test-tim #12
opened by timeki

This view is limited to 50 files because the pull request contains too many changes; see the raw diff for the complete change set.
- .gitignore +0 -8
- README.md +1 -1
- app.py +607 -224
- climateqa/chat.py +0 -198
- climateqa/constants.py +1 -24
- climateqa/engine/chains/__init__.py +0 -0
- climateqa/engine/chains/answer_ai_impact.py +0 -46
- climateqa/engine/chains/answer_chitchat.py +0 -56
- climateqa/engine/chains/chitchat_categorization.py +0 -43
- climateqa/engine/chains/graph_retriever.py +0 -130
- climateqa/engine/chains/intent_categorization.py +0 -90
- climateqa/engine/chains/keywords_extraction.py +0 -40
- climateqa/engine/chains/query_transformation.py +0 -298
- climateqa/engine/chains/retrieve_documents.py +0 -465
- climateqa/engine/chains/retrieve_papers.py +0 -95
- climateqa/engine/chains/retriever.py +0 -126
- climateqa/engine/chains/sample_router.py +0 -66
- climateqa/engine/chains/set_defaults.py +0 -13
- climateqa/engine/chains/translation.py +0 -42
- climateqa/engine/embeddings.py +3 -6
- climateqa/engine/graph.py +0 -333
- climateqa/engine/graph_retriever.py +0 -88
- climateqa/engine/keywords.py +1 -3
- climateqa/engine/llm/__init__.py +0 -3
- climateqa/engine/llm/ollama.py +0 -6
- climateqa/engine/llm/openai.py +1 -1
- climateqa/engine/{chains/prompts.py → prompts.py} +6 -56
- climateqa/engine/{chains/answer_rag.py → rag.py} +60 -39
- climateqa/engine/{chains/reformulation.py → reformulation.py} +1 -1
- climateqa/engine/reranker.py +0 -55
- climateqa/engine/retriever.py +163 -0
- climateqa/engine/utils.py +0 -17
- climateqa/engine/vectorstore.py +2 -4
- climateqa/handle_stream_events.py +0 -126
- climateqa/knowledge/__init__.py +0 -0
- climateqa/knowledge/retriever.py +0 -102
- climateqa/papers/__init__.py +43 -0
- climateqa/{knowledge → papers}/openalex.py +15 -68
- climateqa/utils.py +0 -13
- front/__init__.py +0 -0
- front/callbacks.py +0 -0
- front/deprecated.py +0 -46
- front/event_listeners.py +0 -0
- front/tabs/__init__.py +0 -6
- front/tabs/chat_interface.py +0 -55
- front/tabs/main_tab.py +0 -69
- front/tabs/tab_about.py +0 -38
- front/tabs/tab_config.py +0 -123
- front/tabs/tab_examples.py +0 -40
- front/tabs/tab_figures.py +0 -31
.gitignore
CHANGED
@@ -5,11 +5,3 @@ __pycache__/utils.cpython-38.pyc
 
 notebooks/
 *.pyc
-
-**/.ipynb_checkpoints/
-**/.flashrank_cache/
-
-data/
-sandbox/
-
-*.db
README.md
CHANGED
@@ -4,7 +4,7 @@ emoji: 🌍
 colorFrom: blue
 colorTo: red
 sdk: gradio
-sdk_version:
+sdk_version: 4.19.1
 app_file: app.py
 fullWidth: true
 pinned: false
app.py
CHANGED
@@ -1,30 +1,44 @@
-import gradio as gr
-from
-from
-from front.utils import process_figures
-import logging
-logging.basicConfig(level=logging.WARNING)
-os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # Suppresses INFO and WARNING logs
-logging.getLogger().setLevel(logging.WARNING)
+from climateqa.engine.embeddings import get_embeddings_function
+embeddings_function = get_embeddings_function()
+
+from climateqa.papers.openalex import OpenAlex
+from sentence_transformers import CrossEncoder
+
+reranker = CrossEncoder("mixedbread-ai/mxbai-rerank-xsmall-v1")
+oa = OpenAlex()
+
+import gradio as gr
+import pandas as pd
+import numpy as np
+import os
+import time
+import re
+import json
+
+# from gradio_modal import Modal
+
+from io import BytesIO
+import base64
+
+from datetime import datetime
+from azure.storage.fileshare import ShareServiceClient
 
 from utils import create_user_id
 
+# ClimateQ&A imports
+from climateqa.engine.llm import get_llm
+from climateqa.engine.rag import make_rag_chain
+from climateqa.engine.vectorstore import get_pinecone_vectorstore
+from climateqa.engine.retriever import ClimateQARetriever
+from climateqa.engine.embeddings import get_embeddings_function
+from climateqa.engine.prompts import audience_prompts
+from climateqa.sample_questions import QUESTIONS
+from climateqa.constants import POSSIBLE_REPORTS
+from climateqa.utils import get_image_from_azure_blob_storage
+from climateqa.engine.keywords import make_keywords_chain
+from climateqa.engine.rag import make_rag_papers_chain
 
 # Load environment variables in local mode
 try:
@@ -33,7 +47,6 @@ try:
 except Exception as e:
     pass
 
-
 # Set up Gradio Theme
 theme = gr.themes.Base(
     primary_hue="blue",
@@ -41,7 +54,15 @@ theme = gr.themes.Base(
     font=[gr.themes.GoogleFont("Poppins"), "ui-sans-serif", "system-ui", "sans-serif"],
 )
 
+
+init_prompt = ""
+
+system_template = {
+    "role": "system",
+    "content": init_prompt,
+}
+
 account_key = os.environ["BLOB_ACCOUNT_KEY"]
 if len(account_key) == 86:
     account_key += "=="
@@ -60,235 +81,597 @@ user_id = create_user_id()
+def parse_output_llm_with_sources(output):
+    # Split the content into a list of text and "[Doc X]" references
+    content_parts = re.split(r'\[(Doc\s?\d+(?:,\s?Doc\s?\d+)*)\]', output)
+    parts = []
+    for part in content_parts:
+        if part.startswith("Doc"):
+            subparts = part.split(",")
+            subparts = [subpart.lower().replace("doc","").strip() for subpart in subparts]
+            subparts = [f"""<a href="#doc{subpart}" class="a-doc-ref" target="_self"><span class='doc-ref'><sup>{subpart}</sup></span></a>""" for subpart in subparts]
+            parts.append("".join(subparts))
+        else:
+            parts.append(part)
+    content_parts = "".join(parts)
+    return content_parts
+
+# Create vectorstore and retriever
+vectorstore = get_pinecone_vectorstore(embeddings_function)
 llm = get_llm(provider="openai",max_tokens = 1024,temperature = 0.0)
-if os.getenv("ENV")=="GRADIO_ENV":
-    reranker = get_reranker("nano")
-else:
-    reranker = get_reranker("large")
+
+def make_pairs(lst):
+    """from a list of even length, make tuple pairs"""
+    return [(lst[i], lst[i + 1]) for i in range(0, len(lst), 2)]
+
+def serialize_docs(docs):
+    new_docs = []
+    for doc in docs:
+        new_doc = {}
+        new_doc["page_content"] = doc.page_content
+        new_doc["metadata"] = doc.metadata
+        new_docs.append(new_doc)
+    return new_docs
+
+
+async def chat(query,history,audience,sources,reports):
+    """taking a query and a message history, use a pipeline (reformulation, retriever, answering) to yield a tuple of:
+    (messages in gradio format, messages in langchain format, source documents)"""
+
+    print(f">> NEW QUESTION : {query}")
+
+    if audience == "Children":
+        audience_prompt = audience_prompts["children"]
+    elif audience == "General public":
+        audience_prompt = audience_prompts["general"]
+    elif audience == "Experts":
+        audience_prompt = audience_prompts["experts"]
+    else:
+        audience_prompt = audience_prompts["experts"]
+
+    # Prepare default values
+    if len(sources) == 0:
+        sources = ["IPCC"]
+
+    if len(reports) == 0:
+        reports = []
+
+    retriever = ClimateQARetriever(vectorstore=vectorstore,sources = sources,min_size = 200,reports = reports,k_summary = 3,k_total = 15,threshold = 0.5)
+    rag_chain = make_rag_chain(retriever,llm)
+
+    inputs = {"query": query,"audience": audience_prompt}
+    result = rag_chain.astream_log(inputs) #{"callbacks":[MyCustomAsyncHandler()]})
+    # result = rag_chain.stream(inputs)
+
+    path_reformulation = "/logs/reformulation/final_output"
+    path_keywords = "/logs/keywords/final_output"
+    path_retriever = "/logs/find_documents/final_output"
+    path_answer = "/logs/answer/streamed_output_str/-"
+
+    docs_html = ""
+    output_query = ""
+    output_language = ""
+    output_keywords = ""
+    gallery = []
+
+    try:
+        async for op in result:
+
+            op = op.ops[0]
+
+            if op['path'] == path_reformulation: # reformulated question
+                try:
+                    output_language = op['value']["language"] # str
+                    output_query = op["value"]["question"]
+                except Exception as e:
+                    raise gr.Error(f"ClimateQ&A Error: {e} - The error has been noted, try another question and if the error remains, you can contact us :)")
+
+            if op["path"] == path_keywords:
+                try:
+                    output_keywords = op['value']["keywords"] # str
+                    output_keywords = " AND ".join(output_keywords)
+                except Exception as e:
+                    pass
+
+            elif op['path'] == path_retriever: # documents
+                try:
+                    docs = op['value']['docs'] # List[Document]
+                    docs_html = []
+                    for i, d in enumerate(docs, 1):
+                        docs_html.append(make_html_source(d, i))
+                    docs_html = "".join(docs_html)
+                except TypeError:
+                    print("No documents found")
+                    print("op: ",op)
+                    continue
+
+            elif op['path'] == path_answer: # streamed answer tokens
+                new_token = op['value'] # str
+                # time.sleep(0.01)
+                previous_answer = history[-1][1]
+                previous_answer = previous_answer if previous_answer is not None else ""
+                answer_yet = previous_answer + new_token
+                answer_yet = parse_output_llm_with_sources(answer_yet)
+                history[-1] = (query,answer_yet)
+
+            else:
+                continue
+
+            history = [tuple(x) for x in history]
+            yield history,docs_html,output_query,output_language,gallery,output_query,output_keywords
+
+    except Exception as e:
+        raise gr.Error(f"{e}")
+
+    try:
+        # Log answer on Azure Blob Storage
+        if os.getenv("GRADIO_ENV") != "local":
+            timestamp = str(datetime.now().timestamp())
+            file = timestamp + ".json"
+            prompt = history[-1][0]
+            logs = {
+                "user_id": str(user_id),
+                "prompt": prompt,
+                "query": prompt,
+                "question": output_query,
+                "sources": sources,
+                "docs": serialize_docs(docs),
+                "answer": history[-1][1],
+                "time": timestamp,
+            }
+            log_on_azure(file, logs, share_client)
+    except Exception as e:
+        print(f"Error logging on Azure Blob Storage: {e}")
+        raise gr.Error(f"ClimateQ&A Error: {str(e)[:100]} - The error has been noted, try another question and if the error remains, you can contact us :)")
+
+    image_dict = {}
+    for i,doc in enumerate(docs):
+
+        if doc.metadata["chunk_type"] == "image":
+            try:
+                key = f"Image {i+1}"
+                image_path = doc.metadata["image_path"].split("documents/")[1]
+                img = get_image_from_azure_blob_storage(image_path)
+
+                # Convert the image to a byte buffer
+                buffered = BytesIO()
+                img.save(buffered, format="PNG")
+                img_str = base64.b64encode(buffered.getvalue()).decode()
+
+                # Embedding the base64 string in Markdown
+                markdown_image = f"![](data:image/png;base64,{img_str})"
+                image_dict[key] = {"img":img,"md":markdown_image,"caption":doc.page_content,"key":key,"figure_code":doc.metadata["figure_code"]}
+            except Exception as e:
+                print(f"Skipped adding image {i} because of {e}")
+
+    if len(image_dict) > 0:
+
+        gallery = [x["img"] for x in list(image_dict.values())]
+        img = list(image_dict.values())[0]
+        img_md = img["md"]
+        img_caption = img["caption"]
+        img_code = img["figure_code"]
+        if img_code != "N/A":
+            img_name = f"{img['key']} - {img['figure_code']}"
+        else:
+            img_name = f"{img['key']}"
+
+        answer_yet = history[-1][1] + f"\n\n{img_md}\n<p class='chatbot-caption'><b>{img_name}</b> - {img_caption}</p>"
+        history[-1] = (history[-1][0],answer_yet)
+        history = [tuple(x) for x in history]
+
+    # gallery = [x.metadata["image_path"] for x in docs if (len(x.metadata["image_path"]) > 0 and "IAS" in x.metadata["image_path"])]
+    # if len(gallery) > 0:
+    #     gallery = list(set("|".join(gallery).split("|")))
+    #     gallery = [get_image_from_azure_blob_storage(x) for x in gallery]
+
+    yield history,docs_html,output_query,output_language,gallery,output_query,output_keywords
+
+
+def make_html_source(source,i):
+    meta = source.metadata
+    # content = source.page_content.split(":",1)[1].strip()
+    content = source.page_content.strip()
+
+    toc_levels = []
+    for j in range(2):
+        level = meta[f"toc_level{j}"]
+        if level != "N/A":
+            toc_levels.append(level)
+        else:
+            break
+    toc_levels = " > ".join(toc_levels)
+
+    if len(toc_levels) > 0:
+        name = f"<b>{toc_levels}</b><br/>{meta['name']}"
+    else:
+        name = meta['name']
+
+    if meta["chunk_type"] == "text":
+
+        card = f"""
+    <div class="card" id="doc{i}">
+        <div class="card-content">
+            <h2>Doc {i} - {meta['short_name']} - Page {int(meta['page_number'])}</h2>
+            <p>{content}</p>
+        </div>
+        <div class="card-footer">
+            <span>{name}</span>
+            <a href="{meta['url']}#page={int(meta['page_number'])}" target="_blank" class="pdf-link">
+                <span role="img" aria-label="Open PDF">🔗</span>
+            </a>
+        </div>
+    </div>
+    """
+
+    else:
+
+        if meta["figure_code"] != "N/A":
+            title = f"{meta['figure_code']} - {meta['short_name']}"
+        else:
+            title = f"{meta['short_name']}"
+
+        card = f"""
+    <div class="card card-image">
+        <div class="card-content">
+            <h2>Image {i} - {title} - Page {int(meta['page_number'])}</h2>
+            <p>{content}</p>
+            <p class='ai-generated'>AI-generated description</p>
+        </div>
+        <div class="card-footer">
+            <span>{name}</span>
+            <a href="{meta['url']}#page={int(meta['page_number'])}" target="_blank" class="pdf-link">
+                <span role="img" aria-label="Open PDF">🔗</span>
+            </a>
+        </div>
+    </div>
+    """
+
+    return card
+
+
+# else:
+#     docs_string = "No relevant passages found in the climate science reports (IPCC and IPBES)"
+#     complete_response = "**No relevant passages found in the climate science reports (IPCC and IPBES), you may want to ask a more specific question (specifying your question on climate issues).**"
+#     messages.append({"role": "assistant", "content": complete_response})
+#     gradio_format = make_pairs([a["content"] for a in messages[1:]])
+#     yield gradio_format, messages, docs_string
+
+
+def save_feedback(feed: str, user_id):
+    if len(feed) > 1:
+        timestamp = str(datetime.now().timestamp())
+        file = user_id + timestamp + ".json"
+        logs = {
+            "user_id": user_id,
+            "feedback": feed,
+            "time": timestamp,
+        }
+        log_on_azure(file, logs, share_client)
+        return "Feedback submitted, thank you!"
+
+
+def log_on_azure(file, logs, share_client):
+    logs = json.dumps(logs)
+    file_client = share_client.get_file_client(file)
+    file_client.upload_file(logs)
+
+
+def generate_keywords(query):
+    chain = make_keywords_chain(llm)
+    keywords = chain.invoke(query)
+    keywords = " AND ".join(keywords["keywords"])
+    return keywords
+
+
+papers_cols_widths = {
+    "doc":50,
+    "id":100,
+    "title":300,
+    "doi":100,
+    "publication_year":100,
+    "abstract":500,
+    "rerank_score":100,
+    "is_oa":50,
+}
+
+papers_cols = list(papers_cols_widths.keys())
+papers_cols_widths = list(papers_cols_widths.values())
+
+async def find_papers(query, keywords,after):
+
+    summary = ""
+
+    df_works = oa.search(keywords,after = after)
+    df_works = df_works.dropna(subset=["abstract"])
+    df_works = oa.rerank(query,df_works,reranker)
+    df_works = df_works.sort_values("rerank_score",ascending=False)
+    G = oa.make_network(df_works)
+
+    height = "750px"
+    network = oa.show_network(G,color_by = "rerank_score",notebook=False,height = height)
+    network_html = network.generate_html()
+
+    network_html = network_html.replace("'", "\"")
+    css_to_inject = "<style>#mynetwork { border: none !important; } .card { border: none !important; }</style>"
+    network_html = network_html + css_to_inject
+
+    network_html = f"""<iframe style="width: 100%; height: {height};margin:0 auto" name="result" allow="midi; geolocation; microphone; camera;
+    display-capture; encrypted-media;" sandbox="allow-modals allow-forms
+    allow-scripts allow-same-origin allow-popups
+    allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
+    allowpaymentrequest="" frameborder="0" srcdoc='{network_html}'></iframe>"""
+
+    docs = df_works["content"].head(15).tolist()
+
+    df_works = df_works.reset_index(drop = True).reset_index().rename(columns = {"index":"doc"})
+    df_works["doc"] = df_works["doc"] + 1
+    df_works = df_works[papers_cols]
+
+    yield df_works,network_html,summary
+
+    chain = make_rag_papers_chain(llm)
+    result = chain.astream_log({"question": query,"docs": docs,"language":"English"})
+    path_answer = "/logs/StrOutputParser/streamed_output/-"
+
+    async for op in result:
+
+        op = op.ops[0]
+
+        if op['path'] == path_answer: # streamed summary tokens
+            new_token = op['value'] # str
+            summary += new_token
+        else:
+            continue
+        yield df_works,network_html,summary
 
 
 # --------------------------------------------------------------------
 # Gradio
 # --------------------------------------------------------------------
 
-# Function to update modal visibility
-def update_config_modal_visibility(config_open):
-    new_config_visibility_status = not config_open
-    return gr.update(visible=new_config_visibility_status), new_config_visibility_status
-
-def
+init_prompt = """
+Hello, I am ClimateQ&A, a conversational assistant designed to help you understand climate change and biodiversity loss. I will answer your questions by **sifting through the IPCC and IPBES scientific reports**.
+
+❓ How to use
+- **Language**: You can ask me your questions in any language.
+- **Audience**: You can specify your audience (children, general public, experts) to get a more adapted answer.
+- **Sources**: You can choose to search in the IPCC or IPBES reports, or both.
+
+⚠️ Limitations
+*Please note that the AI is not perfect and may sometimes give irrelevant answers. If you are not satisfied with the answer, please ask a more specific question or report your feedback to help us improve the system.*
+
+What do you want to learn?
+"""
+
+
+def vote(data: gr.LikeData):
+    if data.liked:
+        print(data.value)
+    else:
+        print(data)
+
+
+with gr.Blocks(title="Climate Q&A", css="style.css", theme=theme,elem_id = "main-component") as demo:
+    # user_id_state = gr.State([user_id])
+
+    with gr.Tab("ClimateQ&A"):
 
         with gr.Row(elem_id="chatbot-row"):
-            # Left column - Chat interface
             with gr.Column(scale=2):
+                # state = gr.State([system_template])
+                chatbot = gr.Chatbot(
+                    value=[(None,init_prompt)],
+                    show_copy_button=True,show_label = False,elem_id="chatbot",layout = "panel",
+                    avatar_images = (None,"https://i.ibb.co/YNyd5W2/logo4.png"),
+                )#,avatar_images = ("assets/logo4.png",None))
+
+                # bot.like(vote,None,None)
+
+                with gr.Row(elem_id = "input-message"):
+                    textbox = gr.Textbox(placeholder="Ask me anything here!",show_label=False,scale=7,lines = 1,interactive = True,elem_id="input-textbox")
+                    # submit = gr.Button("",elem_id = "submit-button",scale = 1,interactive = True,icon = "https://static-00.iconduck.com/assets.00/settings-icon-2048x2046-cw28eevx.png")
 
-            # Right column - Content panels
-            with gr.Column(scale=2, variant="panel", elem_id="right-panel"):
-                with gr.Tabs(elem_id="right_panel_tab") as tabs:
-                    # Examples tab
-                    with gr.TabItem("Examples", elem_id="tab-examples", id=0):
-                        examples_hidden = create_examples_tab()
-
-                    # Sources tab
-                    with gr.Tab("Sources", elem_id="tab-sources", id=1) as tab_sources:
-                        sources_textbox = gr.HTML(show_label=False, elem_id="sources-textbox")
-
-                    with gr.Tabs(elem_id="group-subtabs") as tabs_recommended_content:
-                        # Figures subtab
-                        with gr.Tab("Figures", elem_id="tab-figures", id=3) as tab_figures:
-                            sources_raw, new_figures, used_figures, gallery_component, figures_cards, figure_modal = create_figures_tab()
-
-                        with gr.Tab("Papers", elem_id="tab-citations", id=4) as tab_papers:
-                            papers_summary, papers_html, citations_network, papers_modal = create_papers_tab()
-
-        papers_html = main_tab_components["papers_html"]
-        citations_network = main_tab_components["citations_network"]
-        papers_summary = main_tab_components["papers_summary"]
-        tab_recommended_content = main_tab_components["tab_recommended_content"]
-        tab_sources = main_tab_components["tab_sources"]
-        tab_figures = main_tab_components["tab_figures"]
-        tab_graphs = main_tab_components["tab_graphs"]
-        tab_papers = main_tab_components["tab_papers"]
-        graphs_container = main_tab_components["graph_container"]
-
-        config_open = config_components["config_open"]
-        config_modal = config_components["config_modal"]
-        dropdown_sources = config_components["dropdown_sources"]
-        dropdown_reports = config_components["dropdown_reports"]
-        dropdown_external_sources = config_components["dropdown_external_sources"]
-        search_only = config_components["search_only"]
-        dropdown_audience = config_components["dropdown_audience"]
-        after = config_components["after"]
-        output_query = config_components["output_query"]
-        output_language = config_components["output_language"]
-        close_config_modal = config_components["close_config_modal_button"]
-
-        new_sources_hmtl = gr.State([])
-
-        print("textbox id : ", textbox.elem_id)
-
-        for button in [config_button, close_config_modal]:
-            button.click(
-                fn=update_config_modal_visibility,
-                inputs=[config_open],
-                outputs=[config_modal, config_open]
-            )
-
-        if tab_name == "ClimateQ&A":
-            print("chat cqa - message sent")
-
-            # Event for textbox
-            (textbox
-                .submit(start_chat, [textbox, chatbot, search_only], [textbox, tabs, chatbot, sources_raw], queue=False, api_name=f"start_chat_{textbox.elem_id}")
-                .then(chat, [textbox, chatbot, dropdown_audience, dropdown_sources, dropdown_reports, dropdown_external_sources, search_only], [chatbot, new_sources_hmtl, output_query, output_language, new_figures, current_graphs], concurrency_limit=8, api_name=f"chat_{textbox.elem_id}")
-                .then(finish_chat, None, [textbox], api_name=f"finish_chat_{textbox.elem_id}")
-            )
-            # Event for examples_hidden
-            (examples_hidden
-                .change(start_chat, [examples_hidden, chatbot, search_only], [examples_hidden, tabs, chatbot, sources_raw], queue=False, api_name=f"start_chat_{examples_hidden.elem_id}")
-                .then(chat, [examples_hidden, chatbot, dropdown_audience, dropdown_sources, dropdown_reports, dropdown_external_sources, search_only], [chatbot, new_sources_hmtl, output_query, output_language, new_figures, current_graphs], concurrency_limit=8, api_name=f"chat_{examples_hidden.elem_id}")
-                .then(finish_chat, None, [textbox], api_name=f"finish_chat_{examples_hidden.elem_id}")
-            )
-
-        elif tab_name == "Beta - POC Adapt'Action":
-            print("chat poc - message sent")
-            # Event for textbox
-            (textbox
-                .submit(start_chat, [textbox, chatbot, search_only], [textbox, tabs, chatbot, sources_raw], queue=False, api_name=f"start_chat_{textbox.elem_id}")
-                .then(chat_poc, [textbox, chatbot, dropdown_audience, dropdown_sources, dropdown_reports, dropdown_external_sources, search_only], [chatbot, new_sources_hmtl, output_query, output_language, new_figures, current_graphs], concurrency_limit=8, api_name=f"chat_{textbox.elem_id}")
-                .then(finish_chat, None, [textbox], api_name=f"finish_chat_{textbox.elem_id}")
-            )
-            # Event for examples_hidden
-            (examples_hidden
-                .change(start_chat, [examples_hidden, chatbot, search_only], [examples_hidden, tabs, chatbot, sources_raw], queue=False, api_name=f"start_chat_{examples_hidden.elem_id}")
-                .then(chat_poc, [examples_hidden, chatbot, dropdown_audience, dropdown_sources, dropdown_reports, dropdown_external_sources, search_only], [chatbot, new_sources_hmtl, output_query, output_language, new_figures, current_graphs], concurrency_limit=8, api_name=f"chat_{examples_hidden.elem_id}")
-                .then(finish_chat, None, [textbox], api_name=f"finish_chat_{examples_hidden.elem_id}")
-            )
-
-        new_sources_hmtl.change(lambda x : x, inputs = [new_sources_hmtl], outputs = [sources_textbox])
-        current_graphs.change(lambda x: x, inputs=[current_graphs], outputs=[graphs_container])
-        new_figures.change(process_figures, inputs=[sources_raw, new_figures], outputs=[sources_raw, figures_cards, gallery_component])
-
-        component.submit(find_papers, [component, after, dropdown_external_sources], [papers_html, citations_network, papers_summary])
-
-demo.
+            with gr.Column(scale=1, variant="panel",elem_id = "right-panel"):
+
+                with gr.Tabs() as tabs:
+                    with gr.TabItem("Examples",elem_id = "tab-examples",id = 0):
+
+                        examples_hidden = gr.Textbox(visible = False)
+                        first_key = list(QUESTIONS.keys())[0]
+                        dropdown_samples = gr.Dropdown(QUESTIONS.keys(),value = first_key,interactive = True,show_label = True,label = "Select a category of sample questions",elem_id = "dropdown-samples")
+
+                        samples = []
+                        for i,key in enumerate(QUESTIONS.keys()):
+
+                            examples_visible = True if i == 0 else False
+
+                            with gr.Row(visible = examples_visible) as group_examples:
+
+                                examples_questions = gr.Examples(
+                                    QUESTIONS[key],
+                                    [examples_hidden],
+                                    examples_per_page=8,
+                                    run_on_click=False,
+                                    elem_id=f"examples{i}",
+                                    api_name=f"examples{i}",
+                                    # label = "Click on the example question or enter your own",
+                                    # cache_examples=True,
                                 )
+
+                            samples.append(group_examples)
+
+
+                    with gr.Tab("Sources",elem_id = "tab-citations",id = 1):
+                        sources_textbox = gr.HTML(show_label=False, elem_id="sources-textbox")
+                        docs_textbox = gr.State("")
+
+                    # with Modal(visible = False) as config_modal:
+                    with gr.Tab("Configuration",elem_id = "tab-config",id = 2):
+
+                        gr.Markdown("Reminder: You can talk in any language, ClimateQ&A is multi-lingual!")
+
+                        dropdown_sources = gr.CheckboxGroup(
+                            ["IPCC", "IPBES","IPOS"],
+                            label="Select source",
+                            value=["IPCC"],
+                            interactive=True,
+                        )
+
+                        dropdown_reports = gr.Dropdown(
+                            POSSIBLE_REPORTS,
+                            label="Or select specific reports",
+                            multiselect=True,
+                            value=None,
+                            interactive=True,
+                        )
+
+                        dropdown_audience = gr.Dropdown(
+                            ["Children","General public","Experts"],
+                            label="Select audience",
+                            value="Experts",
+                            interactive=True,
+                        )
+
+                        output_query = gr.Textbox(label="Query used for retrieval",show_label = True,elem_id = "reformulated-query",lines = 2,interactive = False)
+                        output_language = gr.Textbox(label="Language",show_label = True,elem_id = "language",lines = 1,interactive = False)
+
+
+    #---------------------------------------------------------------------------------------
+    # OTHER TABS
+    #---------------------------------------------------------------------------------------
+
+    with gr.Tab("Figures",elem_id = "tab-images",elem_classes = "max-height other-tabs"):
+        gallery_component = gr.Gallery()
+
+    with gr.Tab("Papers (beta)",elem_id = "tab-papers",elem_classes = "max-height other-tabs"):
+
+        with gr.Row():
+            with gr.Column(scale=1):
+                query_papers = gr.Textbox(placeholder="Question",show_label=False,lines = 1,interactive = True,elem_id="query-papers")
+                keywords_papers = gr.Textbox(placeholder="Keywords",show_label=False,lines = 1,interactive = True,elem_id="keywords-papers")
+                after = gr.Slider(minimum=1950,maximum=2023,step=1,value=1960,label="Publication date",show_label=True,interactive=True,elem_id="date-papers")
+                search_papers = gr.Button("Search",elem_id="search-papers",interactive=True)
+
+            with gr.Column(scale=7):
+
+                with gr.Tab("Summary",elem_id="papers-summary-tab"):
+                    papers_summary = gr.Markdown(visible=True,elem_id="papers-summary")
+
+                with gr.Tab("Relevant papers",elem_id="papers-results-tab"):
+                    papers_dataframe = gr.Dataframe(visible=True,elem_id="papers-table",headers = papers_cols)
+
+                with gr.Tab("Citations network",elem_id="papers-network-tab"):
+                    citations_network = gr.HTML(visible=True,elem_id="papers-citations-network")
+
+
+    with gr.Tab("About",elem_classes = "max-height other-tabs"):
+        with gr.Row():
+            with gr.Column(scale=1):
+                gr.Markdown("See more info at [https://climateqa.com](https://climateqa.com/docs/intro/)")
+
+
+    def start_chat(query,history):
+        history = history + [(query,None)]
+        history = [tuple(x) for x in history]
+        return (gr.update(interactive = False),gr.update(selected=1),history)
+
+    def finish_chat():
+        return (gr.update(interactive = True,value = ""))
+
+    (textbox
+        .submit(start_chat, [textbox,chatbot], [textbox,tabs,chatbot],queue = False,api_name = "start_chat_textbox")
+        .then(chat, [textbox,chatbot,dropdown_audience, dropdown_sources,dropdown_reports], [chatbot,sources_textbox,output_query,output_language,gallery_component,query_papers,keywords_papers],concurrency_limit = 8,api_name = "chat_textbox")
+        .then(finish_chat, None, [textbox],api_name = "finish_chat_textbox")
+    )
+
+    (examples_hidden
+        .change(start_chat, [examples_hidden,chatbot], [textbox,tabs,chatbot],queue = False,api_name = "start_chat_examples")
+        .then(chat, [examples_hidden,chatbot,dropdown_audience, dropdown_sources,dropdown_reports], [chatbot,sources_textbox,output_query,output_language,gallery_component,query_papers,keywords_papers],concurrency_limit = 8,api_name = "chat_examples")
+        .then(finish_chat, None, [textbox],api_name = "finish_chat_examples")
+    )
+
+
+    def change_sample_questions(key):
+        index = list(QUESTIONS.keys()).index(key)
+        visible_bools = [False] * len(samples)
+        visible_bools[index] = True
+        return [gr.update(visible=visible_bools[i]) for i in range(len(samples))]
+
+
+    dropdown_samples.change(change_sample_questions,dropdown_samples,samples)
+
+    query_papers.submit(generate_keywords,[query_papers], [keywords_papers])
+    search_papers.click(find_papers,[query_papers,keywords_papers,after], [papers_dataframe,citations_network,papers_summary])
+
+    # # textbox.submit(predict_climateqa,[textbox,bot],[None,bot,sources_textbox])
+    # (textbox
+    #     .submit(answer_user, [textbox,examples_hidden, bot], [textbox, bot],queue = False)
+    #     .success(change_tab,None,tabs)
+    #     .success(fetch_sources,[textbox,dropdown_sources], [textbox,sources_textbox,docs_textbox,output_query,output_language])
+    #     .success(answer_bot, [textbox,bot,docs_textbox,output_query,output_language,dropdown_audience], [textbox,bot],queue = True)
+    #     .success(lambda x : textbox,[textbox],[textbox])
+    # )
+
+    # (examples_hidden
+    #     .change(answer_user_example, [textbox,examples_hidden, bot], [textbox, bot],queue = False)
+    #     .success(change_tab,None,tabs)
+    #     .success(fetch_sources,[textbox,dropdown_sources], [textbox,sources_textbox,docs_textbox,output_query,output_language])
+    #     .success(answer_bot, [textbox,bot,docs_textbox,output_query,output_language,dropdown_audience], [textbox,bot],queue=True)
+    #     .success(lambda x : textbox,[textbox],[textbox])
+    # )
+    # submit_button.click(answer_user, [textbox, bot], [textbox, bot], queue=True).then(
+    #     answer_bot, [textbox,bot,dropdown_audience,dropdown_sources], [textbox,bot,sources_textbox]
+    # )
+
+    # with Modal(visible=True) as first_modal:
+    #     gr.Markdown("# Welcome to ClimateQ&A !")
+
+    #     gr.Markdown("### Examples")
+
+    #     examples = gr.Examples(
+    #         ["Yo ça roule","ça boume"],
+    #         [examples_hidden],
+    #         examples_per_page=8,
+    #         run_on_click=False,
+    #         elem_id="examples",
+    #         api_name="examples",
+    #     )
+
+    # submit.click(lambda: Modal(visible=True), None, config_modal)
+
+
+demo.queue()
+
+demo.launch()
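A note on the citation parsing added above: parse_output_llm_with_sources rewrites inline "[Doc X]" markers emitted by the model into superscript HTML anchors that jump to the matching source card (the "#doc{i}" ids produced by make_html_source). A minimal standalone sketch of the same regex logic; the sample model output below is invented for illustration:

import re

# Same pattern as in the new app.py: matches "[Doc 1]" or "[Doc 1, Doc 2]".
DOC_REF = r'\[(Doc\s?\d+(?:,\s?Doc\s?\d+)*)\]'

def to_html_refs(output: str) -> str:
    parts = []
    for part in re.split(DOC_REF, output):  # the capturing group keeps the refs
        if part.startswith("Doc"):
            # "Doc 1, Doc 3" -> ["1", "3"] -> anchors to "#doc1", "#doc3"
            nums = [p.lower().replace("doc", "").strip() for p in part.split(",")]
            parts.append("".join(f'<a href="#doc{n}"><sup>{n}</sup></a>' for n in nums))
        else:
            parts.append(part)
    return "".join(parts)

print(to_html_refs("Warming is unequivocal [Doc 1, Doc 3]."))
# -> Warming is unequivocal <a href="#doc1"><sup>1</sup></a><a href="#doc3"><sup>3</sup></a>.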
climateqa/chat.py
DELETED
@@ -1,198 +0,0 @@
-import os
-from datetime import datetime
-import gradio as gr
-# from .agent import agent
-from gradio import ChatMessage
-from langgraph.graph.state import CompiledStateGraph
-import json
-
-from .handle_stream_events import (
-    init_audience,
-    handle_retrieved_documents,
-    convert_to_docs_to_html,
-    stream_answer,
-    handle_retrieved_owid_graphs,
-    serialize_docs,
-)
-
-# Function to log data on Azure
-def log_on_azure(file, logs, share_client):
-    logs = json.dumps(logs)
-    file_client = share_client.get_file_client(file)
-    file_client.upload_file(logs)
-
-# Chat functions
-def start_chat(query, history, search_only):
-    history = history + [ChatMessage(role="user", content=query)]
-    if not search_only:
-        return (gr.update(interactive=False), gr.update(selected=1), history, [])
-    else:
-        return (gr.update(interactive=False), gr.update(selected=2), history, [])
-
-def finish_chat():
-    return gr.update(interactive=True, value="")
-
-def log_interaction_to_azure(history, output_query, sources, docs, share_client, user_id):
-    try:
-        # Log interaction to Azure if not in local environment
-        if os.getenv("GRADIO_ENV") != "local":
-            timestamp = str(datetime.now().timestamp())
-            prompt = history[1]["content"]
-            logs = {
-                "user_id": str(user_id),
-                "prompt": prompt,
-                "query": prompt,
-                "question": output_query,
-                "sources": sources,
-                "docs": serialize_docs(docs),
-                "answer": history[-1].content,
-                "time": timestamp,
-            }
-            log_on_azure(f"{timestamp}.json", logs, share_client)
-    except Exception as e:
-        print(f"Error logging on Azure Blob Storage: {e}")
-        error_msg = f"ClimateQ&A Error: {str(e)[:100]} - The error has been noted, try another question and if the error remains, you can contact us :)"
-        raise gr.Error(error_msg)
-
-# Main chat function
-async def chat_stream(
-    agent : CompiledStateGraph,
-    query: str,
-    history: list[ChatMessage],
-    audience: str,
-    sources: list[str],
-    reports: list[str],
-    relevant_content_sources_selection: list[str],
-    search_only: bool,
-    share_client,
-    user_id: str
-) -> tuple[list, str, str, str, list, str]:
-    """Process a chat query and return response with relevant sources and visualizations.
-
-    Args:
-        query (str): The user's question
-        history (list): Chat message history
-        audience (str): Target audience type
-        sources (list): Knowledge base sources to search
-        reports (list): Specific reports to search within sources
-        relevant_content_sources_selection (list): Types of content to retrieve (figures, papers, etc)
-        search_only (bool): Whether to only search without generating answer
-
-    Yields:
-        tuple: Contains:
-            - history: Updated chat history
-            - docs_html: HTML of retrieved documents
-            - output_query: Processed query
-            - output_language: Detected language
-            - related_contents: Related content
-            - graphs_html: HTML of relevant graphs
-    """
-    # Log incoming question
-    date_now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
-    print(f">> NEW QUESTION ({date_now}) : {query}")
-
-    audience_prompt = init_audience(audience)
-    sources = sources or ["IPCC", "IPBES"]
-    reports = reports or []
-
-    # Prepare inputs for agent
-    inputs = {
-        "user_input": query,
-        "audience": audience_prompt,
-        "sources_input": sources,
-        "relevant_content_sources_selection": relevant_content_sources_selection,
-        "search_only": search_only,
-        "reports": reports
-    }
-
-    # Get streaming events from agent
-    result = agent.astream_events(inputs, version="v1")
-
-    # Initialize state variables
-    docs = []
-    related_contents = []
-    docs_html = ""
-    new_docs_html = ""
-    output_query = ""
-    output_language = ""
-    output_keywords = ""
-    start_streaming = False
-    graphs_html = ""
-    used_documents = []
-    retrieved_contents = []
-    answer_message_content = ""
-
-    # Define processing steps
-    steps_display = {
-        "categorize_intent": ("🔄️ Analyzing user message", True),
-        "transform_query": ("🔄️ Thinking step by step to answer the question", True),
-        "retrieve_documents": ("🔄️ Searching in the knowledge base", False),
-        "retrieve_local_data": ("🔄️ Searching in the knowledge base", False),
-    }
-
-    try:
-        # Process streaming events
-        async for event in result:
-
-            if "langgraph_node" in event["metadata"]:
-                node = event["metadata"]["langgraph_node"]
-
-                # Handle document retrieval
-                if event["event"] == "on_chain_end" and event["name"] in ["retrieve_documents","retrieve_local_data"] and event["data"]["output"] != None:
-                    history, used_documents, retrieved_contents = handle_retrieved_documents(
-                        event, history, used_documents, retrieved_contents
-                    )
-                if event["event"] == "on_chain_end" and event["name"] == "answer_search" :
-                    docs = event["data"]["input"]["documents"]
-                    docs_html = convert_to_docs_to_html(docs)
-                    related_contents = event["data"]["input"]["related_contents"]
-
-                # Handle intent categorization
-                elif (event["event"] == "on_chain_end" and
-                    node == "categorize_intent" and
-                    event["name"] == "_write"):
-                    intent = event["data"]["output"]["intent"]
-                    output_language = event["data"]["output"].get("language", "English")
-                    history[-1].content = f"Language identified: {output_language}\nIntent identified: {intent}"
-
-                # Handle processing steps display
-                elif event["name"] in steps_display and event["event"] == "on_chain_start":
-                    event_description, display_output = steps_display[node]
-                    if (not hasattr(history[-1], 'metadata') or
-                        history[-1].metadata["title"] != event_description):
-                        history.append(ChatMessage(
-                            role="assistant",
-                            content="",
-                            metadata={'title': event_description}
-                        ))
-
-                # Handle answer streaming
-                elif (event["name"] != "transform_query" and
-                    event["event"] == "on_chat_model_stream" and
-                    node in ["answer_rag","answer_rag_no_docs", "answer_search", "answer_chitchat"]):
-                    history, start_streaming, answer_message_content = stream_answer(
-                        history, event, start_streaming, answer_message_content
-                    )
-
-                # Handle graph retrieval
-                elif event["name"] in ["retrieve_graphs", "retrieve_graphs_ai"] and event["event"] == "on_chain_end":
-                    graphs_html = handle_retrieved_owid_graphs(event, graphs_html)
-
-                # Handle query transformation
-                if event["name"] == "transform_query" and event["event"] == "on_chain_end":
-                    if hasattr(history[-1], "content"):
-                        sub_questions = [q["question"] + "-> relevant sources : " + str(q["sources"]) for q in event["data"]["output"]["questions_list"]]
-                        history[-1].content += "Decompose question into sub-questions:\n\n - " + "\n - ".join(sub_questions)
-
-            yield history, docs_html, output_query, output_language, related_contents, graphs_html
-
-    except Exception as e:
-        print(f"Event {event} has failed")
-        raise gr.Error(str(e))
-
-
-    # Call the function to log interaction
-    log_interaction_to_azure(history, output_query, sources, docs, share_client, user_id)
-
-    yield history, docs_html, output_query, output_language, related_contents, graphs_html
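The deleted chat_stream above and the new chat coroutine in app.py share the same backbone: iterate over an async stream of events or log ops, dispatch on a path or node name, accumulate partial state, and re-yield it to the UI after every step. A toy, framework-free sketch of that loop — all names here are invented, and no LangChain or LangGraph is required:

import asyncio

async def fake_stream():
    # Stand-in for rag_chain.astream_log(...): yields (path, value) ops.
    yield ("/logs/reformulation/final_output", {"question": "Is warming human-caused?"})
    for token in ["Yes, ", "the evidence ", "is unequivocal."]:
        yield ("/logs/answer/streamed_output_str/-", token)

async def consume():
    query, answer = "", ""
    async for path, value in fake_stream():
        if path == "/logs/reformulation/final_output":
            query = value["question"]   # arrives once, early in the stream
        elif path == "/logs/answer/streamed_output_str/-":
            answer += value             # arrives token by token
        # a Gradio generator would `yield` the partial (query, answer) here
    return query, answer

print(asyncio.run(consume()))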
climateqa/constants.py
CHANGED
@@ -1,6 +1,4 @@
 POSSIBLE_REPORTS = [
-    "IPBES IABWFH SPM",
-    "IPBES CBL SPM",
     "IPCC AR6 WGI SPM",
     "IPCC AR6 WGI FR",
     "IPCC AR6 WGI TS",
@@ -44,25 +42,4 @@ POSSIBLE_REPORTS = [
     "IPBES IAS A C5",
     "IPBES IAS A C6",
     "IPBES IAS A SPM"
-]
-
-OWID_CATEGORIES = ['Access to Energy', 'Agricultural Production',
-       'Agricultural Regulation & Policy', 'Air Pollution',
-       'Animal Welfare', 'Antibiotics', 'Biodiversity', 'Biofuels',
-       'Biological & Chemical Weapons', 'CO2 & Greenhouse Gas Emissions',
-       'COVID-19', 'Clean Water', 'Clean Water & Sanitation',
-       'Climate Change', 'Crop Yields', 'Diet Compositions',
-       'Electricity', 'Electricity Mix', 'Energy', 'Energy Efficiency',
-       'Energy Prices', 'Environmental Impacts of Food Production',
-       'Environmental Protection & Regulation', 'Famines', 'Farm Size',
-       'Fertilizers', 'Fish & Overfishing', 'Food Supply', 'Food Trade',
-       'Food Waste', 'Food and Agriculture', 'Forests & Deforestation',
-       'Fossil Fuels', 'Future Population Growth',
-       'Hunger & Undernourishment', 'Indoor Air Pollution', 'Land Use',
-       'Land Use & Yields in Agriculture', 'Lead Pollution',
-       'Meat & Dairy Production', 'Metals & Minerals',
-       'Natural Disasters', 'Nuclear Energy', 'Nuclear Weapons',
-       'Oil Spills', 'Outdoor Air Pollution', 'Ozone Layer', 'Pandemics',
-       'Pesticides', 'Plastic Pollution', 'Renewable Energy', 'Soil',
-       'Transport', 'Urbanization', 'Waste Management', 'Water Pollution',
-       'Water Use & Stress', 'Wildfires']
+]
climateqa/engine/chains/__init__.py
DELETED
File without changes.
climateqa/engine/chains/answer_ai_impact.py
DELETED
@@ -1,46 +0,0 @@
-from langchain.prompts import ChatPromptTemplate
-from langchain_core.output_parsers import StrOutputParser
-
-
-prompt_template = """
-You are ClimateQ&A, a helpful AI Assistant specialized in answering climate-related questions using info from the IPCC and/or IPBES reports.
-Always stay true to climate and nature science and do not make up information.
-If you do not know the answer, just say you do not know.
-
-## Guidelines
-- Explain that the environmental impact of AI is not covered by the IPCC or IPBES reports, but you can recommend info based on the sources below
-- Answer the question in the original language of the question
-
-## Sources
-- You can propose to visit this page https://climateqa.com/docs/carbon-footprint/ to learn more about ClimateQ&A's own carbon footprint
-- You can recommend looking at the work of the AI & climate expert scientist Sasha Luccioni, in particular these papers
-    - Power Hungry Processing: Watts Driving the Cost of AI Deployment? - https://arxiv.org/abs/2311.16863 - about the carbon footprint at the inference stage of AI models
-    - Counting Carbon: A Survey of Factors Influencing the Emissions of Machine Learning - https://arxiv.org/abs/2302.08476
-    - Estimating the Carbon Footprint of BLOOM, a 176B Parameter Language Model - https://arxiv.org/abs/2211.02001 - about the carbon footprint of training a large language model
-- You can also recommend the following tools to calculate the carbon footprint of AI models
-    - CodeCarbon - https://github.com/mlco2/codecarbon to measure the carbon footprint of your code
-    - Ecologits - https://ecologits.ai/ to measure the carbon footprint of using LLM APIs
-"""
-
-
-def make_ai_impact_chain(llm):
-
-    prompt = ChatPromptTemplate.from_messages([
-        ("system", prompt_template),
-        ("user", "{question}")
-    ])
-
-    chain = prompt | llm | StrOutputParser()
-    chain = chain.with_config({"run_name":"ai_impact_chain"})
-
-    return chain
-
-def make_ai_impact_node(llm):
-
-    ai_impact_chain = make_ai_impact_chain(llm)
-
-    async def answer_ai_impact(state,config):
-        answer = await ai_impact_chain.ainvoke({"question":state["user_input"]},config)
-        return {"answer":answer}
-
-    return answer_ai_impact
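All of the deleted chain factories in this PR (make_ai_impact_chain, make_chitchat_chain below, and the categorization chain) are built on the same LCEL composition: prompt | llm | StrOutputParser(). A hedged sketch of how such a chain is assembled and called, assuming the langchain-openai package and an API key are available; the model name is an assumption, not taken from this PR (the repo wires the model via get_llm() instead):

from langchain.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_openai import ChatOpenAI  # assumed provider for this sketch

prompt = ChatPromptTemplate.from_messages([
    ("system", "You are a climate assistant. If you do not know, say so."),
    ("user", "{question}"),
])
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)  # assumed model name
chain = (prompt | llm | StrOutputParser()).with_config({"run_name": "demo_chain"})

# Synchronous call shown here; the deleted nodes used `await chain.ainvoke(...)`.
print(chain.invoke({"question": "What is the carbon footprint of AI inference?"}))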
climateqa/engine/chains/answer_chitchat.py
DELETED
@@ -1,56 +0,0 @@
-from langchain.prompts import ChatPromptTemplate
-from langchain_core.output_parsers import StrOutputParser
-
-
-chitchat_prompt_template = """
-You are ClimateQ&A, an helpful AI Assistant specialized in answering climate-related questions using info from the IPCC and/or IPBES reports.
-Always stay true to climate and nature science and do not make up information.
-If you do not know the answer, just say you do not know.
-
-## Guidelines
-- If it's a conversational question, you can normally chat with the user
-- If the question is not related to any topic about the environment, refuse to answer and politely ask the user to ask another question about the environment
-- If the user ask if you speak any language, you can say you speak all languages :)
-- If the user ask about the bot itself "ClimateQ&A", you can explain that you are an AI assistant specialized in answering climate-related questions using info from the IPCC and/or IPBES reports and propose to visit the website here https://climateqa.com/docs/intro/ for more information
-- If the question is about ESG regulations, standards, or frameworks like the CSRD, TCFD, SASB, GRI, CDP, etc., you can explain that this is not a topic covered by the IPCC or IPBES reports.
-- Precise that you are specialized in finding trustworthy information from the scientific reports of the IPCC and IPBES and other scientific litterature
-- If relevant you can propose up to 3 example of questions they could ask from the IPCC or IPBES reports from the examples below
-- Always answer in the original language of the question
-
-## Examples of questions you can suggest (in the original language of the question)
-"What evidence do we have of climate change?",
-"Are human activities causing global warming?",
-"What are the impacts of climate change?",
-"Can climate change be reversed?",
-"What is the difference between climate change and global warming?",
-"""
-
-
-def make_chitchat_chain(llm):
-
-    prompt = ChatPromptTemplate.from_messages([
-        ("system", chitchat_prompt_template),
-        ("user", "{question}")
-    ])
-
-    chain = prompt | llm | StrOutputParser()
-    chain = chain.with_config({"run_name":"chitchat_chain"})
-
-    return chain
-
-
-
-def make_chitchat_node(llm):
-
-    chitchat_chain = make_chitchat_chain(llm)
-
-    async def answer_chitchat(state,config):
-        print("---- Answer chitchat ----")
-
-        answer = await chitchat_chain.ainvoke({"question":state["user_input"]},config)
-        state["answer"] = answer
-        return state
-        # return {"answer":answer}
-
-    return answer_chitchat
-
climateqa/engine/chains/chitchat_categorization.py
DELETED
@@ -1,43 +0,0 @@
-
-from langchain_core.pydantic_v1 import BaseModel, Field
-from typing import List
-from typing import Literal
-from langchain.prompts import ChatPromptTemplate
-from langchain_core.utils.function_calling import convert_to_openai_function
-from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
-
-
-class IntentCategorizer(BaseModel):
-    """Analyzing the user message input"""
-
-    environment: bool = Field(
-        description="Return 'True' if the question relates to climate change, the environment, nature, etc. (Example: should I eat fish?). Return 'False' if the question is just chit chat or not related to the environment or climate change.",
-    )
-
-
-def make_chitchat_intent_categorization_chain(llm):
-
-    openai_functions = [convert_to_openai_function(IntentCategorizer)]
-    llm_with_functions = llm.bind(functions = openai_functions,function_call={"name":"IntentCategorizer"})
-
-    prompt = ChatPromptTemplate.from_messages([
-        ("system", "You are a helpful assistant, you will analyze, translate and reformulate the user input message using the function provided"),
-        ("user", "input: {input}")
-    ])
-
-    chain = prompt | llm_with_functions | JsonOutputFunctionsParser()
-    return chain
-
-
-def make_chitchat_intent_categorization_node(llm):
-
-    categorization_chain = make_chitchat_intent_categorization_chain(llm)
-
-    def categorize_message(state):
-        output = categorization_chain.invoke({"input": state["user_input"]})
-        print(f"\n\nChit chat output intent categorization: {output}\n")
-        state["search_graphs_chitchat"] = output["environment"]
-        print(f"\n\nChit chat output intent categorization: {state}\n")
-        return state
-
-    return categorize_message
climateqa/engine/chains/graph_retriever.py
DELETED
@@ -1,130 +0,0 @@
-import sys
-import os
-from contextlib import contextmanager
-
-from ..reranker import rerank_docs
-from ..graph_retriever import retrieve_graphs # GraphRetriever
-from ...utils import remove_duplicates_keep_highest_score
-
-
-def divide_into_parts(target, parts):
-    # Base value for each part
-    base = target // parts
-    # Remainder to distribute
-    remainder = target % parts
-    # List to hold the result
-    result = []
-
-    for i in range(parts):
-        if i < remainder:
-            # These parts get base value + 1
-            result.append(base + 1)
-        else:
-            # The rest get the base value
-            result.append(base)
-
-    return result
-
-
-@contextmanager
-def suppress_output():
-    # Open a null device
-    with open(os.devnull, 'w') as devnull:
-        # Store the original stdout and stderr
-        old_stdout = sys.stdout
-        old_stderr = sys.stderr
-        # Redirect stdout and stderr to the null device
-        sys.stdout = devnull
-        sys.stderr = devnull
-        try:
-            yield
-        finally:
-            # Restore stdout and stderr
-            sys.stdout = old_stdout
-            sys.stderr = old_stderr
-
-
-def make_graph_retriever_node(vectorstore, reranker, rerank_by_question=True, k_final=15, k_before_reranking=100):
-
-    async def node_retrieve_graphs(state):
-        print("---- Retrieving graphs ----")
-
-        POSSIBLE_SOURCES = ["IEA", "OWID"]
-        # questions = state["remaining_questions"] if state["remaining_questions"] is not None and state["remaining_questions"]!=[] else [state["query"]]
-        questions = state["questions_list"] if state["questions_list"] is not None and state["questions_list"]!=[] else [state["query"]]
-
-        # sources_input = state["sources_input"]
-        sources_input = ["auto"]
-
-        auto_mode = "auto" in sources_input
-
-        # There are several options to get the final top k
-        # Option 1 - Get 100 documents by question and rerank by question
-        # Option 2 - Get 100/n documents by question and rerank the total
-        if rerank_by_question:
-            k_by_question = divide_into_parts(k_final,len(questions))
-
-        docs = []
-
-        for i,q in enumerate(questions):
-
-            question = q["question"] if isinstance(q, dict) else q
-
-            print(f"Subquestion {i}: {question}")
-
-            # If auto mode, we use all sources
-            if auto_mode:
-                sources = POSSIBLE_SOURCES
-            # Otherwise, we use the config
-            else:
-                sources = sources_input
-
-            if any([x in POSSIBLE_SOURCES for x in sources]):
-
-                sources = [x for x in sources if x in POSSIBLE_SOURCES]
-
-                # Search the document store using the retriever
-                docs_question = await retrieve_graphs(
-                    query = question,
-                    vectorstore = vectorstore,
-                    sources = sources,
-                    k_total = k_before_reranking,
-                    threshold = 0.5,
-                )
-                # docs_question = retriever.get_relevant_documents(question)
-
-                # Rerank
-                if reranker is not None and docs_question!=[]:
-                    with suppress_output():
-                        docs_question = rerank_docs(reranker,docs_question,question)
-                else:
-                    # Add a default reranking score
-                    for doc in docs_question:
-                        doc.metadata["reranking_score"] = doc.metadata["similarity_score"]
-
-                # If rerank by question we select the top documents for each question
-                if rerank_by_question:
-                    docs_question = docs_question[:k_by_question[i]]
-
-                # Add sources used in the metadata
-                for doc in docs_question:
-                    doc.metadata["sources_used"] = sources
-
-                print(f"{len(docs_question)} graphs retrieved for subquestion {i + 1}: {docs_question}")
-
-                docs.extend(docs_question)
-
-            else:
-                print(f"There are no graphs which match the sources filtered on. Sources filtered on: {sources}. Sources available: {POSSIBLE_SOURCES}.")
-
-        # Remove duplicates and keep the duplicate document with the highest reranking score
-        docs = remove_duplicates_keep_highest_score(docs)
-
-        # Sorting the list in descending order by rerank_score
-        # Then select the top k
-        docs = sorted(docs, key=lambda x: x.metadata["reranking_score"], reverse=True)
-        docs = docs[:k_final]
-
-        return {"recommended_content": docs}
-
-    return node_retrieve_graphs
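divide_into_parts is the budget-splitting helper here: it decides how many of the k_final graphs each sub-question may contribute, with the remainder going to the first parts. A standalone check (an equivalent condensed copy of the helper above; the inputs are illustrative):

def divide_into_parts(target, parts):
    # Split `target` into `parts` near-equal integers that sum to `target`
    base, remainder = divmod(target, parts)
    return [base + 1 if i < remainder else base for i in range(parts)]

assert divide_into_parts(15, 4) == [4, 4, 4, 3]   # 15 graphs over 4 sub-questions
assert sum(divide_into_parts(15, 4)) == 15        # the budget is always preserved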
climateqa/engine/chains/intent_categorization.py
DELETED
@@ -1,90 +0,0 @@
-
-from langchain_core.pydantic_v1 import BaseModel, Field
-from typing import List
-from typing import Literal
-from langchain.prompts import ChatPromptTemplate
-from langchain_core.utils.function_calling import convert_to_openai_function
-from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
-
-
-class IntentCategorizer(BaseModel):
-    """Analyzing the user message input"""
-
-    language: str = Field(
-        description="Find the language of the message input in full words (ex: French, English, Spanish, ...), defaults to English",
-        default="English",
-    )
-    intent: str = Field(
-        enum=[
-            "ai_impact",
-            # "geo_info",
-            # "esg",
-            "search",
-            "chitchat",
-        ],
-        description="""
-            Categorize the user input in one of the following category
-            Any question
-
-            Examples:
-            - ai_impact = Environmental impacts of AI: "What are the environmental impacts of AI", "How does AI affect the environment"
-            - search = Searching for any quesiton about climate change, energy, biodiversity, nature, and everything we can find the IPCC or IPBES reports or scientific papers,
-            - chitchat = Any general question that is not related to the environment or climate change or just conversational, or if you don't think searching the IPCC or IPBES reports would be relevant
-        """,
-        # - geo_info = Geolocated info about climate change: Any question where the user wants to know localized impacts of climate change, eg: "What will be the temperature in Marseille in 2050"
-        # - esg = Any question about the ESG regulation, frameworks and standards like the CSRD, TCFD, SASB, GRI, CDP, etc.
-    )
-
-
-
-def make_intent_categorization_chain(llm):
-
-    openai_functions = [convert_to_openai_function(IntentCategorizer)]
-    llm_with_functions = llm.bind(functions = openai_functions,function_call={"name":"IntentCategorizer"})
-
-    prompt = ChatPromptTemplate.from_messages([
-        ("system", "You are a helpful assistant, you will analyze, translate and categorize the user input message using the function provided. Categorize the user input as ai ONLY if it is related to Artificial Intelligence, search if it is related to the environment, climate change, energy, biodiversity, nature, etc. and chitchat if it is just general conversation."),
-        ("user", "input: {input}")
-    ])
-
-    chain = prompt | llm_with_functions | JsonOutputFunctionsParser()
-    return chain
-
-
-def make_intent_categorization_node(llm):
-
-    categorization_chain = make_intent_categorization_chain(llm)
-
-    def categorize_message(state):
-        print("---- Categorize_message ----")
-
-        output = categorization_chain.invoke({"input": state["user_input"]})
-        print(f"\n\nOutput intent categorization: {output}\n")
-        if "language" not in output: output["language"] = "English"
-        output["query"] = state["user_input"]
-        return output
-
-    return categorize_message
-
-
-
-
-# SAMPLE_QUESTIONS = [
-#     "Est-ce que l'IA a un impact sur l'environnement ?",
-#     "Que dit le GIEC sur l'impact de l'IA",
-#     "Qui sont les membres du GIEC",
-#     "What is the impact of El Nino ?",
-#     "Yo",
-#     "Hello ça va bien ?",
-#     "Par qui as tu été créé ?",
-#     "What role do cloud formations play in modulating the Earth's radiative balance, and how are they represented in current climate models?",
-#     "Which industries have the highest GHG emissions?",
-#     "What are invasive alien species and how do they threaten biodiversity and ecosystems?",
-#     "Are human activities causing global warming?",
-#     "What is the motivation behind mining the deep seabed?",
-#     "Tu peux m'écrire un poème sur le changement climatique ?",
-#     "Tu peux m'écrire un poème sur les bonbons ?",
-#     "What will be the temperature in 2100 in Strasbourg?",
-#     "C'est quoi le lien entre biodiversity and changement climatique ?",
-# ]
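A minimal usage sketch (not part of the diff) of the node's output shape, assuming an OpenAI chat model; the example input is hypothetical:

from langchain_openai import ChatOpenAI  # assumed provider

categorize = make_intent_categorization_node(ChatOpenAI(model="gpt-4o-mini", temperature=0))
output = categorize({"user_input": "Which industries have the highest GHG emissions?"})
print(output)
# expected shape: {'language': 'English', 'intent': 'search', 'query': 'Which industries have the highest GHG emissions?'}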
climateqa/engine/chains/keywords_extraction.py
DELETED
@@ -1,40 +0,0 @@
-
-from langchain_core.pydantic_v1 import BaseModel, Field
-from typing import List
-from typing import Literal
-from langchain.prompts import ChatPromptTemplate
-from langchain_core.utils.function_calling import convert_to_openai_function
-from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
-
-
-class KeywordExtraction(BaseModel):
-    """
-    Analyzing the user query to extract keywords to feed a search engine
-    """
-
-    keywords: List[str] = Field(
-        description="""
-            Extract the keywords from the user query to feed a search engine as a list
-            Avoid adding super specific keywords to prefer general keywords
-            Maximum 3 keywords
-
-            Examples:
-            - "What is the impact of deep sea mining ?" -> ["deep sea mining"]
-            - "How will El Nino be impacted by climate change" -> ["el nino","climate change"]
-            - "Is climate change a hoax" -> ["climate change","hoax"]
-        """
-    )
-
-
-def make_keywords_extraction_chain(llm):
-
-    openai_functions = [convert_to_openai_function(KeywordExtraction)]
-    llm_with_functions = llm.bind(functions = openai_functions,function_call={"name":"KeywordExtraction"})
-
-    prompt = ChatPromptTemplate.from_messages([
-        ("system", "You are a helpful assistant"),
-        ("user", "input: {input}")
-    ])
-
-    chain = prompt | llm_with_functions | JsonOutputFunctionsParser()
-    return chain
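A minimal usage sketch (not part of the diff), assuming an OpenAI chat model; the expected output follows the examples embedded in the schema above:

from langchain_openai import ChatOpenAI  # assumed provider

chain = make_keywords_extraction_chain(ChatOpenAI(model="gpt-4o-mini", temperature=0))
print(chain.invoke({"input": "How will El Nino be impacted by climate change?"}))
# expected shape: {'keywords': ['el nino', 'climate change']}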
climateqa/engine/chains/query_transformation.py
DELETED
@@ -1,298 +0,0 @@
-
-
-from langchain_core.pydantic_v1 import BaseModel, Field
-from typing import List
-from typing import Literal
-from langchain.prompts import ChatPromptTemplate
-from langchain_core.utils.function_calling import convert_to_openai_function
-from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
-
-# OLD QUERY ANALYSIS
-# keywords: List[str] = Field(
-#     description="""
-#     Extract the keywords from the user query to feed a search engine as a list
-#     Maximum 3 keywords
-
-#     Examples:
-#     - "What is the impact of deep sea mining ?" -> deep sea mining
-#     - "How will El Nino be impacted by climate change" -> el nino;climate change
-#     - "Is climate change a hoax" -> climate change;hoax
-#     """
-# )
-
-# alternative_queries: List[str] = Field(
-#     description="""
-#     Generate alternative search questions from the user query to feed a search engine
-#     """
-# )
-
-# step_back_question: str = Field(
-#     description="""
-#     You are an expert at world knowledge. Your task is to step back and paraphrase a question to a more generic step-back question, which is easier to answer.
-#     This questions should help you get more context and information about the user query
-#     """
-# )
-# - OpenAlex is for any other questions that are not in the previous categories but could be found in the scientific litterature
-#
-
-
-# topics: List[Literal[
-#     "Climate change",
-#     "Biodiversity",
-#     "Energy",
-#     "Decarbonization",
-#     "Climate science",
-#     "Nature",
-#     "Climate policy and justice",
-#     "Oceans",
-#     "Deep sea mining",
-#     "ESG and regulations",
-#     "CSRD",
-# ]] = Field(
-#     ...,
-#     description = """
-#         Choose the topics that are most relevant to the user query, ex: Climate change, Energy, Biodiversity, ...
-#     """,
-# )
-# date: str = Field(description="The date or period mentioned, ex: 2050, between 2020 and 2050")
-# location:Location
-
-
-ROUTING_INDEX = {
-    "IPx":["IPCC", "IPBES", "IPOS"],
-    "POC": ["AcclimaTerra", "PCAET","Biodiv"],
-    "OpenAlex":["OpenAlex"],
-}
-
-POSSIBLE_SOURCES = [y for values in ROUTING_INDEX.values() for y in values]
-
-# Prompt from the original paper https://arxiv.org/pdf/2305.14283
-# Query Rewriting for Retrieval-Augmented Large Language Models
-class QueryDecomposition(BaseModel):
-    """
-    Decompose the user query into smaller parts to think step by step to answer this question
-    Act as a simple planning agent
-    """
-
-    questions: List[str] = Field(
-        description="""
-        Think step by step to answer this question, and provide one or several search engine questions in the provided language for knowledge that you need.
-        Suppose that the user is looking for information about climate change, energy, biodiversity, nature, and everything we can find the IPCC reports and scientific literature
-        - If it's already a standalone and explicit question, just return the reformulated question for the search engine
-        - If you need to decompose the question, output a list of maximum 2 to 3 questions
-        """
-    )
-
-
-class Location(BaseModel):
-    country:str = Field(...,description="The country if directly mentioned or inferred from the location (cities, regions, adresses), ex: France, USA, ...")
-    location:str = Field(...,description="The specific place if mentioned (cities, regions, addresses), ex: Marseille, New York, Wisconsin, ...")
-
-class QueryTranslation(BaseModel):
-    """Translate the query into a given language"""
-
-    question : str = Field(
-        description="""
-        Translate the questions into the given language
-        If the question is alrealdy in the given language, just return the same question
-        """,
-    )
-
-
-class QueryAnalysis(BaseModel):
-    """
-    Analyze the user query to extract the relevant sources
-
-    Deprecated:
-        Analyzing the user query to extract topics, sources and date
-        Also do query expansion to get alternative search queries
-        Also provide simple keywords to feed a search engine
-    """
-
-    sources: List[Literal["IPCC", "IPBES", "IPOS", "AcclimaTerra", "PCAET","Biodiv"]] = Field( #,"OpenAlex"]] = Field(
-        ...,
-        description="""
-            Given a user question choose which documents would be most relevant for answering their question,
-            - IPCC is for questions about climate change, energy, impacts, and everything we can find the IPCC reports
-            - IPBES is for questions about biodiversity and nature
-            - IPOS is for questions about the ocean and deep sea mining
-            - AcclimaTerra is for questions about any specific place in, or close to, the french region "Nouvelle-Aquitaine"
-            - PCAET is the Plan Climat Eneregie Territorial for the city of Paris
-            - Biodiv is the Biodiversity plan for the city of Paris
-        """,
-    )
-
-
-
-def make_query_decomposition_chain(llm):
-    """Chain to decompose a query into smaller parts to think step by step to answer this question
-
-    Args:
-        llm (_type_): _description_
-
-    Returns:
-        _type_: _description_
-    """
-
-    openai_functions = [convert_to_openai_function(QueryDecomposition)]
-    llm_with_functions = llm.bind(functions = openai_functions,function_call={"name":"QueryDecomposition"})
-
-    prompt = ChatPromptTemplate.from_messages([
-        ("system", "You are a helpful assistant, you will analyze, translate and reformulate the user input message using the function provided"),
-        ("user", "input: {input}")
-    ])
-
-    chain = prompt | llm_with_functions | JsonOutputFunctionsParser()
-    return chain
-
-
-def make_query_analysis_chain(llm):
-    """Analyze the user query to extract the relevant sources"""
-
-    openai_functions = [convert_to_openai_function(QueryAnalysis)]
-    llm_with_functions = llm.bind(functions = openai_functions,function_call={"name":"QueryAnalysis"})
-
-
-
-    prompt = ChatPromptTemplate.from_messages([
-        ("system", "You are a helpful assistant, you will analyze the user input message using the function provided"),
-        ("user", "input: {input}")
-    ])
-
-
-    chain = prompt | llm_with_functions | JsonOutputFunctionsParser()
-    return chain
-
-
-def make_query_translation_chain(llm):
-    """Analyze the user query to extract the relevant sources"""
-
-    openai_functions = [convert_to_openai_function(QueryTranslation)]
-    llm_with_functions = llm.bind(functions = openai_functions,function_call={"name":"QueryTranslation"})
-
-
-
-    prompt = ChatPromptTemplate.from_messages([
-        ("system", "You are a helpful assistant, translate the question into {language}"),
-        ("user", "input: {input}")
-    ])
-
-
-    chain = prompt | llm_with_functions | JsonOutputFunctionsParser()
-    return chain
-
-def group_by_sources_types(sources):
-    sources_types = {}
-    IPx_sources = ["IPCC", "IPBES", "IPOS"]
-    local_sources = ["AcclimaTerra", "PCAET","Biodiv"]
-    if any(source in IPx_sources for source in sources):
-        sources_types["IPx"] = list(set(sources).intersection(IPx_sources))
-    if any(source in local_sources for source in sources):
-        sources_types["POC"] = list(set(sources).intersection(local_sources))
-    return sources_types
-
-
-def make_query_transform_node(llm,k_final=15):
-    """
-    Creates a query transformation node that processes and transforms a given query state.
-    Args:
-        llm: The language model to be used for query decomposition and rewriting.
-        k_final (int, optional): The final number of questions to be generated. Defaults to 15.
-    Returns:
-        function: A function that takes a query state and returns a transformed state.
-    The returned function performs the following steps:
-        1. Checks if the query should be processed in auto mode based on the state.
-        2. Retrieves the input sources from the state or defaults to a predefined routing index.
-        3. Decomposes the query using the decomposition chain.
-        4. Analyzes each decomposed question using the rewriter chain.
-        5. Ensures that the sources returned by the language model are valid.
-        6. Explodes the questions into multiple questions with different sources based on the mode.
-        7. Constructs a new state with the transformed questions and their respective sources.
-    """
-
-
-    decomposition_chain = make_query_decomposition_chain(llm)
-    query_analysis_chain = make_query_analysis_chain(llm)
-    query_translation_chain = make_query_translation_chain(llm)
-
-    def transform_query(state):
-        print("---- Transform query ----")
-
-        auto_mode = state.get("sources_auto", True)
-        sources_input = state.get("sources_input", ROUTING_INDEX["IPx"])
-
-
-        new_state = {}
-
-        # Decomposition
-        decomposition_output = decomposition_chain.invoke({"input":state["query"]})
-        new_state.update(decomposition_output)
-
-
-        # Query Analysis
-        questions = []
-        for question in new_state["questions"]:
-            question_state = {"question":question}
-            query_analysis_output = query_analysis_chain.invoke({"input":question})
-
-            # TODO WARNING llm should always return smthg
-            # The case when the llm does not return any sources or wrong ouput
-            if not query_analysis_output["sources"] or not all(source in ["IPCC", "IPBS", "IPOS","AcclimaTerra", "PCAET","Biodiv"] for source in query_analysis_output["sources"]):
-                query_analysis_output["sources"] = ["IPCC", "IPBES", "IPOS"]
-
-            sources_types = group_by_sources_types(query_analysis_output["sources"])
-            for source_type,sources in sources_types.items():
-                question_state = {
-                    "question":question,
-                    "sources":sources,
-                    "source_type":source_type
-                }
-
-                questions.append(question_state)
-
-        # Translate question into the document language
-        for q in questions:
-            if q["source_type"]=="IPx":
-                translation_output = query_translation_chain.invoke({"input":q["question"],"language":"English"})
-                q["question"] = translation_output["question"]
-            elif q["source_type"]=="POC":
-                translation_output = query_translation_chain.invoke({"input":q["question"],"language":"French"})
-                q["question"] = translation_output["question"]
-
-        # Explode the questions into multiple questions with different sources
-        new_questions = []
-        for q in questions:
-            question,sources,source_type = q["question"],q["sources"], q["source_type"]
-
-            # If not auto mode we take the configuration
-            if not auto_mode:
-                sources = sources_input
-
-            for index,index_sources in ROUTING_INDEX.items():
-                selected_sources = list(set(sources).intersection(index_sources))
-                if len(selected_sources) > 0:
-                    new_questions.append({"question":question,"sources":selected_sources,"index":index, "source_type":source_type})
-
-        # # Add the number of questions to search
-        # k_by_question = k_final // len(new_questions)
-        # for q in new_questions:
-        #     q["k"] = k_by_question
-
-        # new_state["questions"] = new_questions
-        # new_state["remaining_questions"] = new_questions
-
-        n_questions = {
-            "total":len(new_questions),
-            "IPx":len([q for q in new_questions if q["index"] == "IPx"]),
-            "POC":len([q for q in new_questions if q["index"] == "POC"]),
-        }
-
-        new_state = {
-            "questions_list":new_questions,
-            "n_questions":n_questions,
-            "handled_questions_index":[],
-        }
-        return new_state
-
-    return transform_query
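group_by_sources_types is the pure-Python routing step: it buckets the sources chosen by the LLM into the "IPx" and "POC" indexes from ROUTING_INDEX. A standalone check (an equivalent copy of the helper above; the inputs are illustrative):

def group_by_sources_types(sources):
    # Map each detected source to its routing index ("IPx" for global reports,
    # "POC" for the local French corpora); unknown sources are dropped
    sources_types = {}
    IPx_sources = ["IPCC", "IPBES", "IPOS"]
    local_sources = ["AcclimaTerra", "PCAET", "Biodiv"]
    if any(source in IPx_sources for source in sources):
        sources_types["IPx"] = list(set(sources).intersection(IPx_sources))
    if any(source in local_sources for source in sources):
        sources_types["POC"] = list(set(sources).intersection(local_sources))
    return sources_types

assert group_by_sources_types(["IPCC", "PCAET"]) == {"IPx": ["IPCC"], "POC": ["PCAET"]}
assert group_by_sources_types(["OpenAlex"]) == {}  # unknown sources fall through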
climateqa/engine/chains/retrieve_documents.py
DELETED
@@ -1,465 +0,0 @@
-import sys
-import os
-from contextlib import contextmanager
-
-from langchain_core.tools import tool
-from langchain_core.runnables import chain
-from langchain_core.runnables import RunnableParallel, RunnablePassthrough
-from langchain_core.runnables import RunnableLambda
-
-from ..reranker import rerank_docs, rerank_and_sort_docs
-# from ...knowledge.retriever import ClimateQARetriever
-from ...knowledge.openalex import OpenAlexRetriever
-from .keywords_extraction import make_keywords_extraction_chain
-from ..utils import log_event
-from langchain_core.vectorstores import VectorStore
-from typing import List
-from langchain_core.documents.base import Document
-import asyncio
-
-from typing import Any, Dict, List, Tuple
-
-
-def divide_into_parts(target, parts):
-    # Base value for each part
-    base = target // parts
-    # Remainder to distribute
-    remainder = target % parts
-    # List to hold the result
-    result = []
-
-    for i in range(parts):
-        if i < remainder:
-            # These parts get base value + 1
-            result.append(base + 1)
-        else:
-            # The rest get the base value
-            result.append(base)
-
-    return result
-
-
-@contextmanager
-def suppress_output():
-    # Open a null device
-    with open(os.devnull, 'w') as devnull:
-        # Store the original stdout and stderr
-        old_stdout = sys.stdout
-        old_stderr = sys.stderr
-        # Redirect stdout and stderr to the null device
-        sys.stdout = devnull
-        sys.stderr = devnull
-        try:
-            yield
-        finally:
-            # Restore stdout and stderr
-            sys.stdout = old_stdout
-            sys.stderr = old_stderr
-
-
-@tool
-def query_retriever(question):
-    """Just a dummy tool to simulate the retriever query"""
-    return question
-
-def _add_sources_used_in_metadata(docs,sources,question,index):
-    for doc in docs:
-        doc.metadata["sources_used"] = sources
-        doc.metadata["question_used"] = question
-        doc.metadata["index_used"] = index
-    return docs
-
-def _get_k_summary_by_question(n_questions):
-    if n_questions == 0:
-        return 0
-    elif n_questions == 1:
-        return 5
-    elif n_questions == 2:
-        return 3
-    elif n_questions == 3:
-        return 2
-    else:
-        return 1
-
-def _get_k_images_by_question(n_questions):
-    if n_questions == 0:
-        return 0
-    elif n_questions == 1:
-        return 7
-    elif n_questions == 2:
-        return 5
-    elif n_questions == 3:
-        return 3
-    else:
-        return 1
-
-def _add_metadata_and_score(docs: List) -> Document:
-    # Add score to metadata
-    docs_with_metadata = []
-    for i,(doc,score) in enumerate(docs):
-        doc.page_content = doc.page_content.replace("\r\n"," ")
-        doc.metadata["similarity_score"] = score
-        doc.metadata["content"] = doc.page_content
-        if doc.metadata["page_number"] != "N/A":
-            doc.metadata["page_number"] = int(doc.metadata["page_number"]) + 1
-        else:
-            doc.metadata["page_number"] = 1
-        # doc.page_content = f"""Doc {i+1} - {doc.metadata['short_name']}: {doc.page_content}"""
-        docs_with_metadata.append(doc)
-    return docs_with_metadata
-
-def remove_duplicates_chunks(docs):
-    # Remove duplicates or almost duplicates
-    docs = sorted(docs,key=lambda x: x[1],reverse=True)
-    seen = set()
-    result = []
-    for doc in docs:
-        if doc[0].page_content not in seen:
-            seen.add(doc[0].page_content)
-            result.append(doc)
-    return result
-
-async def get_POC_relevant_documents(
-    query: str,
-    vectorstore:VectorStore,
-    sources:list = ["Acclimaterra","PCAET","Plan Biodiversite"],
-    search_figures:bool = False,
-    search_only:bool = False,
-    k_documents:int = 10,
-    threshold:float = 0.6,
-    k_images: int = 5,
-    reports:list = [],
-    min_size:int = 200,
-) :
-    # Prepare base search kwargs
-    filters = {}
-    docs_question = []
-    docs_images = []
-
-    # TODO add source selection
-    # if len(reports) > 0:
-    #     filters["short_name"] = {"$in":reports}
-    # else:
-    #     filters["source"] = { "$in": sources}
-
-    filters_text = {
-        **filters,
-        "chunk_type":"text",
-        # "report_type": {}, # TODO to be completed to choose the right documents / chapters according to the analysis of the question
-    }
-
-    docs_question = vectorstore.similarity_search_with_score(query=query,filter = filters_text,k = k_documents)
-    # remove duplicates or almost duplicates
-    docs_question = remove_duplicates_chunks(docs_question)
-    docs_question = [x for x in docs_question if x[1] > threshold]
-
-    if search_figures:
-        # Images
-        filters_image = {
-            **filters,
-            "chunk_type":"image"
-        }
-        docs_images = vectorstore.similarity_search_with_score(query=query,filter = filters_image,k = k_images)
-
-    docs_question, docs_images = _add_metadata_and_score(docs_question), _add_metadata_and_score(docs_images)
-
-    docs_question = [x for x in docs_question if len(x.page_content) > min_size]
-
-    return {
-        "docs_question" : docs_question,
-        "docs_images" : docs_images
-    }
-
-
-async def get_IPCC_relevant_documents(
-    query: str,
-    vectorstore:VectorStore,
-    sources:list = ["IPCC","IPBES","IPOS"],
-    search_figures:bool = False,
-    reports:list = [],
-    threshold:float = 0.6,
-    k_summary:int = 3,
-    k_total:int = 10,
-    k_images: int = 5,
-    namespace:str = "vectors",
-    min_size:int = 200,
-    search_only:bool = False,
-) :
-
-    # Check if all elements in the list are either IPCC or IPBES
-    assert isinstance(sources,list)
-    assert sources
-    assert all([x in ["IPCC","IPBES","IPOS"] for x in sources])
-    assert k_total > k_summary, "k_total should be greater than k_summary"
-
-    # Prepare base search kwargs
-    filters = {}
-
-    if len(reports) > 0:
-        filters["short_name"] = {"$in":reports}
-    else:
-        filters["source"] = { "$in": sources}
-
-    # INIT
-    docs_summaries = []
-    docs_full = []
-    docs_images = []
-
-    if search_only:
-        # Only search for images if search_only is True
-        if search_figures:
-            filters_image = {
-                **filters,
-                "chunk_type":"image"
-            }
-            docs_images = vectorstore.similarity_search_with_score(query=query,filter = filters_image,k = k_images)
-            docs_images = _add_metadata_and_score(docs_images)
-    else:
-        # Regular search flow for text and optionally images
-        # Search for k_summary documents in the summaries dataset
-        filters_summaries = {
-            **filters,
-            "chunk_type":"text",
-            "report_type": { "$in":["SPM"]},
-        }
-
-        docs_summaries = vectorstore.similarity_search_with_score(query=query,filter = filters_summaries,k = k_summary)
-        docs_summaries = [x for x in docs_summaries if x[1] > threshold]
-
-        # Search for k_total - k_summary documents in the full reports dataset
-        filters_full = {
-            **filters,
-            "chunk_type":"text",
-            "report_type": { "$nin":["SPM"]},
-        }
-        docs_full = vectorstore.similarity_search_with_score(query=query,filter = filters_full,k = k_total)
-
-        if search_figures:
-            # Images
-            filters_image = {
-                **filters,
-                "chunk_type":"image"
-            }
-            docs_images = vectorstore.similarity_search_with_score(query=query,filter = filters_image,k = k_images)
-
-    docs_summaries, docs_full, docs_images = _add_metadata_and_score(docs_summaries), _add_metadata_and_score(docs_full), _add_metadata_and_score(docs_images)
-
-    # Filter if length are below threshold
-    docs_summaries = [x for x in docs_summaries if len(x.page_content) > min_size]
-    docs_full = [x for x in docs_full if len(x.page_content) > min_size]
-
-    return {
-        "docs_summaries" : docs_summaries,
-        "docs_full" : docs_full,
-        "docs_images" : docs_images,
-    }
-
-
-
-def concatenate_documents(index, source_type, docs_question_dict, k_by_question, k_summary_by_question, k_images_by_question):
-    # Keep the right number of documents - The k_summary documents from SPM are placed in front
-    if source_type == "IPx":
-        docs_question = docs_question_dict["docs_summaries"][:k_summary_by_question] + docs_question_dict["docs_full"][:(k_by_question - k_summary_by_question)]
-    elif source_type == "POC" :
-        docs_question = docs_question_dict["docs_question"][:k_by_question]
-    else :
-        raise ValueError("source_type should be either Vector or POC")
-    # docs_question = [doc for key in docs_question_dict.keys() for doc in docs_question_dict[key]][:(k_by_question)]
-
-    images_question = docs_question_dict["docs_images"][:k_images_by_question]
-
-    return docs_question, images_question
-
-
-# The chain callback is not necessary, but it propagates the langchain callbacks to the astream_events logger to display intermediate results
-# @chain
-async def retrieve_documents(
-    current_question: Dict[str, Any],
-    config: Dict[str, Any],
-    source_type: str,
-    vectorstore: VectorStore,
-    reranker: Any,
-    search_figures: bool = False,
-    search_only: bool = False,
-    reports: list = [],
-    rerank_by_question: bool = True,
-    k_images_by_question: int = 5,
-    k_before_reranking: int = 100,
-    k_by_question: int = 5,
-    k_summary_by_question: int = 3
-) -> Tuple[List[Document], List[Document]]:
-    """
-    Unpack the first question of the remaining questions, and retrieve and rerank corresponding documents, based on the question and selected_sources
-
-    Args:
-        state (dict): The current state containing documents, related content, relevant content sources, remaining questions and n_questions.
-        current_question (dict): The current question being processed.
-        config (dict): Configuration settings for logging and other purposes.
-        vectorstore (object): The vector store used to retrieve relevant documents.
-        reranker (object): The reranker used to rerank the retrieved documents.
-        llm (object): The language model used for processing.
-        rerank_by_question (bool, optional): Whether to rerank documents by question. Defaults to True.
-        k_final (int, optional): The final number of documents to retrieve. Defaults to 15.
-        k_before_reranking (int, optional): The number of documents to retrieve before reranking. Defaults to 100.
-        k_summary (int, optional): The number of summary documents to retrieve. Defaults to 5.
-        k_images (int, optional): The number of image documents to retrieve. Defaults to 5.
-    Returns:
-        dict: The updated state containing the retrieved and reranked documents, related content, and remaining questions.
-    """
-    sources = current_question["sources"]
-    question = current_question["question"]
-    index = current_question["index"]
-    source_type = current_question["source_type"]
-
-    print(f"Retrieve documents for question: {question}")
-    await log_event({"question":question,"sources":sources,"index":index},"log_retriever",config)
-
-    print(f"""---- Retrieve documents from {current_question["source_type"]}----""")
-
-    if source_type == "IPx":
-        docs_question_dict = await get_IPCC_relevant_documents(
-            query = question,
-            vectorstore=vectorstore,
-            search_figures = search_figures,
-            sources = sources,
-            min_size = 200,
-            k_summary = k_before_reranking-1,
-            k_total = k_before_reranking,
-            k_images = k_images_by_question,
-            threshold = 0.5,
-            search_only = search_only,
-            reports = reports,
-        )
-
-    if source_type == "POC":
-        docs_question_dict = await get_POC_relevant_documents(
-            query = question,
-            vectorstore=vectorstore,
-            search_figures = search_figures,
-            sources = sources,
-            threshold = 0.5,
-            search_only = search_only,
-            reports = reports,
-            min_size= 200,
-            k_documents= k_before_reranking,
-            k_images= k_by_question
-        )
-
-    # Rerank
-    if reranker is not None and rerank_by_question:
-        with suppress_output():
-            for key in docs_question_dict.keys():
-                docs_question_dict[key] = rerank_and_sort_docs(reranker,docs_question_dict[key],question)
-    else:
-        # Add a default reranking score
-        for doc in docs_question:
-            doc.metadata["reranking_score"] = doc.metadata["similarity_score"]
-
-    # Keep the right number of documents
-    docs_question, images_question = concatenate_documents(index, source_type, docs_question_dict, k_by_question, k_summary_by_question, k_images_by_question)
-
-    # Rerank the documents to put the most relevant in front
-    if reranker is not None and rerank_by_question:
-        docs_question = rerank_and_sort_docs(reranker, docs_question, question)
-
-    # Add sources used in the metadata
-    docs_question = _add_sources_used_in_metadata(docs_question,sources,question,index)
-    images_question = _add_sources_used_in_metadata(images_question,sources,question,index)
-
-    return docs_question, images_question
-
-
-async def retrieve_documents_for_all_questions(state, config, source_type, to_handle_questions_index, vectorstore, reranker, rerank_by_question=True, k_final=15, k_before_reranking=100):
-    """
-    Retrieve documents in parallel for all questions.
-    """
-    # to_handle_questions_index = [x for x in state["questions_list"] if x["source_type"] == "IPx"]
-
-    # TODO split les questions selon le type de sources dans le state question + conditions sur le nombre de questions traités par type de source
-    docs = state.get("documents", [])
-    related_content = state.get("related_content", [])
-    search_figures = "Figures (IPCC/IPBES)" in state["relevant_content_sources_selection"]
-    search_only = state["search_only"]
-    reports = state["reports"]
-
-    k_by_question = k_final // state["n_questions"]["total"]
-    k_summary_by_question = _get_k_summary_by_question(state["n_questions"]["total"])
-    k_images_by_question = _get_k_images_by_question(state["n_questions"]["total"])
-    k_before_reranking=100
-
-    tasks = [
-        retrieve_documents(
-            current_question=question,
-            config=config,
-            source_type=source_type,
-            vectorstore=vectorstore,
-            reranker=reranker,
-            search_figures=search_figures,
-            search_only=search_only,
-            reports=reports,
-            rerank_by_question=rerank_by_question,
-            k_images_by_question=k_images_by_question,
-            k_before_reranking=k_before_reranking,
-            k_by_question=k_by_question,
-            k_summary_by_question=k_summary_by_question
-        )
-        for i, question in enumerate(state["questions_list"]) if i in to_handle_questions_index
-    ]
-    results = await asyncio.gather(*tasks)
-    # Combine results
-    new_state = {"documents": [], "related_contents": [], "handled_questions_index": to_handle_questions_index}
-    for docs_question, images_question in results:
-        new_state["documents"].extend(docs_question)
-        new_state["related_contents"].extend(images_question)
-    return new_state
-
-def make_IPx_retriever_node(vectorstore,reranker,llm,rerank_by_question=True, k_final=15, k_before_reranking=100, k_summary=5):
-
-    async def retrieve_IPx_docs(state, config):
-        source_type = "IPx"
-        IPx_questions_index = [i for i, x in enumerate(state["questions_list"]) if x["source_type"] == "IPx"]
-
-        # return {"documents":[], "related_contents": [], "handled_questions_index": list(range(len(state["questions_list"])))} # TODO Remove
-
-        state = await retrieve_documents_for_all_questions(
-            state=state,
-            config=config,
-            source_type=source_type,
-            to_handle_questions_index=IPx_questions_index,
-            vectorstore=vectorstore,
-            reranker=reranker,
-            rerank_by_question=rerank_by_question,
-            k_final=k_final,
-            k_before_reranking=k_before_reranking,
-        )
-        return state
-
-    return retrieve_IPx_docs
-
-
-def make_POC_retriever_node(vectorstore,reranker,llm,rerank_by_question=True, k_final=15, k_before_reranking=100, k_summary=5):
-
-    async def retrieve_POC_docs_node(state, config):
-        if "POC region" not in state["relevant_content_sources_selection"] :
-            return {}
-
-        source_type = "POC"
-        POC_questions_index = [i for i, x in enumerate(state["questions_list"]) if x["source_type"] == "POC"]
-
-        state = await retrieve_documents_for_all_questions(
-            state=state,
-            config=config,
-            source_type=source_type,
-            to_handle_questions_index=POC_questions_index,
-            vectorstore=vectorstore,
-            reranker=reranker,
-            rerank_by_question=rerank_by_question,
-            k_final=k_final,
-            k_before_reranking=k_before_reranking,
-        )
-        return state
-
-    return retrieve_POC_docs_node
climateqa/engine/chains/retrieve_papers.py
DELETED
@@ -1,95 +0,0 @@
-from climateqa.engine.keywords import make_keywords_chain
-from climateqa.engine.llm import get_llm
-from climateqa.knowledge.openalex import OpenAlex
-from climateqa.engine.chains.answer_rag import make_rag_papers_chain
-from front.utils import make_html_papers
-from climateqa.engine.reranker import get_reranker
-
-oa = OpenAlex()
-
-llm = get_llm(provider="openai",max_tokens = 1024,temperature = 0.0)
-reranker = get_reranker("nano")
-
-
-papers_cols_widths = {
-    "id":100,
-    "title":300,
-    "doi":100,
-    "publication_year":100,
-    "abstract":500,
-    "is_oa":50,
-}
-
-papers_cols = list(papers_cols_widths.keys())
-papers_cols_widths = list(papers_cols_widths.values())
-
-
-
-def generate_keywords(query):
-    chain = make_keywords_chain(llm)
-    keywords = chain.invoke(query)
-    keywords = " AND ".join(keywords["keywords"])
-    return keywords
-
-
-async def find_papers(query,after, relevant_content_sources_selection, reranker= reranker):
-    if "Papers (OpenAlex)" in relevant_content_sources_selection:
-        summary = ""
-        keywords = generate_keywords(query)
-        df_works = oa.search(keywords,after = after)
-
-        print(f"Found {len(df_works)} papers")
-
-        if not df_works.empty:
-            df_works = df_works.dropna(subset=["abstract"])
-            df_works = df_works[df_works["abstract"] != ""].reset_index(drop = True)
-            df_works = oa.rerank(query,df_works,reranker)
-            df_works = df_works.sort_values("rerank_score",ascending=False)
-            docs_html = []
-            for i in range(10):
-                docs_html.append(make_html_papers(df_works, i))
-            docs_html = "".join(docs_html)
-            G = oa.make_network(df_works)
-
-            height = "750px"
-            network = oa.show_network(G,color_by = "rerank_score",notebook=False,height = height)
-            network_html = network.generate_html()
-
-            network_html = network_html.replace("'", "\"")
-            css_to_inject = "<style>#mynetwork { border: none !important; } .card { border: none !important; }</style>"
-            network_html = network_html + css_to_inject
-
-
-            network_html = f"""<iframe style="width: 100%; height: {height};margin:0 auto" name="result" allow="midi; geolocation; microphone; camera;
-            display-capture; encrypted-media;" sandbox="allow-modals allow-forms
-            allow-scripts allow-same-origin allow-popups
-            allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
-            allowpaymentrequest="" frameborder="0" srcdoc='{network_html}'></iframe>"""
-
-
-            docs = df_works["content"].head(10).tolist()
-
-            df_works = df_works.reset_index(drop = True).reset_index().rename(columns = {"index":"doc"})
-            df_works["doc"] = df_works["doc"] + 1
-            df_works = df_works[papers_cols]
-
-            yield docs_html, network_html, summary
-
-            chain = make_rag_papers_chain(llm)
-            result = chain.astream_log({"question": query,"docs": docs,"language":"English"})
-            path_answer = "/logs/StrOutputParser/streamed_output/-"
-
-            async for op in result:
-
-                op = op.ops[0]
-
-                if op['path'] == path_answer: # reforulated question
-                    new_token = op['value'] # str
-                    summary += new_token
-                else:
-                    continue
-                yield docs_html, network_html, summary
-        else :
-            print("No papers found")
-    else :
-        yield "","", ""
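generate_keywords above turns the extracted keyword list into a single OpenAlex query string by joining with " AND ". A standalone illustration of that step (the keyword payload is hypothetical, mimicking the shape returned by the keywords chain):

llm_output = {"keywords": ["deep sea mining", "biodiversity"]}  # hypothetical chain output
query = " AND ".join(llm_output["keywords"])
assert query == "deep sea mining AND biodiversity"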
climateqa/engine/chains/retriever.py
DELETED
@@ -1,126 +0,0 @@
-# import sys
-# import os
-# from contextlib import contextmanager
-
-# from ..reranker import rerank_docs
-# from ...knowledge.retriever import ClimateQARetriever
-
-
-
-
-# def divide_into_parts(target, parts):
-#     # Base value for each part
-#     base = target // parts
-#     # Remainder to distribute
-#     remainder = target % parts
-#     # List to hold the result
-#     result = []
-
-#     for i in range(parts):
-#         if i < remainder:
-#             # These parts get base value + 1
-#             result.append(base + 1)
-#         else:
-#             # The rest get the base value
-#             result.append(base)
-
-#     return result
-
-
-# @contextmanager
-# def suppress_output():
-#     # Open a null device
-#     with open(os.devnull, 'w') as devnull:
-#         # Store the original stdout and stderr
-#         old_stdout = sys.stdout
-#         old_stderr = sys.stderr
-#         # Redirect stdout and stderr to the null device
-#         sys.stdout = devnull
-#         sys.stderr = devnull
-#         try:
-#             yield
-#         finally:
-#             # Restore stdout and stderr
-#             sys.stdout = old_stdout
-#             sys.stderr = old_stderr
-
-
-
-# def make_retriever_node(vectorstore,reranker,rerank_by_question=True, k_final=15, k_before_reranking=100, k_summary=5):
-
-#     def retrieve_documents(state):
-
-#         POSSIBLE_SOURCES = ["IPCC","IPBES","IPOS"] # ,"OpenAlex"]
-#         questions = state["questions"]
-
-#         # Use sources from the user input or from the LLM detection
-#         if "sources_input" not in state or state["sources_input"] is None:
-#             sources_input = ["auto"]
-#         else:
-#             sources_input = state["sources_input"]
-#         auto_mode = "auto" in sources_input
-
-#         # There are several options to get the final top k
-#         # Option 1 - Get 100 documents by question and rerank by question
-#         # Option 2 - Get 100/n documents by question and rerank the total
-#         if rerank_by_question:
-#             k_by_question = divide_into_parts(k_final,len(questions))
-
-#         docs = []
-
-#         for i,q in enumerate(questions):
-
-#             sources = q["sources"]
-#             question = q["question"]
-
-#             # If auto mode, we use the sources detected by the LLM
-#             if auto_mode:
-#                 sources = [x for x in sources if x in POSSIBLE_SOURCES]
-
-#             # Otherwise, we use the config
-#             else:
-#                 sources = sources_input
-
-#             # Search the document store using the retriever
-#             # Configure high top k for further reranking step
-#             retriever = ClimateQARetriever(
-#                 vectorstore=vectorstore,
-#                 sources = sources,
-#                 # reports = ias_reports,
-#                 min_size = 200,
-#                 k_summary = k_summary,
-#                 k_total = k_before_reranking,
-#                 threshold = 0.5,
-#             )
-#             docs_question = retriever.get_relevant_documents(question)
-
-#             # Rerank
-#             if reranker is not None:
-#                 with suppress_output():
-#                     docs_question = rerank_docs(reranker,docs_question,question)
-#             else:
-#                 # Add a default reranking score
-#                 for doc in docs_question:
-#                     doc.metadata["reranking_score"] = doc.metadata["similarity_score"]
-
-#             # If rerank by question we select the top documents for each question
-#             if rerank_by_question:
-#                 docs_question = docs_question[:k_by_question[i]]
-
-#             # Add sources used in the metadata
-#             for doc in docs_question:
-#                 doc.metadata["sources_used"] = sources
-
-#             # Add to the list of docs
-#             docs.extend(docs_question)
-
-#         # Sorting the list in descending order by rerank_score
-#         # Then select the top k
-#         docs = sorted(docs, key=lambda x: x.metadata["reranking_score"], reverse=True)
-#         docs = docs[:k_final]
-
-#         new_state = {"documents":docs}
-#         return new_state
-
-#     return retrieve_documents
|
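Note: the commented-out `divide_into_parts` helper above splits the final top-k budget as evenly as possible across questions. A quick standalone illustration of its behavior (re-stated compactly; only the example call is new):

def divide_into_parts(target, parts):
    # Spread `target` across `parts` buckets, giving the remainder to the first buckets
    base, remainder = divmod(target, parts)
    return [base + 1 if i < remainder else base for i in range(parts)]

print(divide_into_parts(15, 4))  # [4, 4, 4, 3] -> sums back to 15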
climateqa/engine/chains/sample_router.py
DELETED
@@ -1,66 +0,0 @@
-
-# from typing import List
-# from typing import Literal
-# from langchain.prompts import ChatPromptTemplate
-# from langchain_core.utils.function_calling import convert_to_openai_function
-# from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
-
-# # https://livingdatalab.com/posts/2023-11-05-openai-function-calling-with-langchain.html
-
-# class Location(BaseModel):
-#     country:str = Field(...,description="The country if directly mentioned or inferred from the location (cities, regions, adresses), ex: France, USA, ...")
-#     location:str = Field(...,description="The specific place if mentioned (cities, regions, addresses), ex: Marseille, New York, Wisconsin, ...")
-
-# class QueryAnalysis(BaseModel):
-#     """Analyzing the user query"""
-
-#     language: str = Field(
-#         description="Find the language of the query in full words (ex: French, English, Spanish, ...), defaults to English"
-#     )
-#     intent: str = Field(
-#         enum=[
-#             "Environmental impacts of AI",
-#             "Geolocated info about climate change",
-#             "Climate change",
-#             "Biodiversity",
-#             "Deep sea mining",
-#             "Chitchat",
-#         ],
-#         description="""
-#             Categorize the user query in one of the following category,
-
-#             Examples:
-#             - Geolocated info about climate change: "What will be the temperature in Marseille in 2050"
-#             - Climate change: "What is radiative forcing", "How much will
-#         """,
-#     )
-#     sources: List[Literal["IPCC", "IPBES", "IPOS"]] = Field(
-#         ...,
-#         description="""
-#             Given a user question choose which documents would be most relevant for answering their question,
-#             - IPCC is for questions about climate change, energy, impacts, and everything we can find the IPCC reports
-#             - IPBES is for questions about biodiversity and nature
-#             - IPOS is for questions about the ocean and deep sea mining
-
-#         """,
-#     )
-#     date: str = Field(description="The date or period mentioned, ex: 2050, between 2020 and 2050")
-#     location:Location
-#     # query: str = Field(
-#     #     description = """
-#     #         Translate to english and reformulate the following user message to be a short standalone question, in the context of an educational discussion about climate change.
-#     #         The reformulated question will used in a search engine
-#     #         By default, assume that the user is asking information about the last century,
-#     #         Use the following examples
-
-#     #         ### Examples:
-#     #         La technologie nous sauvera-t-elle ? -> Can technology help humanity mitigate the effects of climate change?
-#     #         what are our reserves in fossil fuel? -> What are the current reserves of fossil fuels and how long will they last?
-#     #         what are the main causes of climate change? -> What are the main causes of climate change in the last century?
-
-#     #         Question in English:
-#     #     """
-#     # )
-
-# openai_functions = [convert_to_openai_function(QueryAnalysis)]
-# llm2 = llm.bind(functions = openai_functions,function_call={"name":"QueryAnalysis"})
climateqa/engine/chains/set_defaults.py
DELETED
@@ -1,13 +0,0 @@
-def set_defaults(state):
-    print("---- Setting defaults ----")
-
-    if not state["audience"] or state["audience"] is None:
-        state.update({"audience": "experts"})
-
-    sources_input = state["sources_input"] if "sources_input" in state else ["auto"]
-    state.update({"sources_input": sources_input})
-
-    # if not state["sources_input"] or state["sources_input"] is None:
-    #     state.update({"sources_input": ["auto"]})
-
-    return state
climateqa/engine/chains/translation.py
DELETED
@@ -1,42 +0,0 @@
-
-from langchain_core.pydantic_v1 import BaseModel, Field
-from typing import List
-from typing import Literal
-from langchain.prompts import ChatPromptTemplate
-from langchain_core.utils.function_calling import convert_to_openai_function
-from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
-
-
-class Translation(BaseModel):
-    """Analyzing the user message input"""
-
-    translation: str = Field(
-        description="Translate the message input to English",
-    )
-
-
-def make_translation_chain(llm):
-
-    openai_functions = [convert_to_openai_function(Translation)]
-    llm_with_functions = llm.bind(functions = openai_functions,function_call={"name":"Translation"})
-
-    prompt = ChatPromptTemplate.from_messages([
-        ("system", "You are a helpful assistant, you will translate the user input message to English using the function provided"),
-        ("user", "input: {input}")
-    ])
-
-    chain = prompt | llm_with_functions | JsonOutputFunctionsParser()
-    return chain
-
-
-def make_translation_node(llm):
-    translation_chain = make_translation_chain(llm)
-
-    def translate_query(state):
-        print("---- Translate query ----")
-
-        user_input = state["user_input"]
-        translation = translation_chain.invoke({"input":user_input})
-        return {"query":translation["translation"]}
-
-    return translate_query
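Note: a sketch of how the deleted translation chain would have been wired up, assuming an OpenAI-style chat model that supports function calling (the model choice and the expected outputs are illustrative):

from langchain_openai import ChatOpenAI
# make_translation_chain / make_translation_node as defined in the deleted module above
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
chain = make_translation_chain(llm)
print(chain.invoke({"input": "Quelles sont les causes du changement climatique ?"}))
# expected shape: {"translation": "What are the causes of climate change?"}

translate_query = make_translation_node(llm)
state = {"user_input": "¿Está subiendo el nivel del mar?"}
print(translate_query(state))  # expected shape: {"query": "Is the sea level rising?"}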
climateqa/engine/embeddings.py
CHANGED
@@ -2,7 +2,7 @@
 from langchain_community.embeddings import HuggingFaceBgeEmbeddings
 from langchain_community.embeddings import HuggingFaceEmbeddings
 
-def get_embeddings_function(version = "v1.2",query_instruction = "Represent this sentence for searching relevant passages: "):
+def get_embeddings_function(version = "v1.2"):
 
     if version == "v1.2":
 
@@ -10,12 +10,12 @@ def get_embeddings_function(version = "v1.2",query_instruction = "Represent this
 
         # Best embedding model at a reasonable size at the moment (2023-11-22)
 
        model_name = "BAAI/bge-base-en-v1.5"
-        encode_kwargs = {'normalize_embeddings': True}
+        encode_kwargs = {'normalize_embeddings': True} # set True to compute cosine similarity
        print("Loading embeddings model: ", model_name)
        embeddings_function = HuggingFaceBgeEmbeddings(
            model_name=model_name,
            encode_kwargs=encode_kwargs,
-            query_instruction=query_instruction
+            query_instruction="Represent this sentence for searching relevant passages: "
        )
 
    else:
@@ -23,6 +23,3 @@ def get_embeddings_function(version = "v1.2",query_instruction = "Represent this
 
        embeddings_function = HuggingFaceEmbeddings(model_name = "sentence-transformers/multi-qa-mpnet-base-dot-v1")
 
    return embeddings_function
-
-
-
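Note: after this change the BGE query instruction is hardcoded rather than passed in. A minimal usage sketch of the resulting function (the cosine-similarity check at the end is illustrative):

import numpy as np
from climateqa.engine.embeddings import get_embeddings_function

embeddings = get_embeddings_function(version="v1.2")
q = embeddings.embed_query("Is sea level rise accelerating?")  # instruction is prepended internally
d = embeddings.embed_documents(["Sea level rose about 20 cm over the last century."])[0]
# Vectors are L2-normalized (normalize_embeddings=True), so the dot product equals cosine similarity
print(float(np.dot(q, d)))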
climateqa/engine/graph.py
DELETED
@@ -1,333 +0,0 @@
-import sys
-import os
-from contextlib import contextmanager
-
-from langchain.schema import Document
-from langgraph.graph import END, StateGraph
-from langchain_core.runnables.graph import CurveStyle, MermaidDrawMethod
-
-from typing_extensions import TypedDict
-from typing import List, Dict
-
-import operator
-from typing import Annotated
-
-from IPython.display import display, HTML, Image
-
-from .chains.answer_chitchat import make_chitchat_node
-from .chains.answer_ai_impact import make_ai_impact_node
-from .chains.query_transformation import make_query_transform_node
-from .chains.translation import make_translation_node
-from .chains.intent_categorization import make_intent_categorization_node
-from .chains.retrieve_documents import make_IPx_retriever_node, make_POC_retriever_node
-from .chains.answer_rag import make_rag_node
-from .chains.graph_retriever import make_graph_retriever_node
-from .chains.chitchat_categorization import make_chitchat_intent_categorization_node
-# from .chains.set_defaults import set_defaults
-
-class GraphState(TypedDict):
-    """
-    Represents the state of our graph.
-    """
-    user_input : str
-    language : str
-    intent : str
-    search_graphs_chitchat : bool
-    query: str
-    questions_list : List[dict]
-    handled_questions_index : Annotated[list[int], operator.add]
-    n_questions : int
-    answer: str
-    audience: str = "experts"
-    sources_input: List[str] = ["IPCC","IPBES"]
-    relevant_content_sources_selection: List[str] = ["Figures (IPCC/IPBES)"]
-    sources_auto: bool = True
-    min_year: int = 1960
-    max_year: int = None
-    documents: Annotated[List[Document], operator.add]
-    related_contents : Annotated[List[Document], operator.add]
-    recommended_content : List[Document]
-    search_only : bool = False
-    reports : List[str] = []
-
-def dummy(state):
-    return
-
-def search(state): #TODO
-    return
-
-def answer_search(state):#TODO
-    return
-
-def route_intent(state):
-    intent = state["intent"]
-    if intent in ["chitchat","esg"]:
-        return "answer_chitchat"
-    # elif intent == "ai_impact":
-    #     return "answer_ai_impact"
-    else:
-        # Search route
-        return "answer_climate"
-
-def chitchat_route_intent(state):
-    intent = state["search_graphs_chitchat"]
-    if intent is True:
-        return "retrieve_graphs_chitchat"
-    elif intent is False:
-        return END
-
-def route_translation(state):
-    if state["language"].lower() == "english":
-        return "transform_query"
-    else:
-        return "transform_query"
-        # return "translate_query" #TODO : add translation
-
-
-def route_based_on_relevant_docs(state,threshold_docs=0.2):
-    docs = [x for x in state["documents"] if x.metadata["reranking_score"] > threshold_docs]
-    print("Route : ", ["answer_rag" if len(docs) > 0 else "answer_rag_no_docs"])
-    if len(docs) > 0:
-        return "answer_rag"
-    else:
-        return "answer_rag_no_docs"
-
-def route_continue_retrieve_documents(state):
-    index_question_ipx = [i for i, x in enumerate(state["questions_list"]) if x["source_type"] == "IPx"]
-    questions_ipx_finished = all(elem in state["handled_questions_index"] for elem in index_question_ipx)
-    # if questions_ipx_finished and state["search_only"]:
-    #     return END
-    if questions_ipx_finished:
-        return "end_retrieve_IPx_documents"
-    else:
-        return "retrieve_documents"
-
-
-    # if state["n_questions"]["IPx"] == len(state["handled_questions_index"]) and state["search_only"] :
-    #     return END
-    # elif state["n_questions"]["IPx"] == len(state["handled_questions_index"]):
-    #     return "answer_search"
-    # else :
-    #     return "retrieve_documents"
-
-def route_continue_retrieve_local_documents(state):
-    index_question_poc = [i for i, x in enumerate(state["questions_list"]) if x["source_type"] == "POC"]
-    questions_poc_finished = all(elem in state["handled_questions_index"] for elem in index_question_poc)
-    # if questions_poc_finished and state["search_only"]:
-    #     return END
-    if questions_poc_finished or ("POC region" not in state["relevant_content_sources_selection"]):
-        return "end_retrieve_local_documents"
-    else:
-        return "retrieve_local_data"
-
-    # if state["n_questions"]["POC"] == len(state["handled_questions_index"]) and state["search_only"] :
-    #     return END
-    # elif state["n_questions"]["POC"] == len(state["handled_questions_index"]):
-    #     return "answer_search"
-    # else :
-    #     return "retrieve_local_data"
-
-    # if len(state["remaining_questions"]) == 0 and state["search_only"] :
-    #     return END
-    # elif len(state["remaining_questions"]) > 0:
-    #     return "retrieve_documents"
-    # else:
-    #     return "answer_search"
-
-def route_retrieve_documents(state):
-    sources_to_retrieve = []
-
-    if "Graphs (OurWorldInData)" in state["relevant_content_sources_selection"] :
-        sources_to_retrieve.append("retrieve_graphs")
-
-    if sources_to_retrieve == []:
-        return END
-    return sources_to_retrieve
-
-def make_id_dict(values):
-    return {k:k for k in values}
-
-def make_graph_agent(llm, vectorstore_ipcc, vectorstore_graphs, vectorstore_region, reranker, threshold_docs=0.2):
-
-    workflow = StateGraph(GraphState)
-
-    # Define the node functions
-    categorize_intent = make_intent_categorization_node(llm)
-    transform_query = make_query_transform_node(llm)
-    translate_query = make_translation_node(llm)
-    answer_chitchat = make_chitchat_node(llm)
-    answer_ai_impact = make_ai_impact_node(llm)
-    retrieve_documents = make_IPx_retriever_node(vectorstore_ipcc, reranker, llm)
-    retrieve_graphs = make_graph_retriever_node(vectorstore_graphs, reranker)
-    # retrieve_local_data = make_POC_retriever_node(vectorstore_region, reranker, llm)
-    answer_rag = make_rag_node(llm, with_docs=True)
-    answer_rag_no_docs = make_rag_node(llm, with_docs=False)
-    chitchat_categorize_intent = make_chitchat_intent_categorization_node(llm)
-
-    # Define the nodes
-    # workflow.add_node("set_defaults", set_defaults)
-    workflow.add_node("categorize_intent", categorize_intent)
-    workflow.add_node("answer_climate", dummy)
-    workflow.add_node("answer_search", answer_search)
-    workflow.add_node("transform_query", transform_query)
-    workflow.add_node("translate_query", translate_query)
-    workflow.add_node("answer_chitchat", answer_chitchat)
-    workflow.add_node("chitchat_categorize_intent", chitchat_categorize_intent)
-    workflow.add_node("retrieve_graphs", retrieve_graphs)
-    # workflow.add_node("retrieve_local_data", retrieve_local_data)
-    workflow.add_node("retrieve_graphs_chitchat", retrieve_graphs)
-    workflow.add_node("retrieve_documents", retrieve_documents)
-    workflow.add_node("answer_rag", answer_rag)
-    workflow.add_node("answer_rag_no_docs", answer_rag_no_docs)
-
-    # Entry point
-    workflow.set_entry_point("categorize_intent")
-
-    # CONDITIONAL EDGES
-    workflow.add_conditional_edges(
-        "categorize_intent",
-        route_intent,
-        make_id_dict(["answer_chitchat","answer_climate"])
-    )
-
-    workflow.add_conditional_edges(
-        "chitchat_categorize_intent",
-        chitchat_route_intent,
-        make_id_dict(["retrieve_graphs_chitchat", END])
-    )
-
-    workflow.add_conditional_edges(
-        "answer_climate",
-        route_translation,
-        make_id_dict(["translate_query","transform_query"])
-    )
-
-    workflow.add_conditional_edges(
-        "answer_search",
-        lambda x : route_based_on_relevant_docs(x,threshold_docs=threshold_docs),
-        make_id_dict(["answer_rag","answer_rag_no_docs"])
-    )
-    workflow.add_conditional_edges(
-        "transform_query",
-        route_retrieve_documents,
-        make_id_dict(["retrieve_graphs", END])
-    )
-
-    # Define the edges
-    workflow.add_edge("translate_query", "transform_query")
-    workflow.add_edge("transform_query", "retrieve_documents") #TODO put back
-    # workflow.add_edge("transform_query", "retrieve_local_data")
-    # workflow.add_edge("transform_query", END) # TODO remove
-
-    workflow.add_edge("retrieve_graphs", END)
-    workflow.add_edge("answer_rag", END)
-    workflow.add_edge("answer_rag_no_docs", END)
-    workflow.add_edge("answer_chitchat", "chitchat_categorize_intent")
-    workflow.add_edge("retrieve_graphs_chitchat", END)
-
-    # workflow.add_edge("retrieve_local_data", "answer_search")
-    workflow.add_edge("retrieve_documents", "answer_search")
-
-    # Compile
-    app = workflow.compile()
-    return app
-
-def make_graph_agent_poc(llm, vectorstore_ipcc, vectorstore_graphs, vectorstore_region, reranker, threshold_docs=0.2):
-
-    workflow = StateGraph(GraphState)
-
-    # Define the node functions
-    categorize_intent = make_intent_categorization_node(llm)
-    transform_query = make_query_transform_node(llm)
-    translate_query = make_translation_node(llm)
-    answer_chitchat = make_chitchat_node(llm)
-    answer_ai_impact = make_ai_impact_node(llm)
-    retrieve_documents = make_IPx_retriever_node(vectorstore_ipcc, reranker, llm)
-    retrieve_graphs = make_graph_retriever_node(vectorstore_graphs, reranker)
-    retrieve_local_data = make_POC_retriever_node(vectorstore_region, reranker, llm)
-    answer_rag = make_rag_node(llm, with_docs=True)
-    answer_rag_no_docs = make_rag_node(llm, with_docs=False)
-    chitchat_categorize_intent = make_chitchat_intent_categorization_node(llm)
-
-    # Define the nodes
-    # workflow.add_node("set_defaults", set_defaults)
-    workflow.add_node("categorize_intent", categorize_intent)
-    workflow.add_node("answer_climate", dummy)
-    workflow.add_node("answer_search", answer_search)
-    # workflow.add_node("end_retrieve_local_documents", dummy)
-    # workflow.add_node("end_retrieve_IPx_documents", dummy)
-    workflow.add_node("transform_query", transform_query)
-    workflow.add_node("translate_query", translate_query)
-    workflow.add_node("answer_chitchat", answer_chitchat)
-    workflow.add_node("chitchat_categorize_intent", chitchat_categorize_intent)
-    workflow.add_node("retrieve_graphs", retrieve_graphs)
-    workflow.add_node("retrieve_local_data", retrieve_local_data)
-    workflow.add_node("retrieve_graphs_chitchat", retrieve_graphs)
-    workflow.add_node("retrieve_documents", retrieve_documents)
-    workflow.add_node("answer_rag", answer_rag)
-    workflow.add_node("answer_rag_no_docs", answer_rag_no_docs)
-
-    # Entry point
-    workflow.set_entry_point("categorize_intent")
-
-    # CONDITIONAL EDGES
-    workflow.add_conditional_edges(
-        "categorize_intent",
-        route_intent,
-        make_id_dict(["answer_chitchat","answer_climate"])
-    )
-
-    workflow.add_conditional_edges(
-        "chitchat_categorize_intent",
-        chitchat_route_intent,
-        make_id_dict(["retrieve_graphs_chitchat", END])
-    )
-
-    workflow.add_conditional_edges(
-        "answer_climate",
-        route_translation,
-        make_id_dict(["translate_query","transform_query"])
-    )
-
-    workflow.add_conditional_edges(
-        "answer_search",
-        lambda x : route_based_on_relevant_docs(x,threshold_docs=threshold_docs),
-        make_id_dict(["answer_rag","answer_rag_no_docs"])
-    )
-    workflow.add_conditional_edges(
-        "transform_query",
-        route_retrieve_documents,
-        make_id_dict(["retrieve_graphs", END])
-    )
-
-    # Define the edges
-    workflow.add_edge("translate_query", "transform_query")
-    workflow.add_edge("transform_query", "retrieve_documents") #TODO put back
-    workflow.add_edge("transform_query", "retrieve_local_data")
-    # workflow.add_edge("transform_query", END) # TODO remove
-
-    workflow.add_edge("retrieve_graphs", END)
-    workflow.add_edge("answer_rag", END)
-    workflow.add_edge("answer_rag_no_docs", END)
-    workflow.add_edge("answer_chitchat", "chitchat_categorize_intent")
-    workflow.add_edge("retrieve_graphs_chitchat", END)
-
-    workflow.add_edge("retrieve_local_data", "answer_search")
-    workflow.add_edge("retrieve_documents", "answer_search")
-
-    # Compile
-    app = workflow.compile()
-    return app
-
-
-
-
-def display_graph(app):
-
-    display(
-        Image(
-            app.get_graph(xray = True).draw_mermaid_png(
-                draw_method=MermaidDrawMethod.API,
-            )
-        )
-    )
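Note: the deleted module built the agent as a LangGraph `StateGraph` with conditional routing, the same API used above. A stripped-down sketch of that construction pattern (the node bodies are trivial placeholders, not the real chains):

import operator
from typing import Annotated, List
from typing_extensions import TypedDict
from langgraph.graph import END, StateGraph

class State(TypedDict):
    user_input: str
    intent: str
    documents: Annotated[List[str], operator.add]  # parallel nodes append (fan-in)

def categorize(state):
    return {"intent": "chitchat" if "hi" in state["user_input"].lower() else "search"}

def chitchat(state):
    return {"documents": ["(smalltalk reply)"]}

def retrieve(state):
    return {"documents": ["doc about sea level"]}

workflow = StateGraph(State)
workflow.add_node("categorize_intent", categorize)
workflow.add_node("answer_chitchat", chitchat)
workflow.add_node("retrieve_documents", retrieve)
workflow.set_entry_point("categorize_intent")
workflow.add_conditional_edges(
    "categorize_intent",
    lambda s: "answer_chitchat" if s["intent"] == "chitchat" else "retrieve_documents",
    {"answer_chitchat": "answer_chitchat", "retrieve_documents": "retrieve_documents"},
)
workflow.add_edge("answer_chitchat", END)
workflow.add_edge("retrieve_documents", END)
app = workflow.compile()
print(app.invoke({"user_input": "Is sea level rising?", "documents": []}))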
climateqa/engine/graph_retriever.py
DELETED
@@ -1,88 +0,0 @@
-from langchain_core.retrievers import BaseRetriever
-from langchain_core.documents.base import Document
-from langchain_core.vectorstores import VectorStore
-from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
-
-from typing import List
-
-# class GraphRetriever(BaseRetriever):
-#     vectorstore:VectorStore
-#     sources:list = ["OWID"] # add OurWorldInData later; will need to be integrated with the other retriever
-#     threshold:float = 0.5
-#     k_total:int = 10
-
-#     def _get_relevant_documents(
-#         self, query: str, *, run_manager: CallbackManagerForRetrieverRun
-#     ) -> List[Document]:
-
-#         # Check if all elements in the list are IEA or OWID
-#         assert isinstance(self.sources,list)
-#         assert self.sources
-#         assert any([x in ["OWID"] for x in self.sources])
-
-#         # Prepare base search kwargs
-#         filters = {}
-
-#         filters["source"] = {"$in": self.sources}
-
-#         docs = self.vectorstore.similarity_search_with_score(query=query, filter=filters, k=self.k_total)
-
-#         # Filter if scores are below threshold
-#         docs = [x for x in docs if x[1] > self.threshold]
-
-#         # Remove duplicate documents
-#         unique_docs = []
-#         seen_docs = []
-#         for i, doc in enumerate(docs):
-#             if doc[0].page_content not in seen_docs:
-#                 unique_docs.append(doc)
-#                 seen_docs.append(doc[0].page_content)
-
-#         # Add score to metadata
-#         results = []
-#         for i,(doc,score) in enumerate(unique_docs):
-#             doc.metadata["similarity_score"] = score
-#             doc.metadata["content"] = doc.page_content
-#             results.append(doc)
-
-#         return results
-
-async def retrieve_graphs(
-    query: str,
-    vectorstore:VectorStore,
-    sources:list = ["OWID"], # add OurWorldInData later; will need to be integrated with the other retriever
-    threshold:float = 0.5,
-    k_total:int = 10,
-)-> List[Document]:
-
-    # Check if all elements in the list are IEA or OWID
-    assert isinstance(sources,list)
-    assert sources
-    assert any([x in ["OWID"] for x in sources])
-
-    # Prepare base search kwargs
-    filters = {}
-
-    filters["source"] = {"$in": sources}
-
-    docs = vectorstore.similarity_search_with_score(query=query, filter=filters, k=k_total)
-
-    # Filter if scores are below threshold
-    docs = [x for x in docs if x[1] > threshold]
-
-    # Remove duplicate documents
-    unique_docs = []
-    seen_docs = []
-    for i, doc in enumerate(docs):
-        if doc[0].page_content not in seen_docs:
-            unique_docs.append(doc)
-            seen_docs.append(doc[0].page_content)
-
-    # Add score to metadata
-    results = []
-    for i,(doc,score) in enumerate(unique_docs):
-        doc.metadata["similarity_score"] = score
-        doc.metadata["content"] = doc.page_content
-        results.append(doc)
-
-    return results
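Note: a usage sketch for the deleted `retrieve_graphs` coroutine, assuming a LangChain vector store whose entries carry a `source` metadata field (the store itself is a placeholder named `my_vectorstore`):

import asyncio

async def main(vectorstore):
    # vectorstore: any LangChain VectorStore indexed with OWID graph descriptions
    graphs = await retrieve_graphs(
        query="global CO2 emissions per capita",
        vectorstore=vectorstore,
        sources=["OWID"],
        threshold=0.5,
        k_total=10,
    )
    for doc in graphs:
        print(doc.metadata["similarity_score"], doc.page_content[:60])

# asyncio.run(main(my_vectorstore))  # my_vectorstore is assumed to exist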
climateqa/engine/keywords.py
CHANGED
@@ -11,12 +11,10 @@ class KeywordsOutput(BaseModel):
 
     keywords: list = Field(
         description="""
-        Generate 1 or 2 relevant keywords from the user query to ask a search engine for scientific research papers.
-        Do not use special characters or accents.
+        Generate 1 or 2 relevant keywords from the user query to ask a search engine for scientific research papers.
 
         Example:
         - "What is the impact of deep sea mining ?" -> ["deep sea mining"]
-        - "Quel est l'impact de l'exploitation minière en haute mer ?" -> ["deep sea mining"]
         - "How will El Nino be impacted by climate change" -> ["el nino"]
         - "Is climate change a hoax" -> [Climate change","hoax"]
         """
climateqa/engine/llm/__init__.py
CHANGED
@@ -1,6 +1,5 @@
 from climateqa.engine.llm.openai import get_llm as get_openai_llm
 from climateqa.engine.llm.azure import get_llm as get_azure_llm
-from climateqa.engine.llm.ollama import get_llm as get_ollama_llm
 
 
 def get_llm(provider="openai",**kwargs):
@@ -9,8 +8,6 @@ def get_llm(provider="openai",**kwargs):
         return get_openai_llm(**kwargs)
     elif provider == "azure":
         return get_azure_llm(**kwargs)
-    elif provider == "ollama":
-        return get_ollama_llm(**kwargs)
     else:
         raise ValueError(f"Unknown provider: {provider}")
 
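Note: with the Ollama branch removed, the factory only knows the two hosted providers; usage is unchanged:

from climateqa.engine.llm import get_llm

llm = get_llm(provider="openai", max_tokens=1024, temperature=0.0)
# get_llm(provider="ollama") now raises ValueError("Unknown provider: ollama")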
climateqa/engine/llm/ollama.py
DELETED
@@ -1,6 +0,0 @@
-
-
-from langchain_community.llms import Ollama
-
-def get_llm(model="llama3", **kwargs):
-    return Ollama(model=model, **kwargs)
climateqa/engine/llm/openai.py
CHANGED
@@ -7,7 +7,7 @@ try:
 except Exception:
     pass
 
-def get_llm(model="gpt-
+def get_llm(model="gpt-3.5-turbo-0125",max_tokens=1024, temperature=0.0, streaming=True,timeout=30, **kwargs):
 
     llm = ChatOpenAI(
         model=model,
climateqa/engine/{chains/prompts.py → prompts.py}
RENAMED
@@ -36,40 +36,13 @@ You are given a question and extracted passages of the IPCC and/or IPBES reports
 """
 
 
-# answer_prompt_template_old = """
-# You are ClimateQ&A, an AI Assistant created by Ekimetrics. You are given a question and extracted passages of reports. Provide a clear and structured answer based on the passages provided, the context and the guidelines.
-
-# Guidelines:
-# - If the passages have useful facts or numbers, use them in your answer.
-# - When you use information from a passage, mention where it came from by using [Doc i] at the end of the sentence. i stands for the number of the document.
-# - Do not use the sentence 'Doc i says ...' to say where information came from.
-# - If the same thing is said in more than one document, you can mention all of them like this: [Doc i, Doc j, Doc k]
-# - Do not just summarize each passage one by one. Group your summaries to highlight the key parts in the explanation.
-# - If it makes sense, use bullet points and lists to make your answers easier to understand.
-# - You do not need to use every passage. Only use the ones that help answer the question.
-# - If the documents do not have the information needed to answer the question, just say you do not have enough information.
-# - Consider by default that the question is about the past century unless it is specified otherwise.
-# - If the passage is the caption of a picture, you can still use it as part of your answer as any other document.
-
-# -----------------------
-# Passages:
-# {context}
-
-# -----------------------
-# Question: {query} - Explained to {audience}
-# Answer in {language} with the passages citations:
-# """
-
 answer_prompt_template = """
-You are ClimateQ&A, an AI Assistant created by Ekimetrics. You are given a question and extracted passages of reports. Provide a clear and structured answer based on the passages provided, the context and the guidelines.
+You are ClimateQ&A, an AI Assistant created by Ekimetrics. You are given a question and extracted passages of the IPCC and/or IPBES reports. Provide a clear and structured answer based on the passages provided, the context and the guidelines.
 
 Guidelines:
 - If the passages have useful facts or numbers, use them in your answer.
 - When you use information from a passage, mention where it came from by using [Doc i] at the end of the sentence. i stands for the number of the document.
-
-- The different sources are IPCC, IPBES, PPCP (for Plan Climat Air Energie Territorial de Paris), PBDP (for Plan Biodiversité de Paris), Acclimaterra.
-- Do not mention that you are using specific extract documents, but mention only the source information. "According to IPCC, ..." rather than "According to the provided document from IPCC ..."
-- Make a clear distinction between information from IPCC, IPBES, Acclimaterra that are scientific reports and PPCP, PBDP that are strategic reports. Strategic reports should not be taken has verified facts, but as political or strategic decisions.
+- Do not use the sentence 'Doc i says ...' to say where information came from.
 - If the same thing is said in more than one document, you can mention all of them like this: [Doc i, Doc j, Doc k]
 - Do not just summarize each passage one by one. Group your summaries to highlight the key parts in the explanation.
 - If it makes sense, use bullet points and lists to make your answers easier to understand.
@@ -78,16 +51,16 @@ Guidelines:
 - Consider by default that the question is about the past century unless it is specified otherwise.
 - If the passage is the caption of a picture, you can still use it as part of your answer as any other document.
 
-
 -----------------------
 Passages:
 {context}
 
 -----------------------
-Question: {query} - Explained to {audience}
+Question: {question} - Explained to {audience}
 Answer in {language} with the passages citations:
 """
 
+
 papers_prompt_template = """
 You are ClimateQ&A, an AI Assistant created by Ekimetrics. You are given a question and extracted abstracts of scientific papers. Provide a clear and structured answer based on the abstracts provided, the context and the guidelines.
 
@@ -164,7 +137,7 @@ Guidelines:
 - If the question is not related to environmental issues, never never answer it. Say it's not your role.
 - Make paragraphs by starting new lines to make your answers more readable.
 
-Question: {query}
+Question: {question}
 Answer in {language}:
 """
 
@@ -174,27 +147,4 @@ audience_prompts = {
 "children": "6 year old children that don't know anything about science and climate change and need metaphors to learn",
 "general": "the general public who know the basics in science and climate change and want to learn more about it without technical terms. Still use references to passages.",
 "experts": "expert and climate scientists that are not afraid of technical terms",
-}
-
-
-answer_prompt_graph_template = """
-Given the user question and a list of graphs which are related to the question, rank the graphs based on relevance to the user question. ALWAYS follow the guidelines given below.
-
-### Guidelines ###
-- Keep all the graphs that are given to you.
-- NEVER modify the graph HTML embedding, the category or the source leave them exactly as they are given.
-- Return the ranked graphs as a list of dictionaries with keys 'embedding', 'category', and 'source'.
-- Return a valid JSON output.
-
------------------------
-User question:
-{query}
-
-Graphs and their HTML embedding:
-{recommended_content}
-
------------------------
-{format_instructions}
-
-Output the result as json with a key "graphs" containing a list of dictionaries of the relevant graphs with keys 'embedding', 'category', and 'source'. Do not modify the graph HTML embedding, the category or the source. Do not put any message or text before or after the JSON output.
-"""
+}
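Note: the placeholder rename from `{query}` to `{question}` means downstream callers must now pass a `question` key. A sketch of how the trimmed `answer_prompt_template` is consumed (the example values are illustrative):

from langchain_core.prompts import ChatPromptTemplate
# answer_prompt_template as defined in climateqa/engine/prompts.py above
prompt = ChatPromptTemplate.from_template(answer_prompt_template)
messages = prompt.format_messages(
    context="Doc 1: Global mean sea level rose by about 20 cm since 1900.",
    question="Is sea level rise accelerating?",
    audience="experts",
    language="English",
)
print(messages[0].content[:200])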
climateqa/engine/{chains/answer_rag.py → rag.py}
RENAMED
@@ -2,16 +2,17 @@ from operator import itemgetter
 
 from langchain_core.prompts import ChatPromptTemplate
 from langchain_core.output_parsers import StrOutputParser
+from langchain_core.runnables import RunnablePassthrough, RunnableLambda, RunnableBranch
 from langchain_core.prompts.prompt import PromptTemplate
 from langchain_core.prompts.base import format_document
 
-from climateqa.engine.
-from climateqa.engine.
-import 
-from 
+from climateqa.engine.reformulation import make_reformulation_chain
+from climateqa.engine.prompts import answer_prompt_template,answer_prompt_without_docs_template,answer_prompt_images_template
+from climateqa.engine.prompts import papers_prompt_template
+from climateqa.engine.utils import pass_values, flatten_dict,prepare_chain,rename_chain
+from climateqa.engine.keywords import make_keywords_chain
 
-DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="Source : {source} - Content : {page_content}")
+DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="{page_content}")
 
 def _combine_documents(
     docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, sep="\n\n"
@@ -39,52 +40,72 @@ def get_text_docs(x):
 def get_image_docs(x):
     return [doc for doc in x if doc.metadata["chunk_type"] == "image"]
 
-    prompt = ChatPromptTemplate.from_template(answer_prompt_template)
-    chain = ({
-        "context":lambda x : _combine_documents(x["documents"]),
-        "context_length":lambda x : print("CONTEXT LENGTH : " , len(_combine_documents(x["documents"]))),
-        "query":itemgetter("query"),
-        "language":itemgetter("language"),
-        "audience":itemgetter("audience"),
-    } | prompt | llm | StrOutputParser())
-    return chain
-
-    return chain
-
-async def answer_rag(state,config):
-    print("---- Answer RAG ----")
-    start_time = time.time()
-    print("Sources used : " + "\n".join([x.metadata["short_name"] + " - page " + str(x.metadata["page_number"]) for x in state["documents"]]))
-
-
-    print("RAG elapsed time: ", elapsed_time)
-    print("Answer size : ", len(answer))
-    # print(f"\n\nAnswer:\n{answer}")
-
-    return {"answer":answer}
-
+
+def make_rag_chain(retriever,llm):
+
+    # Construct the prompt
+    prompt = ChatPromptTemplate.from_template(answer_prompt_template)
+    prompt_without_docs = ChatPromptTemplate.from_template(answer_prompt_without_docs_template)
+
+    # ------- CHAIN 0 - Reformulation
+    reformulation = make_reformulation_chain(llm)
+    reformulation = prepare_chain(reformulation,"reformulation")
+
+    # ------- Find all keywords from the reformulated query
+    keywords = make_keywords_chain(llm)
+    keywords = {"keywords":itemgetter("question") | keywords}
+    keywords = prepare_chain(keywords,"keywords")
+
+    # ------- CHAIN 1
+    # Retrieved documents
+    find_documents = {"docs": itemgetter("question") | retriever} | RunnablePassthrough()
+    find_documents = prepare_chain(find_documents,"find_documents")
+
+    # ------- CHAIN 2
+    # Construct inputs for the llm
+    input_documents = {
+        "context":lambda x : _combine_documents(x["docs"]),
+        **pass_values(["question","audience","language","keywords"])
+    }
+
+    # ------- CHAIN 3
+    # Bot answer
+    llm_final = rename_chain(llm,"answer")
+
+    answer_with_docs = {
+        "answer": input_documents | prompt | llm_final | StrOutputParser(),
+        **pass_values(["question","audience","language","query","docs","keywords"]),
+    }
+
+    answer_without_docs = {
+        "answer": prompt_without_docs | llm_final | StrOutputParser(),
+        **pass_values(["question","audience","language","query","docs","keywords"]),
+    }
+
+    # def has_images(x):
+    #     image_docs = [doc for doc in x["docs"] if doc.metadata["chunk_type"]=="image"]
+    #     return len(image_docs) > 0
+
+    def has_docs(x):
+        return len(x["docs"]) > 0
+
+    answer = RunnableBranch(
+        (lambda x: has_docs(x), answer_with_docs),
+        answer_without_docs,
+    )
+
+    # ------- FINAL CHAIN
+    # Build the final chain
+    rag_chain = reformulation | keywords | find_documents | answer
+
+    return rag_chain
 
 
 def make_rag_papers_chain(llm):
 
     prompt = ChatPromptTemplate.from_template(papers_prompt_template)
+
     input_documents = {
         "context":lambda x : _combine_documents(x["docs"]),
         **pass_values(["question","language"])
@@ -110,4 +131,4 @@ def make_illustration_chain(llm):
     }
 
     illustration_chain = input_description_images | prompt_with_images | llm | StrOutputParser()
-    return illustration_chain
+    return illustration_chain
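Note: a sketch of wiring the renamed module together. The exact input contract depends on the reformulation chain (which maps a raw `query` to a standalone `question`), so treat the input keys and printed fields as assumptions; `retriever` is a placeholder for a ClimateQARetriever built on the reports vector store:

from climateqa.engine.llm import get_llm
from climateqa.engine.rag import make_rag_chain

llm = get_llm(provider="openai")
rag_chain = make_rag_chain(retriever, llm)  # retriever assumed to exist

out = rag_chain.invoke({
    "query": "La montée des eaux s'accélère-t-elle ?",
    "audience": "experts",
})
print(out["answer"])     # final answer with [Doc i] citations
print(len(out["docs"]))  # documents kept by the retriever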
climateqa/engine/{chains/reformulation.py → reformulation.py}
RENAMED
@@ -3,7 +3,7 @@ from langchain.output_parsers.structured import StructuredOutputParser, Response
 from langchain_core.prompts import PromptTemplate
 from langchain_core.runnables import RunnablePassthrough, RunnableLambda, RunnableBranch
 
-from climateqa.engine.chains.prompts import reformulation_prompt_template
+from climateqa.engine.prompts import reformulation_prompt_template
 from climateqa.engine.utils import pass_values, flatten_dict
 
 
climateqa/engine/reranker.py
DELETED
@@ -1,55 +0,0 @@
-import os
-from dotenv import load_dotenv
-from scipy.special import expit, logit
-from rerankers import Reranker
-from sentence_transformers import CrossEncoder
-
-load_dotenv()
-
-def get_reranker(model = "nano", cohere_api_key = None):
-
-    assert model in ["nano","tiny","small","large", "jina"]
-
-    if model == "nano":
-        reranker = Reranker('ms-marco-TinyBERT-L-2-v2', model_type='flashrank')
-    elif model == "tiny":
-        reranker = Reranker('ms-marco-MiniLM-L-12-v2', model_type='flashrank')
-    elif model == "small":
-        reranker = Reranker("mixedbread-ai/mxbai-rerank-xsmall-v1", model_type='cross-encoder')
-    elif model == "large":
-        if cohere_api_key is None:
-            cohere_api_key = os.environ["COHERE_API_KEY"]
-        reranker = Reranker("cohere", lang='en', api_key = cohere_api_key)
-    elif model == "jina":
-        # Reached token quota so does not work
-        reranker = Reranker("jina-reranker-v2-base-multilingual", api_key = os.getenv("JINA_RERANKER_API_KEY"))
-        # Does not work without a GPU? And it returns another structure anyway, so the retriever node code would need changing
-        # reranker = CrossEncoder("jinaai/jina-reranker-v2-base-multilingual", automodel_args={"torch_dtype": "auto"}, trust_remote_code=True,)
-    return reranker
-
-
-
-def rerank_docs(reranker,docs,query):
-    if docs == []:
-        return []
-
-    # Get a list of texts from langchain docs
-    input_docs = [x.page_content for x in docs]
-
-    # Rerank using rerankers library
-    results = reranker.rank(query=query, docs=input_docs)
-
-    # Prepare langchain list of docs
-    docs_reranked = []
-    for result in results.results:
-        doc_id = result.document.doc_id
-        doc = docs[doc_id]
-        doc.metadata["reranking_score"] = result.score
-        doc.metadata["query_used_for_retrieval"] = query
-        docs_reranked.append(doc)
-    return docs_reranked
-
-def rerank_and_sort_docs(reranker, docs, query):
-    docs_reranked = rerank_docs(reranker,docs,query)
-    docs_reranked = sorted(docs_reranked, key=lambda x: x.metadata["reranking_score"], reverse=True)
-    return docs_reranked
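Note: the deleted helpers wrap the `rerankers` package; the `rank` call and result shape can be read directly from the code above. A standalone sketch (query and documents are placeholders):

from rerankers import Reranker

reranker = Reranker('ms-marco-TinyBERT-L-2-v2', model_type='flashrank')  # the "nano" option
results = reranker.rank(
    query="impact of climate change on coral reefs",
    docs=["Coral bleaching increases with ocean warming.", "GDP growth in 2020 slowed."],
)
for r in results.results:
    # doc_id indexes back into the input list; score is the cross-encoder relevance
    print(r.document.doc_id, round(r.score, 3))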
climateqa/engine/retriever.py
ADDED
@@ -0,0 +1,163 @@
+# https://github.com/langchain-ai/langchain/issues/8623
+
+import pandas as pd
+
+from langchain_core.retrievers import BaseRetriever
+from langchain_core.vectorstores import VectorStoreRetriever
+from langchain_core.documents.base import Document
+from langchain_core.vectorstores import VectorStore
+from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
+
+from typing import List
+from pydantic import Field
+
+class ClimateQARetriever(BaseRetriever):
+    vectorstore:VectorStore
+    sources:list = ["IPCC","IPBES","IPOS"]
+    reports:list = []
+    threshold:float = 0.6
+    k_summary:int = 3
+    k_total:int = 10
+    namespace:str = "vectors",
+    min_size:int = 200,
+
+
+    def _get_relevant_documents(
+        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
+    ) -> List[Document]:
+
+        # Check if all elements in the list are either IPCC or IPBES
+        assert isinstance(self.sources,list)
+        assert all([x in ["IPCC","IPBES","IPOS"] for x in self.sources])
+        assert self.k_total > self.k_summary, "k_total should be greater than k_summary"
+
+        # Prepare base search kwargs
+        filters = {}
+
+        if len(self.reports) > 0:
+            filters["short_name"] = {"$in":self.reports}
+        else:
+            filters["source"] = { "$in":self.sources}
+
+        # Search for k_summary documents in the summaries dataset
+        filters_summaries = {
+            **filters,
+            "report_type": { "$in":["SPM"]},
+        }
+
+        docs_summaries = self.vectorstore.similarity_search_with_score(query=query,filter = filters_summaries,k = self.k_summary)
+        docs_summaries = [x for x in docs_summaries if x[1] > self.threshold]
+
+        # Search for k_total - k_summary documents in the full reports dataset
+        filters_full = {
+            **filters,
+            "report_type": { "$nin":["SPM"]},
+        }
+        k_full = self.k_total - len(docs_summaries)
+        docs_full = self.vectorstore.similarity_search_with_score(query=query,filter = filters_full,k = k_full)
+
+        # Concatenate documents
+        docs = docs_summaries + docs_full
+
+        # Filter if scores are below threshold
+        docs = [x for x in docs if len(x[0].page_content) > self.min_size]
+        # docs = [x for x in docs if x[1] > self.threshold]
+
+        # Add score to metadata
+        results = []
+        for i,(doc,score) in enumerate(docs):
+            doc.metadata["similarity_score"] = score
+            doc.metadata["content"] = doc.page_content
+            doc.metadata["page_number"] = int(doc.metadata["page_number"]) + 1
+            # doc.page_content = f"""Doc {i+1} - {doc.metadata['short_name']}: {doc.page_content}"""
+            results.append(doc)
+
+        # Sort by score
+        # results = sorted(results,key = lambda x : x.metadata["similarity_score"],reverse = True)
+
+        return results
+
+
+
+# def filter_summaries(df,k_summary = 3,k_total = 10):
+#     # assert source in ["IPCC","IPBES","ALL"], "source arg should be in (IPCC,IPBES,ALL)"
+
+#     # # Filter by source
+#     # if source == "IPCC":
+#     #     df = df.loc[df["source"]=="IPCC"]
+#     # elif source == "IPBES":
+#     #     df = df.loc[df["source"]=="IPBES"]
+#     # else:
+#     #     pass
+
+#     # Separate summaries and full reports
+#     df_summaries = df.loc[df["report_type"].isin(["SPM","TS"])]
+#     df_full = df.loc[~df["report_type"].isin(["SPM","TS"])]
+
+#     # Find passages from summaries dataset
+#     passages_summaries = df_summaries.head(k_summary)
+
+#     # Find passages from full reports dataset
+#     passages_fullreports = df_full.head(k_total - len(passages_summaries))
+
+#     # Concatenate passages
+#     passages = pd.concat([passages_summaries,passages_fullreports],axis = 0,ignore_index = True)
+#     return passages
+
+
+
+# def retrieve_with_summaries(query,retriever,k_summary = 3,k_total = 10,sources = ["IPCC","IPBES"],max_k = 100,threshold = 0.555,as_dict = True,min_length = 300):
+#     assert max_k > k_total
+
+#     validated_sources = ["IPCC","IPBES"]
+#     sources = [x for x in sources if x in validated_sources]
+#     filters = {
+#         "source": { "$in": sources },
+#     }
+#     print(filters)
+
+#     # Retrieve documents
+#     docs = retriever.retrieve(query,top_k = max_k,filters = filters)
+
+#     # Filter by score
+#     docs = [{**x.meta,"score":x.score,"content":x.content} for x in docs if x.score > threshold]
+
+#     if len(docs) == 0:
+#         return []
+#     res = pd.DataFrame(docs)
+#     passages_df = filter_summaries(res,k_summary,k_total)
+#     if as_dict:
+#         contents = passages_df["content"].tolist()
+#         meta = passages_df.drop(columns = ["content"]).to_dict(orient = "records")
+#         passages = []
+#         for i in range(len(contents)):
+#             passages.append({"content":contents[i],"meta":meta[i]})
+#         return passages
+#     else:
+#         return passages_df
+
+
+
+# def retrieve(query,sources = ["IPCC"],threshold = 0.555,k = 10):
+
+
+#     print("hellooooo")
+
+#     # Reformulate queries
+#     reformulated_query,language = reformulate(query)
+
+#     print(reformulated_query)
+
+#     # Retrieve documents
+#     passages = retrieve_with_summaries(reformulated_query,retriever,k_total = k,k_summary = 3,as_dict = True,sources = sources,threshold = threshold)
+#     response = {
+#         "query":query,
+#         "reformulated_query":reformulated_query,
+#         "language":language,
+#         "sources":passages,
+#         "prompts":{"init_prompt":init_prompt,"sources_prompt":sources_prompt},
+#     }
+#     return response
+
|
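
A minimal usage sketch for the restored retriever (not part of the diff): it assumes the embeddings helper in climateqa/engine/embeddings.py is exposed as get_embeddings_function, that the Pinecone helper from climateqa/engine/vectorstore.py is available, and that the Pinecone environment variables are already set. The query string is illustrative.

    # Hedged sketch: get_embeddings_function is an assumed name; adjust to
    # whatever climateqa/engine/embeddings.py actually exports.
    from climateqa.engine.embeddings import get_embeddings_function
    from climateqa.engine.vectorstore import get_pinecone_vectorstore
    from climateqa.engine.retriever import ClimateQARetriever

    embeddings_function = get_embeddings_function()
    vectorstore = get_pinecone_vectorstore(embeddings_function)

    retriever = ClimateQARetriever(
        vectorstore=vectorstore,
        sources=["IPCC", "IPBES"],  # restrict retrieval to these report families
        k_summary=3,                # passages drawn from SPM summaries
        k_total=10,                 # total passages returned
    )
    docs = retriever.get_relevant_documents("How will sea level rise affect coastal cities?")
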
climateqa/engine/utils.py
CHANGED
@@ -1,15 +1,8 @@
 from operator import itemgetter
 from typing import Any, Dict, Iterable, Tuple
-import tiktoken
 from langchain_core.runnables import RunnablePassthrough
 
 
-def num_tokens_from_string(string: str, encoding_name: str = "cl100k_base") -> int:
-    encoding = tiktoken.get_encoding(encoding_name)
-    num_tokens = len(encoding.encode(string))
-    return num_tokens
-
-
 def pass_values(x):
     if not isinstance(x, list):
         x = [x]
@@ -74,13 +67,3 @@ def flatten_dict(
     """
     flat_dict = {k: v for k, v in _flatten_dict(nested_dict, parent_key, sep)}
     return flat_dict
-
-
-
-async def log_event(info,name,config):
-    """Helper function that will run a dummy chain with the given info
-    The astream_event function will catch this chain and stream the dict info to the logger
-    """
-
-    chain = RunnablePassthrough().with_config(run_name=name)
-    _ = await chain.ainvoke(info,config)
climateqa/engine/vectorstore.py
CHANGED
@@ -13,9 +13,7 @@ except:
     pass
 
 
-
-
-def get_pinecone_vectorstore(embeddings,text_key = "content", index_name = os.getenv("PINECONE_API_INDEX")):
+def get_pinecone_vectorstore(embeddings,text_key = "content"):
 
     # # initialize pinecone
     # pinecone.init(
@@ -29,7 +27,7 @@ def get_pinecone_vectorstore(embeddings,text_key = "content", index_name = os.getenv("PINECONE_API_INDEX")):
     # return vectorstore
 
     pc = Pinecone(api_key=os.getenv("PINECONE_API_KEY"))
-    index = pc.Index(index_name)
+    index = pc.Index(os.getenv("PINECONE_API_INDEX"))
 
     vectorstore = PineconeVectorstore(
         index, embeddings, text_key,
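
The restored helper now reads both Pinecone settings from the environment instead of taking an index_name argument, so callers only pass the embeddings object. A sketch of the expected setup (the variable names come from the os.getenv calls above; values are placeholders):

    # Placeholder values for illustration only.
    import os

    os.environ["PINECONE_API_KEY"] = "<your-api-key>"
    os.environ["PINECONE_API_INDEX"] = "<your-index-name>"

    # get_pinecone_vectorstore(embeddings) then connects with:
    #   pc = Pinecone(api_key=os.getenv("PINECONE_API_KEY"))
    #   index = pc.Index(os.getenv("PINECONE_API_INDEX"))
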
climateqa/handle_stream_events.py
DELETED
@@ -1,126 +0,0 @@
-from langchain_core.runnables.schema import StreamEvent
-from gradio import ChatMessage
-from climateqa.engine.chains.prompts import audience_prompts
-from front.utils import make_html_source,parse_output_llm_with_sources,serialize_docs,make_toolbox,generate_html_graphs
-import numpy as np
-
-def init_audience(audience :str) -> str:
-    if audience == "Children":
-        audience_prompt = audience_prompts["children"]
-    elif audience == "General public":
-        audience_prompt = audience_prompts["general"]
-    elif audience == "Experts":
-        audience_prompt = audience_prompts["experts"]
-    else:
-        audience_prompt = audience_prompts["experts"]
-    return audience_prompt
-
-def convert_to_docs_to_html(docs: list[dict]) -> str:
-    docs_html = []
-    for i, d in enumerate(docs, 1):
-        if d.metadata["chunk_type"] == "text":
-            docs_html.append(make_html_source(d, i))
-    return "".join(docs_html)
-
-def handle_retrieved_documents(event: StreamEvent, history : list[ChatMessage], used_documents : list[str],related_content:list[str]) -> tuple[str, list[ChatMessage], list[str]]:
-    """
-    Handles the retrieved documents and returns the HTML representation of the documents
-
-    Args:
-        event (StreamEvent): The event containing the retrieved documents
-        history (list[ChatMessage]): The current message history
-        used_documents (list[str]): The list of used documents
-
-    Returns:
-        tuple[str, list[ChatMessage], list[str]]: The updated HTML representation of the documents, the updated message history and the updated list of used documents
-    """
-    if "documents" not in event["data"]["output"] or event["data"]["output"]["documents"] == []:
-        return history, used_documents, related_content
-
-    try:
-        docs = event["data"]["output"]["documents"]
-
-        used_documents = used_documents + [f"{d.metadata['short_name']} - {d.metadata['name']}" for d in docs]
-        if used_documents!=[]:
-            history[-1].content = "Adding sources :\n\n - " + "\n - ".join(np.unique(used_documents))
-
-        #TODO do the same for related contents
-
-    except Exception as e:
-        print(f"Error getting documents: {e}")
-        print(event)
-    return history, used_documents, related_content
-
-def stream_answer(history: list[ChatMessage], event : StreamEvent, start_streaming : bool, answer_message_content : str)-> tuple[list[ChatMessage], bool, str]:
-    """
-    Handles the streaming of the answer and updates the history with the new message content
-
-    Args:
-        history (list[ChatMessage]): The current message history
-        event (StreamEvent): The event containing the streamed answer
-        start_streaming (bool): A flag indicating if the streaming has started
-        new_message_content (str): The content of the new message
-
-    Returns:
-        tuple[list[ChatMessage], bool, str]: The updated history, the updated streaming flag and the updated message content
-    """
-    if start_streaming == False:
-        start_streaming = True
-        history.append(ChatMessage(role="assistant", content = ""))
-    answer_message_content += event["data"]["chunk"].content
-    answer_message_content = parse_output_llm_with_sources(answer_message_content)
-    history[-1] = ChatMessage(role="assistant", content = answer_message_content)
-    # history.append(ChatMessage(role="assistant", content = new_message_content))
-    return history, start_streaming, answer_message_content
-
-def handle_retrieved_owid_graphs(event :StreamEvent, graphs_html: str) -> str:
-    """
-    Handles the retrieved OWID graphs and returns the HTML representation of the graphs
-
-    Args:
-        event (StreamEvent): The event containing the retrieved graphs
-        graphs_html (str): The current HTML representation of the graphs
-
-    Returns:
-        str: The updated HTML representation
-    """
-    try:
-        recommended_content = event["data"]["output"]["recommended_content"]
-
-        unique_graphs = []
-        seen_embeddings = set()
-
-        for x in recommended_content:
-            embedding = x.metadata["returned_content"]
-
-            # Check if the embedding has already been seen
-            if embedding not in seen_embeddings:
-                unique_graphs.append({
-                    "embedding": embedding,
-                    "metadata": {
-                        "source": x.metadata["source"],
-                        "category": x.metadata["category"]
-                    }
-                })
-                # Add the embedding to the seen set
-                seen_embeddings.add(embedding)
-
-
-        categories = {}
-        for graph in unique_graphs:
-            category = graph['metadata']['category']
-            if category not in categories:
-                categories[category] = []
-            categories[category].append(graph['embedding'])
-
-
-        for category, embeddings in categories.items():
-            graphs_html += f"<h3>{category}</h3>"
-            for embedding in embeddings:
-                graphs_html += f"<div>{embedding}</div>"
-
-
-    except Exception as e:
-        print(f"Error getting graphs: {e}")
-
-    return graphs_html
climateqa/knowledge/__init__.py
DELETED
File without changes
climateqa/knowledge/retriever.py
DELETED
@@ -1,102 +0,0 @@
-# # https://github.com/langchain-ai/langchain/issues/8623
-
-# import pandas as pd
-
-# from langchain_core.retrievers import BaseRetriever
-# from langchain_core.vectorstores import VectorStoreRetriever
-# from langchain_core.documents.base import Document
-# from langchain_core.vectorstores import VectorStore
-# from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
-
-# from typing import List
-# from pydantic import Field
-
-# def _add_metadata_and_score(docs: List) -> Document:
-#     # Add score to metadata
-#     docs_with_metadata = []
-#     for i,(doc,score) in enumerate(docs):
-#         doc.page_content = doc.page_content.replace("\r\n"," ")
-#         doc.metadata["similarity_score"] = score
-#         doc.metadata["content"] = doc.page_content
-#         doc.metadata["page_number"] = int(doc.metadata["page_number"]) + 1
-#         # doc.page_content = f"""Doc {i+1} - {doc.metadata['short_name']}: {doc.page_content}"""
-#         docs_with_metadata.append(doc)
-#     return docs_with_metadata
-
-# class ClimateQARetriever(BaseRetriever):
-#     vectorstore:VectorStore
-#     sources:list = ["IPCC","IPBES","IPOS"]
-#     reports:list = []
-#     threshold:float = 0.6
-#     k_summary:int = 3
-#     k_total:int = 10
-#     namespace:str = "vectors",
-#     min_size:int = 200,
-
-
-#     def _get_relevant_documents(
-#         self, query: str, *, run_manager: CallbackManagerForRetrieverRun
-#     ) -> List[Document]:
-
-#         # Check if all elements in the list are either IPCC or IPBES
-#         assert isinstance(self.sources,list)
-#         assert self.sources
-#         assert all([x in ["IPCC","IPBES","IPOS"] for x in self.sources])
-#         assert self.k_total > self.k_summary, "k_total should be greater than k_summary"
-
-#         # Prepare base search kwargs
-#         filters = {}
-
-#         if len(self.reports) > 0:
-#             filters["short_name"] = {"$in":self.reports}
-#         else:
-#             filters["source"] = { "$in":self.sources}
-
-#         # Search for k_summary documents in the summaries dataset
-#         filters_summaries = {
-#             **filters,
-#             "chunk_type":"text",
-#             "report_type": { "$in":["SPM"]},
-#         }
-
-#         docs_summaries = self.vectorstore.similarity_search_with_score(query=query,filter = filters_summaries,k = self.k_summary)
-#         docs_summaries = [x for x in docs_summaries if x[1] > self.threshold]
-#         # docs_summaries = []
-
-#         # Search for k_total - k_summary documents in the full reports dataset
-#         filters_full = {
-#             **filters,
-#             "chunk_type":"text",
-#             "report_type": { "$nin":["SPM"]},
-#         }
-#         k_full = self.k_total - len(docs_summaries)
-#         docs_full = self.vectorstore.similarity_search_with_score(query=query,filter = filters_full,k = k_full)
-
-#         # Images
-#         filters_image = {
-#             **filters,
-#             "chunk_type":"image"
-#         }
-#         docs_images = self.vectorstore.similarity_search_with_score(query=query,filter = filters_image,k = k_full)
-
-#         # docs_images = []
-
-#         # Concatenate documents
-#         # docs = docs_summaries + docs_full + docs_images
-
-#         # Filter if scores are below threshold
-#         # docs = [x for x in docs if x[1] > self.threshold]
-
-#         docs_summaries, docs_full, docs_images = _add_metadata_and_score(docs_summaries), _add_metadata_and_score(docs_full), _add_metadata_and_score(docs_images)
-
-#         # Filter if length are below threshold
-#         docs_summaries = [x for x in docs_summaries if len(x.page_content) > self.min_size]
-#         docs_full = [x for x in docs_full if len(x.page_content) > self.min_size]
-
-
-#         return {
-#             "docs_summaries" : docs_summaries,
-#             "docs_full" : docs_full,
-#             "docs_images" : docs_images,
-#         }
climateqa/papers/__init__.py
ADDED
@@ -0,0 +1,43 @@
+import pandas as pd
+
+from pyalex import Works, Authors, Sources, Institutions, Concepts, Publishers, Funders
+import pyalex
+
+pyalex.config.email = "[email protected]"
+
+class OpenAlex():
+    def __init__(self):
+        pass
+
+
+
+    def search(self,keywords,n_results = 100,after = None,before = None):
+        works = Works().search(keywords)
+
+        # Keep only the first page of results
+        for page in works.paginate(per_page=n_results):
+            break
+
+        df_works = pd.DataFrame(page)
+
+        return df_works
+
+
+    def make_network(self):
+        pass
+
+
+    def get_abstract_from_inverted_index(self,index):
+
+        # Determine the maximum index to know the length of the reconstructed array
+        max_index = max([max(positions) for positions in index.values()])
+
+        # Initialize a list with placeholders for all positions
+        reconstructed = [''] * (max_index + 1)
+
+        # Iterate through the inverted index and place each token at its respective position(s)
+        for token, positions in index.items():
+            for position in positions:
+                reconstructed[position] = token
+
+        # Join the tokens to form the reconstructed sentence(s)
+        return ' '.join(reconstructed)
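
get_abstract_from_inverted_index undoes the abstract_inverted_index encoding that OpenAlex returns, where each token maps to the list of positions at which it occurs. A self-contained illustration of the decoding (the sample index is fabricated):

    # Fabricated inverted index for illustration.
    index = {"climate": [0], "change": [1], "affects": [2], "the": [3], "ocean": [4]}

    max_index = max(max(positions) for positions in index.values())
    reconstructed = [''] * (max_index + 1)
    for token, positions in index.items():
        for position in positions:
            reconstructed[position] = token

    print(' '.join(reconstructed))  # -> climate change affects the ocean
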
climateqa/{knowledge → papers}/openalex.py
RENAMED
@@ -3,32 +3,18 @@ import networkx as nx
 import matplotlib.pyplot as plt
 from pyvis.network import Network
 
-from langchain_core.retrievers import BaseRetriever
-from langchain_core.vectorstores import VectorStoreRetriever
-from langchain_core.documents.base import Document
-from langchain_core.vectorstores import VectorStore
-from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
-
-from ..engine.utils import num_tokens_from_string
-
-from typing import List
-from pydantic import Field
-
 from pyalex import Works, Authors, Sources, Institutions, Concepts, Publishers, Funders
 import pyalex
 
 pyalex.config.email = "[email protected]"
 
-
-def replace_nan_with_empty_dict(x):
-    return x if pd.notna(x) else {}
-
 class OpenAlex():
     def __init__(self):
         pass
 
 
-    def search(self,keywords,n_results = 100,after = None,before = None):
+
+    def search(self,keywords,n_results = 100,after = None,before = None):
 
         if isinstance(keywords,str):
             works = Works().search(keywords)
@@ -41,36 +27,29 @@ class OpenAlex():
                 break
 
            df_works = pd.DataFrame(page)
-
-            if df_works.empty:
-                return df_works
-
-            df_works = df_works.dropna(subset = ["title"])
-            df_works["primary_location"] = df_works["primary_location"].map(replace_nan_with_empty_dict)
-            df_works["abstract"] = df_works["abstract_inverted_index"].apply(lambda x: self.get_abstract_from_inverted_index(x)).fillna("")
+            df_works["abstract"] = df_works["abstract_inverted_index"].apply(lambda x: self.get_abstract_from_inverted_index(x))
             df_works["is_oa"] = df_works["open_access"].map(lambda x : x.get("is_oa",False))
             df_works["pdf_url"] = df_works["primary_location"].map(lambda x : x.get("pdf_url",None))
-            df_works["
-            df_works["content"] = (df_works["title"] + "\n" + df_works["abstract"]).map(lambda x : x.strip())
-            df_works["num_tokens"] = df_works["content"].map(lambda x : num_tokens_from_string(x))
-
-            df_works = df_works.drop(columns = ["abstract_inverted_index"])
-            df_works["display_name"] = df_works["primary_location"].apply(lambda x :x["source"] if type(x) == dict and 'source' in x else "").apply(lambda x : x["display_name"] if type(x) == dict and "display_name" in x else "")
-            df_works["subtitle"] = df_works["title"].astype(str) + " - " + df_works["display_name"].astype(str) + " - " + df_works["publication_year"].astype(str)
+            df_works["content"] = df_works["title"] + "\n" + df_works["abstract"]
 
-            return df_works
         else:
-
+            df_works = []
+            for keyword in keywords:
+                df_keyword = self.search(keyword,n_results = n_results,after = after,before = before)
+                df_works.append(df_keyword)
+            df_works = pd.concat(df_works,ignore_index=True,axis = 0)
+        return df_works
 
 
     def rerank(self,query,df,reranker):
 
         scores = reranker.rank(
             query,
-            df["content"].tolist()
+            df["content"].tolist(),
+            top_k = len(df),
        )
-        scores
-        scores = [x
+        scores.sort(key = lambda x : x["corpus_id"])
+        scores = [x["score"] for x in scores]
         df["rerank_score"] = scores
         return df
 
@@ -160,36 +139,4 @@ class OpenAlex():
                 reconstructed[position] = token
 
         # Join the tokens to form the reconstructed sentence(s)
-        return ' '.join(reconstructed)
-
-
-
-class OpenAlexRetriever(BaseRetriever):
-    min_year:int = 1960
-    max_year:int = None
-    k:int = 100
-
-    def _get_relevant_documents(
-        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
-    ) -> List[Document]:
-
-        openalex = OpenAlex()
-
-        # Search for documents
-        df_docs = openalex.search(query,n_results=self.k,after = self.min_year,before = self.max_year)
-
-        docs = []
-        for i,row in df_docs.iterrows():
-            num_tokens = row["num_tokens"]
-
-            if num_tokens < 50 or num_tokens > 1000:
-                continue
-
-            doc = Document(
-                page_content = row["content"],
-                metadata = row.to_dict()
-            )
-            docs.append(doc)
-        return docs
-
-
+        return ' '.join(reconstructed)
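
One detail worth noting in the new rerank body: rerankers generally return hits ordered by relevance, so the score list has to be re-sorted by corpus_id (the original row index) before it can be assigned back onto the dataframe positionally. A sketch with fabricated results:

    # Fabricated reranker output: hits come back ordered by score, not row order.
    scores = [
        {"corpus_id": 2, "score": 0.91},
        {"corpus_id": 0, "score": 0.55},
        {"corpus_id": 1, "score": 0.13},
    ]
    scores.sort(key=lambda x: x["corpus_id"])  # restore input (row) order
    scores = [x["score"] for x in scores]
    print(scores)  # -> [0.55, 0.13, 0.91], aligned with df rows 0, 1, 2
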
climateqa/utils.py
CHANGED
@@ -20,16 +20,3 @@ def get_image_from_azure_blob_storage(path):
     file_object = get_file_from_azure_blob_storage(path)
     image = Image.open(file_object)
     return image
-
-def remove_duplicates_keep_highest_score(documents):
-    unique_docs = {}
-
-    for doc in documents:
-        doc_id = doc.metadata.get('doc_id')
-        if doc_id in unique_docs:
-            if doc.metadata['reranking_score'] > unique_docs[doc_id].metadata['reranking_score']:
-                unique_docs[doc_id] = doc
-        else:
-            unique_docs[doc_id] = doc
-
-    return list(unique_docs.values())
front/__init__.py
DELETED
File without changes
front/callbacks.py
DELETED
File without changes
front/deprecated.py
DELETED
@@ -1,46 +0,0 @@
|
|
1 |
-
|
2 |
-
# Functions to toggle visibility
|
3 |
-
def toggle_summary_visibility():
|
4 |
-
global summary_visible
|
5 |
-
summary_visible = not summary_visible
|
6 |
-
return gr.update(visible=summary_visible)
|
7 |
-
|
8 |
-
def toggle_relevant_visibility():
|
9 |
-
global relevant_visible
|
10 |
-
relevant_visible = not relevant_visible
|
11 |
-
return gr.update(visible=relevant_visible)
|
12 |
-
|
13 |
-
def change_completion_status(current_state):
|
14 |
-
current_state = 1 - current_state
|
15 |
-
return current_state
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
def vote(data: gr.LikeData):
|
20 |
-
if data.liked:
|
21 |
-
print(data.value)
|
22 |
-
else:
|
23 |
-
print(data)
|
24 |
-
|
25 |
-
def save_graph(saved_graphs_state, embedding, category):
|
26 |
-
print(f"\nCategory:\n{saved_graphs_state}\n")
|
27 |
-
if category not in saved_graphs_state:
|
28 |
-
saved_graphs_state[category] = []
|
29 |
-
if embedding not in saved_graphs_state[category]:
|
30 |
-
saved_graphs_state[category].append(embedding)
|
31 |
-
return saved_graphs_state, gr.Button("Graph Saved")
|
32 |
-
|
33 |
-
|
34 |
-
# Function to save feedback
|
35 |
-
def save_feedback(feed: str, user_id):
|
36 |
-
if len(feed) > 1:
|
37 |
-
timestamp = str(datetime.now().timestamp())
|
38 |
-
file = user_id + timestamp + ".json"
|
39 |
-
logs = {
|
40 |
-
"user_id": user_id,
|
41 |
-
"feedback": feed,
|
42 |
-
"time": timestamp,
|
43 |
-
}
|
44 |
-
log_on_azure(file, logs, share_client)
|
45 |
-
return "Feedback submitted, thank you!"
|
46 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
front/event_listeners.py
DELETED
File without changes
front/tabs/__init__.py
DELETED
@@ -1,6 +0,0 @@
|
|
1 |
-
from .tab_config import create_config_modal
|
2 |
-
from .tab_examples import create_examples_tab
|
3 |
-
from .tab_papers import create_papers_tab
|
4 |
-
from .tab_figures import create_figures_tab
|
5 |
-
from .chat_interface import create_chat_interface
|
6 |
-
from .tab_about import create_about_tab
|
|
|
|
|
|
|
|
|
|
|
|
|
|
front/tabs/chat_interface.py
DELETED
@@ -1,55 +0,0 @@
|
|
1 |
-
import gradio as gr
|
2 |
-
from gradio.components import ChatMessage
|
3 |
-
|
4 |
-
# Initialize prompt and system template
|
5 |
-
init_prompt = """
|
6 |
-
Hello, I am ClimateQ&A, a conversational assistant designed to help you understand climate change and biodiversity loss. I will answer your questions by **sifting through the IPCC and IPBES scientific reports**.
|
7 |
-
|
8 |
-
❓ How to use
|
9 |
-
- **Language**: You can ask me your questions in any language.
|
10 |
-
- **Audience**: You can specify your audience (children, general public, experts) to get a more adapted answer.
|
11 |
-
- **Sources**: You can choose to search in the IPCC or IPBES reports, or both.
|
12 |
-
- **Relevant content sources**: You can choose to search for figures, papers, or graphs that can be relevant for your question.
|
13 |
-
|
14 |
-
⚠️ Limitations
|
15 |
-
*Please note that the AI is not perfect and may sometimes give irrelevant answers. If you are not satisfied with the answer, please ask a more specific question or report your feedback to help us improve the system.*
|
16 |
-
|
17 |
-
🛈 Information
|
18 |
-
Please note that we log your questions for meta-analysis purposes, so avoid sharing any sensitive or personal information.
|
19 |
-
|
20 |
-
What do you want to learn ?
|
21 |
-
"""
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
# UI Layout Components
|
26 |
-
def create_chat_interface():
|
27 |
-
chatbot = gr.Chatbot(
|
28 |
-
value=[ChatMessage(role="assistant", content=init_prompt)],
|
29 |
-
type="messages",
|
30 |
-
show_copy_button=True,
|
31 |
-
show_label=False,
|
32 |
-
elem_id="chatbot",
|
33 |
-
layout="panel",
|
34 |
-
avatar_images=(None, "https://i.ibb.co/YNyd5W2/logo4.png"),
|
35 |
-
max_height="80vh",
|
36 |
-
height="100vh"
|
37 |
-
)
|
38 |
-
|
39 |
-
with gr.Row(elem_id="input-message"):
|
40 |
-
|
41 |
-
textbox = gr.Textbox(
|
42 |
-
placeholder="Ask me anything here!",
|
43 |
-
show_label=False,
|
44 |
-
scale=12,
|
45 |
-
lines=1,
|
46 |
-
interactive=True,
|
47 |
-
elem_id=f"input-textbox"
|
48 |
-
)
|
49 |
-
|
50 |
-
config_button = gr.Button("", elem_id="config-button")
|
51 |
-
|
52 |
-
return chatbot, textbox, config_button
|
53 |
-
|
54 |
-
|
55 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
front/tabs/main_tab.py
DELETED
@@ -1,69 +0,0 @@
|
|
1 |
-
import gradio as gr
|
2 |
-
from .chat_interface import create_chat_interface
|
3 |
-
from .tab_examples import create_examples_tab
|
4 |
-
from .tab_papers import create_papers_tab
|
5 |
-
from .tab_figures import create_figures_tab
|
6 |
-
from .chat_interface import create_chat_interface
|
7 |
-
|
8 |
-
def cqa_tab(tab_name):
|
9 |
-
# State variables
|
10 |
-
current_graphs = gr.State([])
|
11 |
-
with gr.Tab(tab_name):
|
12 |
-
with gr.Row(elem_id="chatbot-row"):
|
13 |
-
# Left column - Chat interface
|
14 |
-
with gr.Column(scale=2):
|
15 |
-
chatbot, textbox, config_button = create_chat_interface()
|
16 |
-
|
17 |
-
# Right column - Content panels
|
18 |
-
with gr.Column(scale=2, variant="panel", elem_id="right-panel"):
|
19 |
-
with gr.Tabs(elem_id="right_panel_tab") as tabs:
|
20 |
-
# Examples tab
|
21 |
-
with gr.TabItem("Examples", elem_id="tab-examples", id=0):
|
22 |
-
examples_hidden, dropdown_samples, samples = create_examples_tab()
|
23 |
-
|
24 |
-
# Sources tab
|
25 |
-
with gr.Tab("Sources", elem_id="tab-sources", id=1) as tab_sources:
|
26 |
-
sources_textbox = gr.HTML(show_label=False, elem_id="sources-textbox")
|
27 |
-
|
28 |
-
|
29 |
-
# Recommended content tab
|
30 |
-
with gr.Tab("Recommended content", elem_id="tab-recommended_content", id=2) as tab_recommended_content:
|
31 |
-
with gr.Tabs(elem_id="group-subtabs") as tabs_recommended_content:
|
32 |
-
# Figures subtab
|
33 |
-
with gr.Tab("Figures", elem_id="tab-figures", id=3) as tab_figures:
|
34 |
-
sources_raw, new_figures, used_figures, gallery_component, figures_cards, figure_modal = create_figures_tab()
|
35 |
-
|
36 |
-
# Papers subtab
|
37 |
-
with gr.Tab("Papers", elem_id="tab-citations", id=4) as tab_papers:
|
38 |
-
papers_summary, papers_html, citations_network, papers_modal = create_papers_tab()
|
39 |
-
|
40 |
-
# Graphs subtab
|
41 |
-
with gr.Tab("Graphs", elem_id="tab-graphs", id=5) as tab_graphs:
|
42 |
-
graphs_container = gr.HTML(
|
43 |
-
"<h2>There are no graphs to be displayed at the moment. Try asking another question.</h2>",
|
44 |
-
elem_id="graphs-container"
|
45 |
-
)
|
46 |
-
return {
|
47 |
-
"chatbot": chatbot,
|
48 |
-
"textbox": textbox,
|
49 |
-
"tabs": tabs,
|
50 |
-
"sources_raw": sources_raw,
|
51 |
-
"new_figures": new_figures,
|
52 |
-
"current_graphs": current_graphs,
|
53 |
-
"examples_hidden": examples_hidden,
|
54 |
-
"dropdown_samples": dropdown_samples,
|
55 |
-
"samples": samples,
|
56 |
-
"sources_textbox": sources_textbox,
|
57 |
-
"figures_cards": figures_cards,
|
58 |
-
"gallery_component": gallery_component,
|
59 |
-
"config_button": config_button,
|
60 |
-
"papers_html": papers_html,
|
61 |
-
"citations_network": citations_network,
|
62 |
-
"papers_summary": papers_summary,
|
63 |
-
"tab_recommended_content": tab_recommended_content,
|
64 |
-
"tab_sources": tab_sources,
|
65 |
-
"tab_figures": tab_figures,
|
66 |
-
"tab_graphs": tab_graphs,
|
67 |
-
"tab_papers": tab_papers,
|
68 |
-
"graph_container": graphs_container
|
69 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
front/tabs/tab_about.py
DELETED
@@ -1,38 +0,0 @@
|
|
1 |
-
import gradio as gr
|
2 |
-
|
3 |
-
# Citation information
|
4 |
-
CITATION_LABEL = "BibTeX citation for ClimateQ&A"
|
5 |
-
CITATION_TEXT = r"""@misc{climateqa,
|
6 |
-
author={Théo Alves Da Costa, Timothée Bohe},
|
7 |
-
title={ClimateQ&A, AI-powered conversational assistant for climate change and biodiversity loss},
|
8 |
-
year={2024},
|
9 |
-
howpublished= {\url{https://climateqa.com}},
|
10 |
-
}
|
11 |
-
@software{climateqa,
|
12 |
-
author = {Théo Alves Da Costa, Timothée Bohe},
|
13 |
-
publisher = {ClimateQ&A},
|
14 |
-
title = {ClimateQ&A, AI-powered conversational assistant for climate change and biodiversity loss},
|
15 |
-
}
|
16 |
-
"""
|
17 |
-
|
18 |
-
def create_about_tab():
|
19 |
-
with gr.Tab("About", elem_classes="max-height other-tabs"):
|
20 |
-
with gr.Row():
|
21 |
-
with gr.Column(scale=1):
|
22 |
-
gr.Markdown(
|
23 |
-
"""
|
24 |
-
### More info
|
25 |
-
- See more info at [https://climateqa.com](https://climateqa.com/docs/intro/)
|
26 |
-
- Feedbacks on this [form](https://forms.office.com/e/1Yzgxm6jbp)
|
27 |
-
|
28 |
-
### Citation
|
29 |
-
"""
|
30 |
-
)
|
31 |
-
with gr.Accordion(CITATION_LABEL, elem_id="citation", open=False):
|
32 |
-
gr.Textbox(
|
33 |
-
value=CITATION_TEXT,
|
34 |
-
label="",
|
35 |
-
interactive=False,
|
36 |
-
show_copy_button=True,
|
37 |
-
lines=len(CITATION_TEXT.split('\n')),
|
38 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
front/tabs/tab_config.py
DELETED
@@ -1,123 +0,0 @@
|
|
1 |
-
import gradio as gr
|
2 |
-
from gradio_modal import Modal
|
3 |
-
from climateqa.constants import POSSIBLE_REPORTS
|
4 |
-
from typing import TypedDict
|
5 |
-
|
6 |
-
class ConfigPanel(TypedDict):
|
7 |
-
config_open: gr.State
|
8 |
-
config_modal: Modal
|
9 |
-
dropdown_sources: gr.CheckboxGroup
|
10 |
-
dropdown_reports: gr.Dropdown
|
11 |
-
dropdown_external_sources: gr.CheckboxGroup
|
12 |
-
search_only: gr.Checkbox
|
13 |
-
dropdown_audience: gr.Dropdown
|
14 |
-
after: gr.Slider
|
15 |
-
output_query: gr.Textbox
|
16 |
-
output_language: gr.Textbox
|
17 |
-
|
18 |
-
|
19 |
-
def create_config_modal():
|
20 |
-
config_open = gr.State(value=True)
|
21 |
-
with Modal(visible=False, elem_id="modal-config") as config_modal:
|
22 |
-
gr.Markdown("Reminders: You can talk in any language, ClimateQ&A is multi-lingual!")
|
23 |
-
|
24 |
-
dropdown_sources = gr.CheckboxGroup(
|
25 |
-
choices=["IPCC", "IPBES", "IPOS"],
|
26 |
-
label="Select source (by default search in all sources)",
|
27 |
-
value=["IPCC"],
|
28 |
-
interactive=True
|
29 |
-
)
|
30 |
-
|
31 |
-
dropdown_reports = gr.Dropdown(
|
32 |
-
choices=POSSIBLE_REPORTS,
|
33 |
-
label="Or select specific reports",
|
34 |
-
multiselect=True,
|
35 |
-
value=None,
|
36 |
-
interactive=True
|
37 |
-
)
|
38 |
-
|
39 |
-
dropdown_external_sources = gr.CheckboxGroup(
|
40 |
-
choices=["Figures (IPCC/IPBES)", "Papers (OpenAlex)", "Graphs (OurWorldInData)","POC region"],
|
41 |
-
label="Select database to search for relevant content",
|
42 |
-
value=["Figures (IPCC/IPBES)","POC region"],
|
43 |
-
interactive=True
|
44 |
-
)
|
45 |
-
|
46 |
-
search_only = gr.Checkbox(
|
47 |
-
label="Search only for recommended content without chating",
|
48 |
-
value=False,
|
49 |
-
interactive=True,
|
50 |
-
elem_id="checkbox-chat"
|
51 |
-
)
|
52 |
-
|
53 |
-
dropdown_audience = gr.Dropdown(
|
54 |
-
choices=["Children", "General public", "Experts"],
|
55 |
-
label="Select audience",
|
56 |
-
value="Experts",
|
57 |
-
interactive=True
|
58 |
-
)
|
59 |
-
|
60 |
-
after = gr.Slider(
|
61 |
-
minimum=1950,
|
62 |
-
maximum=2023,
|
63 |
-
step=1,
|
64 |
-
value=1960,
|
65 |
-
label="Publication date",
|
66 |
-
show_label=True,
|
67 |
-
interactive=True,
|
68 |
-
elem_id="date-papers",
|
69 |
-
visible=False
|
70 |
-
)
|
71 |
-
|
72 |
-
output_query = gr.Textbox(
|
73 |
-
label="Query used for retrieval",
|
74 |
-
show_label=True,
|
75 |
-
elem_id="reformulated-query",
|
76 |
-
lines=2,
|
77 |
-
interactive=False,
|
78 |
-
visible=False
|
79 |
-
)
|
80 |
-
|
81 |
-
output_language = gr.Textbox(
|
82 |
-
label="Language",
|
83 |
-
show_label=True,
|
84 |
-
elem_id="language",
|
85 |
-
lines=1,
|
86 |
-
interactive=False,
|
87 |
-
visible=False
|
88 |
-
)
|
89 |
-
|
90 |
-
dropdown_external_sources.change(
|
91 |
-
lambda x: gr.update(visible="Papers (OpenAlex)" in x),
|
92 |
-
inputs=[dropdown_external_sources],
|
93 |
-
outputs=[after]
|
94 |
-
)
|
95 |
-
|
96 |
-
close_config_modal_button = gr.Button("Validate and Close", elem_id="close-config-modal")
|
97 |
-
|
98 |
-
|
99 |
-
# return ConfigPanel(
|
100 |
-
# config_open=config_open,
|
101 |
-
# config_modal=config_modal,
|
102 |
-
# dropdown_sources=dropdown_sources,
|
103 |
-
# dropdown_reports=dropdown_reports,
|
104 |
-
# dropdown_external_sources=dropdown_external_sources,
|
105 |
-
# search_only=search_only,
|
106 |
-
# dropdown_audience=dropdown_audience,
|
107 |
-
# after=after,
|
108 |
-
# output_query=output_query,
|
109 |
-
# output_language=output_language
|
110 |
-
# )
|
111 |
-
return {
|
112 |
-
"config_open" : config_open,
|
113 |
-
"config_modal": config_modal,
|
114 |
-
"dropdown_sources": dropdown_sources,
|
115 |
-
"dropdown_reports": dropdown_reports,
|
116 |
-
"dropdown_external_sources": dropdown_external_sources,
|
117 |
-
"search_only": search_only,
|
118 |
-
"dropdown_audience": dropdown_audience,
|
119 |
-
"after": after,
|
120 |
-
"output_query": output_query,
|
121 |
-
"output_language": output_language,
|
122 |
-
"close_config_modal_button": close_config_modal_button
|
123 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
front/tabs/tab_examples.py
DELETED
@@ -1,40 +0,0 @@
|
|
1 |
-
import gradio as gr
|
2 |
-
from climateqa.sample_questions import QUESTIONS
|
3 |
-
|
4 |
-
|
5 |
-
def create_examples_tab():
|
6 |
-
examples_hidden = gr.Textbox(visible=False, elem_id=f"examples-hidden")
|
7 |
-
first_key = list(QUESTIONS.keys())[0]
|
8 |
-
dropdown_samples = gr.Dropdown(
|
9 |
-
choices=QUESTIONS.keys(),
|
10 |
-
value=first_key,
|
11 |
-
interactive=True,
|
12 |
-
label="Select a category of sample questions",
|
13 |
-
elem_id="dropdown-samples"
|
14 |
-
)
|
15 |
-
|
16 |
-
samples = []
|
17 |
-
for i, key in enumerate(QUESTIONS.keys()):
|
18 |
-
examples_visible = (i == 0)
|
19 |
-
with gr.Row(visible=examples_visible) as group_examples:
|
20 |
-
examples_questions = gr.Examples(
|
21 |
-
examples=QUESTIONS[key],
|
22 |
-
inputs=[examples_hidden],
|
23 |
-
examples_per_page=8,
|
24 |
-
run_on_click=False,
|
25 |
-
elem_id=f"examples{i}",
|
26 |
-
api_name=f"examples{i}"
|
27 |
-
)
|
28 |
-
samples.append(group_examples)
|
29 |
-
|
30 |
-
|
31 |
-
def change_sample_questions(key):
|
32 |
-
index = list(QUESTIONS.keys()).index(key)
|
33 |
-
visible_bools = [False] * len(samples)
|
34 |
-
visible_bools[index] = True
|
35 |
-
return [gr.update(visible=visible_bools[i]) for i in range(len(samples))]
|
36 |
-
|
37 |
-
# event listener
|
38 |
-
dropdown_samples.change(change_sample_questions, dropdown_samples, samples)
|
39 |
-
|
40 |
-
return examples_hidden
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
front/tabs/tab_figures.py
DELETED
@@ -1,31 +0,0 @@
|
|
1 |
-
import gradio as gr
|
2 |
-
from gradio_modal import Modal
|
3 |
-
|
4 |
-
|
5 |
-
def create_figures_tab():
|
6 |
-
sources_raw = gr.State()
|
7 |
-
new_figures = gr.State([])
|
8 |
-
used_figures = gr.State([])
|
9 |
-
|
10 |
-
with Modal(visible=False, elem_id="modal_figure_galery") as figure_modal:
|
11 |
-
gallery_component = gr.Gallery(
|
12 |
-
object_fit='scale-down',
|
13 |
-
elem_id="gallery-component",
|
14 |
-
height="80vh"
|
15 |
-
)
|
16 |
-
|
17 |
-
show_full_size_figures = gr.Button(
|
18 |
-
"Show figures in full size",
|
19 |
-
elem_id="show-figures",
|
20 |
-
interactive=True
|
21 |
-
)
|
22 |
-
show_full_size_figures.click(
|
23 |
-
lambda: Modal(visible=True),
|
24 |
-
None,
|
25 |
-
figure_modal
|
26 |
-
)
|
27 |
-
|
28 |
-
figures_cards = gr.HTML(show_label=False, elem_id="sources-figures")
|
29 |
-
|
30 |
-
return sources_raw, new_figures, used_figures, gallery_component, figures_cards, figure_modal
|
31 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|