armanddemasson's picture
feat: model filtering and UI upgrade for TTD
26bb643
raw
history blame
25.7 kB
# Import necessary libraries
import os
import gradio as gr
from azure.storage.fileshare import ShareServiceClient
# Import custom modules
from climateqa.engine.embeddings import get_embeddings_function
from climateqa.engine.llm import get_llm
from climateqa.engine.vectorstore import get_pinecone_vectorstore
from climateqa.engine.reranker import get_reranker
from climateqa.engine.graph import make_graph_agent, make_graph_agent_poc
from climateqa.engine.chains.retrieve_papers import find_papers
from climateqa.chat import start_chat, chat_stream, finish_chat
from climateqa.engine.talk_to_data.main import ask_drias, DRIAS_MODELS
from climateqa.engine.talk_to_data.myVanna import MyVanna
from front.tabs import create_config_modal, cqa_tab, create_about_tab
from front.tabs import MainTabPanel, ConfigPanel
from front.utils import process_figures
from gradio_modal import Modal
from utils import create_user_id
import logging
# Only warnings and errors are emitted by default.
logging.basicConfig(level=logging.WARNING)
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"  # Suppresses INFO and WARNING logs
# basicConfig() does nothing when the root logger already has handlers,
# so the level is also set explicitly on the root logger.
logging.getLogger().setLevel(logging.WARNING)

# Load environment variables in local mode
try:
    from dotenv import load_dotenv

    load_dotenv()
except Exception:
    # Best effort: a missing python-dotenv package or an unreadable .env file
    # must not prevent start-up. The reason is kept at debug level so it can
    # be surfaced by lowering the log level instead of being lost entirely.
    logging.getLogger(__name__).debug(
        "Could not load a .env file; continuing without it", exc_info=True
    )
# Set up Gradio Theme
# Base theme with a blue primary hue and a red secondary hue. The font list is
# tried in order: Poppins (declared as a Google font) first, then the generic
# sans-serif fallbacks. `theme` is passed to gr.Blocks in main_ui().
theme = gr.themes.Base(
    primary_hue="blue",
    secondary_hue="red",
    font=[gr.themes.GoogleFont("Poppins"), "ui-sans-serif", "system-ui", "sans-serif"],
)
# Azure file share credentials. The environment variables are named BLOB_*,
# but the client created below comes from azure.storage.fileshare.
# All three variables are required: a missing one raises KeyError at import.
account_key = os.environ["BLOB_ACCOUNT_KEY"]
if len(account_key) == 86:
    # NOTE(review): presumably the stored key lost its two trailing "="
    # padding characters; they are appended back here. Confirm with ops.
    account_key = account_key + "=="
credential = dict(
    account_key=account_key,
    account_name=os.environ["BLOB_ACCOUNT_NAME"],
)
account_url = os.environ["BLOB_ACCOUNT_URL"]
file_share_name = "climateqa"

# Client for the "climateqa" share; the chat handlers pass it to chat_stream.
service = ShareServiceClient(account_url=account_url, credential=credential)
share_client = service.get_share_client(file_share_name)

# Created once at import time and forwarded with every chat request.
user_id = create_user_id()
# Create vectorstore and retriever
# One embeddings function is shared by the three Pinecone vector stores; the
# index names come from the environment (None is passed when a variable is unset).
embeddings_function = get_embeddings_function()
vectorstore = get_pinecone_vectorstore(
    embeddings_function, index_name=os.getenv("PINECONE_API_INDEX")
)
vectorstore_graphs = get_pinecone_vectorstore(
    embeddings_function,
    index_name=os.getenv("PINECONE_API_INDEX_OWID"),
    text_key="description",
)
vectorstore_region = get_pinecone_vectorstore(
    embeddings_function, index_name=os.getenv("PINECONE_API_INDEX_LOCAL_V2")
)

llm = get_llm(provider="openai", max_tokens=1024, temperature=0.0)

# The "nano" reranker is used only when GRADIO_ENV is exactly "local". Any other
# value, including an unset variable, selects the "large" reranker; reading the
# variable with os.getenv avoids a KeyError at import when it is not defined.
reranker = get_reranker("nano" if os.getenv("GRADIO_ENV") == "local" else "large")

# Agent behind the "ClimateQ&A" tab (see chat()).
agent = make_graph_agent(
    llm=llm,
    vectorstore_ipcc=vectorstore,
    vectorstore_graphs=vectorstore_graphs,
    vectorstore_region=vectorstore_region,
    reranker=reranker,
    threshold_docs=0.2,
)

