db_schema, csv fix, check columns and number tables, eval markdown, style (#16)
Commit: 51d3b40e690563bc1c21a8c3139ca44a396861a6
Co-authored-by: Francesco Giannuzzo <[email protected]>

Files changed:
- app.py +83 -40
- style.css +4 -2
- utilities.py +7 -5
app.py
CHANGED
@@ -81,9 +81,12 @@ def load_data(file, path, use_default):
     try:
         input_data["input_method"] = 'uploaded_file'
         input_data["db_name"] = os.path.splitext(os.path.basename(file))[0]
-
-
-
+        if file.endswith('.sqlite'):
+            #return 'Error: The uploaded file is not a valid SQLite database.'
+            input_data["data_path"] = file #os.path.join(".", "data", "data_interface",f"{input_data['db_name']}.sqlite")
+        else:
+            #change path
+            input_data["data_path"] = os.path.join(".", f"{input_data['db_name']}.sqlite")
         input_data["data"] = us.load_data(file, input_data["db_name"])
         df_current = input_data["data"]['data_frames'].get('MyTable', df_default) # Load the DataFrame
         if(input_data["data"]['data_frames'] and input_data["data"]["db"] is None): #for csv and xlsx files
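The branch above keys `input_data["data_path"]` off the upload's extension: a `.sqlite` file is used from wherever it was uploaded, while CSV/XLSX uploads point at a local `.sqlite` path that the converted data will live in. A minimal sketch of the same decision; the helper name and sample paths are invented for illustration:

```python
import os

def resolve_data_path(uploaded_file: str) -> str:
    """Sketch of the upload-path handling added in load_data (hypothetical helper, not in the repo)."""
    db_name = os.path.splitext(os.path.basename(uploaded_file))[0]
    if uploaded_file.endswith('.sqlite'):
        # SQLite uploads are used directly from wherever they were stored
        return uploaded_file
    # CSV/XLSX uploads get converted later, so point at a local .sqlite file instead
    return os.path.join(".", f"{db_name}.sqlite")

print(resolve_data_path("sales.csv"))       # ./sales.sqlite
print(resolve_data_path("chinook.sqlite"))  # chinook.sqlite (unchanged)
```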
@@ -303,6 +306,7 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
     with select_table_acc:
         previous_selection = gr.State([])
         table_selector = gr.CheckboxGroup(choices=[], label="Select tables from the choosen database", value=[])
+        excluded_tables_info = gr.HTML(label="Non-selectable tables (too many columns)", visible=False)
         table_outputs = [gr.DataFrame(label=f"Table {i+1}", interactive=True, visible=False) for i in range(50)]
         selected_table_names = gr.Textbox(label="Selected tables", visible=False, interactive=False)

@@ -310,55 +314,80 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
         open_model_selection = gr.Button("Choose your models", interactive=False)

         def update_table_list(data):
-            """Dynamically updates the list of available tables."""
+            """Dynamically updates the list of available tables and excluded ones."""
             if isinstance(data, dict) and data:
                 table_names = []
+                excluded_tables = []
+
+                data_frames = input_data['data'].get('data_frames', {})

+                available_tables = []
+                for name, df in data.items():
+                    df_real = data_frames.get(name, None)
+                    if df_real is not None and df_real.shape[1] > 15:
+                        excluded_tables.append(name)
+                    else:
+                        available_tables.append(name)
+
-
+                if input_data['input_method'] == "default" or len(available_tables) < 6:
                     table_names.append("All")

-
-                table_names.append("All") # If there are only a few tables, it makes sense to keep "All"
+                table_names.extend(available_tables)

-
-
+                # Prepare the text to display
+                if excluded_tables:
+                    excluded_text = "<b>⚠️ The following tables have more than 15 columns and cannot be selected:</b><br>" + "<br>".join(f"- {t}" for t in excluded_tables)
+                    excluded_visible = True
+                else:
+                    excluded_text = ""
+                    excluded_visible = False

-
+                return [
+                    gr.update(choices=table_names, value=[]), # CheckboxGroup update
+                    gr.update(value=excluded_text, visible=excluded_visible) # HTML display update
+                ]
+
+            return [
+                gr.update(choices=[], value=[]),
+                gr.update(value="", visible=False)
+            ]

         def show_selected_tables(data, selected_tables):
             updates = []
-
-            input_method = input_data['input_method']
+            data_frames = input_data['data'].get('data_frames', {})

+            available_tables = []
+            for name, df in data.items():
+                df_real = data_frames.get(name)
+                if df_real is not None and df_real.shape[1] <= 15:
+                    available_tables.append(name)
+
+            input_method = input_data['input_method']
             allow_all = input_method == "default" or len(available_tables) < 6
+
             selected_set = set(selected_tables)
             tables_set = set(available_tables)

-            # ▶️
             if allow_all:
                 if "All" in selected_set:
                     selected_tables = ["All"] + available_tables
                 elif selected_set == tables_set:
                     selected_tables = []
                 else:
-                    #
                     selected_tables = [t for t in selected_tables if t in available_tables]
             else:
-                #
                 selected_tables = [t for t in selected_tables if t in available_tables and t != "All"][:5]

-            #
             tables = {name: data[name] for name in selected_tables if name in data}

             for i, (name, df) in enumerate(tables.items()):
                 updates.append(gr.update(value=df, label=f"Table: {name}", visible=True, interactive=False))
+
             for _ in range(len(tables), 50):
                 updates.append(gr.update(visible=False))

-            # ✅ Button enabled only if there is at least one valid table
             updates.append(gr.update(interactive=bool(tables)))

-            # 🔄 Update the CheckboxGroup with consistent logic
             if allow_all:
                 updates.insert(0, gr.update(
                     choices=["All"] + available_tables,
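Both `update_table_list` and `show_selected_tables` now apply the same rule: a table is selectable only if its underlying DataFrame has at most 15 columns, and the "All" shortcut is offered only for the default dataset or when fewer than 6 tables survive the filter. A standalone sketch of that rule, assuming a plain dict of pandas DataFrames (the helper and constant names are illustrative):

```python
import pandas as pd

MAX_COLUMNS = 15      # tables wider than this are excluded from selection
MAX_FREE_CHOICE = 6   # below this many tables the "All" shortcut stays available

def split_tables(data_frames: dict, input_method: str):
    """Sketch of the column-count filter added to update_table_list."""
    available, excluded = [], []
    for name, df in data_frames.items():
        (excluded if df.shape[1] > MAX_COLUMNS else available).append(name)

    choices = []
    if input_method == "default" or len(available) < MAX_FREE_CHOICE:
        choices.append("All")
    choices.extend(available)
    return choices, excluded

frames = {"narrow": pd.DataFrame(columns=list("abc")),
          "wide": pd.DataFrame(columns=[f"c{i}" for i in range(20)])}
print(split_tables(frames, input_method="uploaded_file"))
# (['All', 'narrow'], ['wide'])
```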
@@ -389,7 +418,7 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
             return gr.update(value="", visible=False)

         # Automatically updates the checkbox list when `data_state` changes
-        data_state.change(fn=update_table_list, inputs=[data_state], outputs=[table_selector])
+        data_state.change(fn=update_table_list, inputs=[data_state], outputs=[table_selector, excluded_tables_info])

         # Updates the visible tables and the button state based on user selections
         #table_selector.change(fn=show_selected_tables, inputs=[data_state, table_selector], outputs=table_outputs + [open_model_selection])

@@ -602,9 +631,20 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
             {mirrored_symbols}
         </div>
         """
-
-        #return f"{css_symbols}"+f"# Loading {percent}% #"+f"{mirrored_symbols}"

+    def generate_eval_text(text):
+        symbols = "𓆡 "
+        mirrored_symbols = f'<span class="mirrored">{symbols.strip()}</span>'
+        css_symbols = f'<span class="fish">{symbols.strip()}</span>'
+        return f"""
+        <div class='barcontainer'>
+            {css_symbols}
+            <span class='loading' style="font-family: 'Inter', sans-serif;">
+                {text}
+            </span>
+            {mirrored_symbols}
+        </div>
+        """
     def qatch_flow():
         #caching
         global reset_flag
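`generate_eval_text` only builds the HTML banner; what actually shows or hides it is the extra `gr.Markdown` slot that `qatch_flow` now yields as its first output — hidden with `visible=False` while predictions stream, then shown with `visible=True` around evaluation (see the hunks below). A minimal, self-contained Gradio sketch of that toggle pattern, not taken from the Space itself:

```python
import time
import gradio as gr

def fake_flow():
    # first output slot = the status banner, kept hidden while work streams in
    yield gr.Markdown(visible=False), "working on row 1..."
    time.sleep(1)
    yield gr.Markdown(visible=False), "working on row 2..."
    time.sleep(1)
    # once the streaming part is done, reveal the banner with the eval text
    yield gr.Markdown("𓆡 Evaluation 𓆡", visible=True), "done"

with gr.Blocks() as demo:
    banner = gr.Markdown(visible=False)
    log = gr.Markdown()
    gr.Button("Run").click(fn=fake_flow, inputs=[], outputs=[banner, log])

if __name__ == "__main__":
    demo.launch()
```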
@@ -620,7 +660,7 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
             reset_flag = False
             for model in input_data['models']:
                 model_image_path = next((m["image_path"] for m in model_list_dict if m["code"] == model), None)
-                yield gr.Image(model_image_path), gr.Markdown(), gr.Markdown(), gr.Markdown(), metrics_conc, *[predictions_dict[model][columns_to_visulize] for model in model_list]
+                yield gr.Markdown(visible=False), gr.Image(model_image_path), gr.Markdown(), gr.Markdown(), gr.Markdown(), metrics_conc, *[predictions_dict[model][columns_to_visulize] for model in model_list]
                 count=1
                 for _, row in predictions_dict[model].iterrows():
                 #for index, row in target_df.iterrows():

@@ -636,7 +676,7 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
                         <div style='font-size: 3rem'>➡️</div>
                     </div>
                     """
-                    yield gr.Image(), gr.Markdown(load_text), gr.Markdown(display_question), gr.Markdown(), metrics_conc, *[predictions_dict[model][columns_to_visulize] for model in model_list]
+                    yield gr.Markdown(), gr.Image(), gr.Markdown(load_text), gr.Markdown(display_question), gr.Markdown(), metrics_conc, *[predictions_dict[model][columns_to_visulize] for model in model_list]
                     #time.sleep(0.02)
                     prediction = row['predicted_sql']

@@ -646,19 +686,19 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
                         <div class='sqlquery' font-family: 'Inter', sans-serif;>{prediction}</div>
                     </div>
                     """
-                    yield gr.Image(), gr.Markdown(load_text), gr.Markdown(), gr.Markdown(display_prediction), metrics_conc, *[predictions_dict[model][columns_to_visulize] for model in model_list]
-                yield gr.Image(), gr.Markdown(load_text), gr.Markdown(), gr.Markdown(display_prediction), metrics_conc, *[predictions_dict[model][columns_to_visulize] for model in model_list]
+                    yield gr.Markdown(), gr.Image(), gr.Markdown(load_text), gr.Markdown(), gr.Markdown(display_prediction), metrics_conc, *[predictions_dict[model][columns_to_visulize] for model in model_list]
+                yield gr.Markdown(), gr.Image(), gr.Markdown(load_text), gr.Markdown(), gr.Markdown(display_prediction), metrics_conc, *[predictions_dict[model][columns_to_visulize] for model in model_list]
             metrics_conc = target_df
             if 'valid_efficiency_score' not in metrics_conc.columns:
                 metrics_conc['valid_efficiency_score'] = metrics_conc['VES']
-
+            eval_text = generate_eval_text("End evaluation")
+            yield gr.Markdown(eval_text, visible=True), gr.Image(), gr.Markdown(), gr.Markdown(), gr.Markdown(), metrics_conc, *[predictions_dict[model][columns_to_visulize] for model in model_list]
         else:

             orchestrator_generator = OrchestratorGenerator()
             # TODO: add to target_df column target_df["columns_used"], tables selection
             # print(input_data['data']['db'])
             #print(input_data['data']['selected_tables'])
-            #TODO s
             target_df = orchestrator_generator.generate_dataset(connector=input_data['data']['db'], tables_to_include=input_data['data']['selected_tables'])
             #target_df = orchestrator_generator.generate_dataset(connector=input_data['data']['db'], tables_to_includes=None)

@@ -666,10 +706,10 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
             reset_flag = False
             for model in input_data["models"]:
                 model_image_path = next((m["image_path"] for m in model_list_dict if m["code"] == model), None)
-                yield gr.Image(model_image_path), gr.Markdown(), gr.Markdown(), gr.Markdown(), metrics_conc, *[predictions_dict[model] for model in model_list]
+                yield gr.Markdown(visible=False), gr.Image(model_image_path), gr.Markdown(), gr.Markdown(), gr.Markdown(), metrics_conc, *[predictions_dict[model] for model in model_list]
                 count=0
                 for index, row in target_df.iterrows():
-                    if (reset_flag == False):
+                    if (reset_flag == False):
                         percent_complete = round(((index+1) / len(target_df)) * 100, 2)
                         load_text = f"{generate_loading_text(percent_complete)}"

@@ -680,9 +720,9 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
                             <div style='font-size: 3rem'>➡️</div>
                         </div>
                         """
-                        yield gr.Image(), gr.Markdown(load_text), gr.Markdown(display_question), gr.Markdown(), metrics_conc, *[predictions_dict[model]for model in model_list]
+                        yield gr.Markdown(), gr.Image(), gr.Markdown(load_text), gr.Markdown(display_question), gr.Markdown(), metrics_conc, *[predictions_dict[model]for model in model_list]
                         start_time = time.time()
-                        samples = us.generate_some_samples(input_data[
+                        samples = us.generate_some_samples(input_data["data_path"], row["tbl_name"])

                         schema_text = utils_get_db_tables_info.utils_extract_db_schema_as_string(
                             db_id = input_data["db_name"],

@@ -700,7 +740,7 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
                         answer = response['response']

                         end_time = time.time()
-                        display_prediction = f"""<div class='loading' style='font-size: 1.7rem; font-family: 'Inter', sans-serif;'
+                        display_prediction = f"""<div class='loading' style='font-size: 1.7rem; font-family: 'Inter', sans-serif;'>Predicted SQL:</div>
                         <div style='display: flex; align-items: center;'>
                             <div style='font-size: 3rem'>➡️</div>
                             <div class='sqlquery' font-family: 'Inter', sans-serif;>{prediction}</div>

@@ -717,7 +757,7 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
                             'price':price,
                             'answer':answer,
                             'number_question':count,
-                            'prompt'
+                            'prompt': prompt_to_send
                         }]).dropna(how="all") # Remove only completely empty rows
                         count=count+1
                         # TODO: use a for loop

@@ -730,10 +770,12 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
                         predictions_dict[model] = pd.concat([predictions_dict[model], new_row], ignore_index=True)

                         # yield gr.Textbox(), gr.Textbox(prediction), *[predictions_dict[model] for model in input_data["models"]], None
-                        yield gr.Image(), gr.Markdown(load_text), gr.Markdown(), gr.Markdown(display_prediction), metrics_conc, *[predictions_dict[model]for model in model_list]
+                        yield gr.Markdown(), gr.Image(), gr.Markdown(load_text), gr.Markdown(), gr.Markdown(display_prediction), metrics_conc, *[predictions_dict[model]for model in model_list]

-                yield gr.Image(), gr.Markdown(load_text), gr.Markdown(), gr.Markdown(display_prediction), metrics_conc, *[predictions_dict[model] for model in model_list]
+                yield gr.Markdown(), gr.Image(), gr.Markdown(load_text), gr.Markdown(), gr.Markdown(display_prediction), metrics_conc, *[predictions_dict[model] for model in model_list]
             # END
+            eval_text = generate_eval_text("Evaluation")
+            yield gr.Markdown(eval_text, visible=True), gr.Image(), gr.Markdown(), gr.Markdown(), gr.Markdown(), metrics_conc, *[predictions_dict[model] for model in model_list]
             evaluator = OrchestratorEvaluator()
             for model in input_data["models"]:
                 metrics_df_model = evaluator.evaluate_df(
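Inside the per-question loop above, each model response becomes a one-row DataFrame that is appended to `predictions_dict[model]` with `pd.concat`, and `dropna(how="all")` keeps fully empty records out; the new `'prompt': prompt_to_send` field simply rides along in that record. A small sketch of the same append pattern with shortened, illustrative column names:

```python
import pandas as pd

predictions = pd.DataFrame(columns=["number_question", "answer", "prompt"])

def append_row(df: pd.DataFrame, record: dict) -> pd.DataFrame:
    """Sketch of the per-question append used in qatch_flow."""
    new_row = pd.DataFrame([record]).dropna(how="all")  # skip rows where every field is missing
    return pd.concat([df, new_row], ignore_index=True)

predictions = append_row(predictions, {"number_question": 1,
                                       "answer": "SELECT name FROM customer",
                                       "prompt": "Translate the question ..."})
print(predictions)
```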
@@ -747,8 +789,8 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as

             if 'valid_efficiency_score' not in metrics_conc.columns:
                 metrics_conc['valid_efficiency_score'] = metrics_conc['VES']
-
-            yield gr.Image(), gr.Markdown(), gr.Markdown(), gr.Markdown(), metrics_conc, *[predictions_dict[model] for model in model_list]
+            eval_text = generate_eval_text("End evaluation")
+            yield gr.Markdown(eval_text, visible=True), gr.Image(), gr.Markdown(), gr.Markdown(), gr.Markdown(), metrics_conc, *[predictions_dict[model] for model in model_list]

     # Loading Bar
     with gr.Row():

@@ -771,8 +813,6 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
         with gr.Column():
             with gr.Column():
                 prediction_display = gr.Markdown()
-
-                evaluation_loading = gr.Markdown() # 𓆡

         dataframe_per_model = {}

@@ -793,6 +833,9 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
                 dataframe_per_model[model] = gr.DataFrame()
                 # download_pred_model = gr.DownloadButton(label="Download Prediction per Model", visible=False)

+
+        evaluation_loading = gr.Markdown()
+
         def change_tab():
             return [gr.update(visible=(model in input_data["models"])) for model in model_list]

@@ -809,7 +852,7 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
         submit_models_button.click(
             fn=qatch_flow,
             inputs=[],
-            outputs=[model_logo, variable, question_display, prediction_display, metrics_df] + list(dataframe_per_model.values())
+            outputs=[evaluation_loading, model_logo, variable, question_display, prediction_display, metrics_df] + list(dataframe_per_model.values())
         )

         submit_models_button.click(
style.css
CHANGED
@@ -57,13 +57,15 @@ body, label, button, span, li, p, .prose {
 .mirrored {
     display: inline-block;
     transform: scaleX(-1);
+    position: relative;
+    top: -9.5px;
     font-family: 'Inter', sans-serif;
     font-size: 1.5rem;
     font-weight: 700;
     letter-spacing: 1px;
     text-align: center;
     color: #222;
-    background: linear-gradient(45deg, #1a41d9, #
+    background: linear-gradient(45deg, #1a41d9, #06ffe6);
     -webkit-background-clip: text;
     -webkit-text-fill-color: transparent;
     padding: 20px;

@@ -78,7 +80,7 @@ body, label, button, span, li, p, .prose {
     letter-spacing: 1px;
     text-align: center;
     color: #222;
-    background: linear-gradient(45deg, #1a41d9, #
+    background: linear-gradient(45deg, #1a41d9, #06ffe6);
     -webkit-background-clip: text;
     -webkit-text-fill-color: transparent;
     padding: 20px;
utilities.py
CHANGED
@@ -94,16 +94,18 @@ def increment_filename(filename):
     return new_base + ext

 def prepare_prompt(prompt, question, schema, samples):
-    prompt = prompt.replace("{
-    prompt += f" Some
+    prompt = prompt.replace("{db_schema}", schema).replace("{question}", question)
+    prompt += f" Some instances: {samples}"
     return prompt

-def generate_some_samples(
+def generate_some_samples(file_path, tbl_name):
+    conn = sqlite3.connect(file_path)
     samples = []
     query = f"SELECT * FROM {tbl_name} LIMIT 3"
     try:
-        sample_data =
-        samples.append(
+        sample_data = pd.read_sql_query(query, conn)
+        samples.append(sample_data.to_dict(orient="records"))
+        #samples.append(str(sample_data))
     except Exception as e:
         samples.append(f"Error: {e}")
     return samples