|
|
|
import os |
|
import gradio as gr |
|
|
|
from azure.storage.fileshare import ShareServiceClient |
|
|
|
|
|
from climateqa.engine.embeddings import get_embeddings_function |
|
from climateqa.engine.llm import get_llm |
|
from climateqa.engine.vectorstore import get_pinecone_vectorstore |
|
from climateqa.engine.reranker import get_reranker |
|
from climateqa.engine.graph import make_graph_agent, make_graph_agent_poc |
|
from climateqa.engine.chains.retrieve_papers import find_papers |
|
from climateqa.chat import start_chat, chat_stream, finish_chat |
|
from climateqa.engine.talk_to_data.main import ask_drias, DRIAS_MODELS |
|
from climateqa.engine.talk_to_data.myVanna import MyVanna |
|
|
|
from front.tabs import create_config_modal, cqa_tab, create_about_tab |
|
from front.tabs import MainTabPanel, ConfigPanel |
|
from front.utils import process_figures |
|
from gradio_modal import Modal |
|
|
|
|
|
from utils import create_user_id |
|
import logging |
|
|
|
# Keep console output quiet: warnings and above only.
logging.basicConfig(level=logging.WARNING)
# Silence TensorFlow C++ info logs (0=all, 1=no INFO, 2=no INFO/WARNING).
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
logging.getLogger().setLevel(logging.WARNING)
|
|
|
|
|
|
|
try:
    # Load variables from a local .env file in development; python-dotenv is
    # an optional dependency, so its absence (e.g. in production, where env
    # vars are provided by the host) is not an error.
    from dotenv import load_dotenv
except ImportError:
    pass
else:
    load_dotenv()
|
|
|
|
|
|
|
# Shared Gradio theme for the whole app (used by main_ui below).
theme = gr.themes.Base(
    primary_hue="blue",
    secondary_hue="red",
    font=[gr.themes.GoogleFont("Poppins"), "ui-sans-serif", "system-ui", "sans-serif"],
)
|
|
|
|
|
# Azure File Share client used by chat_stream to persist per-session logs.
account_key = os.environ["BLOB_ACCOUNT_KEY"]
if len(account_key) == 86:
    # Presumably restores base64 padding ("=="") stripped from the stored
    # 88-char account key — TODO confirm against how the secret is provisioned.
    account_key += "=="

credential = {
    "account_key": account_key,
    "account_name": os.environ["BLOB_ACCOUNT_NAME"],
}

account_url = os.environ["BLOB_ACCOUNT_URL"]
file_share_name = "climateqa"
service = ShareServiceClient(account_url=account_url, credential=credential)
share_client = service.get_share_client(file_share_name)

# Identifier attached to this session's logs (see utils.create_user_id).
user_id = create_user_id()
|
|
|
|
|
|
|
# Embeddings + the three Pinecone vector stores: the main document corpus,
# the OWID graphs index (searched on its "description" field), and the
# local/regional documents index.
embeddings_function = get_embeddings_function()
vectorstore = get_pinecone_vectorstore(
    embeddings_function, index_name=os.getenv("PINECONE_API_INDEX")
)
vectorstore_graphs = get_pinecone_vectorstore(
    embeddings_function,
    index_name=os.getenv("PINECONE_API_INDEX_OWID"),
    text_key="description",
)
vectorstore_region = get_pinecone_vectorstore(
    embeddings_function, index_name=os.getenv("PINECONE_API_INDEX_LOCAL_V2")
)

llm = get_llm(provider="openai", max_tokens=1024, temperature=0.0)
# Lighter reranker for local development, larger one otherwise.
if os.environ["GRADIO_ENV"] == "local":
    reranker = get_reranker("nano")
else:
    reranker = get_reranker("large")