# Agent behind the "Beta - POC Adapt'Action" tab (see chat_poc()).
agent_poc = make_graph_agent_poc(
    llm=llm,
    vectorstore_ipcc=vectorstore,
    vectorstore_graphs=vectorstore_graphs,
    vectorstore_region=vectorstore_region,
    reranker=reranker,
    threshold_docs=0,
    version="v4",
)  # TODO put back default 0.2
# Vanna object
# vn = MyVanna(config = {"temperature": 0, "api_key": os.getenv('THEO_API_KEY'), 'model': os.getenv('VANNA_MODEL'), 'pc_api_key': os.getenv('VANNA_PINECONE_API_KEY'), 'index_name': os.getenv('VANNA_INDEX_NAME'), "top_k" : 4})
# db_vanna_path = os.path.join(os.getcwd(), "data/drias/drias.db")
# vn.connect_to_sqlite(db_vanna_path)
# def ask_vanna_query(query):
# return ask_vanna(vn, db_vanna_path, query)
def ask_drias_query(query: str, index_state: int):
    """Forward a Talk to DRIAS question to ``ask_drias``.

    Args:
        query: The question typed by the user.
        index_state: The index of the result currently displayed.

    Returns:
        Whatever ``ask_drias`` returns, unchanged.
    """
    drias_outputs = ask_drias(query, index_state)
    return drias_outputs
async def chat(
    query,
    history,
    audience,
    sources,
    reports,
    relevant_content_sources_selection,
    search_only,
):
    """Stream the answer of the ClimateQ&A agent for one user message.

    Async generator: every event produced by ``chat_stream`` for the
    module-level ``agent`` is yielded unchanged. The seven arguments are
    forwarded positionally, followed by the module-level ``share_client``
    and ``user_id``.
    """
    print("chat cqa - message received")
    event_stream = chat_stream(
        agent,
        query,
        history,
        audience,
        sources,
        reports,
        relevant_content_sources_selection,
        search_only,
        share_client,
        user_id,
    )
    async for update in event_stream:
        yield update
async def chat_poc(
    query,
    history,
    audience,
    sources,
    reports,
    relevant_content_sources_selection,
    search_only,
):
    """Stream the answer of the POC agent for one user message.

    Async generator: every event produced by ``chat_stream`` for the
    module-level ``agent_poc`` is yielded unchanged. The seven arguments are
    forwarded positionally, followed by the module-level ``share_client``
    and ``user_id``.
    """
    print("chat poc - message received")
    event_stream = chat_stream(
        agent_poc,
        query,
        history,
        audience,
        sources,
        reports,
        relevant_content_sources_selection,
        search_only,
        share_client,
        user_id,
    )
    async for update in event_stream:
        yield update
# --------------------------------------------------------------------
# Gradio
# --------------------------------------------------------------------
# Function to update modal visibility
def update_config_modal_visibility(config_open):
    """Flip the open/closed flag of the configuration modal.

    Args:
        config_open: The current flag value.

    Returns:
        A ``(Modal, bool)`` pair: a modal whose visibility is the flipped
        flag, and the flipped flag itself to store back into the state.
    """
    print(config_open)
    toggled = not config_open
    modal_update = Modal(visible=toggled)
    return modal_update, toggled
def update_sources_number_display(
    sources_textbox, figures_cards, current_graphs, papers_html
):
    """Recompute the counters shown in the tab labels.

    Sources, figures and papers are counted by their ``<h2>`` occurrences,
    graphs by their ``<iframe`` occurrences. The "Recommended content" total
    is figures + graphs + papers (sources are not part of it).

    Returns:
        Five label updates, in this order: recommended content, sources,
        figures, graphs, papers.
    """
    n_sources = sources_textbox.count("<h2>")
    n_figures = figures_cards.count("<h2>")
    n_graphs = current_graphs.count("<iframe")
    n_papers = papers_html.count("<h2>")
    n_recommended = n_figures + n_graphs + n_papers

    labels = (
        f"Recommended content ({n_recommended})",
        f"Sources ({n_sources})",
        f"Figures ({n_figures})",
        f"Graphs ({n_graphs})",
        f"Papers ({n_papers})",
    )
    return tuple(gr.update(label=text) for text in labels)
