Spaces:
Running
Running
db_schema , csv fix, check columns and number tables, eval markdown, style (#16)
Browse files- db_schema , csv fix, check columns and number tables, eval markdown, style (51d3b40e690563bc1c21a8c3139ca44a396861a6)
Co-authored-by: Francesco Giannuzzo <[email protected]>
- app.py +83 -40
- style.css +4 -2
- utilities.py +7 -5
app.py
CHANGED
@@ -81,9 +81,12 @@ def load_data(file, path, use_default):
|
|
81 |
try:
|
82 |
input_data["input_method"] = 'uploaded_file'
|
83 |
input_data["db_name"] = os.path.splitext(os.path.basename(file))[0]
|
84 |
-
|
85 |
-
|
86 |
-
|
|
|
|
|
|
|
87 |
input_data["data"] = us.load_data(file, input_data["db_name"])
|
88 |
df_current = input_data["data"]['data_frames'].get('MyTable', df_default) # Carica il DataFrame
|
89 |
if(input_data["data"]['data_frames'] and input_data["data"]["db"] is None): #for csv and xlsx files
|
@@ -303,6 +306,7 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
|
|
303 |
with select_table_acc:
|
304 |
previous_selection = gr.State([])
|
305 |
table_selector = gr.CheckboxGroup(choices=[], label="Select tables from the choosen database", value=[])
|
|
|
306 |
table_outputs = [gr.DataFrame(label=f"Table {i+1}", interactive=True, visible=False) for i in range(50)]
|
307 |
selected_table_names = gr.Textbox(label="Selected tables", visible=False, interactive=False)
|
308 |
|
@@ -310,55 +314,80 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
|
|
310 |
open_model_selection = gr.Button("Choose your models", interactive=False)
|
311 |
|
312 |
def update_table_list(data):
|
313 |
-
"""Dynamically updates the list of available tables."""
|
314 |
if isinstance(data, dict) and data:
|
315 |
table_names = []
|
|
|
|
|
|
|
316 |
|
317 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
318 |
table_names.append("All")
|
319 |
|
320 |
-
|
321 |
-
table_names.append("All") # In caso ci siano poche tabelle, ha senso mantenere "All"
|
322 |
|
323 |
-
|
324 |
-
|
|
|
|
|
|
|
|
|
|
|
325 |
|
326 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
327 |
|
328 |
def show_selected_tables(data, selected_tables):
|
329 |
updates = []
|
330 |
-
|
331 |
-
input_method = input_data['input_method']
|
332 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
333 |
allow_all = input_method == "default" or len(available_tables) < 6
|
|
|
334 |
selected_set = set(selected_tables)
|
335 |
tables_set = set(available_tables)
|
336 |
|
337 |
-
# ▶️
|
338 |
if allow_all:
|
339 |
if "All" in selected_set:
|
340 |
selected_tables = ["All"] + available_tables
|
341 |
elif selected_set == tables_set:
|
342 |
selected_tables = []
|
343 |
else:
|
344 |
-
#
|
345 |
selected_tables = [t for t in selected_tables if t in available_tables]
|
346 |
else:
|
347 |
-
#
|
348 |
selected_tables = [t for t in selected_tables if t in available_tables and t != "All"][:5]
|
349 |
|
350 |
-
#
|
351 |
tables = {name: data[name] for name in selected_tables if name in data}
|
352 |
|
353 |
for i, (name, df) in enumerate(tables.items()):
|
354 |
updates.append(gr.update(value=df, label=f"Table: {name}", visible=True, interactive=False))
|
|
|
355 |
for _ in range(len(tables), 50):
|
356 |
updates.append(gr.update(visible=False))
|
357 |
|
358 |
-
# ✅ Bottone abilitato solo se c'è almeno una tabella valida
|
359 |
updates.append(gr.update(interactive=bool(tables)))
|
360 |
|
361 |
-
# 🔄 Aggiorna la CheckboxGroup con logica coerente
|
362 |
if allow_all:
|
363 |
updates.insert(0, gr.update(
|
364 |
choices=["All"] + available_tables,
|
@@ -389,7 +418,7 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
|
|
389 |
return gr.update(value="", visible=False)
|
390 |
|
391 |
# Automatically updates the checkbox list when `data_state` changes
|
392 |
-
data_state.change(fn=update_table_list, inputs=[data_state], outputs=[table_selector])
|
393 |
|
394 |
# Updates the visible tables and the button state based on user selections
|
395 |
#table_selector.change(fn=show_selected_tables, inputs=[data_state, table_selector], outputs=table_outputs + [open_model_selection])
|
@@ -602,9 +631,20 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
|
|
602 |
{mirrored_symbols}
|
603 |
</div>
|
604 |
"""
|
605 |
-
|
606 |
-
#return f"{css_symbols}"+f"# Loading {percent}% #"+f"{mirrored_symbols}"
|
607 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
608 |
def qatch_flow():
|
609 |
#caching
|
610 |
global reset_flag
|
@@ -620,7 +660,7 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
|
|
620 |
reset_flag = False
|
621 |
for model in input_data['models']:
|
622 |
model_image_path = next((m["image_path"] for m in model_list_dict if m["code"] == model), None)
|
623 |
-
yield gr.Image(model_image_path), gr.Markdown(), gr.Markdown(), gr.Markdown(), metrics_conc, *[predictions_dict[model][columns_to_visulize] for model in model_list]
|
624 |
count=1
|
625 |
for _, row in predictions_dict[model].iterrows():
|
626 |
#for index, row in target_df.iterrows():
|
@@ -636,7 +676,7 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
|
|
636 |
<div style='font-size: 3rem'>➡️</div>
|
637 |
</div>
|
638 |
"""
|
639 |
-
yield gr.Image(), gr.Markdown(load_text), gr.Markdown(display_question), gr.Markdown(), metrics_conc, *[predictions_dict[model][columns_to_visulize] for model in model_list]
|
640 |
#time.sleep(0.02)
|
641 |
prediction = row['predicted_sql']
|
642 |
|
@@ -646,19 +686,19 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
|
|
646 |
<div class='sqlquery' font-family: 'Inter', sans-serif;>{prediction}</div>
|
647 |
</div>
|
648 |
"""
|
649 |
-
yield gr.Image(), gr.Markdown(load_text), gr.Markdown(), gr.Markdown(display_prediction), metrics_conc, *[predictions_dict[model][columns_to_visulize] for model in model_list]
|
650 |
-
yield gr.Image(), gr.Markdown(load_text), gr.Markdown(), gr.Markdown(display_prediction), metrics_conc, *[predictions_dict[model][columns_to_visulize] for model in model_list]
|
651 |
metrics_conc = target_df
|
652 |
if 'valid_efficiency_score' not in metrics_conc.columns:
|
653 |
metrics_conc['valid_efficiency_score'] = metrics_conc['VES']
|
654 |
-
|
|
|
655 |
else:
|
656 |
|
657 |
orchestrator_generator = OrchestratorGenerator()
|
658 |
# TODO: add to target_df column target_df["columns_used"], tables selection
|
659 |
# print(input_data['data']['db'])
|
660 |
#print(input_data['data']['selected_tables'])
|
661 |
-
#TODO s
|
662 |
target_df = orchestrator_generator.generate_dataset(connector=input_data['data']['db'], tables_to_include=input_data['data']['selected_tables'])
|
663 |
#target_df = orchestrator_generator.generate_dataset(connector=input_data['data']['db'], tables_to_includes=None)
|
664 |
|
@@ -666,10 +706,10 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
|
|
666 |
reset_flag = False
|
667 |
for model in input_data["models"]:
|
668 |
model_image_path = next((m["image_path"] for m in model_list_dict if m["code"] == model), None)
|
669 |
-
yield gr.Image(model_image_path), gr.Markdown(), gr.Markdown(), gr.Markdown(), metrics_conc, *[predictions_dict[model] for model in model_list]
|
670 |
count=0
|
671 |
for index, row in target_df.iterrows():
|
672 |
-
if (reset_flag == False):
|
673 |
percent_complete = round(((index+1) / len(target_df)) * 100, 2)
|
674 |
load_text = f"{generate_loading_text(percent_complete)}"
|
675 |
|
@@ -680,9 +720,9 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
|
|
680 |
<div style='font-size: 3rem'>➡️</div>
|
681 |
</div>
|
682 |
"""
|
683 |
-
yield gr.Image(), gr.Markdown(load_text), gr.Markdown(display_question), gr.Markdown(), metrics_conc, *[predictions_dict[model]for model in model_list]
|
684 |
start_time = time.time()
|
685 |
-
samples = us.generate_some_samples(input_data[
|
686 |
|
687 |
schema_text = utils_get_db_tables_info.utils_extract_db_schema_as_string(
|
688 |
db_id = input_data["db_name"],
|
@@ -700,7 +740,7 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
|
|
700 |
answer = response['response']
|
701 |
|
702 |
end_time = time.time()
|
703 |
-
display_prediction = f"""<div class='loading' style='font-size: 1.7rem; font-family: 'Inter', sans-serif;'
|
704 |
<div style='display: flex; align-items: center;'>
|
705 |
<div style='font-size: 3rem'>➡️</div>
|
706 |
<div class='sqlquery' font-family: 'Inter', sans-serif;>{prediction}</div>
|
@@ -717,7 +757,7 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
|
|
717 |
'price':price,
|
718 |
'answer':answer,
|
719 |
'number_question':count,
|
720 |
-
'prompt'
|
721 |
}]).dropna(how="all") # Remove only completely empty rows
|
722 |
count=count+1
|
723 |
# TODO: use a for loop
|
@@ -730,10 +770,12 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
|
|
730 |
predictions_dict[model] = pd.concat([predictions_dict[model], new_row], ignore_index=True)
|
731 |
|
732 |
# yield gr.Textbox(), gr.Textbox(prediction), *[predictions_dict[model] for model in input_data["models"]], None
|
733 |
-
yield gr.Image(), gr.Markdown(load_text), gr.Markdown(), gr.Markdown(display_prediction), metrics_conc, *[predictions_dict[model]for model in model_list]
|
734 |
|
735 |
-
yield gr.Image(), gr.Markdown(load_text), gr.Markdown(), gr.Markdown(display_prediction), metrics_conc, *[predictions_dict[model] for model in model_list]
|
736 |
# END
|
|
|
|
|
737 |
evaluator = OrchestratorEvaluator()
|
738 |
for model in input_data["models"]:
|
739 |
metrics_df_model = evaluator.evaluate_df(
|
@@ -747,8 +789,8 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
|
|
747 |
|
748 |
if 'valid_efficiency_score' not in metrics_conc.columns:
|
749 |
metrics_conc['valid_efficiency_score'] = metrics_conc['VES']
|
750 |
-
|
751 |
-
yield gr.Image(), gr.Markdown(), gr.Markdown(), gr.Markdown(), metrics_conc, *[predictions_dict[model] for model in model_list]
|
752 |
|
753 |
# Loading Bar
|
754 |
with gr.Row():
|
@@ -771,8 +813,6 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
|
|
771 |
with gr.Column():
|
772 |
with gr.Column():
|
773 |
prediction_display = gr.Markdown()
|
774 |
-
|
775 |
-
evaluation_loading = gr.Markdown() # 𓆡
|
776 |
|
777 |
dataframe_per_model = {}
|
778 |
|
@@ -793,6 +833,9 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
|
|
793 |
dataframe_per_model[model] = gr.DataFrame()
|
794 |
# download_pred_model = gr.DownloadButton(label="Download Prediction per Model", visible=False)
|
795 |
|
|
|
|
|
|
|
796 |
def change_tab():
|
797 |
return [gr.update(visible=(model in input_data["models"])) for model in model_list]
|
798 |
|
@@ -809,7 +852,7 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
|
|
809 |
submit_models_button.click(
|
810 |
fn=qatch_flow,
|
811 |
inputs=[],
|
812 |
-
outputs=[model_logo, variable, question_display, prediction_display, metrics_df] + list(dataframe_per_model.values())
|
813 |
)
|
814 |
|
815 |
submit_models_button.click(
|
|
|
81 |
try:
|
82 |
input_data["input_method"] = 'uploaded_file'
|
83 |
input_data["db_name"] = os.path.splitext(os.path.basename(file))[0]
|
84 |
+
if file.endswith('.sqlite'):
|
85 |
+
#return 'Error: The uploaded file is not a valid SQLite database.'
|
86 |
+
input_data["data_path"] = file #os.path.join(".", "data", "data_interface",f"{input_data['db_name']}.sqlite")
|
87 |
+
else:
|
88 |
+
#change path
|
89 |
+
input_data["data_path"] = os.path.join(".", f"{input_data['db_name']}.sqlite")
|
90 |
input_data["data"] = us.load_data(file, input_data["db_name"])
|
91 |
df_current = input_data["data"]['data_frames'].get('MyTable', df_default) # Carica il DataFrame
|
92 |
if(input_data["data"]['data_frames'] and input_data["data"]["db"] is None): #for csv and xlsx files
|
|
|
306 |
with select_table_acc:
|
307 |
previous_selection = gr.State([])
|
308 |
table_selector = gr.CheckboxGroup(choices=[], label="Select tables from the choosen database", value=[])
|
309 |
+
excluded_tables_info = gr.HTML(label="Non-selectable tables (too many columns)", visible=False)
|
310 |
table_outputs = [gr.DataFrame(label=f"Table {i+1}", interactive=True, visible=False) for i in range(50)]
|
311 |
selected_table_names = gr.Textbox(label="Selected tables", visible=False, interactive=False)
|
312 |
|
|
|
314 |
open_model_selection = gr.Button("Choose your models", interactive=False)
|
315 |
|
316 |
def update_table_list(data):
|
317 |
+
"""Dynamically updates the list of available tables and excluded ones."""
|
318 |
if isinstance(data, dict) and data:
|
319 |
table_names = []
|
320 |
+
excluded_tables = []
|
321 |
+
|
322 |
+
data_frames = input_data['data'].get('data_frames', {})
|
323 |
|
324 |
+
available_tables = []
|
325 |
+
for name, df in data.items():
|
326 |
+
df_real = data_frames.get(name, None)
|
327 |
+
if df_real is not None and df_real.shape[1] > 15:
|
328 |
+
excluded_tables.append(name)
|
329 |
+
else:
|
330 |
+
available_tables.append(name)
|
331 |
+
|
332 |
+
if input_data['input_method'] == "default" or len(available_tables) < 6:
|
333 |
table_names.append("All")
|
334 |
|
335 |
+
table_names.extend(available_tables)
|
|
|
336 |
|
337 |
+
# Prepara il testo da mostrare
|
338 |
+
if excluded_tables:
|
339 |
+
excluded_text = "<b>⚠️ The following tables have more than 15 columns and cannot be selected:</b><br>" + "<br>".join(f"- {t}" for t in excluded_tables)
|
340 |
+
excluded_visible = True
|
341 |
+
else:
|
342 |
+
excluded_text = ""
|
343 |
+
excluded_visible = False
|
344 |
|
345 |
+
return [
|
346 |
+
gr.update(choices=table_names, value=[]), # CheckboxGroup update
|
347 |
+
gr.update(value=excluded_text, visible=excluded_visible) # HTML display update
|
348 |
+
]
|
349 |
+
|
350 |
+
return [
|
351 |
+
gr.update(choices=[], value=[]),
|
352 |
+
gr.update(value="", visible=False)
|
353 |
+
]
|
354 |
|
355 |
def show_selected_tables(data, selected_tables):
|
356 |
updates = []
|
357 |
+
data_frames = input_data['data'].get('data_frames', {})
|
|
|
358 |
|
359 |
+
available_tables = []
|
360 |
+
for name, df in data.items():
|
361 |
+
df_real = data_frames.get(name)
|
362 |
+
if df_real is not None and df_real.shape[1] <= 15:
|
363 |
+
available_tables.append(name)
|
364 |
+
|
365 |
+
input_method = input_data['input_method']
|
366 |
allow_all = input_method == "default" or len(available_tables) < 6
|
367 |
+
|
368 |
selected_set = set(selected_tables)
|
369 |
tables_set = set(available_tables)
|
370 |
|
|
|
371 |
if allow_all:
|
372 |
if "All" in selected_set:
|
373 |
selected_tables = ["All"] + available_tables
|
374 |
elif selected_set == tables_set:
|
375 |
selected_tables = []
|
376 |
else:
|
|
|
377 |
selected_tables = [t for t in selected_tables if t in available_tables]
|
378 |
else:
|
|
|
379 |
selected_tables = [t for t in selected_tables if t in available_tables and t != "All"][:5]
|
380 |
|
|
|
381 |
tables = {name: data[name] for name in selected_tables if name in data}
|
382 |
|
383 |
for i, (name, df) in enumerate(tables.items()):
|
384 |
updates.append(gr.update(value=df, label=f"Table: {name}", visible=True, interactive=False))
|
385 |
+
|
386 |
for _ in range(len(tables), 50):
|
387 |
updates.append(gr.update(visible=False))
|
388 |
|
|
|
389 |
updates.append(gr.update(interactive=bool(tables)))
|
390 |
|
|
|
391 |
if allow_all:
|
392 |
updates.insert(0, gr.update(
|
393 |
choices=["All"] + available_tables,
|
|
|
418 |
return gr.update(value="", visible=False)
|
419 |
|
420 |
# Automatically updates the checkbox list when `data_state` changes
|
421 |
+
data_state.change(fn=update_table_list, inputs=[data_state], outputs=[table_selector, excluded_tables_info])
|
422 |
|
423 |
# Updates the visible tables and the button state based on user selections
|
424 |
#table_selector.change(fn=show_selected_tables, inputs=[data_state, table_selector], outputs=table_outputs + [open_model_selection])
|
|
|
631 |
{mirrored_symbols}
|
632 |
</div>
|
633 |
"""
|
|
|
|
|
634 |
|
635 |
+
def generate_eval_text(text):
|
636 |
+
symbols = "𓆡 "
|
637 |
+
mirrored_symbols = f'<span class="mirrored">{symbols.strip()}</span>'
|
638 |
+
css_symbols = f'<span class="fish">{symbols.strip()}</span>'
|
639 |
+
return f"""
|
640 |
+
<div class='barcontainer'>
|
641 |
+
{css_symbols}
|
642 |
+
<span class='loading' style="font-family: 'Inter', sans-serif;">
|
643 |
+
{text}
|
644 |
+
</span>
|
645 |
+
{mirrored_symbols}
|
646 |
+
</div>
|
647 |
+
"""
|
648 |
def qatch_flow():
|
649 |
#caching
|
650 |
global reset_flag
|
|
|
660 |
reset_flag = False
|
661 |
for model in input_data['models']:
|
662 |
model_image_path = next((m["image_path"] for m in model_list_dict if m["code"] == model), None)
|
663 |
+
yield gr.Markdown(visible=False), gr.Image(model_image_path), gr.Markdown(), gr.Markdown(), gr.Markdown(), metrics_conc, *[predictions_dict[model][columns_to_visulize] for model in model_list]
|
664 |
count=1
|
665 |
for _, row in predictions_dict[model].iterrows():
|
666 |
#for index, row in target_df.iterrows():
|
|
|
676 |
<div style='font-size: 3rem'>➡️</div>
|
677 |
</div>
|
678 |
"""
|
679 |
+
yield gr.Markdown(), gr.Image(), gr.Markdown(load_text), gr.Markdown(display_question), gr.Markdown(), metrics_conc, *[predictions_dict[model][columns_to_visulize] for model in model_list]
|
680 |
#time.sleep(0.02)
|
681 |
prediction = row['predicted_sql']
|
682 |
|
|
|
686 |
<div class='sqlquery' font-family: 'Inter', sans-serif;>{prediction}</div>
|
687 |
</div>
|
688 |
"""
|
689 |
+
yield gr.Markdown(), gr.Image(), gr.Markdown(load_text), gr.Markdown(), gr.Markdown(display_prediction), metrics_conc, *[predictions_dict[model][columns_to_visulize] for model in model_list]
|
690 |
+
yield gr.Markdown(), gr.Image(), gr.Markdown(load_text), gr.Markdown(), gr.Markdown(display_prediction), metrics_conc, *[predictions_dict[model][columns_to_visulize] for model in model_list]
|
691 |
metrics_conc = target_df
|
692 |
if 'valid_efficiency_score' not in metrics_conc.columns:
|
693 |
metrics_conc['valid_efficiency_score'] = metrics_conc['VES']
|
694 |
+
eval_text = generate_eval_text("End evaluation")
|
695 |
+
yield gr.Markdown(eval_text, visible=True), gr.Image(), gr.Markdown(), gr.Markdown(), gr.Markdown(), metrics_conc, *[predictions_dict[model][columns_to_visulize] for model in model_list]
|
696 |
else:
|
697 |
|
698 |
orchestrator_generator = OrchestratorGenerator()
|
699 |
# TODO: add to target_df column target_df["columns_used"], tables selection
|
700 |
# print(input_data['data']['db'])
|
701 |
#print(input_data['data']['selected_tables'])
|
|
|
702 |
target_df = orchestrator_generator.generate_dataset(connector=input_data['data']['db'], tables_to_include=input_data['data']['selected_tables'])
|
703 |
#target_df = orchestrator_generator.generate_dataset(connector=input_data['data']['db'], tables_to_includes=None)
|
704 |
|
|
|
706 |
reset_flag = False
|
707 |
for model in input_data["models"]:
|
708 |
model_image_path = next((m["image_path"] for m in model_list_dict if m["code"] == model), None)
|
709 |
+
yield gr.Markdown(visible=False), gr.Image(model_image_path), gr.Markdown(), gr.Markdown(), gr.Markdown(), metrics_conc, *[predictions_dict[model] for model in model_list]
|
710 |
count=0
|
711 |
for index, row in target_df.iterrows():
|
712 |
+
if (reset_flag == False):
|
713 |
percent_complete = round(((index+1) / len(target_df)) * 100, 2)
|
714 |
load_text = f"{generate_loading_text(percent_complete)}"
|
715 |
|
|
|
720 |
<div style='font-size: 3rem'>➡️</div>
|
721 |
</div>
|
722 |
"""
|
723 |
+
yield gr.Markdown(), gr.Image(), gr.Markdown(load_text), gr.Markdown(display_question), gr.Markdown(), metrics_conc, *[predictions_dict[model]for model in model_list]
|
724 |
start_time = time.time()
|
725 |
+
samples = us.generate_some_samples(input_data["data_path"], row["tbl_name"])
|
726 |
|
727 |
schema_text = utils_get_db_tables_info.utils_extract_db_schema_as_string(
|
728 |
db_id = input_data["db_name"],
|
|
|
740 |
answer = response['response']
|
741 |
|
742 |
end_time = time.time()
|
743 |
+
display_prediction = f"""<div class='loading' style='font-size: 1.7rem; font-family: 'Inter', sans-serif;'>Predicted SQL:</div>
|
744 |
<div style='display: flex; align-items: center;'>
|
745 |
<div style='font-size: 3rem'>➡️</div>
|
746 |
<div class='sqlquery' font-family: 'Inter', sans-serif;>{prediction}</div>
|
|
|
757 |
'price':price,
|
758 |
'answer':answer,
|
759 |
'number_question':count,
|
760 |
+
'prompt': prompt_to_send
|
761 |
}]).dropna(how="all") # Remove only completely empty rows
|
762 |
count=count+1
|
763 |
# TODO: use a for loop
|
|
|
770 |
predictions_dict[model] = pd.concat([predictions_dict[model], new_row], ignore_index=True)
|
771 |
|
772 |
# yield gr.Textbox(), gr.Textbox(prediction), *[predictions_dict[model] for model in input_data["models"]], None
|
773 |
+
yield gr.Markdown(), gr.Image(), gr.Markdown(load_text), gr.Markdown(), gr.Markdown(display_prediction), metrics_conc, *[predictions_dict[model]for model in model_list]
|
774 |
|
775 |
+
yield gr.Markdown(), gr.Image(), gr.Markdown(load_text), gr.Markdown(), gr.Markdown(display_prediction), metrics_conc, *[predictions_dict[model] for model in model_list]
|
776 |
# END
|
777 |
+
eval_text = generate_eval_text("Evaluation")
|
778 |
+
yield gr.Markdown(eval_text, visible=True), gr.Image(), gr.Markdown(), gr.Markdown(), gr.Markdown(), metrics_conc, *[predictions_dict[model] for model in model_list]
|
779 |
evaluator = OrchestratorEvaluator()
|
780 |
for model in input_data["models"]:
|
781 |
metrics_df_model = evaluator.evaluate_df(
|
|
|
789 |
|
790 |
if 'valid_efficiency_score' not in metrics_conc.columns:
|
791 |
metrics_conc['valid_efficiency_score'] = metrics_conc['VES']
|
792 |
+
eval_text = generate_eval_text("End evaluation")
|
793 |
+
yield gr.Markdown(eval_text, visible=True), gr.Image(), gr.Markdown(), gr.Markdown(), gr.Markdown(), metrics_conc, *[predictions_dict[model] for model in model_list]
|
794 |
|
795 |
# Loading Bar
|
796 |
with gr.Row():
|
|
|
813 |
with gr.Column():
|
814 |
with gr.Column():
|
815 |
prediction_display = gr.Markdown()
|
|
|
|
|
816 |
|
817 |
dataframe_per_model = {}
|
818 |
|
|
|
833 |
dataframe_per_model[model] = gr.DataFrame()
|
834 |
# download_pred_model = gr.DownloadButton(label="Download Prediction per Model", visible=False)
|
835 |
|
836 |
+
|
837 |
+
evaluation_loading = gr.Markdown()
|
838 |
+
|
839 |
def change_tab():
|
840 |
return [gr.update(visible=(model in input_data["models"])) for model in model_list]
|
841 |
|
|
|
852 |
submit_models_button.click(
|
853 |
fn=qatch_flow,
|
854 |
inputs=[],
|
855 |
+
outputs=[evaluation_loading, model_logo, variable, question_display, prediction_display, metrics_df] + list(dataframe_per_model.values())
|
856 |
)
|
857 |
|
858 |
submit_models_button.click(
|
style.css
CHANGED
@@ -57,13 +57,15 @@ body, label, button, span, li, p, .prose {
|
|
57 |
.mirrored {
|
58 |
display: inline-block;
|
59 |
transform: scaleX(-1);
|
|
|
|
|
60 |
font-family: 'Inter', sans-serif;
|
61 |
font-size: 1.5rem;
|
62 |
font-weight: 700;
|
63 |
letter-spacing: 1px;
|
64 |
text-align: center;
|
65 |
color: #222;
|
66 |
-
background: linear-gradient(45deg, #1a41d9, #
|
67 |
-webkit-background-clip: text;
|
68 |
-webkit-text-fill-color: transparent;
|
69 |
padding: 20px;
|
@@ -78,7 +80,7 @@ body, label, button, span, li, p, .prose {
|
|
78 |
letter-spacing: 1px;
|
79 |
text-align: center;
|
80 |
color: #222;
|
81 |
-
background: linear-gradient(45deg, #1a41d9, #
|
82 |
-webkit-background-clip: text;
|
83 |
-webkit-text-fill-color: transparent;
|
84 |
padding: 20px;
|
|
|
57 |
.mirrored {
|
58 |
display: inline-block;
|
59 |
transform: scaleX(-1);
|
60 |
+
position: relative;
|
61 |
+
top: -9.5px;
|
62 |
font-family: 'Inter', sans-serif;
|
63 |
font-size: 1.5rem;
|
64 |
font-weight: 700;
|
65 |
letter-spacing: 1px;
|
66 |
text-align: center;
|
67 |
color: #222;
|
68 |
+
background: linear-gradient(45deg, #1a41d9, #06ffe6);
|
69 |
-webkit-background-clip: text;
|
70 |
-webkit-text-fill-color: transparent;
|
71 |
padding: 20px;
|
|
|
80 |
letter-spacing: 1px;
|
81 |
text-align: center;
|
82 |
color: #222;
|
83 |
+
background: linear-gradient(45deg, #1a41d9, #06ffe6);
|
84 |
-webkit-background-clip: text;
|
85 |
-webkit-text-fill-color: transparent;
|
86 |
padding: 20px;
|
utilities.py
CHANGED
@@ -94,16 +94,18 @@ def increment_filename(filename):
|
|
94 |
return new_base + ext
|
95 |
|
96 |
def prepare_prompt(prompt, question, schema, samples):
|
97 |
-
prompt = prompt.replace("{
|
98 |
-
prompt += f" Some
|
99 |
return prompt
|
100 |
|
101 |
-
def generate_some_samples(
|
|
|
102 |
samples = []
|
103 |
query = f"SELECT * FROM {tbl_name} LIMIT 3"
|
104 |
try:
|
105 |
-
sample_data =
|
106 |
-
samples.append(
|
|
|
107 |
except Exception as e:
|
108 |
samples.append(f"Error: {e}")
|
109 |
return samples
|
|
|
94 |
return new_base + ext
|
95 |
|
96 |
def prepare_prompt(prompt, question, schema, samples):
|
97 |
+
prompt = prompt.replace("{db_schema}", schema).replace("{question}", question)
|
98 |
+
prompt += f" Some instances: {samples}"
|
99 |
return prompt
|
100 |
|
101 |
+
def generate_some_samples(file_path, tbl_name):
|
102 |
+
conn = sqlite3.connect(file_path)
|
103 |
samples = []
|
104 |
query = f"SELECT * FROM {tbl_name} LIMIT 3"
|
105 |
try:
|
106 |
+
sample_data = pd.read_sql_query(query, conn)
|
107 |
+
samples.append(sample_data.to_dict(orient="records"))
|
108 |
+
#samples.append(str(sample_data))
|
109 |
except Exception as e:
|
110 |
samples.append(f"Error: {e}")
|
111 |
return samples
|