Make ask_drias asynchronous
        climateqa/engine/talk_to_data/main.py
    CHANGED
    
@@ -37,7 +37,7 @@ def ask_llm_column_names(sql_query: str, llm) -> list[str]:
     columns_list = ast.literal_eval(columns.strip("```python\n").strip())
     return columns_list
 
-def ask_drias(query: str, index_state: int = 0) -> tuple:
+async def ask_drias(query: str, index_state: int = 0) -> tuple:
     """Main function to process a DRIAS query and return results.
 
     This function orchestrates the DRIAS workflow, processing a user query to generate
@@ -60,7 +60,7 @@ def ask_drias(query: str, index_state: int = 0) -> tuple:
            - table_list (list): List of table names used
            - error (str): Error message if any
     """
-    final_state = drias_workflow(query)
+    final_state = await drias_workflow(query)
     sql_queries = []
     result_dataframes = []
     figures = []
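Since ask_drias is now a coroutine, callers must await it. The sketch below is a minimal, hypothetical way to drive the new entry point from a script; the query string is illustrative and asyncio.run is not part of the patch itself.

import asyncio

from climateqa.engine.talk_to_data.main import ask_drias

async def main() -> None:
    # index_state=0 selects the first SQL query / dataframe pair returned by the workflow
    results = await ask_drias("Evolution of temperature in Paris", index_state=0)
    print(results)

if __name__ == "__main__":
    asyncio.run(main())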
    	
        climateqa/engine/talk_to_data/sql_query.py
    CHANGED
    
@@ -1,8 +1,10 @@
+import asyncio
+from concurrent.futures import ThreadPoolExecutor
 from typing import TypedDict
 import duckdb
 import pandas as pd
 
-def execute_sql_query(sql_query: str) -> pd.DataFrame:
+async def execute_sql_query(sql_query: str) -> pd.DataFrame:
     """Executes a SQL query on the DRIAS database and returns the results.
 
     This function connects to the DuckDB database containing DRIAS climate data
@@ -18,11 +20,16 @@ def execute_sql_query(sql_query: str) -> pd.DataFrame:
     Raises:
         duckdb.Error: If there is an error executing the SQL query
     """
-
-    # Execute the query
-    results = duckdb.sql(sql_query)
-    # return fetched data
-    return results.fetchdf()
+    def _execute_query():
+        # Execute the query
+        results = duckdb.sql(sql_query)
+        # return fetched data
+        return results.fetchdf()
+
+    # Run the query in a thread pool to avoid blocking
+    loop = asyncio.get_event_loop()
+    with ThreadPoolExecutor() as executor:
+        return await loop.run_in_executor(executor, _execute_query)
 
 
 class IndicatorPerYearAtLocationQueryParams(TypedDict, total=False):
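The wrapper above keeps the blocking DuckDB call off the event loop by pushing it to a thread pool. As a side note, on Python 3.9+ asyncio.to_thread (and asyncio.get_running_loop when a loop reference is needed inside a coroutine) achieves the same effect with less code; this is a hedged alternative sketch, not what the patch does.

import asyncio
import duckdb
import pandas as pd

async def execute_sql_query_to_thread(sql_query: str) -> pd.DataFrame:
    # Run the blocking DuckDB call in the default thread pool executor (Python 3.9+).
    def _run() -> pd.DataFrame:
        return duckdb.sql(sql_query).fetchdf()
    return await asyncio.to_thread(_run)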
    	
        climateqa/engine/talk_to_data/utils.py
    CHANGED
    
@@ -9,7 +9,7 @@ from climateqa.engine.talk_to_data.plot import PLOTS, Plot
 from langchain_core.prompts import ChatPromptTemplate
 
 
-def detect_location_with_openai(sentence):
+async def detect_location_with_openai(sentence):
     """
     Detects locations in a sentence using OpenAI's API via LangChain.
     """
@@ -22,7 +22,7 @@ def detect_location_with_openai(sentence):
     Sentence: "{sentence}"
     """
 
-    response = llm.invoke(prompt)
+    response = await llm.ainvoke(prompt)
     location_list = ast.literal_eval(response.content.strip("```python\n").strip())
     if location_list:
         return location_list[0]
@@ -40,7 +40,7 @@ class ArrayOutput(TypedDict):
     """
     array: Annotated[str, "Syntactically valid python array."]
 
-def detect_year_with_openai(sentence: str) -> str:
+async def detect_year_with_openai(sentence: str) -> str:
     """
     Detects years in a sentence using OpenAI's API via LangChain.
     """
@@ -56,7 +56,7 @@ def detect_year_with_openai(sentence: str) -> str:
     prompt = ChatPromptTemplate.from_template(prompt)
     structured_llm = llm.with_structured_output(ArrayOutput)
     chain = prompt | structured_llm
-    response: ArrayOutput = chain.invoke({"sentence": sentence})
+    response: ArrayOutput = await chain.ainvoke({"sentence": sentence})
     years_list = eval(response['array'])
     if len(years_list) > 0:
         return years_list[0]
@@ -146,7 +146,7 @@ def nearestNeighbourSQL(location: tuple, table: str) -> tuple[str, str]:
     return results['latitude'].iloc[0], results['longitude'].iloc[0]
 
 
-def detect_relevant_tables(user_question: str, plot: Plot, llm) -> list[str]:
+async def detect_relevant_tables(user_question: str, plot: Plot, llm) -> list[str]:
     """Identifies relevant tables for a plot based on user input.
 
     This function uses an LLM to analyze the user's question and the plot
@@ -183,7 +183,7 @@ def detect_relevant_tables(user_question: str, plot: Plot, llm) -> list[str]:
     )
 
     table_names = ast.literal_eval(
-        llm.invoke(prompt).content.strip("```python\n").strip()
+        (await llm.ainvoke(prompt)).content.strip("```python\n").strip()
     )
     return table_names
 
@@ -197,7 +197,7 @@ def replace_coordonates(coords, query, coords_tables):
     return query
 
 
-def detect_relevant_plots(user_question: str, llm):
+async def detect_relevant_plots(user_question: str, llm):
     plots_description = ""
     for plot in PLOTS:
         plots_description += "Name: " + plot["name"]
@@ -223,7 +223,7 @@ def detect_relevant_plots(user_question: str, llm):
     # )
 
     plot_names = ast.literal_eval(
-        llm.invoke(prompt).content.strip("```python\n").strip()
+        (await llm.ainvoke(prompt)).content.strip("```python\n").strip()
     )
     return plot_names
 
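The change in this file is mostly mechanical: each llm.invoke(...) / chain.invoke(...) becomes an awaited ainvoke(...), which LangChain exposes on every Runnable. Below is a minimal standalone sketch of that pattern; the ChatOpenAI model name and prompt are illustrative and not taken from the repository, which builds its model through get_llm(provider="openai").

import asyncio

from langchain_openai import ChatOpenAI

async def demo() -> None:
    llm = ChatOpenAI(model="gpt-4o-mini")  # assumed model, for illustration only
    # Synchronous form:   response = llm.invoke(prompt)
    # Asynchronous form:  the same call awaited, so the event loop stays free during the HTTP request
    response = await llm.ainvoke("Return a python list of the locations in: 'rainfall in Marseille'")
    print(response.content)

asyncio.run(demo())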
    	
        climateqa/engine/talk_to_data/workflow.py
    CHANGED
    
@@ -61,7 +61,7 @@ class State(TypedDict):
     plot_states: dict[str, PlotState]
     error: NotRequired[str]
 
-def drias_workflow(user_input: str) -> State:
+async def drias_workflow(user_input: str) -> State:
     """Performs the complete workflow of Talk To Drias : from user input to sql queries, dataframes and figures generated
 
     Args:
@@ -78,7 +78,7 @@ def drias_workflow(user_input: str) -> State:
 
     llm = get_llm(provider="openai")
 
-    plots = find_relevant_plots(state, llm)
+    plots = await find_relevant_plots(state, llm)
     state['plots'] = plots
 
     if not state['plots']:
@@ -102,7 +102,7 @@ def drias_workflow(user_input: str) -> State:
 
         plot_state['plot_name'] = plot_name
 
-        relevant_tables = find_relevant_tables_per_plot(state, plot, llm)
+        relevant_tables = await find_relevant_tables_per_plot(state, plot, llm)
         if len(relevant_tables) > 0 :
             have_relevant_table = True
 
@@ -110,7 +110,7 @@ def drias_workflow(user_input: str) -> State:
 
         params = {}
         for param_name in plot['params']:
-            param = find_param(state, param_name, relevant_tables[0])
+            param = await find_param(state, param_name, relevant_tables[0])
             if param:
                 params.update(param)
 
@@ -135,7 +135,7 @@ def drias_workflow(user_input: str) -> State:
                 have_sql_query = True
 
             table_state['sql_query'] = sql_query
-            df = execute_sql_query(sql_query)
+            df = await execute_sql_query(sql_query)
 
             if len(df) > 0:
                 have_dataframe = True
@@ -154,22 +154,19 @@ def drias_workflow(user_input: str) -> State:
     elif not have_dataframe:
         state['error'] = "There is no data in our table that can answer to your question"
 
-
     return state
 
-
-def find_relevant_plots(state: State, llm) -> list[str]:
+async def find_relevant_plots(state: State, llm) -> list[str]:
     print("---- Find relevant plots ----")
-    relevant_plots = detect_relevant_plots(state['user_input'], llm)
+    relevant_plots = await detect_relevant_plots(state['user_input'], llm)
     return relevant_plots
 
-def find_relevant_tables_per_plot(state: State, plot: Plot, llm) -> list[str]:
+async def find_relevant_tables_per_plot(state: State, plot: Plot, llm) -> list[str]:
     print(f"---- Find relevant tables for {plot['name']} ----")
-    relevant_tables = detect_relevant_tables(state['user_input'], plot, llm)
+    relevant_tables = await detect_relevant_tables(state['user_input'], plot, llm)
     return relevant_tables
 
-
-def find_param(state: State, param_name:str, table: str) -> dict[str, Any] | None:
+async def find_param(state: State, param_name:str, table: str) -> dict[str, Any] | None:
     """Perform the good method to retrieve the desired parameter
 
     Args:
@@ -181,25 +178,21 @@ def find_param(state: State, param_name:str, table: str) -> dict[str, Any] | None:
         dict[str, Any] | None:
     """
     if param_name == 'location':
-        location = find_location(state['user_input'], table)
+        location = await find_location(state['user_input'], table)
         return location
-    # if param_name == 'indicator_column':
-    #     indicator_column = find_indicator_column(table)
-    #     return {'indicator_column': indicator_column}
     if param_name == 'year':
-        year = find_year(state['user_input'])
+        year = await find_year(state['user_input'])
         return {'year': year}
     return None
 
-
 class Location(TypedDict):
     location: str
     latitude: NotRequired[str]
     longitude: NotRequired[str]
 
-def find_location(user_input: str, table: str) -> Location:
+async def find_location(user_input: str, table: str) -> Location:
     print(f"---- Find location in table {table} ----")
-    location = detect_location_with_openai(user_input)
+    location = await detect_location_with_openai(user_input)
     output: Location = {'location' : location}
     if location:
         coords = loc2coords(location)
@@ -210,7 +203,7 @@ def find_location(user_input: str, table: str) -> Location:
         })
     return output
 
-def find_year(user_input: str) -> str:
+async def find_year(user_input: str) -> str:
     """Extracts year information from user input using LLM.
 
     This function uses an LLM to identify and extract year information from the
@@ -223,7 +216,7 @@ def find_year(user_input: str) -> str:
         str: The extracted year, or empty string if no year found
     """
     print(f"---- Find year ---")
-    year = detect_year_with_openai(user_input)
+    year = await detect_year_with_openai(user_input)
     return year
 
 def find_indicator_column(table: str) -> str:
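The workflow itself remains sequential: each helper is awaited in turn, so the immediate benefit is that the surrounding event loop (for example the UI) is no longer blocked while DRIAS queries run. A possible follow-up, sketched below and not part of this change, would be to run independent helpers concurrently with asyncio.gather now that they are coroutines.

import asyncio

from climateqa.engine.talk_to_data.workflow import find_location, find_year

async def find_params_concurrently(user_input: str, table: str) -> dict:
    # Hypothetical helper: location and year detection are independent LLM calls,
    # so they can be awaited concurrently instead of one after the other.
    location, year = await asyncio.gather(
        find_location(user_input, table),
        find_year(user_input),
    )
    return {**location, 'year': year}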
    	
        front/tabs/tab_drias.py
    CHANGED
    
@@ -4,8 +4,8 @@ from climateqa.engine.talk_to_data.main import ask_drias
 from climateqa.engine.talk_to_data.config import DRIAS_MODELS, DRIAS_UI_TEXT
 
 
-def ask_drias_query(query: str, index_state: int):
-    return ask_drias(query, index_state)
+async def ask_drias_query(query: str, index_state: int):
+    return await ask_drias(query, index_state)
 
 
 def show_results(sql_queries_state, dataframes_state, plots_state):
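Gradio accepts coroutine functions as event callbacks, which is what makes the async ask_drias_query a drop-in replacement for the synchronous version. The self-contained sketch below uses a stand-in handler rather than the real wiring in tab_drias.py, whose outputs (SQL queries, dataframes, figures, pagination state) are more numerous.

import asyncio

import gradio as gr

async def echo_async(message: str) -> str:
    # Stand-in async handler; in the real tab, ask_drias_query plays this role.
    await asyncio.sleep(0.1)  # simulate awaiting the DRIAS workflow
    return f"Processed: {message}"

with gr.Blocks() as demo:
    box = gr.Textbox(label="Ask DRIAS")
    out = gr.Textbox(label="Result")
    # Gradio awaits coroutine callbacks natively, so the UI stays responsive.
    box.submit(echo_async, inputs=box, outputs=out)

if __name__ == "__main__":
    demo.launch()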