# def create_drias_tab():
# with gr.Tab("Beta - Talk to DRIAS", elem_id="tab-vanna", id=6) as tab_vanna:
# vanna_direct_question = gr.Textbox(label="Direct Question", placeholder="You can write direct question here",elem_id="direct-question", interactive=True)
# with gr.Accordion("Details",elem_id = 'vanna-details', open=False) as vanna_details :
# vanna_sql_query = gr.Textbox(label="SQL Query Used", elem_id="sql-query", interactive=False)
# show_vanna_table = gr.Button("Show Table", elem_id="show-table")
# with Modal(visible=False) as vanna_table_modal:
# vanna_table = gr.DataFrame([], elem_id="vanna-table")
# close_vanna_modal = gr.Button("Close", elem_id="close-vanna-modal")
# close_vanna_modal.click(lambda: Modal(visible=False),None, [vanna_table_modal])
# show_vanna_table.click(lambda: Modal(visible=True),None ,[vanna_table_modal])
# vanna_display = gr.Plot()
# vanna_direct_question.submit(ask_drias_query, [vanna_direct_question], [vanna_sql_query ,vanna_table, vanna_display])
def show_results(sql_queries_state, dataframes_state, plots_state):
    """Switch between the "no result" label and the result components.

    Results are considered available only when all three states are
    non-empty; if any one of them is empty the "no result" label is shown.

    Returns:
        Seven visibility updates: the first one is for the "no result"
        label, the remaining six are for the result components and all
        carry the opposite visibility.
    """
    has_results = bool(sql_queries_state and dataframes_state and plots_state)
    no_result_update = gr.update(visible=not has_results)
    result_updates = tuple(gr.update(visible=has_results) for _ in range(6))
    return (no_result_update, *result_updates)
def filter_by_model(dataframes, figures, index_state, model_selection):
    """Return the current dataframe restricted to one model, with its figure.

    Args:
        dataframes: Sequence of dataframes, one per result.
        figures: Sequence of callables, one per result; each is called with
            a dataframe and returns the figure to display.
        index_state: Position of the result currently displayed.
        model_selection: Value to match in the "model" column, or "ALL" to
            keep every row.

    Returns:
        A ``(dataframe, figure)`` pair built from the selected rows.
    """
    selected = dataframes[index_state]
    build_figure = figures[index_state]
    if model_selection == "ALL":
        return selected, build_figure(selected)
    subset = selected[selected["model"] == model_selection]
    return subset, build_figure(subset)
def update_pagination(index, sql_queries):
    """Format the pagination label as "<current>/<total>".

    Args:
        index: Zero-based position of the result currently displayed.
        sql_queries: The list of queries; its length is the total.

    Returns:
        The one-based label, or an empty string when there are no queries.
    """
    if not sql_queries:
        return ""
    current_page = index + 1
    return f"{current_page}/{len(sql_queries)}"
