Spaces:
Runtime error
Runtime error
| from services.ai import rank_tables,generate_sql,generate_answer , correct_sql , evaluate_difficulty | |
| from services.utils import filter_tables | |
| from openai import OpenAI | |
| from database import Database | |
| from filesource import FileSource | |
| import os | |
# Upper bound on how many tables are handed to the SQL generator; schemas
# with more tables than this are reranked and trimmed in run_agent.
MAX_TABLE = 3

# Shared OpenAI-compatible client used for every LLM call in this module.
# Endpoint and key are read from the environment once, at import time.
# NOTE(review): if LLM_KEY is unset the OpenAI constructor presumably falls
# back to OPENAI_API_KEY or raises during import — confirm for deployment.
client = OpenAI(
    base_url=os.environ.get("LLM_ENDPOINT"),
    api_key=os.environ.get("LLM_KEY"),
)
def _select_tables(database, prompt):
    """Return the tables to give the SQL generator, reranked when there are too many."""
    tables = database.get_tables_array()
    if len(tables) > MAX_TABLE:
        print(f"using reranking because number of tables is greater than {MAX_TABLE}")
        ranked = rank_tables(prompt, tables)
        # NOTE(review): the meaning of the leading 0 is defined by
        # services.utils.filter_tables (presumably a score cut-off) — confirm.
        tables = filter_tables(0, ranked)[:MAX_TABLE]
    return tables


def _needs_thinking(prompt, threshold):
    """Ask the LLM to rate *prompt*'s difficulty; True when the rating exceeds *threshold*.

    The rating is free-form LLM output, so an unparseable value is treated as
    "not difficult" rather than aborting the whole agent run.
    """
    raw = evaluate_difficulty(client, prompt)
    try:
        difficulty = int(raw)
    except (TypeError, ValueError):
        print(f"could not parse difficulty {raw!r}, thinking mode stays disabled")
        return False
    if difficulty > threshold:
        print(f"difficulty is > {threshold} so we enable thinking mode")
        return True
    return False


def _run_with_self_correction(database, prompt, sql, tables, retry):
    """Execute *sql*, letting the LLM rewrite it after each failure.

    The query is executed at most *retry* times. No correction is requested
    after the last failed attempt, because it could never be executed.

    Returns:
        (success, result, sql): *result* is what database.query returned
        (None on failure) and *sql* is the last query that was tried.
    """
    for attempt in range(1, retry + 1):
        try:
            print("try to launch sql request")
            result = database.query(sql)
        except Exception as e:  # any backend error is fed back to the LLM
            print(f"Error : {e}")
            if attempt == retry:
                break  # out of attempts: a corrected query would never run
            print("Try to self correct...")
            error = f"{type(e).__name__} - {str(e)}"
            # The final flag is True for the early corrections and False for
            # the last two; NOTE(review): presumably it toggles a richer
            # (thinking) correction mode in services.ai.correct_sql — confirm.
            sql = correct_sql(client, prompt, sql, tables, error, attempt < retry - 2)
            print(sql)
        else:
            return True, result, sql
    return False, None, sql


def run_agent(database, prompt, *, retry=5, thinking_threshold=7):
    """Answer a natural-language *prompt* against *database* using LLM-generated SQL.

    Pipeline:
    1. Pick the relevant tables, reranking when there are more than MAX_TABLE.
    2. Ask the LLM to rate the prompt's difficulty and enable thinking mode
       above *thinking_threshold*.
    3. Generate SQL.
    4. Execute it with LLM self-correction on errors.
    5. Have the LLM phrase an answer from the query result.

    Args:
        database: object exposing ``get_tables_array()`` and ``query(sql)``.
            Presumably a Database or FileSource instance (see the commented
            example below).
        prompt: the user's question.
        retry: maximum number of times the query is executed (default 5).
        thinking_threshold: difficulty rating above which thinking mode is
            used for generation and answering (default 7).

    Returns:
        The value returned by ``generate_answer``, or None when every
        execution attempt failed.
    """
    tables = _select_tables(database, prompt)
    use_thinking = _needs_thinking(prompt, thinking_threshold)
    sql = generate_sql(client, prompt, tables, use_thinking)

    success, result, sql = _run_with_self_correction(database, prompt, sql, tables, retry)
    if not success:
        print(f"giving up after {retry} failed attempts")
        return None

    # NOTE(review): result presumably is a pandas DataFrame (to_markdown) — confirm.
    markdown = result.to_markdown()
    print(markdown)
    return generate_answer(client, sql, prompt, markdown, use_thinking)
| # db = Database("mysql://user:password@localhost:3306/Pokemon") | |
| # db.connect() | |
| # file = FileSource("./Wines.csv") | |
| # file.connect() | |
| # print(run_agent(file,"What is the quality og the win with the less of alcohol ?")) | |