Fix prompts buttons, and NL2SQL bug (#24)
Fix prompts buttons, and NL2SQL bug (b8f53f4140ce72bf889c039fa072989834ee8d73)
Co-authored-by: Francesco Giannuzzo <[email protected]>
- app.py +48 -63
- utilities.py +10 -4
- utils_get_db_tables_info.py +31 -3
app.py
CHANGED
@@ -509,9 +509,13 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
         selected_models = [model for model, selected in zip(model_list, model_selections) if selected]
         input_data['models'] = selected_models
         button_state = bool(selected_models and '{db_schema}' in input_data["prompt"] and '{question}' in input_data["prompt"])
-        return selected_models, gr.update(open=True, visible=True), gr.update(interactive=button_state)
+        return selected_models, gr.update(open=True, visible=True), gr.update(interactive=button_state)

    # Add the Textbox to the interface
+    with gr.Row():
+        button_prompt_nlsql = gr.Button("Choose NL2SQL task")
+        button_prompt_tqa = gr.Button("Choose TQA task")
+
    prompt = gr.TextArea(
        label="Customise the prompt for selected models here or leave the default one.",
        placeholder=prompt_default,
@@ -522,17 +526,20 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as

    # Submit button (initially disabled)
    with gr.Row():
-        submit_models_button = gr.Button("Submit Models for NL2SQL task", interactive=False)
-        submit_models_button_tqa = gr.Button("Submit Models for TQA task", interactive=False)
+        submit_models_button = gr.Button("Submit Models", interactive=False)

    def check_prompt(prompt):
        #TODO
        missing_elements = []
        if(prompt==""):
-            input_data["prompt"] = prompt_default
+            global flag_TQA
+            if not flag_TQA:
+                input_data["prompt"] = prompt_default
+            else:
+                input_data["prompt"] = prompt_default_tqa
            button_state = bool(len(input_data['models']) > 0 and '{db_schema}' in input_data["prompt"] and '{question}' in input_data["prompt"])
        else:
-            input_data["prompt"]=prompt
+            input_data["prompt"] = prompt
            if "{db_schema}" not in prompt:
                missing_elements.append("{db_schema}")
            if "{question}" not in prompt:
@@ -543,21 +550,21 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
                value=f"<div style='text-align: center; font-size: 18px; font-weight: bold;'>"
                      f"❌ Missing {', '.join(missing_elements)} in the prompt ❌</div>",
                visible=True
-            ), gr.update(interactive=button_state)
-        return gr.update(visible=False),
+            ), gr.update(interactive=button_state), gr.TextArea(placeholder=input_data["prompt"])
+        return gr.update(visible=False), gr.update(interactive=button_state), gr.TextArea(placeholder=input_data["prompt"])

-    prompt.change(fn=check_prompt, inputs=[prompt], outputs=[warning_prompt, submit_models_button, submit_models_button_tqa])
+    prompt.change(fn=check_prompt, inputs=[prompt], outputs=[warning_prompt, submit_models_button])
    # Link checkboxes to selection events
    for checkbox in model_checkboxes:
        checkbox.change(
            fn=get_selected_models,
            inputs=model_checkboxes,
-            outputs=[selected_models_output, select_model_acc, submit_models_button, submit_models_button_tqa]
+            outputs=[selected_models_output, select_model_acc, submit_models_button]
        )
    prompt.change(
        fn=get_selected_models,
        inputs=model_checkboxes,
-        outputs=[selected_models_output, select_model_acc, submit_models_button, submit_models_button_tqa]
+        outputs=[selected_models_output, select_model_acc, submit_models_button]
    )

    submit_models_button.click(
@@ -566,11 +573,6 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
        outputs=[selected_models_output, select_model_acc, qatch_acc]
    )

-    submit_models_button_tqa.click(
-        fn=lambda *args: (get_selected_models(*args), gr.update(open=False, visible=True), gr.update(open=True, visible=True)),
-        inputs=model_checkboxes,
-        outputs=[selected_models_output, select_model_acc, qatch_acc]
-    )
    def change_flag():
        global flag_TQA
        flag_TQA = True
@@ -579,8 +581,14 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
        global flag_TQA
        flag_TQA = False

-
-
+    button_prompt_tqa.click(fn = change_flag, inputs=[], outputs=[])
+
+    button_prompt_nlsql.click(fn = dis_flag, inputs=[], outputs=[])
+
+    button_prompt_tqa.click(fn=check_prompt, inputs=[prompt], outputs=[warning_prompt, submit_models_button, prompt])
+
+    button_prompt_nlsql.click(fn=check_prompt, inputs=[prompt], outputs=[warning_prompt, submit_models_button, prompt])
+

    def enable_disable(enable):
        return (
@@ -592,7 +600,6 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
            gr.update(interactive=enable),
            gr.update(interactive=enable),
            *[gr.update(interactive=enable) for _ in table_outputs],
-            gr.update(interactive=enable),
            gr.update(interactive=enable)
        )

@@ -610,24 +617,7 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
            default_checkbox,
            table_selector,
            *table_outputs,
-            open_model_selection,
-            submit_models_button_tqa
-        ]
-    )
-    submit_models_button_tqa.click(
-        fn=enable_disable,
-        inputs=[gr.State(False)],
-        outputs=[
-            *model_checkboxes,
-            submit_models_button,
-            preview_output,
-            submit_button,
-            file_input,
-            default_checkbox,
-            table_selector,
-            *table_outputs,
-            open_model_selection,
-            submit_models_button_tqa
+            open_model_selection
        ]
    )

@@ -645,8 +635,7 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
            default_checkbox,
            table_selector,
            *table_outputs,
-            open_model_selection,
-            submit_models_button_tqa
+            open_model_selection
        ]
    )

@@ -749,13 +738,28 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
            yield gr.Markdown(eval_text, visible=True), gr.Image(), gr.Markdown(), gr.Markdown(), gr.Markdown(), metrics_conc, *[predictions_dict[model][columns_to_visulize] for model in model_list]

        else:
+            global flag_TQA
            orchestrator_generator = OrchestratorGenerator()
            target_df = orchestrator_generator.generate_dataset(connector=input_data['data']['db'], tables_to_include=input_data['data']['selected_tables'])

            #create target_df[target_answer]
            if flag_TQA :
-                if (input_data["prompt"] == prompt_default):
-                    input_data["prompt"] = prompt_default_tqa
+                # if (input_data["prompt"] == prompt_default):
+                #     input_data["prompt"] = prompt_default_tqa
+
+                target_df['db_schema'] = target_df.apply(
+                    lambda row: utils_get_db_tables_info.utils_extract_db_schema_as_string(
+                        db_id=input_data["db_name"],
+                        base_path=input_data["data_path"],
+                        normalize=False,
+                        sql=row["query"],
+                        get_insert_into=True,
+                        model=None,
+                        prompt=input_data["prompt"].format(question=row["question"], db_schema="")
+                    ),
+                    axis=1
+                )
+
                target_df = us.extract_answer(target_df)

            predictor = ModelPrediction()
@@ -766,6 +770,7 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
            count=0
            for index, row in target_df.iterrows():
                if (reset_flag == False):
+                    global flag_TQA
                    percent_complete = round(((index+1) / len(target_df)) * 100, 2)
                    load_text = f"{generate_loading_text(percent_complete)}"

@@ -780,7 +785,6 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
                        #samples = us.generate_some_samples(input_data["data_path"], row["tbl_name"])
                        model_to_send = None if not flag_TQA else model

-
                        db_schema_text = utils_get_db_tables_info.utils_extract_db_schema_as_string(
                            db_id = input_data["db_name"],
                            base_path = input_data["data_path"],
@@ -806,11 +810,11 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
                            prompt=f"{prompt_to_send}",
                            task=task
                        )
+                        end_time = time.time()
                        prediction = response['response_parsed']
                        price = response['cost']
                        answer = response['response']

-                        end_time = time.time()
                        if flag_TQA:
                            task_string = "Answer"
                        else:
@@ -857,6 +861,7 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
            evaluator = OrchestratorEvaluator()

            for model in input_data["models"]:
+                global flag_TQA
                if not flag_TQA:
                    metrics_df_model = evaluator.evaluate_df(
                        df=predictions_dict[model],
@@ -920,11 +925,6 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
        inputs=[],
        outputs=[tab_dict[model] for model in model_list] # Update TabItem visibility
    )
-    submit_models_button_tqa.click(
-        change_tab,
-        inputs=[],
-        outputs=[tab_dict[model] for model in model_list] # Update TabItem visibility
-    )

    selected_models_display = gr.JSON(label="Final input data", visible=False)
    metrics_df = gr.DataFrame(visible=False)
@@ -936,20 +936,10 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
        outputs=[evaluation_loading, model_logo, variable, question_display, prediction_display, metrics_df] + list(dataframe_per_model.values())
    )

-    submit_models_button_tqa.click(
-        fn=qatch_flow_nl_sql,
-        inputs=[],
-        outputs=[evaluation_loading, model_logo, variable, question_display, prediction_display, metrics_df] + list(dataframe_per_model.values())
-    )
-
    submit_models_button.click(
        fn=lambda: gr.update(value=input_data),
        outputs=[selected_models_display]
    )
-    submit_models_button_tqa.click(
-        fn=lambda: gr.update(value=input_data),
-        outputs=[selected_models_display]
-    )

    # Works for METRICS
    metrics_df.change(fn=change_text, inputs=[metrics_df], outputs=[metrics_df_out])
@@ -972,10 +962,6 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
        fn=lambda: gr.update(visible=False),
        outputs=[download_metrics]
    )
-    submit_models_button_tqa.click(
-        fn=lambda: gr.update(visible=False),
-        outputs=[download_metrics]
-    )

    def refresh():
        global reset_flag
@@ -1007,8 +993,7 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
            default_checkbox,
            table_selector,
            *table_outputs,
-            open_model_selection,
-            submit_models_button_tqa
+            open_model_selection
        ]
    )
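For reference, a minimal sketch (not the Space's exact code) of how the new task buttons and `check_prompt` interact: `flag_TQA` picks which default template is used when the textbox is empty, and the submit button only enables once both `{db_schema}` and `{question}` placeholders are present. The prompt templates and model name below are illustrative stand-ins for the ones defined elsewhere in app.py.

```python
# Minimal sketch of the new task-button / prompt validation flow.
# prompt_default and prompt_default_tqa stand in for the real templates.
prompt_default = "Schema: {db_schema}\nQuestion: {question}\nReturn only the SQL query."
prompt_default_tqa = "Schema: {db_schema}\nQuestion: {question}\nReturn only the answer."

flag_TQA = False  # toggled by the "Choose TQA task" / "Choose NL2SQL task" buttons
input_data = {"models": ["model-A"], "prompt": prompt_default}

def check_prompt(prompt: str):
    """Empty prompt falls back to the task default; otherwise both
    {db_schema} and {question} must appear for the submit button to enable."""
    missing = []
    if prompt == "":
        input_data["prompt"] = prompt_default_tqa if flag_TQA else prompt_default
    else:
        input_data["prompt"] = prompt
        if "{db_schema}" not in prompt:
            missing.append("{db_schema}")
        if "{question}" not in prompt:
            missing.append("{question}")
    button_enabled = bool(
        input_data["models"]
        and "{db_schema}" in input_data["prompt"]
        and "{question}" in input_data["prompt"]
    )
    return missing, button_enabled

print(check_prompt(""))                   # ([], True): default template kept
print(check_prompt("Answer {question}"))  # (['{db_schema}'], False): submit stays disabled
```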
utilities.py
CHANGED
@@ -8,6 +8,7 @@ import os
 from qatch.connectors.sqlite_connector import SqliteConnector
 from qatch.evaluate_dataset.metrics_evaluators import CellPrecision, CellRecall, ExecutionAccuracy, TupleCardinality, TupleConstraint, TupleOrder, ValidEfficiencyScore
 import qatch.evaluate_dataset.orchestrator_evaluator as eva
+import utils_get_db_tables_info
 #import tiktoken
 from transformers import AutoTokenizer

@@ -151,11 +152,16 @@ def extract_answer(df):
    answers = []
    for _, row in df.iterrows():
        query = row["query"]
-
-
-
-
+        db_schema = row["db_schema"]
+        #db_path = row["db_path"]
+        try:
+            conn = utils_get_db_tables_info.create_db_temp(db_schema)
+
+            result = pd.read_sql_query(query, conn)
+            answer = result.values.tolist() # Convert the DataFrame to a list of lists
+
            answers.append(answer)
+            conn.close()
        except Exception as e:
            answers.append(f"Error: {e}")

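A minimal sketch, under the assumption that each row of the dataframe carries a `db_schema` string of CREATE TABLE / INSERT INTO statements, of what the new `extract_answer` path does: build an in-memory SQLite database from that string and execute the gold query against it to obtain the target answer. Column names other than `query` and `db_schema` (e.g. `target_answer`) are illustrative.

```python
# Sketch of the new extract_answer path (assumed shapes, not the full utilities.py).
import sqlite3
import pandas as pd

def create_db_temp(schema_sql: str) -> sqlite3.Connection:
    # Same idea as the helper added to utils_get_db_tables_info.py
    conn = sqlite3.connect(":memory:")
    conn.executescript(schema_sql)
    return conn

df = pd.DataFrame([{
    "query": "SELECT name FROM users WHERE age > 30",
    "db_schema": (
        "CREATE TABLE users (name TEXT, age INT);"
        "INSERT INTO users VALUES ('Ada', 36), ('Bob', 20);"
    ),
}])

answers = []
for _, row in df.iterrows():
    try:
        conn = create_db_temp(row["db_schema"])
        result = pd.read_sql_query(row["query"], conn)
        answers.append(result.values.tolist())  # list of lists, as in the Space
        conn.close()
    except Exception as e:
        answers.append(f"Error: {e}")  # the error string becomes the "answer"

df["target_answer"] = answers  # column name assumed for illustration
print(df["target_answer"].iloc[0])  # [['Ada']]
```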
utils_get_db_tables_info.py
CHANGED
@@ -49,11 +49,15 @@ def _get_schema_entries(cursor, sql=None, get_insert_into=False, model: str | None
    tables = [tbl[0] for tbl in cursor.fetchall()]

    for table in tables:
+        entries_per_table = []
        # Retrieve the CREATE TABLE statement for each table
        cursor.execute(f"SELECT sql FROM sqlite_master WHERE type='table' AND name='{table}' AND sql IS NOT NULL;")
        create_table_stmt = cursor.fetchone()
        if create_table_stmt:
-
+            stmt = create_table_stmt[0].strip()
+            if not stmt.endswith(';'):
+                stmt += ';'
+            entries_per_table.append(stmt)

        if get_insert_into:
            # Retrieve all data from the table
@@ -70,9 +74,10 @@ def _get_schema_entries(cursor, sql=None, get_insert_into=False, model: str | None
            for row in rows[:max_len]:
                values = ', '.join(f"'{str(value)}'" if isinstance(value, str) else str(value) for value in row)
                insert_stmt = f"INSERT INTO {table} ({', '.join(column_names)}) VALUES ({values});"
-
+                entries_per_table.append(insert_stmt)

-
+        if model != None : entries_per_table = us.crop_entries_per_token(entries_per_table, model, prompt)
+        entries.extend(entries_per_table)

    return entries

@@ -112,3 +117,26 @@ def _combine_schema_entries(schema_entries, normalize):
        )
        for entry in schema_entries
    )
+
+
+def create_db_temp(schema_sql: str) -> sqlite3.Connection:
+    """
+    Creates a temporary SQLite database in memory by executing the provided SQL schema.
+
+    Args:
+        schema_sql (str): The SQL code containing CREATE TABLE and INSERT INTO.
+
+    Returns:
+        sqlite3.Connection: Connection object to the temporary database.
+    """
+    conn = sqlite3.connect(':memory:')
+    cursor = conn.cursor()
+
+    try:
+        cursor.executescript(schema_sql)
+        conn.commit()
+    except sqlite3.Error as e:
+        conn.close()
+        raise
+
+    return conn
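A self-contained sketch of the per-table grouping now used in `_get_schema_entries`: each table contributes its CREATE TABLE statement plus a few INSERT INTO rows, collected in a per-table list before being merged, so that an optional token-based crop (`us.crop_entries_per_token` in the Space) could act on one table at a time. The helper below uses plain sqlite3 and omits the cropping step.

```python
# Sketch of per-table schema entry collection (cropping step omitted).
import sqlite3

def schema_entries(conn: sqlite3.Connection, max_rows: int = 3) -> list[str]:
    cur = conn.cursor()
    cur.execute("SELECT name FROM sqlite_master WHERE type='table';")
    entries = []
    for (table,) in cur.fetchall():
        per_table = []
        cur.execute(
            "SELECT sql FROM sqlite_master WHERE type='table' AND name=? AND sql IS NOT NULL;",
            (table,),
        )
        stmt = cur.fetchone()[0].strip()
        per_table.append(stmt if stmt.endswith(";") else stmt + ";")
        cur.execute(f"SELECT * FROM {table} LIMIT {max_rows}")
        cols = [d[0] for d in cur.description]
        for row in cur.fetchall():
            values = ", ".join(repr(v) for v in row)
            per_table.append(f"INSERT INTO {table} ({', '.join(cols)}) VALUES ({values});")
        # A per-table token crop would go here before merging.
        entries.extend(per_table)
    return entries

conn = sqlite3.connect(":memory:")
conn.executescript("CREATE TABLE t (x INT); INSERT INTO t VALUES (1), (2);")
print("\n".join(schema_entries(conn)))
```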