def create_drias_tab():
    """Build the "Beta - Talk to DRIAS" tab and wire its events.

    The tab holds a help accordion, a question box, a "no result" label and
    three initially hidden accordions (SQL query, chart with a model
    dropdown, data table), followed by a pagination label and
    Previous/Next buttons.

    Event flow:
      * submitting a question calls ``ask_drias_query``, then
        ``show_results`` to toggle visibility, then ``update_pagination``;
      * changing the model re-filters the current result via
        ``filter_by_model``;
      * Previous/Next move the index by one, clamped to the valid range,
        and refresh the pagination label. They apply the currently selected
        model as well, so the dropdown and the displayed data stay in sync.
    """
    details_text = """
Hi, I'm **Talk to Drias**, designed to answer your questions using [**DRIAS - TRACC 2023**](https://www.drias-climat.fr/accompagnement/sections/401) data.
I'll answer by displaying a list of SQL queries, graphs and data most relevant to your question.
❓ **How to use?**
You can ask me anything about these climate indicators: **temperature**, **precipitation** or **drought**.
You can specify **location** and/or **year**.
You can choose from a list of climate models. By default, we take the **average of each model**.
For example, you can ask:
- What will the temperature be like in Paris?
- What will be the total rainfall in France in 2030?
- How frequent will extreme events be in Lyon?
**Example of indicators in the data**:
- Mean temperature (annual, winter, summer)
- Total precipitation (annual, winter, summer)
- Number of days with remarkable precipitations, with dry ground, with temperature above 30°C
⚠️ **Limitations**:
- You can't ask anything that isn't related to **DRIAS - TRACC 2023** data.
- You can only ask about **locations in France**.
- If you specify a year, there may be **no data for that year for some models**.
- You **cannot compare two models**.
🛈 **Information**
Please note that we **log your questions for meta-analysis purposes**, so avoid sharing any sensitive or personal information.
"""
    # NOTE(review): the nesting of components inside the layout containers
    # below should be confirmed against the rendered UI.
    with gr.Tab("Beta - Talk to DRIAS", elem_id="tab-vanna", id=6):
        with gr.Accordion(label="Details"):
            gr.Markdown(details_text)

        with gr.Row():
            drias_direct_question = gr.Textbox(
                label="Direct Question",
                placeholder="You can write direct question here",
                elem_id="direct-question",
                interactive=True,
            )

        result_text = gr.Textbox(
            label="", elem_id="no-result-label", interactive=False, visible=True
        )

        with gr.Accordion(label="SQL Query Used", visible=False) as query_accordion:
            drias_sql_query = gr.Textbox(
                label="", elem_id="sql-query", interactive=False
            )

        with gr.Accordion(label="Chart", visible=False) as chart_accordion:
            model_selection = gr.Dropdown(
                label="Model", choices=DRIAS_MODELS, value="ALL", interactive=True
            )
            drias_display = gr.Plot(elem_id="vanna-plot")

        with gr.Accordion(
            label="Data used", open=False, visible=False
        ) as table_accordion:
            drias_table = gr.DataFrame([], elem_id="vanna-table")

        pagination_display = gr.Markdown(
            value="", visible=False, elem_id="pagination-display"
        )

        with gr.Row():
            prev_button = gr.Button("Previous", visible=False)
            next_button = gr.Button("Next", visible=False)

        # Per-session state: one entry per result, plus the displayed index.
        sql_queries_state = gr.State([])
        dataframes_state = gr.State([])
        plots_state = gr.State([])
        index_state = gr.State(0)

        drias_direct_question.submit(
            ask_drias_query,
            inputs=[drias_direct_question, index_state],
            outputs=[
                drias_sql_query,
                drias_table,
                drias_display,
                sql_queries_state,
                dataframes_state,
                plots_state,
                index_state,
                result_text,
            ],
        ).then(
            show_results,
            inputs=[sql_queries_state, dataframes_state, plots_state],
            outputs=[
                result_text,
                query_accordion,
                table_accordion,
                chart_accordion,
                prev_button,
                next_button,
                pagination_display,
            ],
        ).then(
            update_pagination,
            inputs=[index_state, sql_queries_state],
            outputs=[pagination_display],
        )

        model_selection.change(
            filter_by_model,
            inputs=[dataframes_state, plots_state, index_state, model_selection],
            outputs=[drias_table, drias_display],
        )

        def show_page(step, index, sql_queries, dataframes, plots, model):
            """Move ``step`` results away from ``index`` and render that result."""
            if not sql_queries:
                # Nothing to page through: leave the display untouched.
                return gr.update(), gr.update(), gr.update(), index
            # Clamp so a stale or out-of-range index can never raise IndexError.
            index = min(max(index + step, 0), len(sql_queries) - 1)
            # Re-apply the selected model so the dropdown matches the display.
            dataframe, figure = filter_by_model(dataframes, plots, index, model)
            return sql_queries[index], dataframe, figure, index

        def show_previous(index, sql_queries, dataframes, plots, model="ALL"):
            return show_page(-1, index, sql_queries, dataframes, plots, model)

        def show_next(index, sql_queries, dataframes, plots, model="ALL"):
            return show_page(1, index, sql_queries, dataframes, plots, model)

        for button, handler in ((prev_button, show_previous), (next_button, show_next)):
            button.click(
                handler,
                inputs=[
                    index_state,
                    sql_queries_state,
                    dataframes_state,
                    plots_state,
                    model_selection,
                ],
                outputs=[drias_sql_query, drias_table, drias_display, index_state],
            ).then(
                update_pagination,
                inputs=[index_state, sql_queries_state],
                outputs=[pagination_display],
            )
def config_event_handling(
    main_tabs_components: list[MainTabPanel], config_componenets: ConfigPanel
):
    """Make every configuration button toggle the shared config modal.

    The modal's close button and the ``config_button`` of each main tab all
    get the same click handler, ``update_config_modal_visibility``, which
    reads the open/closed state and writes back both the modal and the state.

    Args:
        main_tabs_components: The panels whose ``config_button`` is wired.
        config_componenets: The panel holding the modal, its state and its
            close button.
    """
    open_state = config_componenets.config_open
    modal = config_componenets.config_modal

    toggle_buttons = [config_componenets.close_config_modal_button]
    toggle_buttons.extend(panel.config_button for panel in main_tabs_components)

    for toggle_button in toggle_buttons:
        toggle_button.click(
            fn=update_config_modal_visibility,
            inputs=[open_state],
            outputs=[modal, open_state],
        )