# Main agent for the ClimateQ&A tab (documents below 0.2 relevance are dropped).
agent = make_graph_agent(
    llm=llm,
    vectorstore_ipcc=vectorstore,
    vectorstore_graphs=vectorstore_graphs,
    vectorstore_region=vectorstore_region,
    reranker=reranker,
    threshold_docs=0.2,
)
# POC agent for the "Beta - POC Adapt'Action" tab: no relevance threshold,
# pinned to the "v4" graph version.
agent_poc = make_graph_agent_poc(
    llm=llm,
    vectorstore_ipcc=vectorstore,
    vectorstore_graphs=vectorstore_graphs,
    vectorstore_region=vectorstore_region,
    reranker=reranker,
    threshold_docs=0,
    version="v4",
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def ask_drias_query(query: str, index_state: int):
    """Forward the user's question (and current result index) to the DRIAS
    talk-to-data backend and return its result tuple unchanged."""
    response = ask_drias(query, index_state)
    return response
|
|
|
|
|
async def chat(
    query,
    history,
    audience,
    sources,
    reports,
    relevant_content_sources_selection,
    search_only,
):
    """Stream chat updates for the main ClimateQ&A tab.

    Thin async-generator wrapper that runs `chat_stream` with the default
    `agent` and re-yields every event it produces.
    """
    print("chat cqa - message received")
    stream = chat_stream(
        agent,
        query,
        history,
        audience,
        sources,
        reports,
        relevant_content_sources_selection,
        search_only,
        share_client,
        user_id,
    )
    async for update in stream:
        yield update
|
|
|
|
|
async def chat_poc(
    query,
    history,
    audience,
    sources,
    reports,
    relevant_content_sources_selection,
    search_only,
):
    """Stream chat updates for the "Beta - POC Adapt'Action" tab.

    Same wiring as `chat`, but drives `chat_stream` with the POC agent
    (`agent_poc`) instead of the default one.
    """
    print("chat poc - message received")
    stream = chat_stream(
        agent_poc,
        query,
        history,
        audience,
        sources,
        reports,
        relevant_content_sources_selection,
        search_only,
        share_client,
        user_id,
    )
    async for update in stream:
        yield update
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def update_config_modal_visibility(config_open):
    """Toggle the configuration modal.

    Args:
        config_open: current open/closed boolean state of the modal.

    Returns:
        A `Modal` update with the flipped visibility, and the new state value
        (so the `config_open` gr.State stays in sync).
    """
    # Removed leftover debug print(config_open) from the original.
    new_config_visibility_status = not config_open
    return Modal(visible=new_config_visibility_status), new_config_visibility_status
|
|
|
|
|
def update_sources_number_display(
    sources_textbox, figures_cards, current_graphs, papers_html
):
    """Refresh the tab labels with the number of items each panel holds.

    Counts are derived from the rendered HTML: one "<h2>" per source, figure
    and paper card, one "<iframe" per graph. Returns label updates for the
    recommended-content, sources, figures, graphs and papers tabs (in that
    order).
    """
    n_sources = sources_textbox.count("<h2>")
    n_figures = figures_cards.count("<h2>")
    n_graphs = current_graphs.count("<iframe")
    n_papers = papers_html.count("<h2>")
    n_recommended = n_figures + n_graphs + n_papers

    return (
        gr.update(label=f"Recommended content ({n_recommended})"),
        gr.update(label=f"Sources ({n_sources})"),
        gr.update(label=f"Figures ({n_figures})"),
        gr.update(label=f"Graphs ({n_graphs})"),
        gr.update(label=f"Papers ({n_papers})"),
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def show_results(sql_queries_state, dataframes_state, plots_state):
    """Show or hide the DRIAS result widgets.

    When any of the result lists is empty, only the "no result" textbox is
    visible; otherwise it is hidden and the six result widgets (query, table
    and chart accordions, prev/next buttons, pagination) are shown.
    """
    has_results = bool(sql_queries_state and dataframes_state and plots_state)
    return (
        gr.update(visible=not has_results),
        *(gr.update(visible=has_results) for _ in range(6)),
    )
|
|
|
|
|
def filter_by_model(dataframes, figures, index_state, model_selection):
    """Re-filter the currently displayed DRIAS result by climate model.

    Selects the dataframe at `index_state`, keeps only the rows whose
    "model" column matches `model_selection` (unless "ALL" is chosen), and
    rebuilds the figure by calling the matching plot factory on the filtered
    dataframe.
    """
    current_df = dataframes[index_state]
    if model_selection != "ALL":
        current_df = current_df[current_df["model"] == model_selection]
    current_figure = figures[index_state](current_df)
    return current_df, current_figure
|
|
|
|
|
def update_pagination(index, sql_queries):
    """Return the "current/total" pagination text for the DRIAS results.

    Shows a 1-based position over the number of generated SQL queries, or an
    empty string when there are no results yet.
    """
    if not sql_queries:
        return ""
    return f"{index + 1}/{len(sql_queries)}"
|
|
|
|
|
def create_drias_tab():
    """Build the "Beta - Talk to DRIAS" tab.

    Lays out the question textbox and the result widgets (SQL query, chart
    with a model filter, data table), plus Previous/Next pagination over the
    list of results returned by `ask_drias`, and wires all the events.

    Fix vs. original: `show_previous`/`show_next` used to fall through and
    implicitly return None when already at the first/last result, which does
    not match the four declared Gradio outputs; they now always return the
    (possibly unchanged) current result.
    """
    details_text = """
Hi, I'm **Talk to Drias**, designed to answer your questions using [**DRIAS - TRACC 2023**](https://www.drias-climat.fr/accompagnement/sections/401) data.
I'll answer by displaying a list of SQL queries, graphs and data most relevant to your question.

❓ **How to use?**
You can ask me anything about these climate indicators: **temperature**, **precipitation** or **drought**.
You can specify **location** and/or **year**.
You can choose from a list of climate models. By default, we take the **average of each model**.

For example, you can ask:
- What will the temperature be like in Paris?
- What will be the total rainfall in France in 2030?
- How frequent will extreme events be in Lyon?

**Example of indicators in the data**:
- Mean temperature (annual, winter, summer)
- Total precipitation (annual, winter, summer)
- Number of days with remarkable precipitations, with dry ground, with temperature above 30°C

⚠️ **Limitations**:
- You can't ask anything that isn't related to **DRIAS - TRACC 2023** data.
- You can only ask about **locations in France**.
- If you specify a year, there may be **no data for that year for some models**.
- You **cannot compare two models**.

🛈 **Information**
Please note that we **log your questions for meta-analysis purposes**, so avoid sharing any sensitive or personal information.
    """
    with gr.Tab("Beta - Talk to DRIAS", elem_id="tab-vanna", id=6):
        with gr.Accordion(label="Details"):
            gr.Markdown(details_text)

        with gr.Row():
            drias_direct_question = gr.Textbox(
                label="Direct Question",
                placeholder="You can write direct question here",
                elem_id="direct-question",
                interactive=True,
            )

        # Shown alone until a query has produced results (see show_results).
        result_text = gr.Textbox(
            label="", elem_id="no-result-label", interactive=False, visible=True
        )

        with gr.Accordion(label="SQL Query Used", visible=False) as query_accordion:
            drias_sql_query = gr.Textbox(
                label="", elem_id="sql-query", interactive=False
            )

        with gr.Accordion(label="Chart", visible=False) as chart_accordion:
            model_selection = gr.Dropdown(
                label="Model", choices=DRIAS_MODELS, value="ALL", interactive=True
            )
            drias_display = gr.Plot(elem_id="vanna-plot")

        with gr.Accordion(
            label="Data used", open=False, visible=False
        ) as table_accordion:
            drias_table = gr.DataFrame([], elem_id="vanna-table")

        pagination_display = gr.Markdown(value="", visible=False, elem_id="pagination-display")

        with gr.Row():
            prev_button = gr.Button("Previous", visible=False)
            next_button = gr.Button("Next", visible=False)

        # Per-session result history: one SQL query, dataframe and plot
        # factory per result, plus the index of the one currently shown.
        sql_queries_state = gr.State([])
        dataframes_state = gr.State([])
        plots_state = gr.State([])
        index_state = gr.State(0)

        # Question submitted -> run the query, then toggle widget visibility,
        # then refresh the pagination text.
        drias_direct_question.submit(
            ask_drias_query,
            inputs=[drias_direct_question, index_state],
            outputs=[
                drias_sql_query,
                drias_table,
                drias_display,
                sql_queries_state,
                dataframes_state,
                plots_state,
                index_state,
                result_text,
            ],
        ).then(
            show_results,
            inputs=[sql_queries_state, dataframes_state, plots_state],
            outputs=[
                result_text,
                query_accordion,
                table_accordion,
                chart_accordion,
                prev_button,
                next_button,
                pagination_display
            ],
        ).then(
            update_pagination,
            inputs=[index_state, sql_queries_state],
            outputs=[pagination_display],
        )

        model_selection.change(
            filter_by_model,
            inputs=[dataframes_state, plots_state, index_state, model_selection],
            outputs=[drias_table, drias_display],
        )

        def show_previous(index, sql_queries, dataframes, plots):
            """Step back one result (no-op at the first one)."""
            if index > 0:
                index -= 1
            return (
                sql_queries[index],
                dataframes[index],
                plots[index](dataframes[index]),
                index,
            )

        def show_next(index, sql_queries, dataframes, plots):
            """Step forward one result (no-op at the last one)."""
            if index < len(sql_queries) - 1:
                index += 1
            return (
                sql_queries[index],
                dataframes[index],
                plots[index](dataframes[index]),
                index,
            )

        prev_button.click(
            show_previous,
            inputs=[index_state, sql_queries_state, dataframes_state, plots_state],
            outputs=[drias_sql_query, drias_table, drias_display, index_state],
        ).then(
            update_pagination,
            inputs=[index_state, sql_queries_state],
            outputs=[pagination_display],
        )

        next_button.click(
            show_next,
            inputs=[index_state, sql_queries_state, dataframes_state, plots_state],
            outputs=[drias_sql_query, drias_table, drias_display, index_state],
        ).then(
            update_pagination,
            inputs=[index_state, sql_queries_state],
            outputs=[pagination_display],
        )
|
|
|
|
|
def config_event_handling(
    main_tabs_components: list[MainTabPanel], config_componenets: ConfigPanel
):
    """Wire every config button — one per chat tab, plus the modal's own
    close button — to toggle the shared configuration modal."""
    config_open = config_componenets.config_open
    config_modal = config_componenets.config_modal

    toggle_buttons = [config_componenets.close_config_modal_button]
    toggle_buttons.extend(tab.config_button for tab in main_tabs_components)

    for toggle_button in toggle_buttons:
        toggle_button.click(
            fn=update_config_modal_visibility,
            inputs=[config_open],
            outputs=[config_modal, config_open],
        )
|
|
|
|
|
def event_handling(
    main_tab_components: MainTabPanel,
    config_components: ConfigPanel,
    tab_name="ClimateQ&A",
):
    """Wire all chat-related events for one chat tab.

    Connects the three ways of sending a message (textbox submit, example
    click, follow-up example click) to the streaming pipeline
    start_chat -> chat/chat_poc -> finish_chat, then hooks the secondary
    panels (sources, figures, graphs, papers) to their update callbacks.

    Args:
        main_tab_components: widgets of one chat tab (built by cqa_tab).
        config_components: widgets of the shared configuration modal.
        tab_name: selects the pipeline — "ClimateQ&A" binds `chat`,
            "Beta - POC Adapt'Action" binds `chat_poc`.
    """
    # Unpack the tab's widgets into locals for readability below.
    chatbot = main_tab_components.chatbot
    textbox = main_tab_components.textbox
    tabs = main_tab_components.tabs
    sources_raw = main_tab_components.sources_raw
    new_figures = main_tab_components.new_figures
    current_graphs = main_tab_components.current_graphs
    examples_hidden = main_tab_components.examples_hidden
    sources_textbox = main_tab_components.sources_textbox
    figures_cards = main_tab_components.figures_cards
    gallery_component = main_tab_components.gallery_component
    papers_direct_search = main_tab_components.papers_direct_search
    papers_html = main_tab_components.papers_html
    citations_network = main_tab_components.citations_network
    papers_summary = main_tab_components.papers_summary
    tab_recommended_content = main_tab_components.tab_recommended_content
    tab_sources = main_tab_components.tab_sources
    tab_figures = main_tab_components.tab_figures
    tab_graphs = main_tab_components.tab_graphs
    tab_papers = main_tab_components.tab_papers
    graphs_container = main_tab_components.graph_container
    follow_up_examples = main_tab_components.follow_up_examples
    follow_up_examples_hidden = main_tab_components.follow_up_examples_hidden

    # Shared configuration widgets (same instances for both chat tabs).
    dropdown_sources = config_components.dropdown_sources
    dropdown_reports = config_components.dropdown_reports
    dropdown_external_sources = config_components.dropdown_external_sources
    search_only = config_components.search_only
    dropdown_audience = config_components.dropdown_audience
    after = config_components.after
    output_query = config_components.output_query
    output_language = config_components.output_language

    # Per-session state holding the freshly rendered sources HTML
    # (sic: "hmtl" typo kept — renaming would touch many lines below).
    new_sources_hmtl = gr.State([])
    ttd_data = gr.State([])  # NOTE(review): appears unused in this function

    if tab_name == "ClimateQ&A":
        # NOTE(review): this print runs once at wiring time, not per message.
        print("chat cqa - message sent")

        # Textbox submit: prepare the chat, stream the answer, re-enable input.
        (
            textbox.submit(
                start_chat,
                [textbox, chatbot, search_only],
                [textbox, tabs, chatbot, sources_raw],
                queue=False,
                api_name=f"start_chat_{textbox.elem_id}",
            )
            .then(
                chat,
                [
                    textbox,
                    chatbot,
                    dropdown_audience,
                    dropdown_sources,
                    dropdown_reports,
                    dropdown_external_sources,
                    search_only,
                ],
                [
                    chatbot,
                    new_sources_hmtl,
                    output_query,
                    output_language,
                    new_figures,
                    current_graphs,
                    follow_up_examples.dataset,
                ],
                concurrency_limit=8,
                api_name=f"chat_{textbox.elem_id}",
            )
            .then(
                finish_chat, None, [textbox], api_name=f"finish_chat_{textbox.elem_id}"
            )
        )

        # Same pipeline, triggered by clicking one of the example questions.
        (
            examples_hidden.change(
                start_chat,
                [examples_hidden, chatbot, search_only],
                [examples_hidden, tabs, chatbot, sources_raw],
                queue=False,
                api_name=f"start_chat_{examples_hidden.elem_id}",
            )
            .then(
                chat,
                [
                    examples_hidden,
                    chatbot,
                    dropdown_audience,
                    dropdown_sources,
                    dropdown_reports,
                    dropdown_external_sources,
                    search_only,
                ],
                [
                    chatbot,
                    new_sources_hmtl,
                    output_query,
                    output_language,
                    new_figures,
                    current_graphs,
                    follow_up_examples.dataset,
                ],
                concurrency_limit=8,
                api_name=f"chat_{examples_hidden.elem_id}",
            )
            .then(
                finish_chat,
                None,
                [textbox],
                api_name=f"finish_chat_{examples_hidden.elem_id}",
            )
        )
        # Same pipeline, triggered by clicking a follow-up suggestion.
        (
            follow_up_examples_hidden.change(
                start_chat,
                [follow_up_examples_hidden, chatbot, search_only],
                [follow_up_examples_hidden, tabs, chatbot, sources_raw],
                queue=False,
                # NOTE(review): api_name reuses examples_hidden.elem_id,
                # duplicating the example handler's name — confirm intended.
                api_name=f"start_chat_{examples_hidden.elem_id}",
            )
            .then(
                chat,
                [
                    follow_up_examples_hidden,
                    chatbot,
                    dropdown_audience,
                    dropdown_sources,
                    dropdown_reports,
                    dropdown_external_sources,
                    search_only,
                ],
                [
                    chatbot,
                    new_sources_hmtl,
                    output_query,
                    output_language,
                    new_figures,
                    current_graphs,
                    follow_up_examples.dataset,
                ],
                concurrency_limit=8,
                # NOTE(review): same duplicated api_name as above.
                api_name=f"chat_{examples_hidden.elem_id}",
            )
            .then(
                finish_chat,
                None,
                [textbox],
                api_name=f"finish_chat_{follow_up_examples_hidden.elem_id}",
            )
        )

    elif tab_name == "Beta - POC Adapt'Action":
        # NOTE(review): this print runs once at wiring time, not per message.
        print("chat poc - message sent")

        # Textbox submit: same pipeline as above, driven by chat_poc.
        # Note the POC output list has no follow_up_examples.dataset entry.
        (
            textbox.submit(
                start_chat,
                [textbox, chatbot, search_only],
                [textbox, tabs, chatbot, sources_raw],
                queue=False,
                api_name=f"start_chat_{textbox.elem_id}",
            )
            .then(
                chat_poc,
                [
                    textbox,
                    chatbot,
                    dropdown_audience,
                    dropdown_sources,
                    dropdown_reports,
                    dropdown_external_sources,
                    search_only,
                ],
                [
                    chatbot,
                    new_sources_hmtl,
                    output_query,
                    output_language,
                    new_figures,
                    current_graphs,
                ],
                concurrency_limit=8,
                api_name=f"chat_{textbox.elem_id}",
            )
            .then(
                finish_chat, None, [textbox], api_name=f"finish_chat_{textbox.elem_id}"
            )
        )

        # Example questions for the POC tab.
        (
            examples_hidden.change(
                start_chat,
                [examples_hidden, chatbot, search_only],
                [examples_hidden, tabs, chatbot, sources_raw],
                queue=False,
                api_name=f"start_chat_{examples_hidden.elem_id}",
            )
            .then(
                chat_poc,
                [
                    examples_hidden,
                    chatbot,
                    dropdown_audience,
                    dropdown_sources,
                    dropdown_reports,
                    dropdown_external_sources,
                    search_only,
                ],
                [
                    chatbot,
                    new_sources_hmtl,
                    output_query,
                    output_language,
                    new_figures,
                    current_graphs,
                ],
                concurrency_limit=8,
                api_name=f"chat_{examples_hidden.elem_id}",
            )
            .then(
                finish_chat,
                None,
                [textbox],
                api_name=f"finish_chat_{examples_hidden.elem_id}",
            )
        )
        # Follow-up suggestions for the POC tab.
        (
            follow_up_examples_hidden.change(
                start_chat,
                [follow_up_examples_hidden, chatbot, search_only],
                [follow_up_examples_hidden, tabs, chatbot, sources_raw],
                queue=False,
                api_name=f"start_chat_{examples_hidden.elem_id}",
            )
            .then(
                # NOTE(review): binds `chat` (main agent), not `chat_poc`,
                # unlike the other handlers in this branch — confirm intended.
                chat,
                [
                    follow_up_examples_hidden,
                    chatbot,
                    dropdown_audience,
                    dropdown_sources,
                    dropdown_reports,
                    dropdown_external_sources,
                    search_only,
                ],
                [
                    chatbot,
                    new_sources_hmtl,
                    output_query,
                    output_language,
                    new_figures,
                    current_graphs,
                    follow_up_examples.dataset,
                ],
                concurrency_limit=8,
                api_name=f"chat_{examples_hidden.elem_id}",
            )
            .then(
                finish_chat,
                None,
                [textbox],
                api_name=f"finish_chat_{follow_up_examples_hidden.elem_id}",
            )
        )

    # Mirror the state values into their display components.
    new_sources_hmtl.change(
        lambda x: x, inputs=[new_sources_hmtl], outputs=[sources_textbox]
    )
    current_graphs.change(
        lambda x: x, inputs=[current_graphs], outputs=[graphs_container]
    )
    new_figures.change(
        process_figures,
        inputs=[sources_raw, new_figures],
        outputs=[sources_raw, figures_cards, gallery_component],
    )

    # Whenever any panel's content changes, refresh the item counts shown
    # in the tab labels.
    for component in [sources_textbox, figures_cards, current_graphs, papers_html]:
        component.change(
            update_sources_number_display,
            [sources_textbox, figures_cards, current_graphs, papers_html],
            [tab_recommended_content, tab_sources, tab_figures, tab_graphs, tab_papers],
        )

    # Any submitted query also triggers a paper search for the papers panel.
    for component in [textbox, examples_hidden, papers_direct_search]:
        component.submit(
            find_papers,
            [component, after, dropdown_external_sources],
            [papers_html, citations_network, papers_summary],
        )
|
|
|
|
|
|
|
|
|
|
|
|
|
def main_ui():
    """Assemble the full Gradio app.

    Builds the config modal, the two chat tabs (main + POC), the DRIAS tab
    and the about tab, wires all their events, and enables queueing.

    Returns:
        The gr.Blocks demo, ready to be launched.
    """
    with gr.Blocks(
        title="Climate Q&A",
        css_paths=os.getcwd() + "/style.css",
        theme=theme,
        elem_id="main-component",
    ) as demo:
        config_components = create_config_modal()

        with gr.Tabs():
            cqa_components = cqa_tab(tab_name="ClimateQ&A")
            local_cqa_components = cqa_tab(tab_name="Beta - POC Adapt'Action")
            create_drias_tab()
            create_about_tab()

        # Same handler, different tab_name -> different agent pipeline.
        event_handling(cqa_components, config_components, tab_name="ClimateQ&A")
        event_handling(
            local_cqa_components, config_components, tab_name="Beta - POC Adapt'Action"
        )

        config_event_handling([cqa_components, local_cqa_components], config_components)

    demo.queue()

    return demo
|
|
|
|
|
# Build the app at import time (module-level `demo` is what `gradio` reload
# and hosting platforms look for), but only start the server when this file
# is executed directly.
demo = main_ui()

if __name__ == "__main__":
    demo.launch(ssr_mode=False)
|
|