simone-papicchio commited on
Commit
dbef406
·
1 Parent(s): 9aa07eb

fix: add schema database for single query

Browse files
Files changed (2) hide show
  1. app.py +8 -7
  2. utils_get_db_tables_info.py +0 -7
app.py CHANGED
@@ -620,13 +620,6 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
620
  #TODO s
621
  target_df = orchestrator_generator.generate_dataset(connector=input_data['data']['db'], tables_to_include=input_data['data']['selected_tables'])
622
  #target_df = orchestrator_generator.generate_dataset(connector=input_data['data']['db'], tables_to_includes=None)
623
-
624
- schema_text = utils_get_db_tables_info.utils_extract_db_schema_as_string(
625
- db_id = input_data["db_name"],
626
- base_path = input_data["data_path"],
627
- normalize=False,
628
- sql=None
629
- )
630
 
631
  predictor = ModelPrediction()
632
 
@@ -649,6 +642,14 @@ with gr.Blocks(theme='shivi/calm_seafoam', css_paths='style.css', js=js_func) as
649
  yield gr.Image(), gr.Markdown(load_text), gr.Markdown(display_question), gr.Markdown(), metrics_conc, *[predictions_dict[model]for model in model_list]
650
  start_time = time.time()
651
  samples = us.generate_some_samples(input_data['data']['db'], row["tbl_name"])
 
 
 
 
 
 
 
 
652
  prompt_to_send = us.prepare_prompt(input_data["prompt"], question, schema_text, samples)
653
  #PREDICTION SQL
654
  if prompt_to_send == prompt_default:
 
620
  #TODO s
621
  target_df = orchestrator_generator.generate_dataset(connector=input_data['data']['db'], tables_to_include=input_data['data']['selected_tables'])
622
  #target_df = orchestrator_generator.generate_dataset(connector=input_data['data']['db'], tables_to_includes=None)
 
 
 
 
 
 
 
623
 
624
  predictor = ModelPrediction()
625
 
 
642
  yield gr.Image(), gr.Markdown(load_text), gr.Markdown(display_question), gr.Markdown(), metrics_conc, *[predictions_dict[model]for model in model_list]
643
  start_time = time.time()
644
  samples = us.generate_some_samples(input_data['data']['db'], row["tbl_name"])
645
+
646
+ schema_text = utils_get_db_tables_info.utils_extract_db_schema_as_string(
647
+ db_id = row["db_name"],
648
+ base_path = row["data_path"],
649
+ normalize=False,
650
+ sql=row["query"]
651
+ )
652
+
653
  prompt_to_send = us.prepare_prompt(input_data["prompt"], question, schema_text, samples)
654
  #PREDICTION SQL
655
  if prompt_to_send == prompt_default:
utils_get_db_tables_info.py CHANGED
@@ -15,13 +15,6 @@ def utils_extract_db_schema_as_string(
15
  :param sql: Optional SQL query to filter specific tables.
16
  :return: Schema of the database as a single string.
17
  """
18
- #db_path = os.path.join(base_path, db_id, f"{db_id}.sqlite")
19
-
20
- # Connect to the SQLite database
21
-
22
- #if not os.path.exists(db_path):
23
- # raise FileNotFoundError(f"Database file not found at: {db_path}")
24
-
25
  connection = sqlite3.connect(base_path)
26
  cursor = connection.cursor()
27
 
 
15
  :param sql: Optional SQL query to filter specific tables.
16
  :return: Schema of the database as a single string.
17
  """
 
 
 
 
 
 
 
18
  connection = sqlite3.connect(base_path)
19
  cursor = connection.cursor()
20