def event_handling(
    main_tab_components: MainTabPanel,
    config_components: ConfigPanel,
    tab_name="ClimateQ&A",
):
    """Wire all events of one chat tab.

    For the two known tab names, three triggers (textbox submit, hidden
    examples change, hidden follow-up examples change) each start the same
    chain: ``start_chat`` -> chat function -> ``finish_chat``. The
    "ClimateQ&A" tab uses ``chat``; the "Beta - POC Adapt'Action" tab uses
    ``chat_poc`` for all three triggers, including the follow-up examples.
    Any other ``tab_name`` gets no chat chain.

    Regardless of ``tab_name``, the function also mirrors the new sources
    and graphs into their display components, processes new figures,
    refreshes the counters in the tab labels and launches the papers search.

    Args:
        main_tab_components: Components of the tab being wired.
        config_components: Components of the shared configuration panel.
        tab_name: Selects which chat function the tab uses.
    """
    chatbot = main_tab_components.chatbot
    textbox = main_tab_components.textbox
    tabs = main_tab_components.tabs
    sources_raw = main_tab_components.sources_raw
    new_figures = main_tab_components.new_figures
    current_graphs = main_tab_components.current_graphs
    examples_hidden = main_tab_components.examples_hidden
    sources_textbox = main_tab_components.sources_textbox
    figures_cards = main_tab_components.figures_cards
    gallery_component = main_tab_components.gallery_component
    papers_direct_search = main_tab_components.papers_direct_search
    papers_html = main_tab_components.papers_html
    citations_network = main_tab_components.citations_network
    papers_summary = main_tab_components.papers_summary
    tab_recommended_content = main_tab_components.tab_recommended_content
    tab_sources = main_tab_components.tab_sources
    tab_figures = main_tab_components.tab_figures
    tab_graphs = main_tab_components.tab_graphs
    tab_papers = main_tab_components.tab_papers
    graphs_container = main_tab_components.graph_container
    follow_up_examples = main_tab_components.follow_up_examples
    follow_up_examples_hidden = main_tab_components.follow_up_examples_hidden

    dropdown_sources = config_components.dropdown_sources
    dropdown_reports = config_components.dropdown_reports
    dropdown_external_sources = config_components.dropdown_external_sources
    search_only = config_components.search_only
    dropdown_audience = config_components.dropdown_audience
    after = config_components.after
    output_query = config_components.output_query
    output_language = config_components.output_language

    new_sources_hmtl = gr.State([])
    # Not referenced by any event yet; still instantiated so the set of
    # components created for the tab stays the same.
    ttd_data = gr.State([])  # noqa: F841

    # Chat inputs that follow the triggering component, in call order.
    shared_chat_inputs = [
        chatbot,
        dropdown_audience,
        dropdown_sources,
        dropdown_reports,
        dropdown_external_sources,
        search_only,
    ]
    # Chat outputs common to every chain; some chains append the follow-up dataset.
    base_chat_outputs = [
        chatbot,
        new_sources_hmtl,
        output_query,
        output_language,
        new_figures,
        current_graphs,
    ]

    def wire_chat(trigger, source, chat_fn, chat_outputs, api_id, finish_api_id):
        """Register start_chat -> chat_fn -> finish_chat on one trigger."""
        (
            trigger(
                start_chat,
                [source, chatbot, search_only],
                [source, tabs, chatbot, sources_raw],
                queue=False,
                api_name=f"start_chat_{api_id}",
            )
            .then(
                chat_fn,
                [source, *shared_chat_inputs],
                list(chat_outputs),
                concurrency_limit=8,
                api_name=f"chat_{api_id}",
            )
            .then(
                finish_chat,
                None,
                [textbox],
                api_name=f"finish_chat_{finish_api_id}",
            )
        )

    # NOTE(review): the follow-up chains reuse examples_hidden.elem_id for
    # their start_chat_/chat_ endpoint names (only finish_chat_ uses the
    # follow-up id). This looks like a copy-paste slip, but the names are
    # kept as they are because API clients may depend on them.
    if tab_name == "ClimateQ&A":
        # Printed while the events are being registered, not on each message.
        print("chat cqa - message sent")
        outputs_with_follow_ups = [*base_chat_outputs, follow_up_examples.dataset]
        # Event for textbox
        wire_chat(
            textbox.submit,
            textbox,
            chat,
            outputs_with_follow_ups,
            textbox.elem_id,
            textbox.elem_id,
        )
        # Event for examples_hidden
        wire_chat(
            examples_hidden.change,
            examples_hidden,
            chat,
            outputs_with_follow_ups,
            examples_hidden.elem_id,
            examples_hidden.elem_id,
        )
        # Event for follow_up_examples_hidden
        wire_chat(
            follow_up_examples_hidden.change,
            follow_up_examples_hidden,
            chat,
            outputs_with_follow_ups,
            examples_hidden.elem_id,
            follow_up_examples_hidden.elem_id,
        )
    elif tab_name == "Beta - POC Adapt'Action":
        # Printed while the events are being registered, not on each message.
        print("chat poc - message sent")
        # Event for textbox
        wire_chat(
            textbox.submit,
            textbox,
            chat_poc,
            base_chat_outputs,
            textbox.elem_id,
            textbox.elem_id,
        )
        # Event for examples_hidden
        wire_chat(
            examples_hidden.change,
            examples_hidden,
            chat_poc,
            base_chat_outputs,
            examples_hidden.elem_id,
            examples_hidden.elem_id,
        )
        # Event for follow_up_examples_hidden. This used to call `chat`, which
        # answered follow-up questions of the POC tab with the ClimateQ&A agent.
        wire_chat(
            follow_up_examples_hidden.change,
            follow_up_examples_hidden,
            chat_poc,
            [*base_chat_outputs, follow_up_examples.dataset],
            examples_hidden.elem_id,
            follow_up_examples_hidden.elem_id,
        )

    # Mirror the intermediate states into the components that display them.
    new_sources_hmtl.change(
        lambda x: x, inputs=[new_sources_hmtl], outputs=[sources_textbox]
    )
    current_graphs.change(
        lambda x: x, inputs=[current_graphs], outputs=[graphs_container]
    )
    new_figures.change(
        process_figures,
        inputs=[sources_raw, new_figures],
        outputs=[sources_raw, figures_cards, gallery_component],
    )

    # Update sources numbers
    for component in [sources_textbox, figures_cards, current_graphs, papers_html]:
        component.change(
            update_sources_number_display,
            [sources_textbox, figures_cards, current_graphs, papers_html],
            [tab_recommended_content, tab_sources, tab_figures, tab_graphs, tab_papers],
        )

    # Search for papers
    for component in [textbox, examples_hidden, papers_direct_search]:
        component.submit(
            find_papers,
            [component, after, dropdown_external_sources],
            [papers_html, citations_network, papers_summary],
        )
