import os
import sys
import time
import re
import csv
import gradio as gr
import pandas as pd
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
import plotly.colors as pc
from qatch.connectors.sqlite_connector import SqliteConnector
from qatch.generate_dataset.orchestrator_generator import OrchestratorGenerator
from qatch.evaluate_dataset.orchestrator_evaluator import OrchestratorEvaluator
import qatch.evaluate_dataset.orchestrator_evaluator as eva
from prediction import ModelPrediction
import utils_get_db_tables_info
import utilities as us
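# Gradio front-end for QATCH: load or upload a database, pick tables and
# models, run NL2SQL / TQA predictions, and visualize the evaluation metrics.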
# @spaces.GPU
# def model_prediction():
#   pass
# # https://discuss.huggingface.co/t/issues-with-sadtalker-zerogpu-spaces-inquiry-about-community-grant/110625/10
# if os.environ.get("SPACES_ZERO_GPU") is not None:
#     import spaces
# else:
#     class spaces:
#         @staticmethod
#         def GPU(func):
#             def wrapper(*args, **kwargs):
#                 return func(*args, **kwargs)
#             return wrapper
#pnp_path = os.path.join("data", "evaluation_p_np_metrics.csv")
pnp_path = "concatenated_output.csv"
PATH_PKL_TABLES = 'tables_dict_beaver.pkl'
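# JavaScript hook passed to gr.Blocks: it forces the light theme by rewriting
# the "__theme" query parameter and reloading the page.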
js_func = """
function refresh() {
    const url = new URL(window.location);
    if (url.searchParams.get('__theme') !== 'light') {
        url.searchParams.set('__theme', 'light');
        window.location.href = url.href;
    }
}
"""
reset_flag = False
flag_TQA = False
with open('style.css', 'r') as file:
    css = file.read()
# Default DataFrame
df_default = pd.DataFrame({
    'Name': ['Alice', 'Bob', 'Charlie'],
    'Age': [25, 30, 35],
    'City': ['New York', 'Los Angeles', 'Chicago']
})
models_path = "models.csv"
# Global variable that tracks the current data
df_current = df_default.copy()
description = """## π Comparison of Proprietary and Non-Proprietary Databases  
                    ### β€ **Proprietary** :
                    ###                β Economic π°, Medical π₯, Financial π³, Miscellaneous π
                    ###                β BEAVER (FAC BUILDING ADDRESS π’ , TIME QUARTER β±οΈ)
                    ### β€ **Non-Proprietary** 
                    ###                β Spider 1.0 π·οΈ"""
prompt_default = "Translate the following question in SQL code to be executed over the database to fetch the answer.\nReturn the sql code in ```sql ```\nQuestion\n{question}\nDatabase Schema\n{db_schema}\n"
prompt_default_tqa = "Return the answer of the following question based on the provided database. Return your answer as the result of a query executed over the database. Namely, as a list of list where the first list represent the tuples and the second list the values in that tuple.\n Return the answer in answer tag as  .\n Question \n {question}\n Database Schema\n {db_schema}\n"
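# Both default prompts carry the {question} and {db_schema} placeholders;
# check_prompt() below refuses to enable the submit buttons without them.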
input_data = {
    'input_method': "", 
    'data_path': "",
    'db_name': "", 
    'data': {
        'data_frames': {},    # dictionary of dataframes
        'db': None,             # SQLITE3 database object
        'selected_tables': []
    },
    'models': [],
    'prompt': prompt_default
}
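# input_data is the single mutable state shared by all Gradio callbacks below.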
def load_data(file, path, use_default):
    """Carica i dati da un file, un percorso o usa il DataFrame di default."""
    global df_current
    if file is not None:
        try:
            input_data["input_method"] = 'uploaded_file'
            input_data["db_name"] = os.path.splitext(os.path.basename(file))[0]
            if file.endswith('.sqlite'):
                #return 'Error: The uploaded file is not a valid SQLite database.'
                input_data["data_path"] = file #os.path.join(".", "data", "data_interface",f"{input_data['db_name']}.sqlite")
            else:
                #change path
                input_data["data_path"] = os.path.join(".", f"{input_data['db_name']}.sqlite")
            input_data["data"] = us.load_data(file, input_data["db_name"])
            df_current = input_data["data"]['data_frames'].get('MyTable', df_default)  # Carica il DataFrame
            if(input_data["data"]['data_frames'] and input_data["data"]["db"] is None): #for csv and xlsx files
                table2primary_key = {}
                for table_name, df in input_data["data"]['data_frames'].items():
                    # Assign primary keys for each table
                    table2primary_key[table_name] = 'id'
                input_data["data"]["db"] = SqliteConnector(
                    relative_db_path=input_data["data_path"],
                    db_name=input_data["db_name"],
                    tables= input_data["data"]['data_frames'],
                    table2primary_key=table2primary_key
                )
            return input_data["data"]['data_frames']
        except Exception as e:
            return f'Error while loading the file: {e}'
    if use_default:
        if use_default == 'Custom':
            input_data["input_method"] = 'custom'
            #input_data["data_path"] = os.path.join(".", "data", "data_interface", "mytable_0.sqlite")
            input_data["data_path"] = os.path.join(".","mytable_0.sqlite")
            # if the file already exists, pick an incremented name
            while os.path.exists(input_data["data_path"]):
                input_data["data_path"] = us.increment_filename(input_data["data_path"])
            input_data["db_name"] = os.path.splitext(os.path.basename(input_data["data_path"]))[0]
            input_data["data"]['data_frames'] = {'MyTable': df_current}
            if(input_data["data"]['data_frames']):
                table2primary_key = {}
                for table_name, df in input_data["data"]['data_frames'].items():
                    # Assign primary keys for each table
                    table2primary_key[table_name] = 'id'
                input_data["data"]["db"] = SqliteConnector(
                    relative_db_path=input_data["data_path"],
                    db_name=input_data["db_name"],
                    tables= input_data["data"]['data_frames'],
                    table2primary_key=table2primary_key
                )
            df_current = df_default.copy()  # Restore the default data
            return input_data["data"]['data_frames']
        
        if use_default == 'Proprietary vs Non-proprietary':
            input_data["input_method"] = 'default'
            #input_data["data_path"] = os.path.join(".", "data", "data_interface", "default.sqlite")
            #input_data["data_path"] = os.path.join(".", "data", "spider_databases", "defeault.sqlite")
            #input_data["db_name"] = "default"
            #input_data["data"]['db'] =  SqliteConnector(relative_db_path=input_data["data_path"], db_name=input_data["db_name"])
            input_data["data"]['data_frames'] = us.load_tables_dict_from_pkl(PATH_PKL_TABLES)
            return input_data["data"]['data_frames']
    
    selected_inputs = sum([file is not None, bool(path), use_default])
    if selected_inputs > 1:
        return 'Error: Select only one input method at a time.'
        
    return input_data["data"]['data_frames']
def preview_default(use_default, file):
    if file:
        return gr.DataFrame(interactive=True, visible=False, value=df_default), gr.update(value="## ✅ File successfully uploaded!", visible=True)
    if use_default == 'Custom':
        return gr.DataFrame(interactive=True, visible=True, value=df_default), gr.update(value="## Toy Table", visible=True)
    return gr.DataFrame(interactive=False, visible=False, value=df_default), gr.update(value=description, visible=True)
def update_df(new_df):
    """Aggiorna il DataFrame corrente."""
    global df_current  # Usa la variabile globale per aggiornarla
    df_current = new_df
    return df_current
def open_accordion(target):
    # Opens one accordion section and closes the others
    global df_current
    if target == "reset":
        df_current = df_default.copy()
        input_data['input_method'] = ""
        input_data['data_path'] = ""
        input_data['db_name'] = ""
        input_data['data']['data_frames'] = {}
        input_data['data']['selected_tables'] = []
        input_data['data']['db'] = None
        input_data['models'] = []
        return gr.update(open=True), gr.update(open=False, visible=False), gr.update(open=False, visible=False), gr.update(open=False, visible=False), gr.update(open=False, visible=False), gr.update(value='Proprietary vs Non-proprietary'), gr.update(value=None)
    elif target == "model_selection":
        return gr.update(open=False), gr.update(open=False), gr.update(open=True, visible=True), gr.update(open=False), gr.update(open=False)
# Gradio interface
#with gr.Blocks(theme='d8ahazard/rd_blue', css_paths='style.css') as interface:
with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as interface:
    with gr.Row():
        with gr.Column(scale=1):
            gr.Image(
                value=os.path.join(".", "qatch_logo.png"),
                show_label=False,
                container=False,
                interactive=False,
                show_fullscreen_button=False,
                show_download_button=False,
                show_share_button=False,
                height=150,  # in pixels
                width=300
            )
        with gr.Column(scale=1):
            pass
    data_state = gr.State(None)  # Stores the loaded data
    upload_acc = gr.Accordion("Upload data section", open=True, visible=True)
    select_table_acc = gr.Accordion("Select tables section", open=False, visible=False)
    select_model_acc = gr.Accordion("Select models section", open=False, visible=False)
    qatch_acc = gr.Accordion("QATCH execution section", open=False, visible=False)
    metrics_acc = gr.Accordion("Metrics section", open=False, visible=False)
    #################################
    #       DATABASE INSERTION      #
    #################################
    with upload_acc:
        gr.Markdown("## π₯Choose data input method")
        with gr.Row():
            default_checkbox = gr.Radio(label = "Explore the comparison between proprietary and non-proprietary databases or edit a toy table with the values you prefer", choices=['Proprietary vs Non-proprietary', 'Custom'], value='Proprietary vs Non-proprietary')
            #default_checkbox = gr.Checkbox(label="Use default DataFrame"
        
        table_default = gr.Markdown(description, visible=True)
        preview_output = gr.DataFrame(interactive=False, visible=False, value=df_default)
        gr.Markdown("## π Or upload your data")
        file_input = gr.File(label="Drag and drop a file", file_types=[".csv", ".xlsx", ".sqlite"])
        submit_button = gr.Button("Load Data")  # Disabled by default
        output = gr.JSON(visible=False)  # Dictionary output
        # Function to enable the button if there is data to load
        def enable_submit(file, use_default):
            return gr.update(interactive=bool(file or use_default))
        # Function to uncheck the checkbox if a file is uploaded
        def deselect_default(file):
            if file:
                return gr.update(value='Proprietary vs Non-proprietary')
            return gr.update()
        
        def enable_disable_first(enable):
            return (
                gr.update(interactive=enable),
                gr.update(interactive=enable), 
                gr.update(interactive=enable), 
                gr.update(interactive=enable)
            )
        # Enable the button when inputs are provided
        #file_input.change(fn=enable_submit, inputs=[file_input, default_checkbox], outputs=[submit_button])
        #default_checkbox.change(fn=enable_submit, inputs=[file_input, default_checkbox], outputs=[submit_button])
        
        # Show preview of the default DataFrame when checkbox is selected
        default_checkbox.change(fn=preview_default, inputs=[default_checkbox, file_input], outputs=[preview_output, table_default])
        file_input.change(fn=preview_default, inputs=[default_checkbox, file_input], outputs=[preview_output, table_default])
        preview_output.change(fn=update_df, inputs=[preview_output], outputs=[preview_output])
        # Uncheck the checkbox when a file is uploaded
        file_input.change(fn=deselect_default, inputs=[file_input], outputs=[default_checkbox])
        def handle_output(file, use_default):
            """Handles the output when the 'Load Data' button is pressed."""
            result = load_data(file, None, use_default)
            
            if isinstance(result, dict):  # If result is a dictionary of DataFrames
                if len(result) == 1:  # If there's only one table
                    input_data['data']['selected_tables'] = list(input_data['data']['data_frames'].keys())
                    return (
                        gr.update(visible=False),  # Hide JSON output
                        result,  # Save the data state
                        gr.update(visible=False),  # Hide table selection
                        result,  # Maintain the data state
                        gr.update(interactive=False),  # Disable the submit button
                        gr.update(visible=True, open=True),  # Proceed to select_model_acc
                        gr.update(visible=True, open=False)
                    )
                else:
                    return (
                        gr.update(visible=False),
                        result,
                        gr.update(open=True, visible=True),
                        result,
                        gr.update(interactive=False),
                        gr.update(visible=False),  # Keep current behavior
                        gr.update(visible=True, open=False)
                    )
            else:
                return (
                    gr.update(visible=False),
                    None,
                    gr.update(open=False, visible=True),
                    None,
                    gr.update(interactive=True),
                    gr.update(visible=False),
                    gr.update(visible=True, open=False)
                )
        submit_button.click(
            fn=handle_output,
            inputs=[file_input, default_checkbox], 
            outputs=[output, output, select_table_acc, data_state, submit_button, select_model_acc, upload_acc]
        )
        
        submit_button.click(
            fn=enable_disable_first,
            inputs=[gr.State(False)], 
            outputs=[
                preview_output,
                submit_button,
                file_input,
                default_checkbox
            ]
        )
    ######################################
    #        TABLE SELECTION PART        #
    ######################################
    with select_table_acc:
        previous_selection = gr.State([])
        table_selector = gr.CheckboxGroup(choices=[], label="Select tables from the choosen database", value=[])
        excluded_tables_info = gr.HTML(label="Non-selectable tables (too many columns)", visible=False)
        table_outputs = [gr.DataFrame(label=f"Table {i+1}", interactive=True, visible=False) for i in range(50)]
        selected_table_names = gr.Textbox(label="Selected tables", visible=False, interactive=False)
        # Model selection button (initially disabled)
        open_model_selection = gr.Button("Choose your models", interactive=False)
        def update_table_list(data):
            """Dynamically updates the list of available tables and excluded ones."""
            if isinstance(data, dict) and data:
                table_names = []
                excluded_tables = []
                data_frames = input_data['data'].get('data_frames', {})
                available_tables = []
                for name, df in data.items():
                    df_real = data_frames.get(name, None)
                    if input_data['input_method'] != "default":
                        if df_real is not None and df_real.shape[1] > 15:
                            excluded_tables.append(name)
                        else:
                            available_tables.append(name)
                    else:
                        available_tables.append(name)
                if input_data['input_method'] == "default":
                    table_names.append("All")
                    excluded_tables = []
                elif len(available_tables) < 6:
                    table_names.append("All")
                table_names.extend(available_tables)
                if excluded_tables and input_data['input_method'] != "default":
                    excluded_text = "⚠️ The following tables have more than 15 columns and cannot be selected:<br>" + "<br>".join(f"- {t}" for t in excluded_tables)
                    excluded_visible = True
                else:
                    excluded_text = ""
                    excluded_visible = False
                return [
                    gr.update(choices=table_names, value=[]),  # CheckboxGroup update
                    gr.update(value=excluded_text, visible=excluded_visible)  # HTML display update
                ]
            return [
                gr.update(choices=[], value=[]),
                gr.update(value="", visible=False)
            ]
        
        def show_selected_tables(data, selected_tables):
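            """Render the selected tables and enforce the selection rules: the
            "All" pseudo-entry is offered only for the default dataset or when
            fewer than 6 tables are available; uploaded data is capped at 5
            tables of at most 15 columns each."""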
            updates = []
            data_frames = input_data['data'].get('data_frames', {})
            available_tables = []
            for name, df in data.items():
                df_real = data_frames.get(name)
                if input_data['input_method'] != "default" :
                    if df_real is not None and df_real.shape[1] <= 15:
                        available_tables.append(name)
                else:
                    available_tables.append(name)
            input_method = input_data['input_method']
            allow_all = input_method == "default" or len(available_tables) < 6
            selected_set = set(selected_tables)
            tables_set = set(available_tables)
            if allow_all:
                if "All" in selected_set:
                    selected_tables = ["All"] + available_tables
                elif selected_set == tables_set:
                    selected_tables = []
                else:
                    selected_tables = [t for t in selected_tables if t in available_tables]
            else:
                selected_tables = [t for t in selected_tables if t in available_tables and t != "All"][:5]
            tables = {name: data[name] for name in selected_tables if name in data}
            for i, (name, df) in enumerate(tables.items()):
                updates.append(gr.update(value=df, label=f"Table: {name}", visible=True, interactive=False))
            for _ in range(len(tables), 50):
                updates.append(gr.update(visible=False))
            updates.append(gr.update(interactive=bool(tables)))
            if allow_all:
                updates.insert(0, gr.update(
                    choices=["All"] + available_tables,
                    value=selected_tables
                ))
            else:
                if len(selected_tables) >= 5:
                    updates.insert(0, gr.update(
                        choices=selected_tables,
                        value=selected_tables
                    ))
                else:
                    updates.insert(0, gr.update(
                        choices=available_tables,
                        value=selected_tables
                    ))
            return updates
        def show_selected_table_names(data, selected_tables):
            """Displays the names of the selected tables when the button is pressed."""
            if selected_tables:
                available_tables = list(data.keys())  # Actually available names
                if "All" in selected_tables:
                    selected_tables = available_tables
                    if (input_data['input_method'] != "default") : selected_tables = [t for t in selected_tables if len(data[t].columns) <= 15]
                
                input_data['data']['selected_tables'] = selected_tables
                return gr.update(value=", ".join(selected_tables), visible=False)
            return gr.update(value="", visible=False)
        # Automatically updates the checkbox list when `data_state` changes
        data_state.change(fn=update_table_list, inputs=[data_state], outputs=[table_selector, excluded_tables_info])
        # Updates the visible tables and the button state based on user selections
        #table_selector.change(fn=show_selected_tables, inputs=[data_state, table_selector], outputs=table_outputs + [open_model_selection])
        table_selector.change(
            fn=show_selected_tables,
            inputs=[data_state, table_selector],
            outputs=[table_selector] + table_outputs + [open_model_selection]
        )
        # Shows the list of selected tables when "Choose your models" is clicked
        open_model_selection.click(fn=show_selected_table_names, inputs=[data_state, table_selector], outputs=[selected_table_names])
        open_model_selection.click(open_accordion, inputs=gr.State("model_selection"), outputs=[upload_acc, select_table_acc, select_model_acc, qatch_acc, metrics_acc])
        
        reset_data = gr.Button("Back to upload data section")
        
        reset_data.click(
            fn=enable_disable_first,
            inputs=[gr.State(True)], 
            outputs=[
                preview_output,
                submit_button,
                file_input,
                default_checkbox
            ]
        )
        reset_data.click(open_accordion, inputs=gr.State("reset"), outputs=[upload_acc, select_table_acc, select_model_acc, qatch_acc, metrics_acc, default_checkbox, file_input])
    ####################################
    #       MODEL SELECTION PART       #
    ####################################
    with select_model_acc:
        gr.Markdown("# Model Selection")
        # Assume that `us.read_models_csv` also returns the image path
        model_list_dict = us.read_models_csv(models_path)
        model_list = [model["code"] for model in model_list_dict]
        model_images = [model["image_path"] for model in model_list_dict]
        model_names = [model["name"] for model in model_list_dict]
        # Create a mapping between model_list and model_images_names
        model_mapping = dict(zip(model_list, model_names))
        model_mapping_reverse = dict(zip(model_names, model_list))
        model_checkboxes = []
        rows = []
        
        # Dynamically create checkboxes with images (3 per row)
        for i in range(0, len(model_list), 3):
            with gr.Row():
                cols = []
                for j in range(3):
                    if i + j < len(model_list):
                        model = model_list[i + j]
                        image_path = model_images[i + j]
                        with gr.Column():
                            gr.Image(image_path,
                                     show_label=False,
                                     container=False,
                                     interactive=False,
                                     show_fullscreen_button=False,
                                     show_download_button=False,
                                     show_share_button=False)
                            checkbox = gr.Checkbox(label=model_mapping[model], value=False)
                            model_checkboxes.append(checkbox)
                            cols.append(checkbox)
                rows.append(cols)
        selected_models_output = gr.JSON(visible=False)
        # Function to get selected models
        def get_selected_models(*model_selections):
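            """Store the checked models in input_data and enable the submit
            buttons only when at least one model is selected and the prompt
            still contains both {db_schema} and {question}."""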
            selected_models = [model for model, selected in zip(model_list, model_selections) if selected]
            input_data['models'] = selected_models
            button_state = bool(selected_models and '{db_schema}' in input_data["prompt"] and '{question}' in input_data["prompt"])
            return selected_models, gr.update(open=True, visible=True), gr.update(interactive=button_state), gr.update(interactive=button_state)
        # Add the Textbox to the interface
        prompt = gr.TextArea(
            label="Customise the prompt for selected models here or leave the default one.",
            placeholder=prompt_default,
            elem_id="custom-textarea"
        )
        warning_prompt = gr.Markdown(value="## Error in the prompt format", visible=False)
        # Submit button (initially disabled)
        with gr.Row():
            submit_models_button = gr.Button("Submit Models for NL2SQL task", interactive=False)
            submit_models_button_tqa = gr.Button("Submit Models for TQA task", interactive=False)
        def check_prompt(prompt):
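            """Validate the custom prompt: an empty prompt falls back to the
            default one, otherwise {db_schema} and {question} must be present."""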
            #TODO
            missing_elements = []
            if(prompt==""):
                input_data["prompt"] = prompt_default
                button_state = bool(len(input_data['models']) > 0 and '{db_schema}' in input_data["prompt"] and '{question}' in input_data["prompt"])
            else:
                input_data["prompt"]=prompt
                if "{db_schema}" not in prompt:
                    missing_elements.append("{db_schema}")
                if "{question}" not in prompt:
                    missing_elements.append("{question}")
                button_state = bool(len(input_data['models']) > 0 and '{db_schema}' in input_data["prompt"] and '{question}' in input_data["prompt"])
                if missing_elements:
                    return (
                        gr.update(value=f"❌ Missing {', '.join(missing_elements)} in the prompt ❌", visible=True),
                        gr.update(interactive=button_state),
                        gr.update(interactive=button_state)
                    )
            return gr.update(visible=False), gr.update(interactive=button_state), gr.update(interactive=button_state)
        prompt.change(fn=check_prompt, inputs=[prompt], outputs=[warning_prompt, submit_models_button, submit_models_button_tqa])
        # Link checkboxes to selection events
        for checkbox in model_checkboxes:
            checkbox.change(
                fn=get_selected_models, 
                inputs=model_checkboxes, 
                outputs=[selected_models_output, select_model_acc, submit_models_button, submit_models_button_tqa]
            )
        prompt.change(
            fn=get_selected_models, 
            inputs=model_checkboxes, 
            outputs=[selected_models_output, select_model_acc, submit_models_button, submit_models_button_tqa]
        )
        submit_models_button.click(
            fn=lambda *args: (get_selected_models(*args)[0], gr.update(open=False, visible=True), gr.update(open=True, visible=True)),
            inputs=model_checkboxes,
            outputs=[selected_models_output, select_model_acc, qatch_acc]
        )
        submit_models_button_tqa.click(
            fn=lambda *args: (get_selected_models(*args)[0], gr.update(open=False, visible=True), gr.update(open=True, visible=True)),
            inputs=model_checkboxes,
            outputs=[selected_models_output, select_model_acc, qatch_acc]
        )
        def change_flag():
            global flag_TQA
            flag_TQA = True
        def dis_flag():
            global flag_TQA
            flag_TQA = False
        submit_models_button.click(fn=dis_flag, inputs=[], outputs=[])
        submit_models_button_tqa.click(fn=change_flag, inputs=[], outputs=[])
        def enable_disable(enable):
            return (
                *[gr.update(interactive=enable) for _ in model_checkboxes],
                gr.update(interactive=enable),
                gr.update(interactive=enable),
                gr.update(interactive=enable), 
                gr.update(interactive=enable), 
                gr.update(interactive=enable), 
                gr.update(interactive=enable),
                *[gr.update(interactive=enable) for _ in table_outputs],
                gr.update(interactive=enable),
                gr.update(interactive=enable)
            )
        
        reset_data = gr.Button("Back to upload data section")
        
        submit_models_button.click(
            fn=enable_disable,
            inputs=[gr.State(False)],
            outputs=[
                *model_checkboxes,
                submit_models_button,
                preview_output,
                submit_button,
                file_input,
                default_checkbox,
                table_selector,
                *table_outputs,
                open_model_selection,
                submit_models_button_tqa
            ]
        )
        submit_models_button_tqa.click(
            fn=enable_disable,
            inputs=[gr.State(False)],
            outputs=[
                *model_checkboxes,
                submit_models_button,
                preview_output,
                submit_button,
                file_input,
                default_checkbox,
                table_selector,
                *table_outputs,
                open_model_selection,
                submit_models_button_tqa
            ]
        )
        
        reset_data.click(open_accordion, inputs=gr.State("reset"), outputs=[upload_acc, select_table_acc, select_model_acc, qatch_acc, metrics_acc, default_checkbox, file_input])
        
        reset_data.click(
            fn=enable_disable,
            inputs=[gr.State(True)],
            outputs=[
                *model_checkboxes,
                submit_models_button,
                preview_output,
                submit_button,
                file_input,
                default_checkbox,
                table_selector,
                *table_outputs,
                open_model_selection,
                submit_models_button_tqa
            ]
        )
    #############################
    #      QATCH EXECUTION      #
    #############################
    with qatch_acc:
        def change_text(text):
            return text
        loading_symbols= {1:"π", 
                          2: "π π", 
                          3: "π π π", 
                          4: "π π π π", 
                          5: "π π π π π",
                          6: "π π π π π π",
                          7: "π π π π π π π",                          
                          8: "π π π π π π π π",
                          9: "π π π π π π π π π",
                          10:"π π π π π π π π π π",                                                 
                        }
        def generate_loading_text(percent):
            num_symbols = (round(percent) % 11) + 1
            symbols = loading_symbols.get(num_symbols, "🏃")
            # The styled HTML wrapper (symbols mirrored on both sides) was lost
            # to markup stripping; return the loading line as plain Markdown.
            return f"{symbols}  Generation {percent}%  {symbols}"
        
        def generate_eval_text(text):
            symbols = "💡"
            return f"{symbols}  {text}  {symbols}"
        
        def qatch_flow_nl_sql():
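            """Generator driving the QATCH run.

            Each yield emits, in order: the evaluation banner, the model logo,
            the loading text, the NL question, the predicted output, the metrics
            DataFrame, and one predictions DataFrame per model in model_list.
            """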
            global reset_flag
            global flag_TQA
            predictions_dict = {model: pd.DataFrame(columns=['id', 'question', 'predicted_sql', 'time', 'query', 'db_path']) for model in model_list}
            metrics_conc = pd.DataFrame()
            columns_to_visulize = ["db_path", "tbl_name", "test_category", "sql_tag", "query", "question", "predicted_sql", "time", "price", "answer"]
            if (input_data['input_method']=="default"):
                target_df = us.load_csv(pnp_path)                 #target_df = us.load_csv("priority_non_priority_metrics.csv")
                #predictions_dict = {model: pd.DataFrame(columns=target_df.columns) for model in model_list}
                target_df = target_df[target_df["tbl_name"].isin(input_data['data']['selected_tables'])]
                target_df = target_df[target_df["model"].isin(input_data['models'])]
                predictions_dict = {model: target_df[target_df["model"] == model] if model in target_df["model"].unique() else pd.DataFrame(columns=target_df.columns) for model in model_list}
                reset_flag = False
                for model in input_data['models']:
                    model_image_path = next((m["image_path"] for m in model_list_dict if m["code"] == model), None)                
                    yield gr.Markdown(visible=False), gr.Image(model_image_path), gr.Markdown(), gr.Markdown(), gr.Markdown(), metrics_conc, *[predictions_dict[model][columns_to_visualize] for model in model_list]
                    count = 1
                    for _, row in predictions_dict[model].iterrows():
                    #for index, row in target_df.iterrows():
                        if not reset_flag:
                            percent_complete = round(count / len(predictions_dict[model]) * 100, 2)
                            count += 1
                            load_text = generate_loading_text(percent_complete)
                            question = row['question']
                            
                            display_question = f"""Natural Language:
 
                                                    
                                                """
                            yield gr.Markdown(), gr.Image(), gr.Markdown(load_text), gr.Markdown(display_question), gr.Markdown(), metrics_conc, *[predictions_dict[model][columns_to_visulize] for model in model_list]
                            prediction = row['predicted_sql']
                            display_prediction = f"**Predicted SQL:** {prediction}"
                            yield gr.Markdown(), gr.Image(), gr.Markdown(load_text), gr.Markdown(), gr.Markdown(display_prediction), metrics_conc, *[predictions_dict[model][columns_to_visualize] for model in model_list]
                    yield gr.Markdown(), gr.Image(), gr.Markdown(load_text), gr.Markdown(), gr.Markdown(display_prediction), metrics_conc, *[predictions_dict[model][columns_to_visualize] for model in model_list]
                metrics_conc = target_df
                if 'valid_efficency_score' not in metrics_conc.columns and 'VES' in metrics_conc.columns:
                    metrics_conc['valid_efficency_score'] = metrics_conc['VES']
                eval_text = generate_eval_text("End evaluation")
                yield gr.Markdown(eval_text, visible=True), gr.Image(), gr.Markdown(), gr.Markdown(), gr.Markdown(), metrics_conc, *[predictions_dict[model][columns_to_visualize] for model in model_list]
            
            else:
                orchestrator_generator = OrchestratorGenerator()
                target_df = orchestrator_generator.generate_dataset(connector=input_data['data']['db'], tables_to_include=input_data['data']['selected_tables'])
                # create target_df['target_answer'] for the TQA task
                if flag_TQA:
                    if input_data["prompt"] == prompt_default:
                        input_data["prompt"] = prompt_default_tqa
                    target_df = us.extract_answer(target_df)
                predictor = ModelPrediction()
                reset_flag = False
                for model in input_data["models"]:
                    model_image_path = next((m["image_path"] for m in model_list_dict if m["code"] == model), None)
                    yield gr.Markdown(visible=False), gr.Image(model_image_path), gr.Markdown(), gr.Markdown(), gr.Markdown(), metrics_conc, *[predictions_dict[model] for model in model_list]
                    count = 0
                    for index, row in target_df.iterrows():
                        if not reset_flag:
                            percent_complete = round(((index + 1) / len(target_df)) * 100, 2)
                            load_text = generate_loading_text(percent_complete)
                            
                            question = row['question']
                            display_question = f"**Natural Language:** {question}"
                            yield gr.Markdown(), gr.Image(), gr.Markdown(load_text), gr.Markdown(display_question), gr.Markdown(), metrics_conc, *[predictions_dict[model] for model in model_list]
                            #samples = us.generate_some_samples(input_data["data_path"], row["tbl_name"])
                            model_to_send = model if flag_TQA else None
                            db_schema_text = utils_get_db_tables_info.utils_extract_db_schema_as_string(
                                db_id = input_data["db_name"], 
                                base_path = input_data["data_path"], 
                                normalize=False, 
                                sql=row["query"],
                                get_insert_into=True,
                                model = model_to_send,
                                prompt = input_data["prompt"].format(question=question, db_schema=""),
                            )
                            
                            #prompt_to_send = us.prepare_prompt(input_data["prompt"], question, schema_text, samples)
                            prompt_to_send = input_data["prompt"]
                            #PREDICTION SQL
                            # TODO add button for QA or SP and pass to .make_prediction parameter TASK
                            if flag_TQA: task="QA" 
                            else: task="SP"
                            start_time = time.time()
                            response = predictor.make_prediction(
                                question=question,
                                db_schema=db_schema_text,
                                model_name=model,
                                prompt=f"{prompt_to_send}",
                                task=task 
                            )
                            prediction = response['response_parsed']
                            price = response['cost']
                            answer = response['response']
                            end_time = time.time()
                            task_string = "Answer" if flag_TQA else "SQL"
                            display_prediction = f"**Predicted {task_string}:** {prediction}"
                            # Create a new row as dataframe
                            new_row = pd.DataFrame([{
                                'id': index,
                                'question': question,
                                'predicted_sql': prediction,
                                'time': end_time - start_time,
                                'query': row["query"],
                                'db_path': input_data["data_path"],
                                'price': price,
                                'answer': answer,
                                'number_question': count,
                                'target_answer': row["target_answer"] if flag_TQA else None,
                            }]).dropna(how="all")  # Remove only completely empty rows
                            count += 1
                            # TODO: use a for loop
                            if flag_TQA:
                                new_row['predicted_answer'] = prediction
                            for col in target_df.columns:
                                if col not in new_row.columns:
                                    new_row[col] = row[col]
                            # Update model's prediction dataframe incrementally
                            if not new_row.empty:
                                predictions_dict[model] = pd.concat([predictions_dict[model], new_row], ignore_index=True)
                            # yield gr.Textbox(), gr.Textbox(prediction), *[predictions_dict[model] for model in input_data["models"]], None
                            yield gr.Markdown(), gr.Image(), gr.Markdown(load_text), gr.Markdown(), gr.Markdown(display_prediction), metrics_conc, *[predictions_dict[model] for model in model_list]
                    yield gr.Markdown(), gr.Image(), gr.Markdown(load_text), gr.Markdown(), gr.Markdown(display_prediction), metrics_conc, *[predictions_dict[model] for model in model_list]
                    # END 
                eval_text = generate_eval_text("Evaluation")
                yield gr.Markdown(eval_text, visible=True), gr.Image(), gr.Markdown(), gr.Markdown(), gr.Markdown(), metrics_conc, *[predictions_dict[model] for model in model_list]
                
                evaluator = OrchestratorEvaluator()
                for model in input_data["models"]:
                    if not flag_TQA:
                        metrics_df_model = evaluator.evaluate_df(
                            df=predictions_dict[model],
                            target_col_name="query",
                            prediction_col_name="predicted_sql",
                            db_path_name="db_path"
                        )
                    else: 
                        metrics_df_model = us.evaluate_answer(predictions_dict[model])
                    metrics_df_model['model'] = model
                    metrics_conc = pd.concat([metrics_conc, metrics_df_model], ignore_index=True)
                if 'valid_efficency_score' not in metrics_conc.columns and 'VES' in metrics_conc.columns:
                    metrics_conc['valid_efficency_score'] = metrics_conc['VES']
                
                eval_text = generate_eval_text("End evaluation")
                yield gr.Markdown(eval_text, visible=True), gr.Image(), gr.Markdown(), gr.Markdown(), gr.Markdown(), metrics_conc, *[predictions_dict[model] for model in model_list]
        # Loading Bar
        with gr.Row():
            # progress = gr.Progress()
            variable = gr.Markdown()
        # NL -> MODEL -> Generated Query
        with gr.Row():
            with gr.Column():
                with gr.Column():
                    question_display = gr.Markdown()
            with gr.Column():
                model_logo = gr.Image(visible=True, 
                                      show_label=False,
                                      container=False,
                                      interactive=False,
                                      show_fullscreen_button=False,
                                      show_download_button=False,
                                      show_share_button=False)
            with gr.Column():
                with gr.Column():
                    prediction_display = gr.Markdown()
        dataframe_per_model = {}
        with gr.Tabs() as model_tabs:
            tab_dict = {}
            for model, model_name in zip(model_list, model_names):
                with gr.TabItem(model_name, visible=(model in input_data["models"])) as tab:
                    gr.Markdown(f"**Results for {model}**")
                    tab_dict[model] = tab
                    dataframe_per_model[model] = gr.DataFrame()
                    #TODO download metrics per model
                    # download_pred_model = gr.DownloadButton(label="Download Prediction per Model", visible=False)
        evaluation_loading = gr.Markdown()
        def change_tab():
            return [gr.update(visible=(model in input_data["models"])) for model in model_list]
        submit_models_button.click(
            change_tab,
            inputs=[],
            outputs=[tab_dict[model] for model in model_list]  # Update TabItem visibility
        )
        submit_models_button_tqa.click(
            change_tab,
            inputs=[],
            outputs=[tab_dict[model] for model in model_list]  # Update TabItem visibility
        )
        selected_models_display = gr.JSON(label="Final input data", visible=False)
        metrics_df = gr.DataFrame(visible=False)
        metrics_df_out = gr.DataFrame(visible=False)
        submit_models_button.click(
            fn=qatch_flow_nl_sql,
            inputs=[],
            outputs=[evaluation_loading, model_logo, variable, question_display, prediction_display, metrics_df] + list(dataframe_per_model.values())
        )
        submit_models_button_tqa.click(
            fn=qatch_flow_nl_sql,
            inputs=[],
            outputs=[evaluation_loading, model_logo, variable, question_display, prediction_display, metrics_df] + list(dataframe_per_model.values())
        )
        submit_models_button.click(
            fn=lambda: gr.update(value=input_data), 
            outputs=[selected_models_display]
        )
        submit_models_button_tqa.click(
            fn=lambda: gr.update(value=input_data), 
            outputs=[selected_models_display]
        )
        # Works for METRICS
        metrics_df.change(fn=change_text, inputs=[metrics_df], outputs=[metrics_df_out])
        proceed_to_metrics_button = gr.Button("Proceed to Metrics", visible=False)
        proceed_to_metrics_button.click(
            fn=lambda: (gr.update(open=False, visible=True), gr.update(open=True, visible=True)), 
            outputs=[qatch_acc, metrics_acc]
        )
        def allow_download(metrics_df_out):
            #path = os.path.join(".", "data", "data_results", "results.csv")
            path = os.path.join(".", "results.csv")
            metrics_df_out.to_csv(path, index=False)
            return gr.update(value=path, visible=True), gr.update(visible=True), gr.update(interactive=True)
        
        download_metrics = gr.DownloadButton(label="Download Metrics Evaluation", visible=False)
        submit_models_button.click(
            fn=lambda: gr.update(visible=False),
            outputs=[download_metrics]
        )
        submit_models_button_tqa.click(
            fn=lambda: gr.update(visible=False),
            outputs=[download_metrics]
        )
        def refresh():
            global reset_flag
            global flag_TQA
            reset_flag = True
            flag_TQA = False
        reset_data = gr.Button("Back to upload data section", interactive=True)
        
        metrics_df_out.change(fn=allow_download, inputs=[metrics_df_out], outputs=[download_metrics, proceed_to_metrics_button, reset_data])
        reset_data.click(open_accordion, inputs=gr.State("reset"), outputs=[upload_acc, select_table_acc, select_model_acc, qatch_acc, metrics_acc, default_checkbox, file_input])
        #WHY NOT WORKING?
        reset_data.click(
            fn=lambda: gr.update(visible=False),
            outputs=[download_metrics]
        )
        reset_data.click(refresh)
        reset_data.click(
            fn=enable_disable,
            inputs=[gr.State(True)],
            outputs=[
                *model_checkboxes,
                submit_models_button,
                preview_output,
                submit_button,
                file_input,
                default_checkbox,
                table_selector,
                *table_outputs,
                open_model_selection,
                submit_models_button_tqa
            ]
        )
    ##########################################
    #     METRICS VISUALIZATION SECTION      #
    ##########################################
    with metrics_acc:
        #data_path = 'test_results_metrics1.csv'
        @gr.render(inputs=metrics_df_out)
        def function_metrics(metrics_df_out):
            
            ####################################
            #     UTILS FUNCTIONS SECTION      #
            ####################################
            
            def load_data_csv_es():
                
                if input_data["input_method"]=="default":  
                    global flag_TQA
                    df = pd.read_csv(pnp_path)
                    df = df[df['model'].isin(input_data["models"])]
                    df = df[df['tbl_name'].isin(input_data["data"]["selected_tables"])]
                    
                    df['model'] = df['model'].replace('DeepSeek-R1-Distill-Llama-70B', 'DS-Llama3 70B')
                    df['model'] = df['model'].replace('gpt-3.5', 'GPT-3.5')
                    df['model'] = df['model'].replace('gpt-4o-mini', 'GPT-4o-mini')
                    df['model'] = df['model'].replace('llama-70', 'Llama-70B')
                    df['model'] = df['model'].replace('llama-8', 'Llama-8B')
                    df['test_category'] = df['test_category'].replace('many-to-many-generator', 'MANY-TO-MANY')
                    if flag_TQA:
                        flag_TQA = False  # TODO: remove once TQA predictions are implemented
                    return df
                return metrics_df_out
            
            def calculate_average_metrics(df, selected_metrics):
                # Exclude the 'tuple_order' column from the selected metrics
                #TODO tuple_order has NULL VALUE 
                selected_metrics = [metric for metric in selected_metrics if metric != 'tuple_order']
                df['avg_metric'] = df[selected_metrics].mean(axis=1)
                return df
            def generate_model_colors():
                """Generates a unique color map for models in the dataset."""
                df = load_data_csv_es()
                unique_models = df['model'].unique()  # Extract unique models
                num_models = len(unique_models)
                
                # Use the Plotly color scale (you can change it if needed)
                color_palette = ['#00B4D8', '#BCE784', '#C84630', '#F79256', '#D269FC']
                #color_palette = pc.qualitative.Plotly  # ['#636EFA', '#EF553B', '#00CC96', ...]
                
                # If there are more models than colors, cycle through them
                colors = {model: color_palette[i % len(color_palette)] for i, model in enumerate(unique_models)}
                
                return colors
            MODEL_COLORS = generate_model_colors()
            
            def generate_db_category_colors():
                """Assigns 3 distinct colors to db_category groups."""
                return {
                    "Spider": "#1f77b4",        # blue
                    "Beaver": "#ff7f0e",        # orange
                    "Economic": "#2ca02c",      # all the remaining categories share green
                    "Financial": "#2ca02c",
                    "Medical": "#2ca02c",
                    "Miscellaneous": "#2ca02c"
                }
            DB_CATEGORY_COLORS = generate_db_category_colors()
            def normalize_valid_efficency_score(df):
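                """Min-max normalize 'valid_efficency_score' into [0, 1]."""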
                df['valid_efficency_score'] = df['valid_efficency_score'].replace([np.nan, ''], 0)
                df['valid_efficency_score'] = df['valid_efficency_score'].astype(int)
                min_val = df['valid_efficency_score'].min()
                max_val = df['valid_efficency_score'].max()
                
                if min_val == max_val:
                    # All values are equal: to avoid division by zero, set the
                    # score to 0 for an empty column and to 1.0 otherwise.
                    if pd.isna(min_val):
                        df['valid_efficency_score'] = 0
                    else:
                        df['valid_efficency_score'] = 1.0
                else:
                    df['valid_efficency_score'] = (
                        df['valid_efficency_score'] - min_val
                    ) / (max_val - min_val)
                
                return df
            ####################################
            #     GRAPH FUNCTIONS SECTION      #
            ####################################
            # BAR CHART FOR AVERAGE METRICS WITH UPDATE FUNCTION
            def plot_metric(df, radio_metric, qatch_selected_metrics, external_selected_metric, group_by, selected_models):
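                """Bar chart of the mean of the selected metrics, grouped per
                model or per (table, model)."""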
                df = df[df['model'].isin(selected_models)]
                df = normalize_valid_efficency_score(df)
                # Map human-readable labels -> internal metric names
                qatch_selected_internal = [qatch_metrics_dict[label] for label in qatch_selected_metrics]
                external_selected_internal = [external_metrics_dict[label] for label in external_selected_metric]
                
                selected_metrics = qatch_selected_internal if radio_metric == "Qatch" else external_selected_internal
                df = calculate_average_metrics(df, selected_metrics)
                if group_by == ["model"]:
                    # Bar plot per "model"
                    avg_metrics = df.groupby("model")['avg_metric'].mean().reset_index()
                    avg_metrics['text_label'] = avg_metrics['avg_metric'].apply(lambda x: f'{x:.2f}')
                    fig = px.bar(
                        avg_metrics, 
                        x="model", 
                        y="avg_metric", 
                        color="model", 
                        color_discrete_map=MODEL_COLORS,
                        title='Average metrics per Model 🧠',
                        labels={"model": "Model", "avg_metric": "Average Metrics"},
                        template='simple_white',
                        #template='plotly_dark',
                        text='text_label'
                    )
                else:
                    if group_by != ["tbl_name", "model"]:
                        group_by = ["tbl_name", "model"]
                    avg_metrics = df.groupby(group_by)['avg_metric'].mean().reset_index()
                    avg_metrics['text_label'] = avg_metrics['avg_metric'].apply(lambda x: f'{x:.2f}')
                    fig = px.bar(
                        avg_metrics, 
                        x=group_by[0], 
                        y='avg_metric', 
                        color='model',  
                        color_discrete_map=MODEL_COLORS, 
                        barmode='group',
                        title=f'Average metrics per {group_by[0]} 📊',
                        labels={group_by[0]: group_by[0].capitalize(), 'avg_metric': 'Average Metrics'},
                        template='simple_white',
                        #template='plotly_dark',
                        text='text_label'
                    )
                fig.update_traces(textposition='outside', textfont_size=10)
                # Apply the Inter font to the whole layout
                fig.update_layout(
                    margin=dict(t=80),
                    title=dict(
                        font=dict(
                            family="Inter, sans-serif",
                            size=22,
                            #color="white"
                        ),
                        x=0.5
                    ),
                    xaxis=dict(
                        title=dict(
                            font=dict(
                                family="Inter, sans-serif",
                                size=18,
                                #color="white"
                            )
                        ),
                        tickfont=dict(
                            family="Inter, sans-serif",
                            #color="white"
                            size=16
                        )
                    ),
                    yaxis=dict(
                        title=dict(
                            font=dict(
                                family="Inter, sans-serif",
                                size=18,
                                #color="white"
                            )
                        ),
                        tickfont=dict(
                            family="Inter, sans-serif",
                            #color="white"
                        )
                    ),
                    legend=dict(
                        title=dict(
                            font=dict(
                                family="Inter, sans-serif",
                                size=16,
                                #color="white"
                            )
                        ),
                        font=dict(
                            family="Inter, sans-serif",
                            #color="white"
                        )
                    )
                )
                return gr.Plot(fig, visible=True)
            
            def update_plot(radio_metric, qatch_selected_metrics, external_selected_metric,group_by, selected_models):
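                """Reload the evaluation CSV and rebuild the bar chart from the current control values."""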
                df = load_data_csv_es()
                return plot_metric(df, radio_metric, qatch_selected_metrics, external_selected_metric, group_by, selected_models)
   
            # BAR CHART OF AVERAGE METRICS FOR THE PROPRIETARY DATASETS, WITH UPDATE FUNCTION
            def plot_metric_propietary(df, radio_metric, qatch_selected_metrics, external_selected_metric, selected_models):
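                """Bar chart of the selected metrics averaged per (db_category, model).

                `selected_models` is either the string "All" or a single model name,
                mirroring the radio button used for the proprietary view.
                """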
                if selected_models == "All":
                    selected_models = models
                else:
                    selected_models = [selected_models]
                
                df = df[df['model'].isin(selected_models)]
                df = normalize_valid_efficency_score(df)
                
                # Convert human-readable labels -> internal metric names
                qatch_selected_internal = [qatch_metrics_dict[label] for label in qatch_selected_metrics]
                external_selected_internal = [external_metrics_dict[label] for label in external_selected_metric]
                
                selected_metrics = qatch_selected_internal if radio_metric == "Qatch" else external_selected_internal
                
                df = calculate_average_metrics(df, selected_metrics)
                
                avg_metrics = df.groupby(["db_category", "model"])['avg_metric'].mean().reset_index()
                avg_metrics['text_label'] = avg_metrics['avg_metric'].apply(lambda x: f'{x:.2f}')
                fig = px.bar(
                    avg_metrics, 
                    x='db_category', 
                    y='avg_metric', 
                    color='model',  
                    color_discrete_map=MODEL_COLORS, 
                    barmode='group',
                    title='Average metrics per database type 📊',
                    labels={'db_category': 'Database Category', 'avg_metric': 'Average Metrics'},
                    template='simple_white',
                    text='text_label'
                )
                fig.update_traces(textposition='outside', textfont_size=14)
                # Update the layout with the Inter font
                fig.update_layout(
                    margin=dict(t=80),
                    title=dict(
                        font=dict(
                            family="Inter, sans-serif",
                            size=24,
                            color="black"
                        ),
                        x=0.5
                    ),
                    xaxis=dict(
                        title=dict(
                            text='Database Category',
                            font=dict(
                                family='Inter, sans-serif',
                                size=22,
                                color='black'
                            )
                        ),
                        tickfont=dict(
                            family='Inter, sans-serif',
                            color='black',
                            size=20
                        )
                    ),
                    yaxis=dict(
                        title=dict(
                            text='Average Metrics',
                            font=dict(
                                family='Inter, sans-serif',
                                size=22,
                                color='black'
                            )
                        ),
                        tickfont=dict(
                            family='Inter, sans-serif',
                            color='black'
                        )
                    ),
                    legend=dict(
                        title=dict(
                            text='Models',
                            font=dict(
                                family='Inter, sans-serif',
                                size=20,
                                color='black'
                            )
                        ),
                        font=dict(
                            family='Inter, sans-serif',
                            color='black',
                            size=18
                        )
                    )
                )
                return gr.Plot(fig, visible=True)
            
            def update_plot_propietary(radio_metric, qatch_selected_metrics, external_selected_metric, selected_models):
                df = load_data_csv_es()
                return plot_metric_propietary(df, radio_metric, qatch_selected_metrics, external_selected_metric, selected_models)
                  
            # DUMBBELL CHART: NON-PROPRIETARY (SPIDER) VS PROPRIETARY DATASETS
            
            def lollipop_propietary(selected_models):
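                """Dumbbell chart comparing, per model, the average QATCH metrics on
                Spider (non-proprietary) against the mean over the proprietary categories."""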
                df = load_data_csv_es()
                # Keep only the relevant database categories
                target_cats = ["Spider", "Economic", "Financial", "Medical", "Miscellaneous", "Beaver"]
                df = df[df['db_category'].isin(target_cats)]
                df = df[df['model'].isin(selected_models)]
                
                df = normalize_valid_efficency_score(df)
                df = calculate_average_metrics(df, qatch_metrics)
                # Average per db_category and model
                avg_metrics = df.groupby(["db_category", "model"])['avg_metric'].mean().reset_index()
                # Split Spider from the other categories
                spider_df = avg_metrics[avg_metrics["db_category"] == "Spider"]
                other_df = avg_metrics[avg_metrics["db_category"] != "Spider"]
                # Mean of the other categories for each model
                other_mean_df = other_df.groupby("model")["avg_metric"].mean().reset_index()
                other_mean_df["db_category"] = "Others"
                # Rename for clarity and consistency
                spider_df = spider_df.rename(columns={"avg_metric": "Spider"})
                other_mean_df = other_mean_df.rename(columns={"avg_metric": "Others"})
                # Merge the two datasets
                merged_df = pd.merge(spider_df[["model", "Spider"]], other_mean_df[["model", "Others"]], on="model")
                # Sort by model (or by value, if preferred)
                merged_df = merged_df.sort_values(by="model")
                fig = go.Figure()
                # Add a connector line between Spider and Others for each model
                for _, row in merged_df.iterrows():
                    fig.add_trace(go.Scatter(
                        x=[row["Spider"], row["Others"]],
                        y=[row["model"]] * 2,
                        mode='lines',
                        line=dict(color='gray', width=2),
                        showlegend=False
                    ))
                # Marker for Spider
                fig.add_trace(go.Scatter(
                    x=merged_df["Spider"],
                    y=merged_df["model"],
                    mode='markers',
                    name='Non-Proprietary (Spider)',
                    marker=dict(size=10, color='#C84630')
                ))
                # Marker for Others (mean of the proprietary categories)
                fig.add_trace(go.Scatter(
                    x=merged_df["Others"],
                    y=merged_df["model"],
                    mode='markers',
                    name='Proprietary Databases',
                    marker=dict(size=10, color='#0077B6')
                ))
                fig.update_layout(
                    xaxis_title='Average Metrics',
                    yaxis_title='Models',
                    template='simple_white',
                    #template='plotly_dark',
                    margin=dict(t=80),
                    title=dict(
                        font=dict(
                            family="Inter, sans-serif",
                            size=22,
                            color="black"
                        ),
                        x=0.5,
                        text='Dumbbell graph: Non-Proprietary (Spider 🕷️) vs Proprietary Databases 📊'
                    ),
                    legend_title='Type of Databases:',
                    height=600,
                    xaxis=dict(
                        title=dict(
                            text='Average Metrics',
                            font=dict(
                                family='Inter, sans-serif',
                                size=18,
                                color='black'
                            )
                        ),
                        tickfont=dict(
                            family='Inter, sans-serif',
                            color='black'
                        )
                    ),
                    yaxis=dict(
                        title=dict(
                            text='Models',
                            font=dict(
                                family='Inter, sans-serif',
                                size=18,
                                color='black'
                            )
                        ),
                        tickfont=dict(
                            family='Inter, sans-serif',
                            color='black'
                        )
                    ),
                    legend=dict(
                        title=dict(
                            text='Type of Databases',
                            font=dict(
                                family='Inter, sans-serif',
                                size=18,
                                color='black'
                            )
                        ),
                        font=dict(
                            family='Inter, sans-serif',
                            color='black',
                            size=14
                        )
                    )
                )
                return gr.Plot(fig, visible=True)
                  
            # RADAR OR BAR CHART BASED ON CATEGORY COUNT
            def plot_radar(df, selected_models, selected_metrics, selected_categories):
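                """Radar chart of the average metric per (model, test_category);
                falls back to a grouped bar chart when fewer than 3 categories are selected."""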
                if "External" in selected_metrics:
                    selected_metrics = ["execution_accuracy", "valid_efficency_score"]
                else:
                    selected_metrics = ["cell_precision", "cell_recall", "tuple_order", "tuple_cardinality", "tuple_constraint"]
                # Filter models and normalize
                df = df[df['model'].isin(selected_models)]
                df = normalize_valid_efficency_score(df)
                df = calculate_average_metrics(df, selected_metrics)
                avg_metrics = df.groupby(['model', 'test_category'])['avg_metric'].mean().reset_index()
                if avg_metrics.empty:
                    print("Error: No data available to compute averages.")
                    return go.Figure()
                categories = selected_categories
                if len(categories) < 3:
                    # 📊 BAR PLOT
                    fig = go.Figure()
                    for model in selected_models:
                        model_data = avg_metrics[avg_metrics['model'] == model]
                        values = [
                            model_data[model_data['test_category'] == cat]['avg_metric'].values[0]
                            if cat in model_data['test_category'].values else 0
                            for cat in categories
                        ]
                        fig.add_trace(go.Bar(
                            x=categories,
                            y=values,
                            name=model,
                            marker=dict(color=MODEL_COLORS.get(model, "gray"))
                        ))
                    fig.update_layout(
                        barmode='group',
                        title=dict(
                            text='📊 Bar Plot of Metrics per Model (Few Categories)',
                            font=dict(
                                family='Inter, sans-serif',
                                size=22,
                                #color='white'
                            ),
                            x=0.5
                        ),
                        template='simple_white',
                        #template='plotly_dark',
                        xaxis=dict(
                            title=dict(
                                text='Test Category',
                                font=dict(
                                    family='Inter, sans-serif',
                                    size=18,
                                    #color='white'
                                )
                            ),
                            tickfont=dict(
                                family='Inter, sans-serif',
                                size=16
                                #color='white'
                            )
                        ),
                        yaxis=dict(
                            title=dict(
                                text='Average Metrics',
                                font=dict(
                                    family='Inter, sans-serif',
                                    size=18,
                                    #color='white'
                                )
                            ),
                            tickfont=dict(
                                family='Inter, sans-serif',
                                #color='white'
                            )
                        ),
                        legend=dict(
                            title=dict(
                                text='Models',
                                font=dict(
                                    family='Inter, sans-serif',
                                    size=16,
                                    #color='white'
                                )
                            ),
                            font=dict(
                                family='Inter, sans-serif',
                                #color='white'
                            )
                        )
                    )
                else:
                    # 🕸️ RADAR PLOT
                    fig = go.Figure()
                    for model in sorted(selected_models, key=lambda m: avg_metrics[avg_metrics['model'] == m]['avg_metric'].mean(), reverse=True):
                        model_data = avg_metrics[avg_metrics['model'] == model]
                        values = [
                            model_data[model_data['test_category'] == cat]['avg_metric'].values[0]
                            if cat in model_data['test_category'].values else 0
                            for cat in categories
                        ]
                        fig.add_trace(go.Scatterpolar(
                            r=values,
                            theta=categories,
                            fill='toself',
                            name=model,
                            line=dict(color=MODEL_COLORS.get(model, "gray"))
                        ))
                    fig.update_layout(
                        polar=dict(
                            radialaxis=dict(
                                visible=True,
                                range=[0, max(avg_metrics['avg_metric'].max(), 0.5)],
                                tickfont=dict(
                                    family='Inter, sans-serif',
                                    #color='white'
                                )
                            ),
                            angularaxis=dict(
                                tickfont=dict(
                                    family='Inter, sans-serif',
                                    size=16
                                    #color='white'
                                )
                            )
                        ),
                        title=dict(
                            text='🕸️ Radar Plot of Metrics per Model (Average per SQL Category)',
                            font=dict(
                                family='Inter, sans-serif',
                                size=22,
                                #color='white'
                            ),
                            x=0.5
                        ),
                        legend=dict(
                            title=dict(
                                text='Models',
                                font=dict(
                                    family='Inter, sans-serif',
                                    size=18,
                                    #color='white'
                                )
                            ),
                            font=dict(
                                family='Inter, sans-serif',
                                size=16
                                #color='white'
                            )
                        ),
                        template='simple_white'
                        #template='plotly_dark'
                    )
                return fig
            def update_radar(selected_models, selected_metrics, selected_categories):
                df = load_data_csv_es()
                return plot_radar(df, selected_models, selected_metrics, selected_categories)
            # RADAR OR BAR CHART FOR SUB-CATEGORIES BASED ON CATEGORY COUNT
            def plot_radar_sub(df, selected_models, selected_metrics, selected_category):
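                """Like plot_radar, but drilled down to the SQL sub-categories (sql_tag)
                of the selected test_category."""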
                if "External" in selected_metrics:
                    selected_metrics = ["execution_accuracy", "valid_efficency_score"]
                else:
                    selected_metrics = ["cell_precision", "cell_recall", "tuple_order", "tuple_cardinality", "tuple_constraint"]
                df = df[df['model'].isin(selected_models)]
                df = normalize_valid_efficency_score(df)
                df = calculate_average_metrics(df, selected_metrics)
                if isinstance(selected_category, str):
                    selected_category = [selected_category]
                df = df[df['test_category'].isin(selected_category)]
                avg_metrics = df.groupby(['model', 'sql_tag'])['avg_metric'].mean().reset_index()
                if avg_metrics.empty:
                    print("Error: No data available to compute averages.")
                    return go.Figure()
                categories = df['sql_tag'].unique().tolist()
                if len(categories) < 3:
                    # 📊 BAR PLOT
                    fig = go.Figure()
                    for model in selected_models:
                        model_data = avg_metrics[avg_metrics['model'] == model]
                        values = [
                            model_data[model_data['sql_tag'] == cat]['avg_metric'].values[0]
                            if cat in model_data['sql_tag'].values else 0
                            for cat in categories
                        ]
                        fig.add_trace(go.Bar(
                            x=categories,
                            y=values,
                            name=model,
                            marker=dict(color=MODEL_COLORS.get(model, "gray"))
                        ))
                    fig.update_layout(
                        barmode='group',
                        title=dict(
                            text='📊 Bar Plot of Metrics per Model (Few Sub-Categories)',
                            font=dict(
                                family='Inter, sans-serif',
                                size=22,
                                #color='white'
                            ),
                            x=0.5
                        ),
                        template='simple_white',
                        #template='plotly_dark',
                        xaxis=dict(
                            title=dict(
                                text='SQL Tag (Sub Category)',
                                font=dict(
                                    family='Inter, sans-serif',
                                    size=18,
                                    #color='white'
                                )
                            ),
                            tickfont=dict(
                                family='Inter, sans-serif',
                                #color='white'
                            )
                        ),
                        yaxis=dict(
                            title=dict(
                                text='Average Metrics',
                                font=dict(
                                    family='Inter, sans-serif',
                                    size=18,
                                    #color='white'
                                )
                            ),
                            tickfont=dict(
                                family='Inter, sans-serif',
                                #color='white'
                            )
                        ),
                        legend=dict(
                            title=dict(
                                text='Models',
                                font=dict(
                                    family='Inter, sans-serif',
                                    size=16,
                                    #color='white'
                                )
                            ),
                            font=dict(
                                family='Inter, sans-serif',
                                size=14
                                #color='white'
                            )
                        )
                    )
                else:
                    # 🕸️ RADAR PLOT
                    fig = go.Figure()
                    
                    for model in sorted(selected_models, key=lambda m: avg_metrics[avg_metrics['model'] == m]['avg_metric'].mean(), reverse=True):
                        model_data = avg_metrics[avg_metrics['model'] == model]
                        values = [
                            model_data[model_data['sql_tag'] == cat]['avg_metric'].values[0]
                            if cat in model_data['sql_tag'].values else 0
                            for cat in categories
                        ]
                        fig.add_trace(go.Scatterpolar(
                            r=values,
                            theta=categories,
                            fill='toself',
                            name=model,
                            line=dict(color=MODEL_COLORS.get(model, "gray"))
                        ))
                    fig.update_layout(
                        polar=dict(
                            radialaxis=dict(
                                visible=True,
                                range=[0, max(avg_metrics['avg_metric'].max(), 0.5)],
                                tickfont=dict(
                                    family='Inter, sans-serif',
                                    #color='white'
                                )
                            ),
                            angularaxis=dict(
                                tickfont=dict(
                                    family='Inter, sans-serif',
                                    size=16
                                    #color='white'
                                )
                            )
                        ),
                        title=dict(
                            text='🕸️ Radar Plot of Metrics per Model (Average per SQL Sub-Category)',
                            font=dict(
                                family='Inter, sans-serif',
                                size=22,
                                #color='white'
                            ),
                            x=0.5
                        ),
                        legend=dict(
                            title=dict(
                                text='Models',
                                font=dict(
                                    family='Inter, sans-serif',
                                    size=16,
                                    #color='white'
                                )
                            ),
                            font=dict(
                                family='Inter, sans-serif',
                                size=14,
                                #color='white'
                            )
                        ),
                        template='simple_white'
                        #template='plotly_dark'
                    )
                return fig
            def update_radar_sub(selected_models, selected_metrics, selected_category):
                df = load_data_csv_es()
                return plot_radar_sub(df, selected_models, selected_metrics, selected_category)
            # RANKING FOR THE 3 WORST RESULTS WITH UPDATE FUNCTION
            def worst_cases_text(df, selected_models, selected_metrics, selected_categories):
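                """Build markdown for the 3 lowest-scoring cases (model, table, question),
                plus the corresponding raw model answers; uses the TQA columns when flag_TQA is set."""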
                global flag_TQA
                if selected_models == "All":
                    selected_models = models
                else:
                    selected_models = [selected_models]
                
                if selected_categories == "All":
                    selected_categories = principal_categories
                else:
                    selected_categories = [selected_categories]
                
                df = df[df['model'].isin(selected_models)]
                df = df[df['test_category'].isin(selected_categories)]
                
                if "external" in selected_metrics:
                    selected_metrics = ["execution_accuracy", "valid_efficency_score"]
                else:
                    selected_metrics = ["cell_precision", "cell_recall", "tuple_order", "tuple_cardinality", "tuple_constraint"]
                
                df = normalize_valid_efficency_score(df)
                df = calculate_average_metrics(df, selected_metrics)
                
                if flag_TQA:
                    df["target_answer"] = df["target_answer"].apply(
                        lambda x: " - ".join([",".join(map(str, item)) for item in x]) if isinstance(x, list) else str(x)
                    )
                    df["predicted_answer"] = df["predicted_answer"].apply(
                        lambda x: " - ".join([",".join(map(str, item)) for item in x]) if isinstance(x, list) else str(x)
                    )
                    
                    worst_cases_df = df.groupby(['model', 'tbl_name', 'test_category', 'question', 'target_answer', 'predicted_answer', 'answer', 'sql_tag'])['avg_metric'].mean().reset_index()
                else:
                    worst_cases_df = df.groupby(['model', 'tbl_name', 'test_category', 'question', 'query', 'predicted_sql', 'answer', 'sql_tag'])['avg_metric'].mean().reset_index()
                    
                worst_cases_df = worst_cases_df.sort_values(by="avg_metric", ascending=True).reset_index(drop=True)
                
                worst_cases_top_3 = worst_cases_df.head(3).copy()  # copy to avoid SettingWithCopyWarning
                worst_cases_top_3["avg_metric"] = worst_cases_top_3["avg_metric"].round(2)
                worst_str = []
                answer_str = []
                
                medals = ["π₯", "π₯", "π₯"]
                for i, row in worst_cases_top_3.iterrows():
                    if flag_TQA:
                        entry = (
                            f"{medals[i]} {row['model']} - {row['tbl_name']} - {row['test_category']} - {row['sql_tag']} ({row['avg_metric']})  \n"
                            f"- Question: {row['question']}  \n"
                            f"- Original Answer: `{row['target_answer']}`  \n"
                            f"- Predicted Answer: `{row['predicted_answer']}`  \n\n"
                        )
                        
                        worst_str.append(entry)
                    else:
                        entry = (
                            f"{medals[i]} {row['model']} - {row['tbl_name']} - {row['test_category']} - {row['sql_tag']} ({row['avg_metric']})  \n"
                            f"- Question: {row['question']}  \n"
                            f"- Original Query: `{row['query']}`  \n"
                            f"- Predicted SQL: `{row['predicted_sql']}`  \n\n"
                        )
                    
                        worst_str.append(entry)
                    
                    raw_answer = (
                        f"{medals[i]} {row['model']} - {row['tbl_name']} - {row['test_category']} - {row['sql_tag']} ({row['avg_metric']})  \n"
                        f"- Raw Answer:  \n`{row['answer']}`  \n"
                    )
                    
                    answer_str.append(raw_answer)
                
                return worst_str[0], worst_str[1], worst_str[2], answer_str[0], answer_str[1], answer_str[2]
            def update_worst_cases_text(selected_models, selected_metrics, selected_categories):
                df = load_data_csv_es()
                return worst_cases_text(df, selected_models, selected_metrics, selected_categories)
            # LINE CHART FOR CUMULATIVE TIME WITH UPDATE FUNCTION
            def plot_cumulative_flow(df, selected_models, max_points):
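                """Line chart of cumulative price versus cumulative time per model,
                optionally limited to the first points via max_points."""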
                df = df[df['model'].isin(selected_models)]
                df = normalize_valid_efficency_score(df)
                fig = go.Figure()
                for model in selected_models:
                    model_df = df[df['model'] == model].copy()
                    
                    # Limit the number of points if requested
                    if max_points is not None:
                        model_df = model_df.head(max_points + 1)
                    
                    # Custom tooltip (Plotly hover text uses <br> for line breaks)
                    model_df['hover_info'] = model_df.apply(
                        lambda row:
                            f"Id question: {row['number_question']}<br>"
                            f"Question: {row['question']}<br>"
                            f"Target: {row['query']}<br>"
                            f"Prediction: {row['predicted_sql']}<br>"
                            f"Category: {row['test_category']}",
                        axis=1
                    )
                    # Cumulative totals
                    model_df['cumulative_time'] = model_df['time'].cumsum()
                    model_df['cumulative_price'] = model_df['price'].cumsum()
                    # Model color
                    color = MODEL_COLORS.get(model, "gray")
                    
                    fig.add_trace(go.Scatter(
                        x=model_df['cumulative_time'],
                        y=model_df['cumulative_price'],
                        mode='lines+markers',
                        name=model,
                        line=dict(width=2, color=color),
                        customdata=model_df['hover_info'],
                        hovertemplate=
                            "Model: " + model + "<br>" +
                            "Cumulative Time: %{x}s<br>" +
                            "Cumulative Price: $%{y:.2f}<br>" +
                            "<br>Details:<br>%{customdata}"
                    ))
                # Layout with the Inter font
                fig.update_layout(
                    title=dict(
                        text="Cumulative Price Flow Chart π°",
                        font=dict(
                            family="Inter, sans-serif",
                            size=24,
                            #color="white"
                        ),
                        x=0.5
                    ),
                    xaxis=dict(
                        title=dict(
                            text="Cumulative Time (s)",
                            font=dict(
                                family="Inter, sans-serif",
                                size=20,
                                #color="white"
                            )
                        ),
                        tickfont=dict(
                            family="Inter, sans-serif",
                            size=18
                            #color="white"
                        )
                    ),
                    yaxis=dict(
                        title=dict(
                            text="Cumulative Price ($)",
                            font=dict(
                                family="Inter, sans-serif",
                                size=20,
                                #color="white"
                            )
                        ),
                        tickfont=dict(
                            family="Inter, sans-serif",
                            size=18
                            #color="white"
                        )
                    ),
                    legend=dict(
                        title=dict(
                            text="Models",
                            font=dict(
                                family="Inter, sans-serif",
                                size=18,
                                #color="white"
                            )
                        ),
                        font=dict(
                            family="Inter, sans-serif",
                            size=16,
                            #color="white"
                        )
                    ),
                    template='simple_white',
                    #template="plotly_dark"
                )
                return fig
            def update_query_rate(selected_models, max_points):
                df = load_data_csv_es()
                return plot_cumulative_flow(df, selected_models, max_points)
            #######################
            #  PARAMETER SECTION  #
            #######################
            qatch_metrics_dict = {
                "Cell Precision": "cell_precision",
                "Cell Recall": "cell_recall",
                "Tuple Order": "tuple_order",
                "Tuple Cardinality": "tuple_cardinality",
                "Tuple Constraint": "tuple_constraint"
            }
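            # Example of the label -> column conversion applied before filtering:
            #   [qatch_metrics_dict[label] for label in ["Cell Precision", "Tuple Order"]]
            #   -> ["cell_precision", "tuple_order"]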
            qatch_metrics = ["cell_precision", "cell_recall", "tuple_order", "tuple_cardinality", "tuple_constraint"]
            last_valid_qatch_metrics_selection = qatch_metrics.copy()  # store the last valid selection
            def enforce_qatch_metrics_selection(selected):
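                """Keep the CheckboxGroup non-empty: restore the last valid selection when cleared."""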
                global last_valid_qatch_metrics_selection
                if not selected:  # if no metric is selected
                    return gr.update(value=last_valid_qatch_metrics_selection)
                last_valid_qatch_metrics_selection = selected  # otherwise update the last valid selection
                return gr.update(value=selected)
            external_metrics_dict = {
                "Execution Accuracy": "execution_accuracy",
                "Valid Efficency Score": "valid_efficency_score"
            }
            external_metric = ["execution_accuracy", "valid_efficency_score"]
            last_valid_external_metric_selection = external_metric.copy()  # store the last valid selection
            def enforce_external_metric_selection(selected):
                global last_valid_external_metric_selection
                if not selected:  # if no metric is selected
                    return gr.update(value=last_valid_external_metric_selection)
                last_valid_external_metric_selection = selected  # otherwise update the last valid selection
                return gr.update(value=selected)
            
            all_metrics = {
                "Qatch": ["qatch"],
                "External": ["external"]
            }
            
            group_options = {
                "Table": ["tbl_name", "model"],
                "Model": ["model"]
            }
            df_initial = load_data_csv_es()
            models = df_initial['model'].unique().tolist()
            last_valid_model_selection = models.copy()  # store the last valid selection
            def enforce_model_selection(selected):
                global last_valid_model_selection
                if not selected:  # if no model is selected
                    return gr.update(value=last_valid_model_selection)
                last_valid_model_selection = selected  # otherwise update the last valid selection
                return gr.update(value=selected)
            
            all_categories = df_initial['sql_tag'].unique().tolist()
            
            principal_categories = df_initial['test_category'].unique().tolist()
            last_valid_category_selection = principal_categories.copy()  # store the last valid selection
            def enforce_category_selection(selected):
                global last_valid_category_selection
                if not selected:  # if no category is selected
                    return gr.update(value=last_valid_category_selection)
                last_valid_category_selection = selected  # otherwise update the last valid selection
                return gr.update(value=selected)
            
            all_categories_as_dic = {cat: [f"{cat}"] for cat in principal_categories}
            
            all_categories_as_dic_ranking = {cat: [f"{cat}"] for cat in principal_categories}
            all_categories_as_dic_ranking["All"] = principal_categories
            
            all_model_as_dic = {cat: [f"{cat}"] for cat in models}
            all_model_as_dic["All"] = models
            
            ###########################
            #  VISUALIZATION SECTION  #
            ###########################
            gr.Markdown("""# Model Performance Analysis""")
            
            #FOR BAR
            gr.Markdown("""## Section 1: Model - Data""")
           
            with gr.Row():
                with gr.Column(scale=1):
                    with gr.Row():
                        choose_metrics_bar = gr.Radio(
                            choices=list(all_metrics.keys()),
                            label="Select the metrics group that you want to use:",
                            value="Qatch"
                        )
                        
                    with gr.Row():
                        qatch_info = gr.HTML("""
                            Qatch metric info ℹ️
                        """, visible=True)

                        external_info = gr.HTML("""
                            External metric info ℹ️
                        """, visible=False)
                
                qatch_metric_multiselect_bar = gr.CheckboxGroup(
                    choices=list(qatch_metrics_dict.keys()),
                    label="Select one or mode Qatch metrics:",
                    value=list(qatch_metrics_dict.keys()),
                    visible=True
                )
                external_metric_select_bar = gr.CheckboxGroup(
                    choices=list(external_metrics_dict.keys()),
                    label="Select one or more External metrics:",
                    visible=False
                )
                
                if input_data['input_method'] == 'default':
                    model_radio_bar =  gr.Radio(
                        choices=list(all_model_as_dic.keys()),
                        label="Select the model that you want to use:",
                        value="All"
                    )  
                else:
                    model_multiselect_bar = gr.CheckboxGroup(
                        choices=models,
                        label="Select one or more models:",
                        value=models,
                        interactive=len(models) > 1
                    )
                    
                    group_radio = gr.Radio(
                        choices=list(group_options.keys()),
                        label="Select the grouping view:",
                        value="Table"
                    )
            def toggle_metric_selector(selected_type):
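                """Swap the visibility of the Qatch/External info boxes and metric selectors,
                resetting the selector being shown to its full metric list and clearing the hidden one."""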
                if selected_type == "Qatch":
                    return gr.update(visible=True), gr.update(visible=False), gr.update(visible=True, value=list(qatch_metrics_dict.keys())), gr.update(visible=False, value=[])
                else:
                    return gr.update(visible=False), gr.update(visible=True), gr.update(visible=False, value=[]), gr.update(visible=True, value=list(external_metrics_dict.keys()))
            
            output_plot = gr.Plot(visible=False)
            
            if input_data['input_method'] == 'default':
                with gr.Row():
                    lollipop_propietary(models)
            
            #FOR RADAR
            gr.Markdown("""## Section 2: Model - Category""")
            with gr.Row():
                all_metrics_radar = gr.Radio(
                    choices=list(all_metrics.keys()),
                    label="Select the metrics group that you want to use:",
                    value="Qatch"
                )
                
                model_multiselect_radar = gr.CheckboxGroup(
                    choices=models,
                    label="Select one or more models:",
                    value=models,
                    interactive=len(models) > 1
                )
                    
            with gr.Row():
                with gr.Column(scale=1):
                    category_multiselect_radar = gr.CheckboxGroup(
                        choices=principal_categories,
                        label="Select one or more categories:",
                        value=principal_categories
                    )
                with gr.Column(scale=1):
                    category_radio_radar = gr.Radio(
                        choices=list(all_categories_as_dic.keys()),
                        label="Select the metrics that you want to use:",
                        value=list(all_categories_as_dic.keys())[0]
                    )
                
            with gr.Row():  
                with gr.Column(scale=1):
                    radar_plot_multiselect = gr.Plot(value=update_radar(models, "Qatch", principal_categories))
                    
                with gr.Column(scale=1):
                    radar_plot_radio = gr.Plot(value=update_radar_sub(models, "Qatch", list(all_categories_as_dic.keys())[0]))
            #FOR RANKING
            with gr.Row():
                all_metrics_ranking = gr.Radio(
                    choices=list(all_metrics.keys()),
                    label="Select the metrics group that you want to use:",
                    value="Qatch"
                )
                model_choices = list(all_model_as_dic.keys())
                
                if len(model_choices) == 2:
                    # only one real model plus "All": keep just the model (assumed to come first)
                    model_choices = [model_choices[0]]
                    selected_value = model_choices[0]
                else:
                    selected_value = "All"
                
                model_radio_ranking = gr.Radio(
                    choices=model_choices,
                    label="Select the model that you want to use:",
                    value=selected_value
                )
                
                category_radio_ranking = gr.Radio(
                    choices=list(all_categories_as_dic_ranking.keys()),
                    label="Select the category that you want to use",
                    value="All"
                )
            
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("## β 3 Worst Cases\n")
                    
                    worst_first, worst_second, worst_third, raw_first, raw_second, raw_third = update_worst_cases_text("All", "Qatch", "All")
                    
                    with gr.Row():
                        first = gr.Markdown(worst_first)
                        
                    with gr.Row():
                        first_button = gr.Button("Show raw answer for 🥇")
                        
                    with gr.Row():
                        second = gr.Markdown(worst_second)
                        
                    with gr.Row():
                        second_button = gr.Button("Show raw answer for 🥈")
                        
                    with gr.Row():
                        third = gr.Markdown(worst_third)
                        
                    with gr.Row():
                        third_button = gr.Button("Show raw answer for 🥉")
                        
                with gr.Column(scale=1):
                    gr.Markdown("""## Raw Answer""")
                    row_answer_first = gr.Markdown(value=raw_first, visible=True)
                    row_answer_second = gr.Markdown(value=raw_second, visible=False)
                    row_answer_third = gr.Markdown(value=raw_third, visible=False)
            
            #FOR RATE
            gr.Markdown("""## Section 3: Time - Price""")
            with gr.Row():
                model_multiselect_rate = gr.CheckboxGroup(
                    choices=models,
                    label="Select one or more models:",
                    value=models,
                    interactive=len(models) > 1
                )
            with gr.Row():
                slicer = gr.Slider(minimum=0, maximum=max(df_initial["number_question"]), step=1, value=max(df_initial["number_question"]), label="Number of instances to visualize", elem_id="custom-slider")
            
            query_rate_plot = gr.Plot(value=update_query_rate(models, len(df_initial["number_question"].unique())))
            #FOR RESET
            reset_data = gr.Button("Back to upload data section")
            ###############################
            #  CALLBACK FUNCTION SECTION  #
            ###############################
            
            #FOR BAR
            def on_change(radio_metric, qatch_metric_multiselect_bar, external_metric_select_bar, selected_group, selected_models):
                return update_plot(radio_metric, qatch_metric_multiselect_bar, external_metric_select_bar, group_options[selected_group], selected_models)
            
            def on_change_propietary(radio_metric, qatch_metric_multiselect_bar, external_metric_select_bar, selected_models):
                return update_plot_propietary(radio_metric, qatch_metric_multiselect_bar, external_metric_select_bar, selected_models)
            #FOR RADAR
            def on_radar_multiselect_change(selected_models, selected_metrics, selected_categories):
                return update_radar(selected_models, selected_metrics, selected_categories)
            
            def on_radar_radio_change(selected_models, selected_metrics, selected_category):
                return update_radar_sub(selected_models, selected_metrics, selected_category)
            
            #FOR RANKING
            def on_ranking_change(selected_models, selected_metrics, selected_categories):
                return update_worst_cases_text(selected_models, selected_metrics, selected_categories)
            
            def show_first():
                return (
                    gr.update(visible=True),
                    gr.update(visible=False),
                    gr.update(visible=False)
                )
            def show_second():
                return (
                    gr.update(visible=False),
                    gr.update(visible=True),
                    gr.update(visible=False)
                )
            def show_third():
                return (
                    gr.update(visible=False),
                    gr.update(visible=False),
                    gr.update(visible=True)
                )
            
            
            
            
            ######################
            #  ON CLICK SECTION  #
            ######################
            
            #FOR BAR
            if input_data['input_method'] == 'default':
                proceed_to_metrics_button.click(on_change_propietary, inputs=[choose_metrics_bar, qatch_metric_multiselect_bar, external_metric_select_bar, model_radio_bar], outputs=output_plot)
                qatch_metric_multiselect_bar.change(on_change_propietary, inputs=[choose_metrics_bar,qatch_metric_multiselect_bar, external_metric_select_bar, model_radio_bar], outputs=output_plot)
                external_metric_select_bar.change(on_change_propietary, inputs=[choose_metrics_bar,qatch_metric_multiselect_bar, external_metric_select_bar, model_radio_bar], outputs=output_plot)
                model_radio_bar.change(on_change_propietary, inputs=[choose_metrics_bar,qatch_metric_multiselect_bar, external_metric_select_bar, model_radio_bar], outputs=output_plot)
                qatch_metric_multiselect_bar.change(fn=enforce_qatch_metrics_selection, inputs=qatch_metric_multiselect_bar, outputs=qatch_metric_multiselect_bar)
                choose_metrics_bar.change(fn=toggle_metric_selector, inputs=choose_metrics_bar, outputs=[qatch_info, external_info, qatch_metric_multiselect_bar, external_metric_select_bar])
                external_metric_select_bar.change(fn=enforce_external_metric_selection, inputs=external_metric_select_bar, outputs=external_metric_select_bar)
            
            else:
                proceed_to_metrics_button.click(on_change, inputs=[choose_metrics_bar, qatch_metric_multiselect_bar, external_metric_select_bar, group_radio, model_multiselect_bar], outputs=output_plot)
                qatch_metric_multiselect_bar.change(on_change, inputs=[choose_metrics_bar,qatch_metric_multiselect_bar, external_metric_select_bar, group_radio, model_multiselect_bar], outputs=output_plot)
                external_metric_select_bar.change(on_change, inputs=[choose_metrics_bar,qatch_metric_multiselect_bar, external_metric_select_bar, group_radio, model_multiselect_bar], outputs=output_plot)
                group_radio.change(on_change, inputs=[choose_metrics_bar,qatch_metric_multiselect_bar, external_metric_select_bar, group_radio, model_multiselect_bar], outputs=output_plot)
                model_multiselect_bar.change(on_change, inputs=[choose_metrics_bar,qatch_metric_multiselect_bar, external_metric_select_bar, group_radio, model_multiselect_bar], outputs=output_plot)
                qatch_metric_multiselect_bar.change(fn=enforce_qatch_metrics_selection, inputs=qatch_metric_multiselect_bar, outputs=qatch_metric_multiselect_bar)
                model_multiselect_bar.change(fn=enforce_model_selection, inputs=model_multiselect_bar, outputs=model_multiselect_bar)
                choose_metrics_bar.change(fn=toggle_metric_selector, inputs=choose_metrics_bar, outputs=[qatch_info, external_info, qatch_metric_multiselect_bar, external_metric_select_bar])
                external_metric_select_bar.change(fn=enforce_external_metric_selection, inputs=external_metric_select_bar, outputs=external_metric_select_bar)
            #FOR RADAR MULTISELECT
            model_multiselect_radar.change(on_radar_multiselect_change, inputs=[model_multiselect_radar, all_metrics_radar, category_multiselect_radar], outputs=radar_plot_multiselect)
            all_metrics_radar.change(on_radar_multiselect_change, inputs=[model_multiselect_radar, all_metrics_radar, category_multiselect_radar], outputs=radar_plot_multiselect)
            category_multiselect_radar.change(on_radar_multiselect_change, inputs=[model_multiselect_radar, all_metrics_radar, category_multiselect_radar], outputs=radar_plot_multiselect)
            model_multiselect_radar.change(fn=enforce_model_selection, inputs=model_multiselect_radar, outputs=model_multiselect_radar)
            category_multiselect_radar.change(fn=enforce_category_selection, inputs=category_multiselect_radar, outputs=category_multiselect_radar)
            
            #FOR RADAR RADIO
            model_multiselect_radar.change(on_radar_radio_change, inputs=[model_multiselect_radar, all_metrics_radar, category_radio_radar], outputs=radar_plot_radio)
            all_metrics_radar.change(on_radar_radio_change, inputs=[model_multiselect_radar, all_metrics_radar, category_radio_radar], outputs=radar_plot_radio)
            category_radio_radar.change(on_radar_radio_change, inputs=[model_multiselect_radar, all_metrics_radar, category_radio_radar], outputs=radar_plot_radio)
            
            #FOR RANKING
            model_radio_ranking.change(on_ranking_change, inputs=[model_radio_ranking, all_metrics_ranking, category_radio_ranking], outputs=[first, second, third, row_answer_first, row_answer_second, row_answer_third])
            model_radio_ranking.change(fn=show_first, outputs=[row_answer_first, row_answer_second, row_answer_third])
            all_metrics_ranking.change(on_ranking_change, inputs=[model_radio_ranking, all_metrics_ranking, category_radio_ranking], outputs=[first, second, third, row_answer_first, row_answer_second, row_answer_third])
            all_metrics_ranking.change(fn=show_first, outputs=[row_answer_first, row_answer_second, row_answer_third])
            category_radio_ranking.change(on_ranking_change, inputs=[model_radio_ranking, all_metrics_ranking, category_radio_ranking], outputs=[first, second, third, row_answer_first, row_answer_second, row_answer_third])
            category_radio_ranking.change(fn=show_first, outputs=[row_answer_first, row_answer_second, row_answer_third])
            model_radio_ranking.change(fn=enforce_model_selection, inputs=model_radio_ranking, outputs=model_radio_ranking)
            category_radio_ranking.change(fn=enforce_category_selection, inputs=category_radio_ranking, outputs=category_radio_ranking)
            first_button.click(fn=show_first, outputs=[row_answer_first, row_answer_second, row_answer_third])
            second_button.click(fn=show_second, outputs=[row_answer_first, row_answer_second, row_answer_third])
            third_button.click(fn=show_third, outputs=[row_answer_first, row_answer_second, row_answer_third])
            
            #FOR RATE
            model_multiselect_rate.change(update_query_rate, inputs=[model_multiselect_rate, slicer], outputs=query_rate_plot)
            proceed_to_metrics_button.click(update_query_rate, inputs=[model_multiselect_rate, slicer], outputs=query_rate_plot)
            model_multiselect_rate.change(fn=enforce_model_selection, inputs=model_multiselect_rate, outputs=model_multiselect_rate)
            slicer.change(update_query_rate, inputs=[model_multiselect_rate, slicer], outputs=query_rate_plot)
            #FOR RESET   
            reset_data.click(open_accordion, inputs=gr.State("reset"), outputs=[upload_acc, select_table_acc, select_model_acc, qatch_acc, metrics_acc, default_checkbox, file_input])         
            reset_data.click(fn=lambda: gr.update(visible=False), outputs=[download_metrics])
            reset_data.click(fn=enable_disable, inputs=[gr.State(True)], outputs=[*model_checkboxes, submit_models_button, preview_output, submit_button, file_input, default_checkbox, table_selector, *table_outputs, open_model_selection])
                   
interface.launch(share=True)