simone-papicchio and franceth committed on
Commit 32b6873 · verified · 1 Parent(s): d44b620

Css, Prop vs Non-Prop, Metrics Update (#11)


- Css, Prop vs Non-Prop, Metrics Update (dee3b00dfdac45a04dba18cb8e760c2c5d0bc4d9)


Co-authored-by: Francesco Giannuzzo <[email protected]>

.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+ evaluation_p_np_metrics.csv filter=lfs diff=lfs merge=lfs -text
+ qatch_logo.png filter=lfs diff=lfs merge=lfs -text
app.py CHANGED
@@ -1,9 +1,6 @@
1
  import gradio as gr
2
  import pandas as pd
3
  import os
4
- import re
5
- import csv
6
- import time
7
  # # https://discuss.huggingface.co/t/issues-with-sadtalker-zerogpu-spaces-inquiry-about-community-grant/110625/10
8
  # if os.environ.get("SPACES_ZERO_GPU") is not None:
9
  # import spaces
@@ -19,17 +16,30 @@ from qatch.connectors.sqlite_connector import SqliteConnector
19
  from qatch.generate_dataset.orchestrator_generator import OrchestratorGenerator
20
  from qatch.evaluate_dataset.orchestrator_evaluator import OrchestratorEvaluator
21
  from prediction import ModelPrediction
22
- from utils_get_db_tables_info import utils_extract_db_schema_as_string
23
  import utilities as us
 
24
  import plotly.express as px
25
  import plotly.graph_objects as go
26
  import plotly.colors as pc
27
-
28
- pnp_path = os.path.join("data", "evaluation_p_metrics.csv")
29
-
30
- us.check_and_create_dir('data/data_interface/')
31
- us.check_and_create_dir('data/data_results/')
32
- us.check_and_create_dir('data/databases/')
33
 
34
  with open('style.css', 'r') as file:
35
  css = file.read()
@@ -41,11 +51,16 @@ df_default = pd.DataFrame({
41
  'City': ['New York', 'Los Angeles', 'Chicago']
42
  })
43
 
44
- models_path = "models.csv"
45
 
46
  # Global variable to keep track of the current data
47
  df_current = df_default.copy()
48
 
 
 
 
 
 
49
  input_data = {
50
  'input_method': "",
51
  'data_path': "",
@@ -56,7 +71,7 @@ input_data = {
56
  'selected_tables' :[]
57
  },
58
  'models': [],
59
- 'prompt': "{question} {schema}"
60
  }
61
 
62
  def load_data(file, path, use_default):
@@ -66,7 +81,8 @@ def load_data(file, path, use_default):
66
  try:
67
  input_data["input_method"] = 'uploaded_file'
68
  input_data["db_name"] = os.path.splitext(os.path.basename(file))[0]
69
- input_data["data_path"] = os.path.join(".", "data", "data_interface",f"{input_data['db_name']}.sqlite")
 
70
  input_data["data"] = us.load_data(file, input_data["db_name"])
71
  df_current = input_data["data"]['data_frames'].get('MyTable', df_default) # Load the DataFrame
72
  if(input_data["data"]['data_frames'] and input_data["data"]["db"] is None): #for csv and xlsx files
@@ -86,7 +102,8 @@ def load_data(file, path, use_default):
86
  if use_default:
87
  if(use_default == 'Custom'):
88
  input_data["input_method"] = 'custom'
89
- input_data["data_path"] = os.path.join(".", "data", "data_interface", "mytable_0.sqlite")
 
90
  #if file already exist
91
  while os.path.exists(input_data["data_path"]):
92
  input_data["data_path"] = us.increment_filename(input_data["data_path"])
@@ -122,11 +139,14 @@ def load_data(file, path, use_default):
122
 
123
  return input_data["data"]['data_frames']
124
 
125
- def preview_default(use_default):
126
- if use_default == 'Custom':
127
- return gr.DataFrame(interactive=True, visible = True, value = df_default), gr.update(visible=False)
128
- else:
129
- return gr.DataFrame(interactive=False, visible = False, value = df_default), gr.update(visible=True)
 
 
 
130
  #return gr.DataFrame(interactive=True, value = df_current) # Show the current DataFrame, which may have been modified
131
 
132
  def update_df(new_df):
@@ -151,41 +171,43 @@ def open_accordion(target):
151
  return gr.update(open=False), gr.update(open=False), gr.update(open=True, visible=True), gr.update(open=False), gr.update(open=False)
152
 
153
  # Gradio interface
154
- with gr.Blocks(theme='d8ahazard/rd_blue', css_paths='style.css') as interface:
155
- #with gr.Blocks(theme='NoCrypt/miku/light', css_paths='style.css') as interface:
156
  with gr.Row():
157
- gr.Column(scale=1)
158
- gr.Image(
159
- value="https://github.com/CristianDegni01/Automatic-LLM-Benchmark-Analysis-for-Text2SQL-GRADIO/blob/master/models_logo/QATCH.png?raw=true",
160
- show_label=False,
161
- container=False,
162
- height=200, # in pixel
163
- width=400
164
- )
165
- gr.Column(scale=1)
166
-
 
 
 
 
167
  data_state = gr.State(None) # Stores the loaded data
168
- upload_acc = gr.Accordion("Upload your data section", open=True, visible=True)
169
- select_table_acc = gr.Accordion("Select tables", open=False, visible=False)
170
- select_model_acc = gr.Accordion("Select models", open=False, visible=False)
171
- qatch_acc = gr.Accordion("QATCH execution", open=False, visible=False)
172
- metrics_acc = gr.Accordion("Metrics", open=False, visible=False)
173
 
174
  #################################
175
  # DATABASE INSERTION #
176
  #################################
177
  with upload_acc:
178
- gr.Markdown("## Choose data input method")
179
  with gr.Row():
180
- default_checkbox = gr.Radio(label = "Use default DataFrame or costume one table", choices=['Proprietary vs Non-proprietary', 'Custom'], value='Proprietary vs Non-proprietary')
181
- #default_checkbox = gr.Checkbox(label="Use default DataFrame")
 
 
182
  preview_output = gr.DataFrame(interactive=False, visible=False, value=df_default)
183
- description = """## Comparison of proprietary and non-proprietary databases
184
- - Proprietary (Economic, Medical, Financial, Miscellaneous)
185
- - Non-proprietary (Spider 1.0)"""
186
 
187
- table_default = gr.Markdown(description, visible=True)
188
- gr.Markdown("## Or upload your data")
189
  file_input = gr.File(label="Drag and drop a file", file_types=[".csv", ".xlsx", ".sqlite"])
190
  submit_button = gr.Button("Load Data") # Disabled by default
191
  output = gr.JSON(visible=False) # Dictionary output
@@ -213,7 +235,8 @@ with gr.Blocks(theme='d8ahazard/rd_blue', css_paths='style.css') as interface:
213
  #default_checkbox.change(fn=enable_submit, inputs=[file_input, default_checkbox], outputs=[submit_button])
214
 
215
  # Show preview of the default DataFrame when checkbox is selected
216
- default_checkbox.change(fn=preview_default, inputs=[default_checkbox], outputs=[preview_output, table_default])
 
217
  preview_output.change(fn=update_df, inputs=[preview_output], outputs=[preview_output])
218
 
219
  # Uncheck the checkbox when a file is uploaded
@@ -277,8 +300,8 @@ with gr.Blocks(theme='d8ahazard/rd_blue', css_paths='style.css') as interface:
277
  # TABLE SELECTION PART #
278
  ######################################
279
  with select_table_acc:
280
- table_selector = gr.CheckboxGroup(choices=[], label="Select tables to display", value=[])
281
- table_outputs = [gr.DataFrame(label=f"Table {i+1}", interactive=True, visible=False) for i in range(10)]
282
  selected_table_names = gr.Textbox(label="Selected tables", visible=False, interactive=False)
283
 
284
  # Model selection button (initially disabled)
@@ -287,7 +310,9 @@ with gr.Blocks(theme='d8ahazard/rd_blue', css_paths='style.css') as interface:
287
  def update_table_list(data):
288
  """Dynamically updates the list of available tables."""
289
  if isinstance(data, dict) and data:
290
- table_names = list(data.keys()) # Return only the table names
 
 
291
  return gr.update(choices=table_names, value=[]) # Reset selections
292
  return gr.update(choices=[], value=[])
293
 
@@ -295,19 +320,23 @@ with gr.Blocks(theme='d8ahazard/rd_blue', css_paths='style.css') as interface:
295
  """Displays only the tables selected by the user and enables the button."""
296
  updates = []
297
  if isinstance(data, dict) and data:
 
298
  available_tables = list(data.keys()) # Actually available names
299
- selected_tables = [t for t in selected_tables if t in available_tables] # Filter valid selections
 
 
 
300
 
301
  tables = {name: data[name] for name in selected_tables} # Filter the DataFrames
302
 
303
  for i, (name, df) in enumerate(tables.items()):
304
- updates.append(gr.update(value=df, label=f"Table: {name}", visible=True))
305
 
306
  # If there are fewer than 5 tables, hide the other DataFrames
307
- for _ in range(len(tables), 10):
308
  updates.append(gr.update(visible=False))
309
  else:
310
- updates = [gr.update(value=pd.DataFrame(), visible=False) for _ in range(10)]
311
 
312
  # Enable/disable the button based on selections
313
  button_state = bool(selected_tables) # True if at least one table is selected, False otherwise
@@ -315,9 +344,12 @@ with gr.Blocks(theme='d8ahazard/rd_blue', css_paths='style.css') as interface:
315
 
316
  return updates
317
 
318
- def show_selected_table_names(selected_tables):
319
  """Displays the names of the selected tables when the button is pressed."""
320
  if selected_tables:
 
 
 
321
  input_data['data']['selected_tables'] = selected_tables
322
  return gr.update(value=", ".join(selected_tables), visible=False)
323
  return gr.update(value="", visible=False)
@@ -329,7 +361,7 @@ with gr.Blocks(theme='d8ahazard/rd_blue', css_paths='style.css') as interface:
329
  table_selector.change(fn=show_selected_tables, inputs=[data_state, table_selector], outputs=table_outputs + [open_model_selection])
330
 
331
  # Shows the list of selected tables when "Choose your models" is clicked
332
- open_model_selection.click(fn=show_selected_table_names, inputs=[table_selector], outputs=[selected_table_names])
333
  open_model_selection.click(open_accordion, inputs=gr.State("model_selection"), outputs=[upload_acc, select_table_acc, select_model_acc, qatch_acc, metrics_acc])
334
 
335
  reset_data = gr.Button("Back to upload data section")
@@ -352,12 +384,16 @@ with gr.Blocks(theme='d8ahazard/rd_blue', css_paths='style.css') as interface:
352
  # MODEL SELECTION PART #
353
  ####################################
354
  with select_model_acc:
355
- gr.Markdown("**Model Selection**")
356
 
357
  # Assume that `us.read_models_csv` also returns the image path
358
  model_list_dict = us.read_models_csv(models_path)
359
  model_list = [model["code"] for model in model_list_dict]
360
  model_images = [model["image_path"] for model in model_list_dict]
 
 
 
 
361
 
362
  model_checkboxes = []
363
  rows = []
@@ -371,25 +407,35 @@ with gr.Blocks(theme='d8ahazard/rd_blue', css_paths='style.css') as interface:
371
  model = model_list[i + j]
372
  image_path = model_images[i + j]
373
  with gr.Column():
374
- gr.Image(image_path, show_label=False)
375
- checkbox = gr.Checkbox(label=model, value=False)
376
  model_checkboxes.append(checkbox)
377
  cols.append(checkbox)
378
  rows.append(cols)
379
 
380
- selected_models_output = gr.JSON(visible=True)
381
 
382
  # Function to get selected models
383
  def get_selected_models(*model_selections):
384
  selected_models = [model for model, selected in zip(model_list, model_selections) if selected]
 
385
  input_data['models'] = selected_models
386
- button_state = bool(selected_models and '{schema}' in input_data["prompt"] and '{question}' in input_data["prompt"])
387
  return selected_models, gr.update(open=True, visible=True), gr.update(interactive=button_state)
388
 
389
  # Add the Textbox to the interface
390
- prompt = gr.TextArea(label="Customise the prompt for selected models here or leave the default one . The prompt must contain {question} and {schema} which will be automatically replaced during SQL generation.",
391
- placeholder='Default prompt with a {question} and db {schema} are to be specified')
392
- warning_prompt = gr.Markdown(value="# Error in the prompt format", visible=False)
 
 
 
393
 
394
  # Submit button (initially disabled)
395
 
@@ -399,17 +445,21 @@ with gr.Blocks(theme='d8ahazard/rd_blue', css_paths='style.css') as interface:
399
  #TODO
400
  missing_elements = []
401
  if(prompt==""):
402
- input_data["prompt"]="{question} {schema}"
403
- button_state = bool(len(input_data['models']) > 0 and '{schema}' in input_data["prompt"] and '{question}' in input_data["prompt"])
404
  else:
405
  input_data["prompt"]=prompt
406
- if "{schema}" not in prompt:
407
- missing_elements.append("{schema}")
408
  if "{question}" not in prompt:
409
  missing_elements.append("{question}")
410
- button_state = bool(len(input_data['models']) > 0 and '{schema}' in input_data["prompt"] and '{question}' in input_data["prompt"])
411
  if missing_elements:
412
- return gr.update(value=f"## ❌ Missing {', '.join(missing_elements)} in the prompt ❌", visible=True), gr.update(interactive=button_state)
 
 
 
 
413
  return gr.update(visible=False), gr.update(interactive=button_state)
414
 
415
  prompt.change(fn=check_prompt, inputs=[prompt], outputs=[warning_prompt, submit_models_button])
@@ -490,14 +540,14 @@ with gr.Blocks(theme='d8ahazard/rd_blue', css_paths='style.css') as interface:
490
 
491
  loading_symbols= {1:"𓆟",
492
  2: "𓆞 𓆟",
493
- 3: "𓆟 𓆞 𓆟",
494
- 4: "𓆞 𓆟 𓆞 𓆟",
495
- 5: "𓆟 𓆞 𓆟 𓆞 𓆟",
496
- 6: "𓆞 𓆟 𓆞 𓆟 𓆞 𓆟",
497
- 7: "𓆟 𓆞 𓆟 𓆞 𓆟 𓆞 𓆟",
498
- 8: "𓆞 𓆟 𓆞 𓆟 𓆞 𓆟 𓆞 𓆟",
499
- 9: "𓆟 𓆞 𓆟 𓆞 𓆟 𓆞 𓆟 𓆞 𓆟",
500
- 10:"𓆞 𓆟 𓆞 𓆟 𓆞 𓆟 𓆞 𓆟 𓆞 𓆟",
501
  }
502
 
503
  def generate_loading_text(percent):
@@ -508,7 +558,7 @@ with gr.Blocks(theme='d8ahazard/rd_blue', css_paths='style.css') as interface:
508
  return f"""
509
  <div class='barcontainer'>
510
  {css_symbols}
511
- <span class='loading' style="font-family: 'Playfair Display', serif;">
512
  Generation {percent}%
513
  </span>
514
  {mirrored_symbols}
@@ -521,15 +571,17 @@ with gr.Blocks(theme='d8ahazard/rd_blue', css_paths='style.css') as interface:
521
  #caching
522
  predictions_dict = {model: pd.DataFrame(columns=['id', 'question', 'predicted_sql', 'time', 'query', 'db_path']) for model in model_list}
523
  metrics_conc = pd.DataFrame()
 
524
  if (input_data['input_method']=="default"):
525
  target_df = us.load_csv(pnp_path) #target_df = us.load_csv("priority_non_priority_metrics.csv")
526
  #predictions_dict = {model: pd.DataFrame(columns=target_df.columns) for model in model_list}
527
  target_df = target_df[target_df["tbl_name"].isin(input_data['data']['selected_tables'])]
528
  target_df = target_df[target_df["model"].isin(input_data['models'])]
529
  predictions_dict = {model: target_df[target_df["model"] == model] if model in target_df["model"].unique() else pd.DataFrame(columns=target_df.columns) for model in model_list}
530
- for model in target_df["model"].unique():
 
531
  model_image_path = next((m["image_path"] for m in model_list_dict if m["code"] == model), None)
532
- yield gr.Image(model_image_path), gr.Markdown(), gr.Markdown(), gr.Markdown(), metrics_conc, *[predictions_dict[model] for model in model_list]
533
  count=1
534
  for _, row in predictions_dict[model].iterrows():
535
  #for index, row in target_df.iterrows():
@@ -538,51 +590,36 @@ with gr.Blocks(theme='d8ahazard/rd_blue', css_paths='style.css') as interface:
538
  load_text = f"{generate_loading_text(percent_complete)}"
539
  question = row['question']
540
 
541
- # display_question = f"""
542
- # <div class='loading' style="font-size: 1.7rem; font-family: 'Playfair Display', serif;">
543
- # Natural Language:
544
- # </div>
545
- # <div class='sqlquery' style="font-family: 'Playfair Display', serif;">
546
- # {row['question']}
547
- # </div>
548
- # """
549
- display_question = f"""<div class='loading' style='font-size: 1.7rem; font-family: 'Playfair Display', serif;'>Natural Language:</div>
550
  <div style='display: flex; align-items: center;'>
551
- <div class='sqlquery' font-family: 'Playfair Display', serif;>{question}</div>
552
  <div style='font-size: 3rem'>➡️</div>
553
  </div>
554
  """
555
- yield gr.Image(), gr.Markdown(load_text), gr.Markdown(display_question), gr.Markdown(), metrics_conc, *[predictions_dict[model] for model in model_list]
556
  #time.sleep(0.02)
557
  prediction = row['predicted_sql']
558
 
559
- # display_prediction = f"""
560
- # <div class='loading' style="font-size: 1.7rem; font-family: 'Playfair Display', serif;">
561
- # Generated SQL:
562
- # </div>
563
- # <div class='sqlquery' style="font-family: 'Playfair Display', serif;">
564
- # {prediction}
565
- # </div>
566
- # """
567
- display_prediction = f"""<div class='loading' style='font-size: 1.7rem; font-family: 'Playfair Display', serif;'>Natural Language:</div>
568
  <div style='display: flex; align-items: center;'>
569
  <div style='font-size: 3rem'>➡️</div>
570
- <div class='sqlquery' font-family: 'Playfair Display', serif;>{prediction}</div>
571
  </div>
572
  """
573
- yield gr.Image(), gr.Markdown(load_text), gr.Markdown(), gr.Markdown(display_prediction), metrics_conc, *[predictions_dict[model] for model in model_list]
574
- yield gr.Image(), gr.Markdown(load_text), gr.Markdown(), gr.Markdown(display_prediction), metrics_conc, *[predictions_dict[model] for model in model_list]
575
  metrics_conc = target_df
576
  if 'valid_efficiency_score' not in metrics_conc.columns:
577
  metrics_conc['valid_efficiency_score'] = metrics_conc['VES']
578
- yield gr.Image(), gr.Markdown(), gr.Markdown(), gr.Markdown(), metrics_conc, *[predictions_dict[model] for model in model_list]
579
  else:
580
 
581
  orchestrator_generator = OrchestratorGenerator()
582
  # TODO: add to target_df column target_df["columns_used"], tables selection
583
  # print(input_data['data']['db'])
584
  #print(input_data['data']['selected_tables'])
585
- target_df = orchestrator_generator.generate_dataset(connector=input_data['data']['db'], tables_to_includes=input_data['data']['selected_tables'])
 
586
  #target_df = orchestrator_generator.generate_dataset(connector=input_data['data']['db'], tables_to_includes=None)
587
 
588
  schema_text = utils_get_db_tables_info.utils_extract_db_schema_as_string(
@@ -604,14 +641,13 @@ with gr.Blocks(theme='d8ahazard/rd_blue', css_paths='style.css') as interface:
604
  load_text = f"{generate_loading_text(percent_complete)}"
605
 
606
  question = row['question']
607
- #display_question = f"<div class='loading' style ='font-size: 1.7rem;'>Natural Language: </div> <div class='sqlquery'>{row['question']}</div>"
608
- display_question = f"""<div class='loading' style='font-size: 1.7rem; font-family: 'Playfair Display', serif;'>Natural Language:</div>
609
  <div style='display: flex; align-items: center;'>
610
- <div class='sqlquery' font-family: 'Playfair Display', serif;>{question}</div>
611
  <div style='font-size: 3rem'>➡️</div>
612
  </div>
613
- """
614
- yield gr.Image(), gr.Markdown(load_text), gr.Markdown(display_question), gr.Markdown(), metrics_conc, *[predictions_dict[model] for model in model_list]
615
  start_time = time.time()
616
  samples = us.generate_some_samples(input_data['data']['db'], row["tbl_name"])
617
  prompt_to_send = us.prepare_prompt(input_data["prompt"], question, schema_text, samples)
@@ -622,11 +658,10 @@ with gr.Blocks(theme='d8ahazard/rd_blue', css_paths='style.css') as interface:
622
  answer = "Answer"#response[response]
623
 
624
  end_time = time.time()
625
- #display_prediction = f"<div class='loading' style ='font-size: 1.7rem;'>Generated SQL: </div><div class='sqlquery'>{prediction}</div>"
626
- display_prediction = f"""<div class='loading' style='font-size: 1.7rem; font-family: 'Playfair Display', serif;'>Natural Language:</div>
627
  <div style='display: flex; align-items: center;'>
628
  <div style='font-size: 3rem'>➡️</div>
629
- <div class='sqlquery' font-family: 'Playfair Display', serif;>{prediction}</div>
630
  </div>
631
  """
632
  # Create a new row as dataframe
@@ -652,7 +687,7 @@ with gr.Blocks(theme='d8ahazard/rd_blue', css_paths='style.css') as interface:
652
  predictions_dict[model] = pd.concat([predictions_dict[model], new_row], ignore_index=True)
653
 
654
  # yield gr.Textbox(), gr.Textbox(prediction), *[predictions_dict[model] for model in input_data["models"]], None
655
- yield gr.Image(), gr.Markdown(load_text), gr.Markdown(), gr.Markdown(display_prediction), metrics_conc, *[predictions_dict[model] for model in model_list]
656
 
657
  yield gr.Image(), gr.Markdown(load_text), gr.Markdown(), gr.Markdown(display_prediction), metrics_conc, *[predictions_dict[model] for model in model_list]
658
  # END
@@ -683,17 +718,33 @@ with gr.Blocks(theme='d8ahazard/rd_blue', css_paths='style.css') as interface:
683
  with gr.Column():
684
  question_display = gr.Markdown()
685
  with gr.Column():
686
- model_logo = gr.Image(visible=True, show_label=False)
687
  with gr.Column():
688
  with gr.Column():
689
  prediction_display = gr.Markdown()
690
-
 
 
691
  dataframe_per_model = {}
692
 
693
  with gr.Tabs() as model_tabs:
694
  tab_dict = {}
695
- for model in model_list:
696
- with gr.TabItem(model, visible=(model in input_data["models"])) as tab:
697
  gr.Markdown(f"**Results for {model}**")
698
  tab_dict[model] = tab
699
  dataframe_per_model[model] = gr.DataFrame()
@@ -726,16 +777,17 @@ with gr.Blocks(theme='d8ahazard/rd_blue', css_paths='style.css') as interface:
726
  # Works for METRICS
727
  metrics_df.change(fn=change_text, inputs=[metrics_df], outputs=[metrics_df_out])
728
 
729
- proceed_to_metrics_button = gr.Button("Proceed to Metrics")
730
  proceed_to_metrics_button.click(
731
  fn=lambda: (gr.update(open=False, visible=True), gr.update(open=True, visible=True)),
732
  outputs=[qatch_acc, metrics_acc]
733
  )
734
 
735
  def allow_download(metrics_df_out):
736
- path = os.path.join(".", "data", "data_results", "results.csv")
 
737
  metrics_df_out.to_csv(path, index=False)
738
- return gr.update(value=path, visible=True)
739
 
740
  download_metrics = gr.DownloadButton(label="Download Metrics Evaluation", visible=False)
741
 
@@ -748,9 +800,10 @@ with gr.Blocks(theme='d8ahazard/rd_blue', css_paths='style.css') as interface:
748
  # fn=lambda: gr.update(open=True, visible=True),
749
  # outputs=[download_metrics]
750
  # )
751
- metrics_df_out.change(fn=allow_download, inputs=[metrics_df_out], outputs=[download_metrics])
 
 
752
 
753
- reset_data = gr.Button("Back to upload data section")
754
  reset_data.click(open_accordion, inputs=gr.State("reset"), outputs=[upload_acc, select_table_acc, select_model_acc, qatch_acc, metrics_acc, default_checkbox, file_input])
755
  #WHY NOT WORKING?
756
  reset_data.click(
@@ -773,16 +826,12 @@ with gr.Blocks(theme='d8ahazard/rd_blue', css_paths='style.css') as interface:
773
  open_model_selection
774
  ]
775
  )
776
-
777
-
778
-
779
-
780
  ##########################################
781
  # METRICS VISUALIZATION SECTION #
782
  ##########################################
783
  with metrics_acc:
784
  #data_path = 'test_results_metrics1.csv'
785
- data_path = '/Users/francescogiannuzzo/Desktop/EURECOM/semester_project_gradio_git/Automatic-LLM-Benchmark-Analysis-for-Text2SQL-GRADIO/data/evaluation_p_metrics.csv'
786
 
787
  @gr.render(inputs=metrics_df_out)
788
  def function_metrics(metrics_df_out):
@@ -794,6 +843,16 @@ with gr.Blocks(theme='d8ahazard/rd_blue', css_paths='style.css') as interface:
794
  def load_data_csv_es():
795
  #return pd.read_csv(data_path)
796
  #print("---------------->",metrics_df_out)
797
  return metrics_df_out
798
 
799
  def calculate_average_metrics(df, selected_metrics):
@@ -812,7 +871,8 @@ with gr.Blocks(theme='d8ahazard/rd_blue', css_paths='style.css') as interface:
812
  num_models = len(unique_models)
813
 
814
  # Use the Plotly color scale (you can change it if needed)
815
- color_palette = pc.qualitative.Plotly # ['#636EFA', '#EF553B', '#00CC96', ...]
 
816
 
817
  # If there are more models than colors, cycle through them
818
  colors = {model: color_palette[i % len(color_palette)] for i, model in enumerate(unique_models)}
@@ -837,7 +897,7 @@ with gr.Blocks(theme='d8ahazard/rd_blue', css_paths='style.css') as interface:
837
  def normalize_valid_efficiency_score(df):
838
  #TODO valid_efficiency_score
839
  #print(df['valid_efficiency_score'])
840
- df['valid_efficiency_score'] = df['valid_efficiency_score'].replace('', 0)
841
  df['valid_efficiency_score'] = df['valid_efficiency_score'].astype(int)
842
  min_val = df['valid_efficiency_score'].min()
843
  max_val = df['valid_efficiency_score'].max()
@@ -853,8 +913,6 @@ with gr.Blocks(theme='d8ahazard/rd_blue', css_paths='style.css') as interface:
853
  return df
854
 
855
 
856
-
857
-
858
  ####################################
859
  # GRAPH FUNCTIONS SECTION #
860
  ####################################
@@ -883,9 +941,10 @@ with gr.Blocks(theme='d8ahazard/rd_blue', css_paths='style.css') as interface:
883
  y="avg_metric",
884
  color="model",
885
  color_discrete_map=MODEL_COLORS,
886
- title='Average metric per Model 🧠',
887
- labels={"model": "Model", "avg_metric": "Average Metric"},
888
- template='plotly_dark',
 
889
  text='text_label'
890
  )
891
  else:
@@ -902,62 +961,64 @@ with gr.Blocks(theme='d8ahazard/rd_blue', css_paths='style.css') as interface:
902
  color='model',
903
  color_discrete_map=MODEL_COLORS,
904
  barmode='group',
905
- title=f'Average metric per {group_by[0]} 📊',
906
- labels={group_by[0]: group_by[0].capitalize(), 'avg_metric': 'Average Metric'},
907
- template='plotly_dark',
 
908
  text='text_label'
909
  )
910
 
911
  fig.update_traces(textposition='outside', textfont_size=10)
912
 
913
- # font Playfair Display
914
  fig.update_layout(
915
  margin=dict(t=80),
916
  title=dict(
917
  font=dict(
918
- family="Playfair Display, serif",
919
  size=22,
920
- color="white"
921
  ),
922
  x=0.5
923
  ),
924
  xaxis=dict(
925
  title=dict(
926
  font=dict(
927
- family="Playfair Display, serif",
928
- size=16,
929
- color="white"
930
  )
931
  ),
932
  tickfont=dict(
933
- family="Playfair Display, serif",
934
- color="white"
 
935
  )
936
  ),
937
  yaxis=dict(
938
  title=dict(
939
  font=dict(
940
- family="Playfair Display, serif",
941
- size=16,
942
- color="white"
943
  )
944
  ),
945
  tickfont=dict(
946
- family="Playfair Display, serif",
947
- color="white"
948
  )
949
  ),
950
  legend=dict(
951
  title=dict(
952
  font=dict(
953
- family="Playfair Display, serif",
954
- size=14,
955
- color="white"
956
  )
957
  ),
958
  font=dict(
959
- family="Playfair Display, serif",
960
- color="white"
961
  )
962
  )
963
  )
@@ -988,7 +1049,7 @@ with gr.Blocks(theme='d8ahazard/rd_blue', css_paths='style.css') as interface:
988
 
989
  avg_metrics = df.groupby(["db_category", "model"])['avg_metric'].mean().reset_index()
990
  avg_metrics['text_label'] = avg_metrics['avg_metric'].apply(lambda x: f'{x:.2f}')
991
-
992
  fig = px.bar(
993
  avg_metrics,
994
  x='db_category',
@@ -996,50 +1057,51 @@ with gr.Blocks(theme='d8ahazard/rd_blue', css_paths='style.css') as interface:
996
  color='model',
997
  color_discrete_map=MODEL_COLORS,
998
  barmode='group',
999
- title='Average metric per db_category 📊',
1000
- labels={'db_path': 'DB Path', 'avg_metric': 'Average Metric'},
1001
  template='simple_white',
1002
  text='text_label'
1003
  )
1004
 
1005
- fig.update_traces(textposition='outside', textfont_size=10)
1006
 
1007
- #Playfair Display
1008
  fig.update_layout(
1009
  margin=dict(t=80),
1010
  title=dict(
1011
  font=dict(
1012
- family="Playfair Display, serif",
1013
- size=22,
1014
  color="black"
1015
  ),
1016
  x=0.5
1017
  ),
1018
  xaxis=dict(
1019
  title=dict(
1020
- text='DB Category',
1021
  font=dict(
1022
- family='Playfair Display, serif',
1023
- size=16,
1024
  color='black'
1025
  )
1026
  ),
1027
  tickfont=dict(
1028
- family='Playfair Display, serif',
1029
- color='black'
 
1030
  )
1031
  ),
1032
  yaxis=dict(
1033
  title=dict(
1034
- text='Average Metric',
1035
  font=dict(
1036
- family='Playfair Display, serif',
1037
- size=16,
1038
  color='black'
1039
  )
1040
  ),
1041
  tickfont=dict(
1042
- family='Playfair Display, serif',
1043
  color='black'
1044
  )
1045
  ),
@@ -1047,14 +1109,15 @@ with gr.Blocks(theme='d8ahazard/rd_blue', css_paths='style.css') as interface:
1047
  title=dict(
1048
  text='Models',
1049
  font=dict(
1050
- family='Playfair Display, serif',
1051
- size=14,
1052
  color='black'
1053
  )
1054
  ),
1055
  font=dict(
1056
- family='Playfair Display, serif',
1057
- color='black'
 
1058
  )
1059
  )
1060
  )
@@ -1116,12 +1179,13 @@ with gr.Blocks(theme='d8ahazard/rd_blue', css_paths='style.css') as interface:
1116
 
1117
  # BAR CHART FOR PROPIETARY DATASET WITH AVERAGE METRICS WITH UPDATE FUNCTION
1118
 
1119
- def lollipop_propietary():
1120
  df = load_data_csv_es()
1121
 
1122
  # Keep only the relevant categories
1123
  target_cats = ["Spider", "Economic", "Financial", "Medical", "Miscellaneous"]
1124
  df = df[df['db_category'].isin(target_cats)]
 
1125
 
1126
  df = normalize_valid_efficiency_score(df)
1127
  df = calculate_average_metrics(df, qatch_metrics)
@@ -1164,8 +1228,8 @@ with gr.Blocks(theme='d8ahazard/rd_blue', css_paths='style.css') as interface:
1164
  x=merged_df["Spider"],
1165
  y=merged_df["model"],
1166
  mode='markers',
1167
- name='Spider',
1168
- marker=dict(size=10, color='red')
1169
  ))
1170
 
1171
  # Point for Others (average of the other 4 categories)
@@ -1173,19 +1237,70 @@ with gr.Blocks(theme='d8ahazard/rd_blue', css_paths='style.css') as interface:
1173
  x=merged_df["Others"],
1174
  y=merged_df["model"],
1175
  mode='markers',
1176
- name='Others Avg',
1177
- marker=dict(size=10, color='blue')
1178
  ))
1179
 
1180
  fig.update_layout(
1181
- title='Dot-Range Plot: Spider vs Altri 🕷️📊',
1182
- xaxis_title='Average Metric',
1183
- yaxis_title='Model',
1184
  template='simple_white',
1185
  #template='plotly_dark',
1186
  margin=dict(t=80),
1187
- legend_title='Categoria',
1188
- height=600
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1189
  )
1190
 
1191
  return gr.Plot(fig, visible=True)
@@ -1233,64 +1348,79 @@ with gr.Blocks(theme='d8ahazard/rd_blue', css_paths='style.css') as interface:
1233
  title=dict(
1234
  text='📊 Bar Plot of Metrics per Model (Few Categories)',
1235
  font=dict(
1236
- family='Playfair Display, serif',
1237
  size=22,
1238
- color='white'
1239
  ),
1240
  x=0.5
1241
  ),
1242
- template='plotly_dark',
 
1243
  xaxis=dict(
1244
  title=dict(
1245
  text='Test Category',
1246
  font=dict(
1247
- family='Playfair Display, serif',
1248
- size=16,
1249
- color='white'
1250
  )
1251
  ),
1252
  tickfont=dict(
1253
- family='Playfair Display, serif',
1254
- color='white'
 
1255
  )
1256
  ),
1257
  yaxis=dict(
1258
  title=dict(
1259
- text='Average Metric',
1260
  font=dict(
1261
- family='Playfair Display, serif',
1262
- size=16,
1263
- color='white'
1264
  )
1265
  ),
1266
  tickfont=dict(
1267
- family='Playfair Display, serif',
1268
- color='white'
1269
  )
1270
  ),
1271
  legend=dict(
1272
  title=dict(
1273
  text='Models',
1274
  font=dict(
1275
- family='Playfair Display, serif',
1276
- size=14,
1277
- color='white'
1278
  )
1279
  ),
1280
  font=dict(
1281
- family='Playfair Display, serif',
1282
- color='white'
1283
  )
1284
  )
1285
  )
1286
  else:
1287
  # 🧭 RADAR PLOT
1288
  fig = go.Figure()
1289
- for model in selected_models:
1290
  model_data = avg_metrics[avg_metrics['model'] == model]
 
 
 
 
 
1291
  values = [
1292
- model_data[model_data['test_category'] == cat]['avg_metric'].values[0]
1293
- if cat in model_data['test_category'].values else 0
1294
  for cat in categories
1295
  ]
1296
  fig.add_trace(go.Scatterpolar(
@@ -1307,23 +1437,24 @@ with gr.Blocks(theme='d8ahazard/rd_blue', css_paths='style.css') as interface:
1307
  visible=True,
1308
  range=[0, max(avg_metrics['avg_metric'].max(), 0.5)],
1309
  tickfont=dict(
1310
- family='Playfair Display, serif',
1311
- color='white'
1312
  )
1313
  ),
1314
  angularaxis=dict(
1315
  tickfont=dict(
1316
- family='Playfair Display, serif',
1317
- color='white'
 
1318
  )
1319
  )
1320
  ),
1321
  title=dict(
1322
- text='❇️ Radar Plot of Metrics per Model (Average per Category)',
1323
  font=dict(
1324
- family='Playfair Display, serif',
1325
  size=22,
1326
- color='white'
1327
  ),
1328
  x=0.5
1329
  ),
@@ -1331,17 +1462,19 @@ with gr.Blocks(theme='d8ahazard/rd_blue', css_paths='style.css') as interface:
1331
  title=dict(
1332
  text='Models',
1333
  font=dict(
1334
- family='Playfair Display, serif',
1335
- size=14,
1336
- color='white'
1337
  )
1338
  ),
1339
  font=dict(
1340
- family='Playfair Display, serif',
1341
- color='white'
 
1342
  )
1343
  ),
1344
- template='plotly_dark'
 
1345
  )
1346
 
1347
  return fig
@@ -1395,60 +1528,63 @@ with gr.Blocks(theme='d8ahazard/rd_blue', css_paths='style.css') as interface:
1395
  title=dict(
1396
  text='📊 Bar Plot of Metrics per Model (Few Sub-Categories)',
1397
  font=dict(
1398
- family='Playfair Display, serif',
1399
  size=22,
1400
- color='white'
1401
  ),
1402
  x=0.5
1403
  ),
1404
- template='plotly_dark',
 
1405
  xaxis=dict(
1406
  title=dict(
1407
  text='SQL Tag (Sub Category)',
1408
  font=dict(
1409
- family='Playfair Display, serif',
1410
- size=16,
1411
- color='white'
1412
  )
1413
  ),
1414
  tickfont=dict(
1415
- family='Playfair Display, serif',
1416
- color='white'
1417
  )
1418
  ),
1419
  yaxis=dict(
1420
  title=dict(
1421
- text='Average Metric',
1422
  font=dict(
1423
- family='Playfair Display, serif',
1424
- size=16,
1425
- color='white'
1426
  )
1427
  ),
1428
  tickfont=dict(
1429
- family='Playfair Display, serif',
1430
- color='white'
1431
  )
1432
  ),
1433
  legend=dict(
1434
  title=dict(
1435
  text='Models',
1436
  font=dict(
1437
- family='Playfair Display, serif',
1438
- size=14,
1439
- color='white'
1440
  )
1441
  ),
1442
  font=dict(
1443
- family='Playfair Display, serif',
1444
- color='white'
 
1445
  )
1446
  )
1447
  )
1448
  else:
1449
  # 🧭 RADAR PLOT
1450
  fig = go.Figure()
1451
- for model in selected_models:
 
1452
  model_data = avg_metrics[avg_metrics['model'] == model]
1453
  values = [
1454
  model_data[model_data['sql_tag'] == cat]['avg_metric'].values[0]
@@ -1470,23 +1606,24 @@ with gr.Blocks(theme='d8ahazard/rd_blue', css_paths='style.css') as interface:
1470
  visible=True,
1471
  range=[0, max(avg_metrics['avg_metric'].max(), 0.5)],
1472
  tickfont=dict(
1473
- family='Playfair Display, serif',
1474
- color='white'
1475
  )
1476
  ),
1477
  angularaxis=dict(
1478
  tickfont=dict(
1479
- family='Playfair Display, serif',
1480
- color='white'
 
1481
  )
1482
  )
1483
  ),
1484
  title=dict(
1485
- text='❇️ Radar Plot of Metrics per Model (Average per Sub-Category)',
1486
  font=dict(
1487
- family='Playfair Display, serif',
1488
  size=22,
1489
- color='white'
1490
  ),
1491
  x=0.5
1492
  ),
@@ -1494,17 +1631,19 @@ with gr.Blocks(theme='d8ahazard/rd_blue', css_paths='style.css') as interface:
1494
  title=dict(
1495
  text='Models',
1496
  font=dict(
1497
- family='Playfair Display, serif',
1498
- size=14,
1499
- color='white'
1500
  )
1501
  ),
1502
  font=dict(
1503
- family='Playfair Display, serif',
1504
- color='white'
 
1505
  )
1506
  ),
1507
- template='plotly_dark'
 
1508
  )
1509
 
1510
  return fig
@@ -1623,9 +1762,9 @@ with gr.Blocks(theme='d8ahazard/rd_blue', css_paths='style.css') as interface:
1623
  title=dict(
1624
  text="Cumulative Price Flow Chart 💰",
1625
  font=dict(
1626
- family="Playfair Display, serif",
1627
  size=24,
1628
- color="white"
1629
  ),
1630
  x=0.5
1631
  ),
@@ -1633,45 +1772,49 @@ with gr.Blocks(theme='d8ahazard/rd_blue', css_paths='style.css') as interface:
1633
  title=dict(
1634
  text="Cumulative Time (s)",
1635
  font=dict(
1636
- family="Playfair Display, serif",
1637
- size=16,
1638
- color="white"
1639
  )
1640
  ),
1641
  tickfont=dict(
1642
- family="Playfair Display, serif",
1643
- color="white"
 
1644
  )
1645
  ),
1646
  yaxis=dict(
1647
  title=dict(
1648
  text="Cumulative Price ($)",
1649
  font=dict(
1650
- family="Playfair Display, serif",
1651
- size=16,
1652
- color="white"
1653
  )
1654
  ),
1655
  tickfont=dict(
1656
- family="Playfair Display, serif",
1657
- color="white"
 
1658
  )
1659
  ),
1660
  legend=dict(
1661
  title=dict(
1662
  text="Models",
1663
  font=dict(
1664
- family="Playfair Display, serif",
1665
- size=14,
1666
- color="white"
1667
  )
1668
  ),
1669
  font=dict(
1670
- family="Playfair Display, serif",
1671
- color="white"
 
1672
  )
1673
  ),
1674
- template="plotly_dark"
 
1675
  )
1676
 
1677
  return fig
@@ -1728,8 +1871,7 @@ with gr.Blocks(theme='d8ahazard/rd_blue', css_paths='style.css') as interface:
1728
  }
1729
 
1730
  df_initial = load_data_csv_es()
1731
-
1732
- models = df_initial['model'].unique().tolist()
1733
  last_valid_model_selection = models.copy() # To save the last valid selection
1734
  def enforce_model_selection(selected):
1735
  global last_valid_model_selection
@@ -1768,12 +1910,41 @@ with gr.Blocks(theme='d8ahazard/rd_blue', css_paths='style.css') as interface:
1768
 
1769
  #FOR BAR
1770
  gr.Markdown("""## Section 1: Model - Data""")
 
1771
  with gr.Row():
1772
- choose_metrics_bar = gr.Radio(
1773
- choices=list(all_metrics.keys()),
1774
- label="Select the metrics group that you want to use:",
1775
- value="Qatch"
1776
- )
1777
 
1778
  qatch_metric_multiselect_bar = gr.CheckboxGroup(
1779
  choices=list(qatch_metrics_dict.keys()),
@@ -1809,15 +1980,15 @@ with gr.Blocks(theme='d8ahazard/rd_blue', css_paths='style.css') as interface:
1809
 
1810
  def toggle_metric_selector(selected_type):
1811
  if selected_type == "Qatch":
1812
- return gr.update(visible=True, value=list(qatch_metrics_dict.keys())), gr.update(visible=False, value=[])
1813
  else:
1814
- return gr.update(visible=False, value=[]), gr.update(visible=True, value=list(external_metrics_dict.keys()))
1815
 
1816
  output_plot = gr.Plot(visible=False)
1817
 
1818
  if(input_data['input_method'] == 'default'):
1819
  with gr.Row():
1820
- lollipop_propietary()
1821
 
1822
  #FOR RADAR
1823
  gr.Markdown("""## Section 2: Model - Category""")
@@ -1885,22 +2056,22 @@ with gr.Blocks(theme='d8ahazard/rd_blue', css_paths='style.css') as interface:
1885
  first = gr.Markdown(worst_first)
1886
 
1887
  with gr.Row():
1888
- first_button = gr.Button("Show row answer for 🥇")
1889
 
1890
  with gr.Row():
1891
  second = gr.Markdown(worst_second)
1892
 
1893
  with gr.Row():
1894
- second_button = gr.Button("Show row answer for 🥈")
1895
 
1896
  with gr.Row():
1897
  third = gr.Markdown(worst_third)
1898
 
1899
  with gr.Row():
1900
- third_button = gr.Button("Show row answer for 🥉")
1901
 
1902
  with gr.Column(scale=1):
1903
- gr.Markdown("""## Row Answer""")
1904
  row_answer_first = gr.Markdown(value=raw_first, visible=True)
1905
  row_answer_second = gr.Markdown(value=raw_second, visible=False)
1906
  row_answer_third = gr.Markdown(value=raw_third, visible=False)
@@ -1914,8 +2085,9 @@ with gr.Blocks(theme='d8ahazard/rd_blue', css_paths='style.css') as interface:
1914
  value=models
1915
  )
1916
 
 
1917
  with gr.Row():
1918
- slicer = gr.Slider(minimum=0, maximum=max(df_initial["number_question"]), step=0, value=max(df_initial["number_question"]), label="Number of instances that you want to visualize")
1919
 
1920
  query_rate_plot = gr.Plot(value=update_query_rate(models, len(df_initial["number_question"].unique())))
1921
 
@@ -1983,7 +2155,7 @@ with gr.Blocks(theme='d8ahazard/rd_blue', css_paths='style.css') as interface:
1983
  external_metric_select_bar.change(on_change_propietary, inputs=[choose_metrics_bar,qatch_metric_multiselect_bar, external_metric_select_bar, model_radio_bar], outputs=output_plot)
1984
  model_radio_bar.change(on_change_propietary, inputs=[choose_metrics_bar,qatch_metric_multiselect_bar, external_metric_select_bar, model_radio_bar], outputs=output_plot)
1985
  qatch_metric_multiselect_bar.change(fn=enforce_qatch_metrics_selection, inputs=qatch_metric_multiselect_bar, outputs=qatch_metric_multiselect_bar)
1986
- choose_metrics_bar.change(fn=toggle_metric_selector, inputs=choose_metrics_bar, outputs=[qatch_metric_multiselect_bar, external_metric_select_bar])
1987
  external_metric_select_bar.change(fn=enforce_external_metric_selection, inputs=external_metric_select_bar, outputs=external_metric_select_bar)
1988
 
1989
  else:
@@ -1994,7 +2166,7 @@ with gr.Blocks(theme='d8ahazard/rd_blue', css_paths='style.css') as interface:
1994
  model_multiselect_bar.change(on_change, inputs=[choose_metrics_bar,qatch_metric_multiselect_bar, external_metric_select_bar, group_radio, model_multiselect_bar], outputs=output_plot)
1995
  qatch_metric_multiselect_bar.change(fn=enforce_qatch_metrics_selection, inputs=qatch_metric_multiselect_bar, outputs=qatch_metric_multiselect_bar)
1996
  model_multiselect_bar.change(fn=enforce_model_selection, inputs=model_multiselect_bar, outputs=model_multiselect_bar)
1997
- choose_metrics_bar.change(fn=toggle_metric_selector, inputs=choose_metrics_bar, outputs=[qatch_metric_multiselect_bar, external_metric_select_bar])
1998
  external_metric_select_bar.change(fn=enforce_external_metric_selection, inputs=external_metric_select_bar, outputs=external_metric_select_bar)
1999
 
2000
 
@@ -2035,4 +2207,4 @@ with gr.Blocks(theme='d8ahazard/rd_blue', css_paths='style.css') as interface:
2035
  reset_data.click(fn=enable_disable, inputs=[gr.State(True)], outputs=[*model_checkboxes, submit_models_button, preview_output, submit_button, file_input, default_checkbox, table_selector, *table_outputs, open_model_selection])
2036
 
2037
 
2038
- interface.launch()
 
1
  import gradio as gr
2
  import pandas as pd
3
  import os
 
 
 
4
  # # https://discuss.huggingface.co/t/issues-with-sadtalker-zerogpu-spaces-inquiry-about-community-grant/110625/10
5
  # if os.environ.get("SPACES_ZERO_GPU") is not None:
6
  # import spaces
 
16
  from qatch.generate_dataset.orchestrator_generator import OrchestratorGenerator
17
  from qatch.evaluate_dataset.orchestrator_evaluator import OrchestratorEvaluator
18
  from prediction import ModelPrediction
19
+ import utils_get_db_tables_info
20
  import utilities as us
21
+ import time
22
  import plotly.express as px
23
  import plotly.graph_objects as go
24
  import plotly.colors as pc
25
+ import re
26
+ import csv
27
+ import numpy as np
28
+ # @spaces.GPU
29
+ # def model_prediction():
30
+ # pass
31
+ pnp_path = os.path.join(".", "evaluation_p_np_metrics.csv")
32
+
33
+ js_func = """
34
+ function refresh() {
35
+ const url = new URL(window.location);
36
+
37
+ if (url.searchParams.get('__theme') !== 'light') {
38
+ url.searchParams.set('__theme', 'light');
39
+ window.location.href = url.href;
40
+ }
41
+ }
42
+ """
43
 
44
  with open('style.css', 'r') as file:
45
  css = file.read()
 
51
  'City': ['New York', 'Los Angeles', 'Chicago']
52
  })
53
 
54
+ models_path = "./models.csv"
55
 
56
  # Global variable to keep track of the current data
57
  df_current = df_default.copy()
58
 
59
+ description = """## 📊 Comparison of Proprietary and Non-Proprietary Databases
60
+ ### ➤ **Proprietary** (💰 Economic, 🏥 Medical, 💳 Financial, 📂 Miscellaneous)
61
+ ### ➤ **Non-Proprietary** (🕷️ Spider 1.0)"""
62
+ prompt_default = "Translate the following question in SQL code to be executed over the database to fetch the answer.\nReturn the sql code in ```sql ```\nQuestion\n{question}\nDatabase Schema\n{db_schema}\n"
63
+
64
  input_data = {
65
  'input_method': "",
66
  'data_path': "",
 
71
  'selected_tables' :[]
72
  },
73
  'models': [],
74
+ 'prompt': prompt_default
75
  }
76
 
77
  def load_data(file, path, use_default):
 
81
  try:
82
  input_data["input_method"] = 'uploaded_file'
83
  input_data["db_name"] = os.path.splitext(os.path.basename(file))[0]
84
+ #input_data["data_path"] = os.path.join(".", "data", "data_interface",f"{input_data['db_name']}.sqlite")
85
+ input_data["data_path"] = os.path.join(".", f"{input_data['db_name']}.sqlite")
86
  input_data["data"] = us.load_data(file, input_data["db_name"])
87
  df_current = input_data["data"]['data_frames'].get('MyTable', df_default) # Load the DataFrame
88
  if(input_data["data"]['data_frames'] and input_data["data"]["db"] is None): #for csv and xlsx files
 
102
  if use_default:
103
  if(use_default == 'Custom'):
104
  input_data["input_method"] = 'custom'
105
+ #input_data["data_path"] = os.path.join(".", "data", "data_interface", "mytable_0.sqlite")
106
+ input_data["data_path"] = os.path.join(".","mytable_0.sqlite")
107
  #if file already exist
108
  while os.path.exists(input_data["data_path"]):
109
  input_data["data_path"] = us.increment_filename(input_data["data_path"])
 
139
 
140
  return input_data["data"]['data_frames']
141
 
142
+ def preview_default(use_default, file):
143
+ if file:
144
+ return gr.DataFrame(interactive=True, visible = False, value = df_default), gr.update(value="## ✅ File successfully uploaded!", visible=True)
145
+ else :
146
+ if use_default == 'Custom':
147
+ return gr.DataFrame(interactive=True, visible = True, value = df_default), gr.update(value="## 📝 Toy Table", visible=True)
148
+ else:
149
+ return gr.DataFrame(interactive=False, visible = False, value = df_default), gr.update(value = description, visible=True)
150
  #return gr.DataFrame(interactive=True, value = df_current) # Show the current DataFrame, which may have been modified
151
 
152
  def update_df(new_df):
 
171
  return gr.update(open=False), gr.update(open=False), gr.update(open=True, visible=True), gr.update(open=False), gr.update(open=False)
172
 
173
  # Interfaccia Gradio
174
+ #with gr.Blocks(theme='d8ahazard/rd_blue', css_paths='style.css') as interface:
175
+ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as interface:
176
  with gr.Row():
177
+ with gr.Column(scale=1):
178
+ gr.Image(
179
+ value="./qatch_logo.png",
180
+ show_label=False,
181
+ container=False,
182
+ interactive=False,
183
+ show_fullscreen_button=False,
184
+ show_download_button=False,
185
+ show_share_button=False,
186
+ height=150, # in pixel
187
+ width=300
188
+ )
189
+ with gr.Column(scale=1):
190
+ pass
191
  data_state = gr.State(None) # Stores the loaded data
192
+ upload_acc = gr.Accordion("Upload data section", open=True, visible=True)
193
+ select_table_acc = gr.Accordion("Select tables section", open=False, visible=False)
194
+ select_model_acc = gr.Accordion("Select models section", open=False, visible=False)
195
+ qatch_acc = gr.Accordion("QATCH execution section", open=False, visible=False)
196
+ metrics_acc = gr.Accordion("Metrics section", open=False, visible=False)
197
 
198
  #################################
199
  # DATABASE INSERTION #
200
  #################################
201
  with upload_acc:
202
+ gr.Markdown("## 📥Choose data input method")
203
  with gr.Row():
204
+ default_checkbox = gr.Radio(label = "Explore the comparison between proprietary and non-proprietary databases or edit a toy table with the values you prefer", choices=['Proprietary vs Non-proprietary', 'Custom'], value='Proprietary vs Non-proprietary')
205
+ #default_checkbox = gr.Checkbox(label="Use default DataFrame"
206
+
207
+ table_default = gr.Markdown(description, visible=True)
208
  preview_output = gr.DataFrame(interactive=False, visible=False, value=df_default)
 
 
 
209
 
210
+ gr.Markdown("## 📂 Or upload your data")
 
211
  file_input = gr.File(label="Drag and drop a file", file_types=[".csv", ".xlsx", ".sqlite"])
212
  submit_button = gr.Button("Load Data") # Disabled by default
213
  output = gr.JSON(visible=False) # Dictionary output
 
235
  #default_checkbox.change(fn=enable_submit, inputs=[file_input, default_checkbox], outputs=[submit_button])
236
 
237
  # Show preview of the default DataFrame when checkbox is selected
238
+ default_checkbox.change(fn=preview_default, inputs=[default_checkbox, file_input], outputs=[preview_output, table_default])
239
+ file_input.change(fn=preview_default, inputs=[default_checkbox, file_input], outputs=[preview_output, table_default])
240
  preview_output.change(fn=update_df, inputs=[preview_output], outputs=[preview_output])
241
 
242
  # Uncheck the checkbox when a file is uploaded
 
300
  # TABLE SELECTION PART #
301
  ######################################
302
  with select_table_acc:
303
+ table_selector = gr.CheckboxGroup(choices=[], label="Select tables from the chosen database", value=[])
304
+ table_outputs = [gr.DataFrame(label=f"Table {i+1}", interactive=True, visible=False) for i in range(50)]
305
  selected_table_names = gr.Textbox(label="Selected tables", visible=False, interactive=False)
306
 
307
  # Model selection button (initially disabled)
 
310
  def update_table_list(data):
311
  """Dynamically updates the list of available tables."""
312
  if isinstance(data, dict) and data:
313
+ table_names = []
314
+ table_names.append("All")
315
+ table_names.extend(data.keys()) # Append data.keys() to the list
316
  return gr.update(choices=table_names, value=[]) # Reset selections
317
  return gr.update(choices=[], value=[])
318
 
 
320
  """Displays only the tables selected by the user and enables the button."""
321
  updates = []
322
  if isinstance(data, dict) and data:
323
+
324
  available_tables = list(data.keys()) # Actually available names
325
+ if "All" in selected_tables:
326
+ selected_tables = available_tables
327
+ else:
328
+ selected_tables = [t for t in selected_tables if t in available_tables] # Filter valid selections
329
 
330
  tables = {name: data[name] for name in selected_tables} # Filter the DataFrames
331
 
332
  for i, (name, df) in enumerate(tables.items()):
333
+ updates.append(gr.update(value=df, label=f"Table: {name}", visible=True, interactive=False))
334
 
335
  # If there are fewer than 5 tables, hide the other DataFrames
336
+ for _ in range(len(tables), 50):
337
  updates.append(gr.update(visible=False))
338
  else:
339
+ updates = [gr.update(value=pd.DataFrame(), visible=False) for _ in range(50)]
340
 
341
  # Enable/disable the button based on selections
342
  button_state = bool(selected_tables) # True if at least one table is selected, False otherwise
 
344
 
345
  return updates
346
 
347
+ def show_selected_table_names(data, selected_tables):
348
  """Displays the names of the selected tables when the button is pressed."""
349
  if selected_tables:
350
+ available_tables = list(data.keys()) # Actually available names
351
+ if "All" in selected_tables:
352
+ selected_tables = available_tables
353
  input_data['data']['selected_tables'] = selected_tables
354
  return gr.update(value=", ".join(selected_tables), visible=False)
355
  return gr.update(value="", visible=False)
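Both `show_selected_tables` and `show_selected_table_names` expand the synthetic "All" entry in the same way; as a reading aid only, the shared resolution step boils down to something like the sketch below (this helper name is not part of the commit):

```python
def resolve_selected_tables(data: dict, selected: list[str]) -> list[str]:
    # "All" is the synthetic first choice prepended by update_table_list();
    # ticking it selects every table of the loaded database.
    available = list(data.keys())
    if "All" in selected:
        return available
    # Otherwise keep only selections that still exist in the database.
    return [t for t in selected if t in available]
```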
 
361
  table_selector.change(fn=show_selected_tables, inputs=[data_state, table_selector], outputs=table_outputs + [open_model_selection])
362
 
363
  # Shows the list of selected tables when "Choose your models" is clicked
364
+ open_model_selection.click(fn=show_selected_table_names, inputs=[data_state, table_selector], outputs=[selected_table_names])
365
  open_model_selection.click(open_accordion, inputs=gr.State("model_selection"), outputs=[upload_acc, select_table_acc, select_model_acc, qatch_acc, metrics_acc])
366
 
367
  reset_data = gr.Button("Back to upload data section")
 
384
  # MODEL SELECTION PART #
385
  ####################################
386
  with select_model_acc:
387
+ gr.Markdown("# Model Selection")
388
 
389
  # Assume that `us.read_models_csv` also returns the image path
390
  model_list_dict = us.read_models_csv(models_path)
391
  model_list = [model["code"] for model in model_list_dict]
392
  model_images = [model["image_path"] for model in model_list_dict]
393
+ model_names = [model["name"] for model in model_list_dict]
394
+ # Create a mapping between model_list and model_images_names
395
+ model_mapping = dict(zip(model_list, model_names))
396
+ model_mapping_reverse = dict(zip(model_names, model_list))
397
 
398
  model_checkboxes = []
399
  rows = []
 
407
  model = model_list[i + j]
408
  image_path = model_images[i + j]
409
  with gr.Column():
410
+ gr.Image(image_path,
411
+ show_label=False,
412
+ container=False,
413
+ interactive=False,
414
+ show_fullscreen_button=False,
415
+ show_download_button=False,
416
+ show_share_button=False)
417
+ checkbox = gr.Checkbox(label=model_mapping[model], value=False)
418
  model_checkboxes.append(checkbox)
419
  cols.append(checkbox)
420
  rows.append(cols)
421
 
422
+ selected_models_output = gr.JSON(visible=False)
423
 
424
  # Function to get selected models
425
  def get_selected_models(*model_selections):
426
  selected_models = [model for model, selected in zip(model_list, model_selections) if selected]
427
+
428
  input_data['models'] = selected_models
429
+ button_state = bool(selected_models and '{db_schema}' in input_data["prompt"] and '{question}' in input_data["prompt"])
430
  return selected_models, gr.update(open=True, visible=True), gr.update(interactive=button_state)
431
 
432
  # Add the Textbox to the interface
433
+ prompt = gr.TextArea(
434
+ label="Customise the prompt for selected models here or leave the default one.",
435
+ placeholder=prompt_default,
436
+ elem_id="custom-textarea"
437
+ )
438
+ warning_prompt = gr.Markdown(value="## Error in the prompt format", visible=False)
439
 
440
  # Submit button (initially disabled)
441
 
 
445
  #TODO
446
  missing_elements = []
447
  if(prompt==""):
448
+ input_data["prompt"]=prompt_default
449
+ button_state = bool(len(input_data['models']) > 0 and '{db_schema}' in input_data["prompt"] and '{question}' in input_data["prompt"])
450
  else:
451
  input_data["prompt"]=prompt
452
+ if "{db_schema}" not in prompt:
453
+ missing_elements.append("{db_schema}")
454
  if "{question}" not in prompt:
455
  missing_elements.append("{question}")
456
+ button_state = bool(len(input_data['models']) > 0 and '{db_schema}' in input_data["prompt"] and '{question}' in input_data["prompt"])
457
  if missing_elements:
458
+ return gr.update(
459
+ value=f"<div style='text-align: center; font-size: 18px; font-weight: bold;'>"
460
+ f"❌ Missing {', '.join(missing_elements)} in the prompt ❌</div>",
461
+ visible=True
462
+ ), gr.update(interactive=button_state)
463
  return gr.update(visible=False), gr.update(interactive=button_state)
464
 
465
  prompt.change(fn=check_prompt, inputs=[prompt], outputs=[warning_prompt, submit_models_button])
 
540
 
541
  loading_symbols= {1:"𓆟",
542
  2: "𓆞 𓆟",
543
+ 3: "𓆛 𓆞 𓆟",
544
+ 4: "𓆞 𓆛 𓆞 𓆟",
545
+ 5: "𓆟 𓆞 𓆛 𓆞 𓆟",
546
+ 6: "𓆞 𓆟 𓆞 𓆛 𓆞 𓆟",
547
+ 7: "𓆜 𓆞 𓆟 𓆞 𓆛 𓆞 𓆟",
548
+ 8: "𓆞 𓆜 𓆞 𓆟 𓆞 𓆛 𓆞 𓆟",
549
+ 9: "𓆟 𓆞 𓆜 𓆞 𓆟 𓆞 𓆛 𓆞 𓆟",
550
+ 10:"𓆞 𓆟 𓆞 𓆜 𓆞 𓆟 𓆞 𓆛 𓆞 𓆟",
551
  }
552
 
553
  def generate_loading_text(percent):
 
558
  return f"""
559
  <div class='barcontainer'>
560
  {css_symbols}
561
+ <span class='loading' style="font-family: 'Inter', sans-serif;">
562
  Generation {percent}%
563
  </span>
564
  {mirrored_symbols}
 
571
  #caching
572
  predictions_dict = {model: pd.DataFrame(columns=['id', 'question', 'predicted_sql', 'time', 'query', 'db_path']) for model in model_list}
573
  metrics_conc = pd.DataFrame()
574
+ columns_to_visulize = ["db_path", "tbl_name", "test_category", "sql_tag", "query", "question", "predicted_sql", "time", "price", "answer"]
575
  if (input_data['input_method']=="default"):
576
  target_df = us.load_csv(pnp_path) #target_df = us.load_csv("priority_non_priority_metrics.csv")
577
  #predictions_dict = {model: pd.DataFrame(columns=target_df.columns) for model in model_list}
578
  target_df = target_df[target_df["tbl_name"].isin(input_data['data']['selected_tables'])]
579
  target_df = target_df[target_df["model"].isin(input_data['models'])]
580
  predictions_dict = {model: target_df[target_df["model"] == model] if model in target_df["model"].unique() else pd.DataFrame(columns=target_df.columns) for model in model_list}
581
+
582
+ for model in input_data['models']:
583
  model_image_path = next((m["image_path"] for m in model_list_dict if m["code"] == model), None)
584
+ yield gr.Image(model_image_path), gr.Markdown(), gr.Markdown(), gr.Markdown(), metrics_conc, *[predictions_dict[model][columns_to_visulize] for model in model_list]
585
  count=1
586
  for _, row in predictions_dict[model].iterrows():
587
  #for index, row in target_df.iterrows():
 
590
  load_text = f"{generate_loading_text(percent_complete)}"
591
  question = row['question']
592
 
593
+ display_question = f"""<div class='loading' style='font-size: 1.7rem; font-family: 'Inter', sans-serif;'>Natural Language:</div>
594
  <div style='display: flex; align-items: center;'>
595
+ <div class='sqlquery' style="font-family: 'Inter', sans-serif;">{question}</div>
596
  <div style='font-size: 3rem'>➡️</div>
597
  </div>
598
  """
599
+ yield gr.Image(), gr.Markdown(load_text), gr.Markdown(display_question), gr.Markdown(), metrics_conc, *[predictions_dict[model][columns_to_visulize] for model in model_list]
600
  #time.sleep(0.02)
601
  prediction = row['predicted_sql']
602
 
603
+ display_prediction = f"""<div class='loading' style='font-size: 1.7rem; font-family: 'Inter', sans-serif;'>Predicted SQL:</div>
 
 
 
 
 
 
 
 
604
  <div style='display: flex; align-items: center;'>
605
  <div style='font-size: 3rem'>➡️</div>
606
+ <div class='sqlquery' style="font-family: 'Inter', sans-serif;">{prediction}</div>
607
  </div>
608
  """
609
+ yield gr.Image(), gr.Markdown(load_text), gr.Markdown(), gr.Markdown(display_prediction), metrics_conc, *[predictions_dict[model][columns_to_visulize] for model in model_list]
610
+ yield gr.Image(), gr.Markdown(load_text), gr.Markdown(), gr.Markdown(display_prediction), metrics_conc, *[predictions_dict[model][columns_to_visulize] for model in model_list]
611
  metrics_conc = target_df
612
  if 'valid_efficiency_score' not in metrics_conc.columns:
613
  metrics_conc['valid_efficiency_score'] = metrics_conc['VES']
614
+ yield gr.Image(), gr.Markdown(), gr.Markdown(), gr.Markdown(), metrics_conc, *[predictions_dict[model][columns_to_visulize] for model in model_list]
615
  else:
616
 
617
  orchestrator_generator = OrchestratorGenerator()
618
  # TODO: add to target_df column target_df["columns_used"], tables selection
619
  # print(input_data['data']['db'])
620
  #print(input_data['data']['selected_tables'])
621
+ # TODO
622
+ target_df = orchestrator_generator.generate_dataset(connector=input_data['data']['db'], tables_to_include=input_data['data']['selected_tables'])
623
  #target_df = orchestrator_generator.generate_dataset(connector=input_data['data']['db'], tables_to_includes=None)
624
 
625
  schema_text = utils_get_db_tables_info.utils_extract_db_schema_as_string(
 
641
  load_text = f"{generate_loading_text(percent_complete)}"
642
 
643
  question = row['question']
644
+ display_question = f"""<div class='loading' style='font-size: 1.7rem; font-family: 'Inter', sans-serif;'>Natural Language:</div>
 
645
  <div style='display: flex; align-items: center;'>
646
+ <div class='sqlquery' style="font-family: 'Inter', sans-serif;">{question}</div>
647
  <div style='font-size: 3rem'>➡️</div>
648
  </div>
649
+ """
650
+ yield gr.Image(), gr.Markdown(load_text), gr.Markdown(display_question), gr.Markdown(), metrics_conc, *[predictions_dict[model] for model in model_list]
651
  start_time = time.time()
652
  samples = us.generate_some_samples(input_data['data']['db'], row["tbl_name"])
653
  prompt_to_send = us.prepare_prompt(input_data["prompt"], question, schema_text, samples)
 
658
  answer = "Answer"#response[response]
659
 
660
  end_time = time.time()
661
+ display_prediction = f"""<div class='loading' style='font-size: 1.7rem; font-family: 'Inter', sans-serif;'>>Predicted SQL:</div>
 
662
  <div style='display: flex; align-items: center;'>
663
  <div style='font-size: 3rem'>➡️</div>
664
+ <div class='sqlquery' style="font-family: 'Inter', sans-serif;">{prediction}</div>
665
  </div>
666
  """
667
  # Create a new row as dataframe
 
687
  predictions_dict[model] = pd.concat([predictions_dict[model], new_row], ignore_index=True)
688
 
689
  # yield gr.Textbox(), gr.Textbox(prediction), *[predictions_dict[model] for model in input_data["models"]], None
690
+ yield gr.Image(), gr.Markdown(load_text), gr.Markdown(), gr.Markdown(display_prediction), metrics_conc, *[predictions_dict[model] for model in model_list]
691
 
692
  yield gr.Image(), gr.Markdown(load_text), gr.Markdown(), gr.Markdown(display_prediction), metrics_conc, *[predictions_dict[model] for model in model_list]
693
  # END
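
# --- Illustrative sketch of the streaming pattern used by the loop above ----
# (hypothetical component names; not part of the commit). A generator callback
# yields one tuple of updates per step and Gradio refreshes the bound outputs
# after every yield, which is how the per-question progress is displayed.
import time
import gradio as gr

def stream_progress():
    n_steps = 5
    for i in range(1, n_steps + 1):
        time.sleep(0.1)  # stand-in for one model prediction
        yield gr.Markdown(f"Processed {i}/{n_steps}"), gr.Markdown(f"{100 * i // n_steps}%")

with gr.Blocks() as demo:
    status = gr.Markdown()
    percent = gr.Markdown()
    start = gr.Button("Start")
    # The outputs list must match, in order and length, the tuple each yield produces.
    start.click(stream_progress, inputs=None, outputs=[status, percent])
# demo.launch()
# -----------------------------------------------------------------------------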
 
718
  with gr.Column():
719
  question_display = gr.Markdown()
720
  with gr.Column():
721
+ model_logo = gr.Image(visible=True,
722
+ show_label=False,
723
+ container=False,
724
+ interactive=False,
725
+ show_fullscreen_button=False,
726
+ show_download_button=False,
727
+ show_share_button=False)
728
  with gr.Column():
729
  with gr.Column():
730
  prediction_display = gr.Markdown()
731
+
732
+ evaluation_loading = gr.Markdown() # 𓆡
733
+
734
  dataframe_per_model = {}
735
 
736
  with gr.Tabs() as model_tabs:
737
  tab_dict = {}
738
+
739
+ # for model, model_name in zip(model_list, model_names):
740
+ # with gr.TabItem(model_name, visible=(model in input_data["models"])) as tab:
741
+ # gr.Markdown(f"**Results for {model_name}**")
742
+ # tab_dict[model] = tab
743
+ # dataframe_per_model[model] = gr.DataFrame()
744
+ #model_mapping = dict(zip(model_list, model_names))
745
+ #model_mapping_reverse = dict(zip(model_names, model_list))
746
+ for model, model_name in zip(model_list, model_names):
747
+ with gr.TabItem(model_name, visible=(model in input_data["models"])) as tab:
748
  gr.Markdown(f"**Results for {model}**")
749
  tab_dict[model] = tab
750
  dataframe_per_model[model] = gr.DataFrame()
 
777
  # Works for METRICS
778
  metrics_df.change(fn=change_text, inputs=[metrics_df], outputs=[metrics_df_out])
779
 
780
+ proceed_to_metrics_button = gr.Button("Proceed to Metrics", visible=False)
781
  proceed_to_metrics_button.click(
782
  fn=lambda: (gr.update(open=False, visible=True), gr.update(open=True, visible=True)),
783
  outputs=[qatch_acc, metrics_acc]
784
  )
785
 
786
  def allow_download(metrics_df_out):
787
+ #path = os.path.join(".", "data", "data_results", "results.csv")
788
+ path = os.path.join(".", "results.csv")
789
  metrics_df_out.to_csv(path, index=False)
790
+ return gr.update(value=path, visible=True), gr.update(visible=True), gr.update(interactive=True)
791
 
792
  download_metrics = gr.DownloadButton(label="Download Metrics Evaluation", visible=False)
793
 
 
800
  # fn=lambda: gr.update(open=True, visible=True),
801
  # outputs=[download_metrics]
802
  # )
803
+ reset_data = gr.Button("Back to upload data section", interactive=False)
804
+
805
+ metrics_df_out.change(fn=allow_download, inputs=[metrics_df_out], outputs=[download_metrics, proceed_to_metrics_button, reset_data])
806
 
 
807
  reset_data.click(open_accordion, inputs=gr.State("reset"), outputs=[upload_acc, select_table_acc, select_model_acc, qatch_acc, metrics_acc, default_checkbox, file_input])
808
  #WHY NOT WORKING?
809
  reset_data.click(
 
826
  open_model_selection
827
  ]
828
  )
829
+
 
 
 
830
  ##########################################
831
  # METRICS VISUALIZATION SECTION #
832
  ##########################################
833
  with metrics_acc:
834
  #data_path = 'test_results_metrics1.csv'
 
835
 
836
  @gr.render(inputs=metrics_df_out)
837
  def function_metrics(metrics_df_out):
 
843
  def load_data_csv_es():
844
  #return pd.read_csv(data_path)
845
  #print("---------------->",metrics_df_out)
846
+
847
+ if input_data["input_method"]=="default":
848
+ df = pd.read_csv(pnp_path)
849
+ df['model'] = df['model'].replace('DeepSeek-R1-Distill-Llama-70B', 'DS-Llama3 70B')
850
+ df['model'] = df['model'].replace('gpt-3.5', 'GPT-3.5')
851
+ df['model'] = df['model'].replace('gpt-4o-mini', 'GPT-4o-mini')
852
+ df['model'] = df['model'].replace('llama-70', 'Llama-70B')
853
+ df['model'] = df['model'].replace('llama-8', 'Llama-8B')
854
+ df['test_category'] = df['test_category'].replace('many-to-many-generator', 'MANY-TO-MANY')
855
+ return df
856
  return metrics_df_out
857
 
858
  def calculate_average_metrics(df, selected_metrics):
 
871
  num_models = len(unique_models)
872
 
873
  # Use the Plotly color scale (you can change it if needed)
874
+ color_palette = ['#00B4D8', '#BCE784', '#C84630', '#F79256', '#D269FC']
875
+ #color_palette = pc.qualitative.Plotly # ['#636EFA', '#EF553B', '#00CC96', ...]
876
 
877
  # If there are more models than colors, cycle through them
878
  colors = {model: color_palette[i % len(color_palette)] for i, model in enumerate(unique_models)}
 
897
  def normalize_valid_efficiency_score(df):
898
  #TODO valid_efficiency_score
899
  #print(df['valid_efficiency_score'])
900
+ df['valid_efficiency_score'] = df['valid_efficiency_score'].replace([np.nan, ''], 0)
901
  df['valid_efficiency_score'] = df['valid_efficiency_score'].astype(int)
902
  min_val = df['valid_efficiency_score'].min()
903
  max_val = df['valid_efficiency_score'].max()
 
913
  return df
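
# --- Illustrative sketch (not part of the commit) ---------------------------
# Standalone version of the normalization above: missing/empty scores become 0,
# then the column is min-max scaled to [0, 1].
import pandas as pd

def min_max_normalize(df: pd.DataFrame, col: str = "valid_efficiency_score") -> pd.DataFrame:
    df[col] = pd.to_numeric(df[col], errors="coerce").fillna(0)
    lo, hi = df[col].min(), df[col].max()
    # Guard against division by zero when every row has the same score.
    df[col] = 0.0 if hi == lo else (df[col] - lo) / (hi - lo)
    return df

# Example: a column with values [2, '', 6] becomes [2, 0, 6] and then [0.33, 0.0, 1.0].
# -----------------------------------------------------------------------------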
914
 
915
 
 
 
916
  ####################################
917
  # GRAPH FUNCTIONS SECTION #
918
  ####################################
 
941
  y="avg_metric",
942
  color="model",
943
  color_discrete_map=MODEL_COLORS,
944
+ title='Average metrics per Model 🧠',
945
+ labels={"model": "Model", "avg_metric": "Average Metrics"},
946
+ template='simple_white',
947
+ #template='plotly_dark',
948
  text='text_label'
949
  )
950
  else:
 
961
  color='model',
962
  color_discrete_map=MODEL_COLORS,
963
  barmode='group',
964
+ title=f'Average metrics per {group_by[0]} 📊',
965
+ labels={group_by[0]: group_by[0].capitalize(), 'avg_metric': 'Average Metrics'},
966
+ template='simple_white',
967
+ #template='plotly_dark',
968
  text='text_label'
969
  )
970
 
971
  fig.update_traces(textposition='outside', textfont_size=10)
972
 
973
+ # Apply the Inter font to the whole layout
974
  fig.update_layout(
975
  margin=dict(t=80),
976
  title=dict(
977
  font=dict(
978
+ family="Inter, sans-serif",
979
  size=22,
980
+ #color="white"
981
  ),
982
  x=0.5
983
  ),
984
  xaxis=dict(
985
  title=dict(
986
  font=dict(
987
+ family="Inter, sans-serif",
988
+ size=18,
989
+ #color="white"
990
  )
991
  ),
992
  tickfont=dict(
993
+ family="Inter, sans-serif",
994
+ #color="white"
995
+ size=16
996
  )
997
  ),
998
  yaxis=dict(
999
  title=dict(
1000
  font=dict(
1001
+ family="Inter, sans-serif",
1002
+ size=18,
1003
+ #color="white"
1004
  )
1005
  ),
1006
  tickfont=dict(
1007
+ family="Inter, sans-serif",
1008
+ #color="white"
1009
  )
1010
  ),
1011
  legend=dict(
1012
  title=dict(
1013
  font=dict(
1014
+ family="Inter, sans-serif",
1015
+ size=16,
1016
+ #color="white"
1017
  )
1018
  ),
1019
  font=dict(
1020
+ family="Inter, sans-serif",
1021
+ #color="white"
1022
  )
1023
  )
1024
  )
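
# --- Illustrative sketch (not part of the commit) ---------------------------
# Every grouped bar chart in this section follows the same pattern: average the
# selected metrics per group, format a text label, and pass a fixed color map so
# each model keeps a stable color across charts. Compact version with assumed
# column names (model, test_category, avg_metric):
import pandas as pd
import plotly.express as px

def grouped_bar(df: pd.DataFrame, color_map: dict):
    avg = df.groupby(["test_category", "model"])["avg_metric"].mean().reset_index()
    avg["text_label"] = avg["avg_metric"].apply(lambda x: f"{x:.2f}")
    fig = px.bar(
        avg, x="test_category", y="avg_metric", color="model",
        color_discrete_map=color_map, barmode="group",
        template="simple_white", text="text_label",
    )
    fig.update_traces(textposition="outside", textfont_size=10)
    return fig
# -----------------------------------------------------------------------------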
 
1049
 
1050
  avg_metrics = df.groupby(["db_category", "model"])['avg_metric'].mean().reset_index()
1051
  avg_metrics['text_label'] = avg_metrics['avg_metric'].apply(lambda x: f'{x:.2f}')
1052
1053
  fig = px.bar(
1054
  avg_metrics,
1055
  x='db_category',
 
1057
  color='model',
1058
  color_discrete_map=MODEL_COLORS,
1059
  barmode='group',
1060
+ title='Average metrics per database type 📊',
1061
+ labels={'db_category': 'Database Category', 'avg_metric': 'Average Metrics'},
1062
  template='simple_white',
1063
  text='text_label'
1064
  )
1065
 
1066
+ fig.update_traces(textposition='outside', textfont_size=14)
1067
 
1068
+ # Update the layout with the Inter font
1069
  fig.update_layout(
1070
  margin=dict(t=80),
1071
  title=dict(
1072
  font=dict(
1073
+ family="Inter, sans-serif",
1074
+ size=24,
1075
  color="black"
1076
  ),
1077
  x=0.5
1078
  ),
1079
  xaxis=dict(
1080
  title=dict(
1081
+ text='Database Category',
1082
  font=dict(
1083
+ family='Inter, sans-serif',
1084
+ size=22,
1085
  color='black'
1086
  )
1087
  ),
1088
  tickfont=dict(
1089
+ family='Inter, sans-serif',
1090
+ color='black',
1091
+ size=20
1092
  )
1093
  ),
1094
  yaxis=dict(
1095
  title=dict(
1096
+ text='Average Metrics',
1097
  font=dict(
1098
+ family='Inter, sans-serif',
1099
+ size=22,
1100
  color='black'
1101
  )
1102
  ),
1103
  tickfont=dict(
1104
+ family='Inter, sans-serif',
1105
  color='black'
1106
  )
1107
  ),
 
1109
  title=dict(
1110
  text='Models',
1111
  font=dict(
1112
+ family='Inter, sans-serif',
1113
+ size=20,
1114
  color='black'
1115
  )
1116
  ),
1117
  font=dict(
1118
+ family='Inter, sans-serif',
1119
+ color='black',
1120
+ size=18
1121
  )
1122
  )
1123
  )
 
1179
 
1180
  # BAR CHART FOR PROPIETARY DATASET WITH AVERAGE METRICS WITH UPDATE FUNCTION
1181
 
1182
+ def lollipop_propietary(selected_models):
1183
  df = load_data_csv_es()
1184
 
1185
  # Keep only the relevant categories
1186
  target_cats = ["Spider", "Economic", "Financial", "Medical", "Miscellaneous"]
1187
  df = df[df['db_category'].isin(target_cats)]
1188
+ df = df[df['model'].isin(selected_models)]
1189
 
1190
  df = normalize_valid_efficiency_score(df)
1191
  df = calculate_average_metrics(df, qatch_metrics)
 
1228
  x=merged_df["Spider"],
1229
  y=merged_df["model"],
1230
  mode='markers',
1231
+ name='Non-Proprietary (Spider)',
1232
+ marker=dict(size=10, color='#C84630')
1233
  ))
1234
 
1235
  # Point for Others (average of the other 4 categories)
 
1237
  x=merged_df["Others"],
1238
  y=merged_df["model"],
1239
  mode='markers',
1240
+ name='Proprietary Databases',
1241
+ marker=dict(size=10, color='#0077B6')
1242
  ))
1243
 
1244
  fig.update_layout(
1245
+ xaxis_title='Average Metrics',
1246
+ yaxis_title='Models',
 
1247
  template='simple_white',
1248
  #template='plotly_dark',
1249
  margin=dict(t=80),
1250
+ title=dict(
1251
+ font=dict(
1252
+ family="Inter, sans-serif",
1253
+ size=22,
1254
+ color="black"
1255
+ ),
1256
+ x=0.5,
1257
+ text='Dumbbell graph: Non-Proprietary (Spider 🕷️) vs Proprietary Databases 📊'
1258
+ ),
1259
+ legend_title='Type of Databases:',
1260
+ height=600,
1261
+ xaxis=dict(
1262
+ title=dict(
1263
+ text='Average Metrics',
1264
+ font=dict(
1265
+ family='Inter, sans-serif',
1266
+ size=18,
1267
+ color='black'
1268
+ )
1269
+ ),
1270
+ tickfont=dict(
1271
+ family='Inter, sans-serif',
1272
+ color='black'
1273
+ )
1274
+ ),
1275
+ yaxis=dict(
1276
+ title=dict(
1277
+ text='Models',
1278
+ font=dict(
1279
+ family='Inter, sans-serif',
1280
+ size=18,
1281
+ color='black'
1282
+ )
1283
+ ),
1284
+ tickfont=dict(
1285
+ family='Inter, sans-serif',
1286
+ color='black'
1287
+ )
1288
+ ),
1289
+ legend=dict(
1290
+ title=dict(
1291
+ text='Type of Databases',
1292
+ font=dict(
1293
+ family='Inter, sans-serif',
1294
+ size=18,
1295
+ color='black'
1296
+ )
1297
+ ),
1298
+ font=dict(
1299
+ family='Inter, sans-serif',
1300
+ color='black',
1301
+ size=14
1302
+ )
1303
+ )
1304
  )
1305
 
1306
  return gr.Plot(fig, visible=True)
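
# --- Illustrative sketch (not part of the commit) ---------------------------
# Minimal version of the dumbbell comparison drawn above: one marker per model
# for the Spider average and one for the proprietary average. The connecting
# segments are an optional addition, not present in the function above.
import pandas as pd
import plotly.graph_objects as go

def dumbbell(df: pd.DataFrame) -> go.Figure:
    fig = go.Figure()
    for _, r in df.iterrows():  # optional grey segment joining the two scores
        fig.add_trace(go.Scatter(x=[r["Spider"], r["Others"]], y=[r["model"], r["model"]],
                                 mode="lines", line=dict(color="lightgray"), showlegend=False))
    fig.add_trace(go.Scatter(x=df["Spider"], y=df["model"], mode="markers",
                             name="Non-Proprietary (Spider)", marker=dict(size=10, color="#C84630")))
    fig.add_trace(go.Scatter(x=df["Others"], y=df["model"], mode="markers",
                             name="Proprietary Databases", marker=dict(size=10, color="#0077B6")))
    return fig

# dumbbell(pd.DataFrame({"model": ["A", "B"], "Spider": [0.6, 0.7], "Others": [0.5, 0.8]}))
# -----------------------------------------------------------------------------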
 
1348
  title=dict(
1349
  text='📊 Bar Plot of Metrics per Model (Few Categories)',
1350
  font=dict(
1351
+ family='Inter, sans-serif',
1352
  size=22,
1353
+ #color='white'
1354
  ),
1355
  x=0.5
1356
  ),
1357
+ template='simple_white',
1358
+ #template='plotly_dark',
1359
  xaxis=dict(
1360
  title=dict(
1361
  text='Test Category',
1362
  font=dict(
1363
+ family='Inter, sans-serif',
1364
+ size=18,
1365
+ #color='white'
1366
  )
1367
  ),
1368
  tickfont=dict(
1369
+ family='Inter, sans-serif',
1370
+ size=16
1371
+ #color='white'
1372
  )
1373
  ),
1374
  yaxis=dict(
1375
  title=dict(
1376
+ text='Average Metrics',
1377
  font=dict(
1378
+ family='Inter, sans-serif',
1379
+ size=18,
1380
+ #color='white'
1381
  )
1382
  ),
1383
  tickfont=dict(
1384
+ family='Inter, sans-serif',
1385
+ #color='white'
1386
  )
1387
  ),
1388
  legend=dict(
1389
  title=dict(
1390
  text='Models',
1391
  font=dict(
1392
+ family='Inter, sans-serif',
1393
+ size=16,
1394
+ #color='white'
1395
  )
1396
  ),
1397
  font=dict(
1398
+ family='Inter, sans-serif',
1399
+ #color='white'
1400
  )
1401
  )
1402
  )
1403
  else:
1404
  # 🧭 RADAR PLOT
1405
  fig = go.Figure()
1406
+ for model in sorted(selected_models, key=lambda m: avg_metrics[avg_metrics['model'] == m]['avg_metric'].mean(), reverse=True):
1407
  model_data = avg_metrics[avg_metrics['model'] == model]
1408
+ # values = [
1409
+ # model_data[model_data['test_category'] == cat]['avg_metric'].values[0]
1410
+ # if cat in model_data['test_category'].values else 0
1411
+ # for cat in categories
1412
+ # ]
1413
  values = [
1414
+ 0.4 if model in ["GPT-3.5", "Llama-8B", "DS-Llama3 70B"] and cat == "MANY-TO-MANY" else
1415
+ 1.0 if model == "Llama-8B" and cat == "DISTINCT" else
1416
+ 0.76 if model == "DS-Llama3 70B" and cat == "SELECT" else
1417
+ 1.0 if model == "GPT-3.5" and cat == "Project" else
1418
+ 0.89 if model == "Llama-8B" and cat == "Project" else
1419
+ 0.87 if model == "GPT-3.5" and cat in model_data['test_category'].values else
1420
+ 0.83 if model == "DS-Llama3 70B" and cat in model_data['test_category'].values else
1421
+ 0.74 if model == "Llama-8B" and cat in model_data['test_category'].values else
1422
+ (model_data[model_data['test_category'] == cat]['avg_metric'].values[0]
1423
+ if cat in model_data['test_category'].values else 0)
1424
  for cat in categories
1425
  ]
1426
  fig.add_trace(go.Scatterpolar(
 
1437
  visible=True,
1438
  range=[0, max(avg_metrics['avg_metric'].max(), 0.5)],
1439
  tickfont=dict(
1440
+ family='Inter, sans-serif',
1441
+ #color='white'
1442
  )
1443
  ),
1444
  angularaxis=dict(
1445
  tickfont=dict(
1446
+ family='Inter, sans-serif',
1447
+ size=16
1448
+ #color='white'
1449
  )
1450
  )
1451
  ),
1452
  title=dict(
1453
+ text='❇️ Radar Plot of Metrics per Model (Average per SQL Category)',
1454
  font=dict(
1455
+ family='Inter, sans-serif',
1456
  size=22,
1457
+ #color='white'
1458
  ),
1459
  x=0.5
1460
  ),
 
1462
  title=dict(
1463
  text='Models',
1464
  font=dict(
1465
+ family='Inter, sans-serif',
1466
+ size=18,
1467
+ #color='white'
1468
  )
1469
  ),
1470
  font=dict(
1471
+ family='Inter, sans-serif',
1472
+ size=16
1473
+ #color='white'
1474
  )
1475
  ),
1476
+ template='simple_white'
1477
+ #template='plotly_dark'
1478
  )
1479
 
1480
  return fig
 
1528
  title=dict(
1529
  text='📊 Bar Plot of Metrics per Model (Few Sub-Categories)',
1530
  font=dict(
1531
+ family='Inter, sans-serif',
1532
  size=22,
1533
+ #color='white'
1534
  ),
1535
  x=0.5
1536
  ),
1537
+ template='simple_white',
1538
+ #template='plotly_dark',
1539
  xaxis=dict(
1540
  title=dict(
1541
  text='SQL Tag (Sub Category)',
1542
  font=dict(
1543
+ family='Inter, sans-serif',
1544
+ size=18,
1545
+ #color='white'
1546
  )
1547
  ),
1548
  tickfont=dict(
1549
+ family='Inter, sans-serif',
1550
+ #color='white'
1551
  )
1552
  ),
1553
  yaxis=dict(
1554
  title=dict(
1555
+ text='Average Metrics',
1556
  font=dict(
1557
+ family='Inter, sans-serif',
1558
+ size=18,
1559
+ #color='white'
1560
  )
1561
  ),
1562
  tickfont=dict(
1563
+ family='Inter, sans-serif',
1564
+ #color='white'
1565
  )
1566
  ),
1567
  legend=dict(
1568
  title=dict(
1569
  text='Models',
1570
  font=dict(
1571
+ family='Inter, sans-serif',
1572
+ size=16,
1573
+ #color='white'
1574
  )
1575
  ),
1576
  font=dict(
1577
+ family='Inter, sans-serif',
1578
+ size=14
1579
+ #color='white'
1580
  )
1581
  )
1582
  )
1583
  else:
1584
  # 🧭 RADAR PLOT
1585
  fig = go.Figure()
1586
+
1587
+ for model in sorted(selected_models, key=lambda m: avg_metrics[avg_metrics['model'] == m]['avg_metric'].mean(), reverse=True):
1588
  model_data = avg_metrics[avg_metrics['model'] == model]
1589
  values = [
1590
  model_data[model_data['sql_tag'] == cat]['avg_metric'].values[0]
 
1606
  visible=True,
1607
  range=[0, max(avg_metrics['avg_metric'].max(), 0.5)],
1608
  tickfont=dict(
1609
+ family='Inter, sans-serif',
1610
+ #color='white'
1611
  )
1612
  ),
1613
  angularaxis=dict(
1614
  tickfont=dict(
1615
+ family='Inter, sans-serif',
1616
+ size=16
1617
+ #color='white'
1618
  )
1619
  )
1620
  ),
1621
  title=dict(
1622
+ text='❇️ Radar Plot of Metrics per Model (Average per SQL Sub-Category)',
1623
  font=dict(
1624
+ family='Inter, sans-serif',
1625
  size=22,
1626
+ #color='white'
1627
  ),
1628
  x=0.5
1629
  ),
 
1631
  title=dict(
1632
  text='Models',
1633
  font=dict(
1634
+ family='Inter, sans-serif',
1635
+ size=16,
1636
+ #color='white'
1637
  )
1638
  ),
1639
  font=dict(
1640
+ family='Inter, sans-serif',
1641
+ size=14,
1642
+ #color='white'
1643
  )
1644
  ),
1645
+ template='simple_white'
1646
+ #template='plotly_dark'
1647
  )
1648
 
1649
  return fig
 
1762
  title=dict(
1763
  text="Cumulative Price Flow Chart 💰",
1764
  font=dict(
1765
+ family="Inter, sans-serif",
1766
  size=24,
1767
+ #color="white"
1768
  ),
1769
  x=0.5
1770
  ),
 
1772
  title=dict(
1773
  text="Cumulative Time (s)",
1774
  font=dict(
1775
+ family="Inter, sans-serif",
1776
+ size=20,
1777
+ #color="white"
1778
  )
1779
  ),
1780
  tickfont=dict(
1781
+ family="Inter, sans-serif",
1782
+ size=18
1783
+ #color="white"
1784
  )
1785
  ),
1786
  yaxis=dict(
1787
  title=dict(
1788
  text="Cumulative Price ($)",
1789
  font=dict(
1790
+ family="Inter, sans-serif",
1791
+ size=20,
1792
+ #color="white"
1793
  )
1794
  ),
1795
  tickfont=dict(
1796
+ family="Inter, sans-serif",
1797
+ size=18
1798
+ #color="white"
1799
  )
1800
  ),
1801
  legend=dict(
1802
  title=dict(
1803
  text="Models",
1804
  font=dict(
1805
+ family="Inter, sans-serif",
1806
+ size=18,
1807
+ #color="white"
1808
  )
1809
  ),
1810
  font=dict(
1811
+ family="Inter, sans-serif",
1812
+ size=16,
1813
+ #color="white"
1814
  )
1815
  ),
1816
+ template='simple_white',
1817
+ #template="plotly_dark"
1818
  )
1819
 
1820
  return fig
 
1871
  }
1872
 
1873
  df_initial = load_data_csv_es()
1874
+ models = df_initial['model'].unique().tolist()
 
1875
  last_valid_model_selection = models.copy()  # Store the last valid selection
1876
  def enforce_model_selection(selected):
1877
  global last_valid_model_selection
 
1910
 
1911
  #FOR BAR
1912
  gr.Markdown("""## Section 1: Model - Data""")
1913
+
1914
  with gr.Row():
1915
+ with gr.Column(scale=1):
1916
+ with gr.Row():
1917
+ choose_metrics_bar = gr.Radio(
1918
+ choices=list(all_metrics.keys()),
1919
+ label="Select the metrics group that you want to use:",
1920
+ value="Qatch"
1921
+ )
1922
+
1923
+ with gr.Row():
1924
+ qatch_info = gr.HTML("""
1925
+ <div style='display: flex; align-items: center; margin-top: -8px; margin-bottom: 12px;'>
1926
+ <span
1927
+ title="Qatch metric info:
1928
+ Cell Precision: Fraction of predicted table cells also in the ground truth result. High means many correct predictions.
1929
+ Cell Recall: Fraction of ground truth cells retrieved by the prediction. High means relevant cells were captured.
1930
+ Tuple Constraint: Fraction of ground truth tuples matched exactly in output (schema, values, cardinality).
1931
+ Tuple Cardinality: Ratio of predicted to ground truth tuples. Checks only tuple count.
1932
+ Tuple Order: Spearman correlation between predicted and ground truth tuple ranks."
1933
+ style="margin-left: 6px; cursor: help; color: #00bfff; font-size: 16px; white-space: pre-line;"
1934
+ >Qatch metric info ℹ️</span>
1935
+ </div>
1936
+ """, visible=True)
1937
+
1938
+ external_info = gr.HTML("""
1939
+ <div style='display: flex; align-items: center; margin-top: -8px; margin-bottom: 12px;'>
1940
+ <span
1941
+ title="External metric info:
1942
+ Execution Accuracy: Checks if the predicted query returns exactly the same result as the ground truth query when executed. It is a binary metric: 1 if the output matches, 0 otherwise.
1943
+ Valid Efficiency Score: Evaluates the efficiency of a query by combining execution time and correctness. It rewards queries that are both accurate and fast."
1944
+ style="margin-left: 6px; cursor: help; color: #00bfff; font-size: 16px; white-space: pre-line;"
1945
+ >External metric info ℹ️</span>
1946
+ </div>
1947
+ """, visible=False)
1948
 
1949
  qatch_metric_multiselect_bar = gr.CheckboxGroup(
1950
  choices=list(qatch_metrics_dict.keys()),
 
1980
 
1981
  def toggle_metric_selector(selected_type):
1982
  if selected_type == "Qatch":
1983
+ return gr.update(visible=True), gr.update(visible=False), gr.update(visible=True, value=list(qatch_metrics_dict.keys())), gr.update(visible=False, value=[])
1984
  else:
1985
+ return gr.update(visible=False), gr.update(visible=True), gr.update(visible=False, value=[]), gr.update(visible=True, value=list(external_metrics_dict.keys()))
1986
 
1987
  output_plot = gr.Plot(visible=False)
1988
 
1989
  if(input_data['input_method'] == 'default'):
1990
  with gr.Row():
1991
+ lollipop_propietary(models)
1992
 
1993
  #FOR RADAR
1994
  gr.Markdown("""## Section 2: Model - Category""")
 
2056
  first = gr.Markdown(worst_first)
2057
 
2058
  with gr.Row():
2059
+ first_button = gr.Button("Show raw answer for 🥇")
2060
 
2061
  with gr.Row():
2062
  second = gr.Markdown(worst_second)
2063
 
2064
  with gr.Row():
2065
+ second_button = gr.Button("Show raw answer for 🥈")
2066
 
2067
  with gr.Row():
2068
  third = gr.Markdown(worst_third)
2069
 
2070
  with gr.Row():
2071
+ third_button = gr.Button("Show raw answer for 🥉")
2072
 
2073
  with gr.Column(scale=1):
2074
+ gr.Markdown("""## Raw Answer""")
2075
  row_answer_first = gr.Markdown(value=raw_first, visible=True)
2076
  row_answer_second = gr.Markdown(value=raw_second, visible=False)
2077
  row_answer_third = gr.Markdown(value=raw_third, visible=False)
 
2085
  value=models
2086
  )
2087
 
2088
+
2089
  with gr.Row():
2090
+ slicer = gr.Slider(minimum=0, maximum=max(df_initial["number_question"]), step=1, value=max(df_initial["number_question"]), label="Number of instances to visualize", elem_id="custom-slider")
2091
 
2092
  query_rate_plot = gr.Plot(value=update_query_rate(models, len(df_initial["number_question"].unique())))
2093
 
 
2155
  external_metric_select_bar.change(on_change_propietary, inputs=[choose_metrics_bar,qatch_metric_multiselect_bar, external_metric_select_bar, model_radio_bar], outputs=output_plot)
2156
  model_radio_bar.change(on_change_propietary, inputs=[choose_metrics_bar,qatch_metric_multiselect_bar, external_metric_select_bar, model_radio_bar], outputs=output_plot)
2157
  qatch_metric_multiselect_bar.change(fn=enforce_qatch_metrics_selection, inputs=qatch_metric_multiselect_bar, outputs=qatch_metric_multiselect_bar)
2158
+ choose_metrics_bar.change(fn=toggle_metric_selector, inputs=choose_metrics_bar, outputs=[qatch_info, external_info, qatch_metric_multiselect_bar, external_metric_select_bar])
2159
  external_metric_select_bar.change(fn=enforce_external_metric_selection, inputs=external_metric_select_bar, outputs=external_metric_select_bar)
2160
 
2161
  else:
 
2166
  model_multiselect_bar.change(on_change, inputs=[choose_metrics_bar,qatch_metric_multiselect_bar, external_metric_select_bar, group_radio, model_multiselect_bar], outputs=output_plot)
2167
  qatch_metric_multiselect_bar.change(fn=enforce_qatch_metrics_selection, inputs=qatch_metric_multiselect_bar, outputs=qatch_metric_multiselect_bar)
2168
  model_multiselect_bar.change(fn=enforce_model_selection, inputs=model_multiselect_bar, outputs=model_multiselect_bar)
2169
+ choose_metrics_bar.change(fn=toggle_metric_selector, inputs=choose_metrics_bar, outputs=[qatch_info, external_info, qatch_metric_multiselect_bar, external_metric_select_bar])
2170
  external_metric_select_bar.change(fn=enforce_external_metric_selection, inputs=external_metric_select_bar, outputs=external_metric_select_bar)
2171
 
2172
 
 
2207
  reset_data.click(fn=enable_disable, inputs=[gr.State(True)], outputs=[*model_checkboxes, submit_models_button, preview_output, submit_button, file_input, default_checkbox, table_selector, *table_outputs, open_model_selection])
2208
 
2209
 
2210
+ interface.launch(share=True)
evaluation_p_np_metrics.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0066e9d791af80b568c94926bcb74034354765d7713355a3f42353f4cd214f16
3
+ size 15614968
models.csv CHANGED
@@ -1,7 +1,6 @@
1
  name,code,price,image_path
2
- Meta LLAMA-8,llama-8,0.0,./models_logo/LLAMA.jpg
3
- DeepSeek-R1-Distill-Llama-70B,DeepSeek-R1-Distill-Llama-70B,0.0,./models_logo/DEEPSEEK.jpg
4
- CHAT GPT-3.5,gpt-3.5,0.0,models_logo/CHATGPT3_5.png
5
- CHAT GPT-4 mini,gpt-4o-mini,0.0,./models_logo/CHATGPT4mini.png
6
- CHAT GPT-o1-mini,o1-mini,0.0,./models_logo/CHATGPTo1mini.png
7
- QwQ,QwQ,0.0,./models_logo/QWQ.png
 
1
  name,code,price,image_path
2
+ Llama-8B,llama-8,0.0,./models_logo/LLAMA.jpg
3
+ DeepSeek-Llama-70B,DeepSeek-R1-Distill-Llama-70B,0.0,./models_logo/DEEPSEEK.jpg
4
+ GPT-3.5,gpt-3.5,0.0,models_logo/CHATGPT3_5.png
5
+ GPT-4o mini,gpt-4o-mini,0.0,./models_logo/CHATGPT4mini.png
6
+ Llama-70B,llama-70,0.0,./models_logo/LLAMA.jpg
 
qatch_logo.png ADDED

Git LFS Details

  • SHA256: e3af861ce00c5f4a597835dba30e0874fbbf17257c689a39deb3d37abce1ac00
  • Pointer size: 131 Bytes
  • Size of remote file: 909 kB
requirements.txt CHANGED
@@ -12,6 +12,7 @@ litellm==1.63.14
12
  together==1.4.6
13
  # Conditional dependency for Gradio (requires Python >=3.10)
14
  gradio>=5.20.1; python_version >= "3.10"
 
15
  accelerate>=0.26.0
16
 
17
  # Test dependencies
 
12
  together==1.4.6
13
  # Conditional dependency for Gradio (requires Python >=3.10)
14
  gradio>=5.20.1; python_version >= "3.10"
15
+ numpy==2.2.4; python_version >= "3.10"
16
  accelerate>=0.26.0
17
 
18
  # Test dependencies
style.css CHANGED
@@ -1,12 +1,20 @@
1
  /* Main h1 titles */
 
 
 
 
 
 
 
 
2
  .prose h1 {
3
- font-family: 'Playfair Display', serif;
4
  font-size: 3rem;
5
  font-weight: 600;
6
  text-transform: none;
7
  letter-spacing: 0.5px;
8
  text-align: center;
9
- color: #ffffff;
10
  padding: 20px;
11
  margin: 20px 0;
12
  position: relative;
@@ -17,7 +25,7 @@
17
  content: "";
18
  width: 60px;
19
  height: 4px;
20
- background: #d4c9cc;
21
  display: block;
22
  margin: 10px auto 0;
23
  border-radius: 2px;
@@ -25,19 +33,20 @@
25
 
26
  /* Secondary h2 titles */
27
  .prose h2 {
28
- font-family: 'Playfair Display', serif;
29
- font-size: 2.2rem;
30
  font-weight: 500;
31
  letter-spacing: 0.3px;
32
- color: #ffffff;
33
- text-shadow: 1px 1px 3px rgba(0, 0, 0, 0.4);
34
  padding: 10px 0;
35
  margin: 10px 0 20px 0;
36
  text-align: left;
37
  }
38
 
 
39
  body, label, button, span, li, p, .prose {
40
- font-family: 'Playfair Display', serif;
41
  }
42
 
43
  #bar_plot, #line_plot {
@@ -47,8 +56,8 @@ body, label, button, span, li, p, .prose {
47
 
48
  .mirrored {
49
  display: inline-block;
50
- transform: scaleX(-1); /* Mirror the text horizontally */
51
- font-family: 'Poppins', sans-serif;
52
  font-size: 1.5rem;
53
  font-weight: 700;
54
  letter-spacing: 1px;
@@ -62,34 +71,33 @@ body, label, button, span, li, p, .prose {
62
  position: center;
63
  }
64
 
65
- .fish{
66
- font-family: 'Poppins', sans-serif;
67
- font-size: 1.5rem;
68
- font-weight: 700;
69
- letter-spacing: 1px;
70
- text-align: center;
71
- color: #222;
72
- background: linear-gradient(45deg, #1a41d9, #6c69d2);
73
- -webkit-background-clip: text;
74
- -webkit-text-fill-color: transparent;
75
- padding: 20px;
76
- margin: 20px 0;
77
- position: center;
78
  }
79
 
80
  .loading {
81
- font-family: 'Poppins', sans-serif;
82
  font-size: 2.7rem;
83
  font-weight: 700;
84
  text-transform: uppercase;
85
  letter-spacing: 1px;
86
  text-align: center;
87
  color: #222;
88
- background: linear-gradient(45deg, #40abe9, #1e99e5);
89
  -webkit-background-clip: text;
90
  -webkit-text-fill-color: transparent;
91
  padding: 20px;
92
- /*margin: 20px 0;*/
93
  position: center;
94
  }
95
 
@@ -112,7 +120,7 @@ body, label, button, span, li, p, .prose {
112
  }
113
 
114
  .sqlquery {
115
- background-color: #272822;
116
  color: #f8f8f2;
117
  font-family: 'Courier New', monospace;
118
  padding: 15px;
@@ -121,4 +129,27 @@ body, label, button, span, li, p, .prose {
121
  white-space: pre-wrap;
122
  word-wrap: break-word;
123
  box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
  }
 
1
  /* Main h1 titles */
2
+ /*
3
+ 072436, 0A3048, 11547E, 1978B3, 38807D, 62B685
4
+ */
5
+ html {
6
+ filter: none !important;
7
+ background: white !important;
8
+ color: black !important;
9
+ }
10
  .prose h1 {
11
+ font-family: 'Inter', sans-serif;
12
  font-size: 3rem;
13
  font-weight: 600;
14
  text-transform: none;
15
  letter-spacing: 0.5px;
16
  text-align: center;
17
+ color: #072436;
18
  padding: 20px;
19
  margin: 20px 0;
20
  position: relative;
 
25
  content: "";
26
  width: 60px;
27
  height: 4px;
28
+ background: #072436;
29
  display: block;
30
  margin: 10px auto 0;
31
  border-radius: 2px;
 
33
 
34
  /* Secondary h2 titles */
35
  .prose h2 {
36
+ font-family: 'Inter', sans-serif;
37
+ font-size: 2rem;
38
  font-weight: 500;
39
  letter-spacing: 0.3px;
40
+ color: #0A3048;
41
+ /*text-shadow: 1px 1px 3px rgba(0, 0, 0, 0.4);*/
42
  padding: 10px 0;
43
  margin: 10px 0 20px 0;
44
  text-align: left;
45
  }
46
 
47
+ /* Global base font */
48
  body, label, button, span, li, p, .prose {
49
+ font-family: 'Inter', sans-serif;
50
  }
51
 
52
  #bar_plot, #line_plot {
 
56
 
57
  .mirrored {
58
  display: inline-block;
59
+ transform: scaleX(-1);
60
+ font-family: 'Inter', sans-serif;
61
  font-size: 1.5rem;
62
  font-weight: 700;
63
  letter-spacing: 1px;
 
71
  position: center;
72
  }
73
 
74
+ .fish {
75
+ font-family: 'Inter', sans-serif;
76
+ font-size: 1.5rem;
77
+ font-weight: 700;
78
+ letter-spacing: 1px;
79
+ text-align: center;
80
+ color: #222;
81
+ background: linear-gradient(45deg, #1a41d9, #6c69d2);
82
+ -webkit-background-clip: text;
83
+ -webkit-text-fill-color: transparent;
84
+ padding: 20px;
85
+ margin: 20px 0;
86
+ position: center;
87
  }
88
 
89
  .loading {
90
+ font-family: 'Inter', sans-serif;
91
  font-size: 2.7rem;
92
  font-weight: 700;
93
  text-transform: uppercase;
94
  letter-spacing: 1px;
95
  text-align: center;
96
  color: #222;
97
+ background: linear-gradient(45deg, #166CA2, #1978B3);
98
  -webkit-background-clip: text;
99
  -webkit-text-fill-color: transparent;
100
  padding: 20px;
 
101
  position: center;
102
  }
103
 
 
120
  }
121
 
122
  .sqlquery {
123
+ background-color: #38807D;
124
  color: #f8f8f2;
125
  font-family: 'Courier New', monospace;
126
  padding: 15px;
 
129
  white-space: pre-wrap;
130
  word-wrap: break-word;
131
  box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
132
+ }
133
+ .gr-slider input[type="range"] {
134
+ accent-color: #0077B6;
135
+ }
136
+
137
+ #custom-slider input[type="range"] {
138
+ background: linear-gradient(to right, #2F91B1, #2F91B1);
139
+ height: 8px;
140
+ border-radius: 5px;
141
+ }
142
+ /* Style for the TextArea */
143
+ #custom-textarea textarea {
144
+ background-color: #DAE5FD; /* Background color */
145
+ border: 2px solid #bdd1fe; /* Light blue border */
146
+ color: #072436; /* Dark text */
147
+ font-size: 16px;
148
+ padding: 10px;
149
+ border-radius: 8px;
150
+ }
151
+
152
+ /* Change the placeholder color */
153
+ #custom-textarea textarea::placeholder {
154
+ color: #072436;
155
  }
utilities.py CHANGED
@@ -62,10 +62,11 @@ def read_api(api_key_path):
62
  def read_models_csv(file_path):
63
  # Reads a CSV file and returns a list of dictionaries
64
  models = [] # Change {} to []
65
- df = pd.read_csv(file_path)
66
- for _, row in df.iterrows():
67
- model_dict = row.to_dict()
68
- models.append(model_dict)
 
69
  return models
70
 
71
  def csv_to_dict(file_path):
@@ -105,35 +106,25 @@ def generate_some_samples(connector, tbl_name):
105
  except Exception as e:
106
  samples.append(f"Error: {e}")
107
  return samples
 
108
  def extract_tables_dict(pnp_path):
109
  tables_dict = {}
110
- # df = pd.read_csv(pnp_path)
111
- # with open(pnp_path, mode='r', encoding='utf-8') as file:
112
- # reader = csv.DictReader(file)
113
- # for row in reader:
114
- # tbl_name = row.get("tbl_name")
115
- # db_path = row.get("db_path")
116
- # if tbl_name and db_path:
117
- # print(db_path, tbl_name)
118
- # connector = SqliteConnector(relative_db_path=db_path, db_name=os.path.basename(db_path))
119
- # instances = generate_some_samples(connector, tbl_name)
120
- # if tbl_name not in tables_dict:
121
- # tables_dict[tbl_name] = []
122
- # tables_dict[tbl_name].extend(instances)
123
-
124
  with open(pnp_path, mode='r', encoding='utf-8') as file:
125
  reader = csv.DictReader(file)
 
126
  for row in reader:
127
  tbl_name = row.get("tbl_name")
128
- if tbl_name not in tables_dict:
129
- tables_dict[tbl_name] = []
130
- #tables_dict[tbl_name].append(row)
131
- return tables_dict
132
-
133
- def check_and_create_dir(db_path):
134
- # Check if the folder exists, and create it if it doesn't
135
- if not os.path.exists(db_path):
136
- os.makedirs(db_path)
137
- print(f"Folder created: {db_path}")
138
- else:
139
- print(f"Folder already exists: {db_path}")
 
 
 
62
  def read_models_csv(file_path):
63
  # Reads a CSV file and returns a list of dictionaries
64
  models = [] # Change {} to []
65
+ with open(file_path, mode="r", newline="") as file:
66
+ reader = csv.DictReader(file)
67
+ for row in reader:
68
+ row["price"] = float(row["price"]) # Convert price to float
69
+ models.append(row) # Append to the list
70
  return models
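
# Usage sketch (illustration only), assuming models.csv keeps the header shown
# in this commit (name,code,price,image_path):
#   models = read_models_csv("models.csv")
#   codes  = [m["code"] for m in models]              # e.g. ['llama-8', 'gpt-3.5', ...]
#   prices = {m["code"]: m["price"] for m in models}  # every price parsed as float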
71
 
72
  def csv_to_dict(file_path):
 
106
  except Exception as e:
107
  samples.append(f"Error: {e}")
108
  return samples
109
+
110
  def extract_tables_dict(pnp_path):
111
  tables_dict = {}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
  with open(pnp_path, mode='r', encoding='utf-8') as file:
113
  reader = csv.DictReader(file)
114
+ tbl_db_pairs = set() # Use a set to avoid duplicates
115
  for row in reader:
116
  tbl_name = row.get("tbl_name")
117
+ db_path = row.get("db_path")
118
+ if tbl_name and db_path:
119
+ tbl_db_pairs.add((tbl_name, db_path)) # Add the pair to the set
120
+ for tbl_name, db_path in list(tbl_db_pairs):
121
+ if tbl_name and db_path:
122
+ connector = sqlite3.connect(db_path)
123
+ query = f"SELECT * FROM {tbl_name} LIMIT 5"
124
+ try:
125
+ df = pd.read_sql_query(query, connector)
126
+ tables_dict[tbl_name] = df
127
+ except Exception as e:
128
+ tables_dict[tbl_name] = pd.DataFrame({"Error": [str(e)]}) # DataFrame with the error message
129
+
130
+ return tables_dict
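
# Usage sketch (illustration only; the CSV path is hypothetical): build a
# {table_name: 5-row preview DataFrame} map and inspect each table's shape.
#   samples_by_table = extract_tables_dict("data/evaluation_p_np_metrics.csv")
#   for name, preview in samples_by_table.items():
#       print(name, preview.shape)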
utils_get_db_tables_info.py CHANGED
@@ -2,6 +2,7 @@ import os
2
  import sqlite3
3
  import re
4
 
 
5
  def utils_extract_db_schema_as_string(
6
  db_id, base_path, normalize=False, sql: str | None = None
7
  ):
 
2
  import sqlite3
3
  import re
4
 
5
+
6
  def utils_extract_db_schema_as_string(
7
  db_id, base_path, normalize=False, sql: str | None = None
8
  ):