# if tab_name == "Beta - POC Adapt'Action": # Not untill results are good enough
# # Drias search
# textbox.submit(ask_vanna, [textbox], [vanna_sql_query ,vanna_table, vanna_display])
def main_ui():
    """Assemble the whole Gradio application and return it without launching.

    Creates the configuration modal, the two chat tabs, the DRIAS tab and
    the about tab, then wires the events of both chat tabs and of the
    configuration buttons, and finally enables the queue.

    Returns:
        The ``gr.Blocks`` application.
    """
    chat_tab_names = ("ClimateQ&A", "Beta - POC Adapt'Action")
    stylesheet = f"{os.getcwd()}/style.css"

    demo = gr.Blocks(
        title="Climate Q&A",
        css_paths=stylesheet,
        theme=theme,
        elem_id="main-component",
    )
    with demo:
        config_components = create_config_modal()

        with gr.Tabs():
            cqa_components, local_cqa_components = (
                cqa_tab(tab_name=name) for name in chat_tab_names
            )
            create_drias_tab()
            create_about_tab()

        chat_panels = [cqa_components, local_cqa_components]
        for panel, name in zip(chat_panels, chat_tab_names):
            event_handling(panel, config_components, tab_name=name)

        config_event_handling(chat_panels, config_components)

        demo.queue()

    return demo
# Build the application when the module is executed or imported, keeping it
# under the module-level name `demo`, then start the server.
# NOTE(review): launch() runs on plain import as well, since there is no
# `if __name__ == "__main__"` guard — confirm the hosting setup relies on this.
demo = main_ui()
demo.launch(ssr_mode=False)