test-tim #12
opened by timeki
This view is limited to 50 files because it contains too many changes.
See the raw diff here.
- .gitattributes +0 -2
- .gitignore +0 -13
- README.md +1 -1
- app.py +608 -486
- climateqa/chat.py +0 -194
- climateqa/constants.py +1 -59
- climateqa/engine/chains/__init__.py +0 -0
- climateqa/engine/chains/answer_ai_impact.py +0 -46
- climateqa/engine/chains/answer_chitchat.py +0 -56
- climateqa/engine/chains/chitchat_categorization.py +0 -43
- climateqa/engine/chains/follow_up.py +0 -33
- climateqa/engine/chains/graph_retriever.py +0 -130
- climateqa/engine/chains/intent_categorization.py +0 -97
- climateqa/engine/chains/keywords_extraction.py +0 -40
- climateqa/engine/chains/query_transformation.py +0 -300
- climateqa/engine/chains/retrieve_documents.py +0 -705
- climateqa/engine/chains/retrieve_papers.py +0 -95
- climateqa/engine/chains/retriever.py +0 -126
- climateqa/engine/chains/sample_router.py +0 -66
- climateqa/engine/chains/set_defaults.py +0 -13
- climateqa/engine/chains/standalone_question.py +0 -42
- climateqa/engine/chains/translation.py +0 -42
- climateqa/engine/embeddings.py +3 -6
- climateqa/engine/graph.py +0 -346
- climateqa/engine/graph_retriever.py +0 -88
- climateqa/engine/keywords.py +1 -3
- climateqa/engine/llm/__init__.py +0 -3
- climateqa/engine/llm/ollama.py +0 -6
- climateqa/engine/llm/openai.py +1 -1
- climateqa/engine/{chains/prompts.py → prompts.py} +6 -107
- climateqa/engine/{chains/answer_rag.py → rag.py} +60 -41
- climateqa/engine/{chains/reformulation.py → reformulation.py} +1 -1
- climateqa/engine/reranker.py +0 -55
- climateqa/engine/retriever.py +163 -0
- climateqa/engine/talk_to_data/config.py +0 -11
- climateqa/engine/talk_to_data/drias/config.py +0 -124
- climateqa/engine/talk_to_data/drias/plot_informations.py +0 -88
- climateqa/engine/talk_to_data/drias/plots.py +0 -434
- climateqa/engine/talk_to_data/drias/queries.py +0 -83
- climateqa/engine/talk_to_data/input_processing.py +0 -257
- climateqa/engine/talk_to_data/ipcc/config.py +0 -98
- climateqa/engine/talk_to_data/ipcc/plot_informations.py +0 -50
- climateqa/engine/talk_to_data/ipcc/plots.py +0 -189
- climateqa/engine/talk_to_data/ipcc/queries.py +0 -144
- climateqa/engine/talk_to_data/main.py +0 -124
- climateqa/engine/talk_to_data/myVanna.py +0 -13
- climateqa/engine/talk_to_data/objects/llm_outputs.py +0 -13
- climateqa/engine/talk_to_data/objects/location.py +0 -12
- climateqa/engine/talk_to_data/objects/plot.py +0 -23
- climateqa/engine/talk_to_data/objects/states.py +0 -19
.gitattributes
CHANGED
@@ -44,5 +44,3 @@ documents/climate_gpt_v2_only_giec.faiss filter=lfs diff=lfs merge=lfs -text
 documents/climate_gpt_v2.faiss filter=lfs diff=lfs merge=lfs -text
 climateqa_v3.db filter=lfs diff=lfs merge=lfs -text
 climateqa_v3.faiss filter=lfs diff=lfs merge=lfs -text
-data/drias/drias.db filter=lfs diff=lfs merge=lfs -text
-front/assets/*.png filter=lfs diff=lfs merge=lfs -text
.gitignore
CHANGED
@@ -5,16 +5,3 @@ __pycache__/utils.cpython-38.pyc
 
 notebooks/
 *.pyc
-
-**/.ipynb_checkpoints/
-**/.flashrank_cache/
-
-data/
-sandbox/
-
-climateqa/talk_to_data/database/
-*.db
-
-data_ingestion/
-.vscode
-*old/
README.md
CHANGED
@@ -4,7 +4,7 @@ emoji: 🌍
 colorFrom: blue
 colorTo: red
 sdk: gradio
-sdk_version:
+sdk_version: 4.19.1
 app_file: app.py
 fullWidth: true
 pinned: false
app.py
CHANGED
@@ -1,44 +1,52 @@
-
-
-import gradio as gr
+from climateqa.engine.embeddings import get_embeddings_function
+embeddings_function = get_embeddings_function()
 
-from
+from climateqa.papers.openalex import OpenAlex
+from sentence_transformers import CrossEncoder
 
-
-
-from climateqa.engine.llm import get_llm
-from climateqa.engine.vectorstore import get_pinecone_vectorstore
-from climateqa.engine.reranker import get_reranker
-from climateqa.engine.graph import make_graph_agent, make_graph_agent_poc
-from climateqa.engine.chains.retrieve_papers import find_papers
-from climateqa.chat import start_chat, chat_stream, finish_chat
+reranker = CrossEncoder("mixedbread-ai/mxbai-rerank-xsmall-v1")
+oa = OpenAlex()
 
-
-
-
-
+import gradio as gr
+import pandas as pd
+import numpy as np
+import os
+import time
+import re
+import json
 
-from
-from gradio_modal import Modal
+# from gradio_modal import Modal
 
+from io import BytesIO
+import base64
+
+from datetime import datetime
+from azure.storage.fileshare import ShareServiceClient
 
 from utils import create_user_id
-import logging
 
-logging.basicConfig(level=logging.WARNING)
-os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"  # Suppresses INFO and WARNING logs
-logging.getLogger().setLevel(logging.WARNING)
 
+# ClimateQ&A imports
+from climateqa.engine.llm import get_llm
+from climateqa.engine.rag import make_rag_chain
+from climateqa.engine.vectorstore import get_pinecone_vectorstore
+from climateqa.engine.retriever import ClimateQARetriever
+from climateqa.engine.embeddings import get_embeddings_function
+from climateqa.engine.prompts import audience_prompts
+from climateqa.sample_questions import QUESTIONS
+from climateqa.constants import POSSIBLE_REPORTS
+from climateqa.utils import get_image_from_azure_blob_storage
+from climateqa.engine.keywords import make_keywords_chain
+from climateqa.engine.rag import make_rag_papers_chain
+
 # Load environment variables in local mode
 try:
     from dotenv import load_dotenv
-
     load_dotenv()
 except Exception as e:
     pass
 
-
 # Set up Gradio Theme
 theme = gr.themes.Base(
     primary_hue="blue",
@@ -46,7 +54,15 @@ theme = gr.themes.Base(
     font=[gr.themes.GoogleFont("Poppins"), "ui-sans-serif", "system-ui", "sans-serif"],
 )
 
-
+
+
+init_prompt = ""
+
+system_template = {
+    "role": "system",
+    "content": init_prompt,
+}
+
 account_key = os.environ["BLOB_ACCOUNT_KEY"]
 if len(account_key) == 86:
     account_key += "=="
@@ -64,102 +80,365 @@ share_client = service.get_share_client(file_share_name)
 user_id = create_user_id()
 
 
+
+def parse_output_llm_with_sources(output):
+    # Split the content into a list of text and "[Doc X]" references
+    content_parts = re.split(r'\[(Doc\s?\d+(?:,\s?Doc\s?\d+)*)\]', output)
+    parts = []
+    for part in content_parts:
+        if part.startswith("Doc"):
+            subparts = part.split(",")
+            subparts = [subpart.lower().replace("doc","").strip() for subpart in subparts]
+            subparts = [f"""<a href="#doc{subpart}" class="a-doc-ref" target="_self"><span class='doc-ref'><sup>{subpart}</sup></span></a>""" for subpart in subparts]
+            parts.append("".join(subparts))
+        else:
+            parts.append(part)
+    content_parts = "".join(parts)
+    return content_parts
+
+
 # Create vectorstore and retriever
-
-
-    embeddings_function, index_name=os.getenv("PINECONE_API_INDEX")
-)
-vectorstore_graphs = get_pinecone_vectorstore(
-    embeddings_function,
-    index_name=os.getenv("PINECONE_API_INDEX_OWID"),
-    text_key="description",
-)
-vectorstore_region = get_pinecone_vectorstore(
-    embeddings_function, index_name=os.getenv("PINECONE_API_INDEX_LOCAL_V2")
-)
+vectorstore = get_pinecone_vectorstore(embeddings_function)
+llm = get_llm(provider="openai",max_tokens = 1024,temperature = 0.0)
 
-…
-    sources = sources
-
-
-
-    async for event in chat_stream(
-        agent,
-        query,
-        history,
-        audience,
-        sources,
-        reports,
-        relevant_content_sources_selection,
-        search_only,
-        share_client,
-        user_id,
-    ):
-        yield event
-
-
-async def chat_poc(
-    query,
-    history,
-    audience,
-    sources,
-    reports,
-    relevant_content_sources_selection,
-    search_only,
-):
-    print("chat poc - message received")
-    async for event in chat_stream(
-        agent_poc,
-        query,
-        history,
-        audience,
-        sources,
-        reports,
-        relevant_content_sources_selection,
-        search_only,
-        share_client,
-        user_id,
-    ):
-        yield event
+
+def make_pairs(lst):
+    """from a list of even lenght, make tupple pairs"""
+    return [(lst[i], lst[i + 1]) for i in range(0, len(lst), 2)]
+
+
+def serialize_docs(docs):
+    new_docs = []
+    for doc in docs:
+        new_doc = {}
+        new_doc["page_content"] = doc.page_content
+        new_doc["metadata"] = doc.metadata
+        new_docs.append(new_doc)
+    return new_docs
+
+
+
+async def chat(query,history,audience,sources,reports):
+    """taking a query and a message history, use a pipeline (reformulation, retriever, answering) to yield a tuple of:
+    (messages in gradio format, messages in langchain format, source documents)"""
+
+    print(f">> NEW QUESTION : {query}")
+
+    if audience == "Children":
+        audience_prompt = audience_prompts["children"]
+    elif audience == "General public":
+        audience_prompt = audience_prompts["general"]
+    elif audience == "Experts":
+        audience_prompt = audience_prompts["experts"]
+    else:
+        audience_prompt = audience_prompts["experts"]
+
+    # Prepare default values
+    if len(sources) == 0:
+        sources = ["IPCC"]
+
+    if len(reports) == 0:
+        reports = []
+
+    retriever = ClimateQARetriever(vectorstore=vectorstore,sources = sources,min_size = 200,reports = reports,k_summary = 3,k_total = 15,threshold=0.5)
+    rag_chain = make_rag_chain(retriever,llm)
+
+    inputs = {"query": query,"audience": audience_prompt}
+    result = rag_chain.astream_log(inputs) #{"callbacks":[MyCustomAsyncHandler()]})
+    # result = rag_chain.stream(inputs)
+
+    path_reformulation = "/logs/reformulation/final_output"
+    path_keywords = "/logs/keywords/final_output"
+    path_retriever = "/logs/find_documents/final_output"
+    path_answer = "/logs/answer/streamed_output_str/-"
+
+    docs_html = ""
+    output_query = ""
+    output_language = ""
+    output_keywords = ""
+    gallery = []
+
+    try:
+        async for op in result:
+
+            op = op.ops[0]
+
+            if op['path'] == path_reformulation: # reforulated question
+                try:
+                    output_language = op['value']["language"] # str
+                    output_query = op["value"]["question"]
+                except Exception as e:
+                    raise gr.Error(f"ClimateQ&A Error: {e} - The error has been noted, try another question and if the error remains, you can contact us :)")
+
+            if op["path"] == path_keywords:
+                try:
+                    output_keywords = op['value']["keywords"] # str
+                    output_keywords = " AND ".join(output_keywords)
+                except Exception as e:
+                    pass
+
+
+            elif op['path'] == path_retriever: # documents
+                try:
+                    docs = op['value']['docs'] # List[Document]
+                    docs_html = []
+                    for i, d in enumerate(docs, 1):
+                        docs_html.append(make_html_source(d, i))
+                    docs_html = "".join(docs_html)
+                except TypeError:
+                    print("No documents found")
+                    print("op: ",op)
+                    continue
+
+            elif op['path'] == path_answer: # final answer
+                new_token = op['value'] # str
+                # time.sleep(0.01)
+                previous_answer = history[-1][1]
+                previous_answer = previous_answer if previous_answer is not None else ""
+                answer_yet = previous_answer + new_token
+                answer_yet = parse_output_llm_with_sources(answer_yet)
+                history[-1] = (query,answer_yet)
+
+
+
+            else:
+                continue
+
+            history = [tuple(x) for x in history]
+            yield history,docs_html,output_query,output_language,gallery,output_query,output_keywords
+
+    except Exception as e:
+        raise gr.Error(f"{e}")
+
+
+    try:
+        # Log answer on Azure Blob Storage
+        if os.getenv("GRADIO_ENV") != "local":
+            timestamp = str(datetime.now().timestamp())
+            file = timestamp + ".json"
+            prompt = history[-1][0]
+            logs = {
+                "user_id": str(user_id),
+                "prompt": prompt,
+                "query": prompt,
+                "question":output_query,
+                "sources":sources,
+                "docs":serialize_docs(docs),
+                "answer": history[-1][1],
+                "time": timestamp,
+            }
+            log_on_azure(file, logs, share_client)
+    except Exception as e:
+        print(f"Error logging on Azure Blob Storage: {e}")
+        raise gr.Error(f"ClimateQ&A Error: {str(e)[:100]} - The error has been noted, try another question and if the error remains, you can contact us :)")
+
+    image_dict = {}
+    for i,doc in enumerate(docs):
+
+        if doc.metadata["chunk_type"] == "image":
+            try:
+                key = f"Image {i+1}"
+                image_path = doc.metadata["image_path"].split("documents/")[1]
+                img = get_image_from_azure_blob_storage(image_path)
+
+                # Convert the image to a byte buffer
+                buffered = BytesIO()
+                img.save(buffered, format="PNG")
+                img_str = base64.b64encode(buffered.getvalue()).decode()
+
+                # Embedding the base64 string in Markdown
+                markdown_image = f"![](data:image/png;base64,{img_str})"
+                image_dict[key] = {"img":img,"md":markdown_image,"caption":doc.page_content,"key":key,"figure_code":doc.metadata["figure_code"]}
+            except Exception as e:
+                print(f"Skipped adding image {i} because of {e}")
+
+    if len(image_dict) > 0:
+
+        gallery = [x["img"] for x in list(image_dict.values())]
+        img = list(image_dict.values())[0]
+        img_md = img["md"]
+        img_caption = img["caption"]
+        img_code = img["figure_code"]
+        if img_code != "N/A":
+            img_name = f"{img['key']} - {img['figure_code']}"
+        else:
+            img_name = f"{img['key']}"
+
+        answer_yet = history[-1][1] + f"\n\n{img_md}\n<p class='chatbot-caption'><b>{img_name}</b> - {img_caption}</p>"
+        history[-1] = (history[-1][0],answer_yet)
+        history = [tuple(x) for x in history]
+
+    # gallery = [x.metadata["image_path"] for x in docs if (len(x.metadata["image_path"]) > 0 and "IAS" in x.metadata["image_path"])]
+    # if len(gallery) > 0:
+    #     gallery = list(set("|".join(gallery).split("|")))
+    #     gallery = [get_image_from_azure_blob_storage(x) for x in gallery]
+
+    yield history,docs_html,output_query,output_language,gallery,output_query,output_keywords
+
+
+def make_html_source(source,i):
+    meta = source.metadata
+    # content = source.page_content.split(":",1)[1].strip()
+    content = source.page_content.strip()
+
+    toc_levels = []
+    for j in range(2):
+        level = meta[f"toc_level{j}"]
+        if level != "N/A":
+            toc_levels.append(level)
+        else:
+            break
+    toc_levels = " > ".join(toc_levels)
+
+    if len(toc_levels) > 0:
+        name = f"<b>{toc_levels}</b><br/>{meta['name']}"
+    else:
+        name = meta['name']
+
+    if meta["chunk_type"] == "text":
+
+        card = f"""
+    <div class="card" id="doc{i}">
+        <div class="card-content">
+            <h2>Doc {i} - {meta['short_name']} - Page {int(meta['page_number'])}</h2>
+            <p>{content}</p>
+        </div>
+        <div class="card-footer">
+            <span>{name}</span>
+            <a href="{meta['url']}#page={int(meta['page_number'])}" target="_blank" class="pdf-link">
+                <span role="img" aria-label="Open PDF">🔗</span>
+            </a>
+        </div>
+    </div>
+    """
+
+    else:
+
+        if meta["figure_code"] != "N/A":
+            title = f"{meta['figure_code']} - {meta['short_name']}"
+        else:
+            title = f"{meta['short_name']}"
+
+        card = f"""
+    <div class="card card-image">
+        <div class="card-content">
+            <h2>Image {i} - {title} - Page {int(meta['page_number'])}</h2>
+            <p>{content}</p>
+            <p class='ai-generated'>AI-generated description</p>
+        </div>
+        <div class="card-footer">
+            <span>{name}</span>
+            <a href="{meta['url']}#page={int(meta['page_number'])}" target="_blank" class="pdf-link">
+                <span role="img" aria-label="Open PDF">🔗</span>
+            </a>
+        </div>
+    </div>
+    """
+
+    return card
+
+
+
+    # else:
+    #     docs_string = "No relevant passages found in the climate science reports (IPCC and IPBES)"
+    #     complete_response = "**No relevant passages found in the climate science reports (IPCC and IPBES), you may want to ask a more specific question (specifying your question on climate issues).**"
+    #     messages.append({"role": "assistant", "content": complete_response})
+    #     gradio_format = make_pairs([a["content"] for a in messages[1:]])
+    #     yield gradio_format, messages, docs_string
+
+
+def save_feedback(feed: str, user_id):
+    if len(feed) > 1:
+        timestamp = str(datetime.now().timestamp())
+        file = user_id + timestamp + ".json"
+        logs = {
+            "user_id": user_id,
+            "feedback": feed,
+            "time": timestamp,
+        }
+        log_on_azure(file, logs, share_client)
+        return "Feedback submitted, thank you!"
+
+
+
+
+def log_on_azure(file, logs, share_client):
+    logs = json.dumps(logs)
+    file_client = share_client.get_file_client(file)
+    file_client.upload_file(logs)
+
+
+def generate_keywords(query):
+    chain = make_keywords_chain(llm)
+    keywords = chain.invoke(query)
+    keywords = " AND ".join(keywords["keywords"])
+    return keywords
+
+
+
+papers_cols_widths = {
+    "doc":50,
+    "id":100,
+    "title":300,
+    "doi":100,
+    "publication_year":100,
+    "abstract":500,
+    "rerank_score":100,
+    "is_oa":50,
+}
+
+papers_cols = list(papers_cols_widths.keys())
+papers_cols_widths = list(papers_cols_widths.values())
+
+async def find_papers(query, keywords,after):
+
+    summary = ""
+
+    df_works = oa.search(keywords,after = after)
+    df_works = df_works.dropna(subset=["abstract"])
+    df_works = oa.rerank(query,df_works,reranker)
+    df_works = df_works.sort_values("rerank_score",ascending=False)
+    G = oa.make_network(df_works)
+
+    height = "750px"
+    network = oa.show_network(G,color_by = "rerank_score",notebook=False,height = height)
+    network_html = network.generate_html()
+
+    network_html = network_html.replace("'", "\"")
+    css_to_inject = "<style>#mynetwork { border: none !important; } .card { border: none !important; }</style>"
+    network_html = network_html + css_to_inject
+
+
+    network_html = f"""<iframe style="width: 100%; height: {height};margin:0 auto" name="result" allow="midi; geolocation; microphone; camera;
+    display-capture; encrypted-media;" sandbox="allow-modals allow-forms
+    allow-scripts allow-same-origin allow-popups
+    allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
+    allowpaymentrequest="" frameborder="0" srcdoc='{network_html}'></iframe>"""
+
+
+    docs = df_works["content"].head(15).tolist()
+
+    df_works = df_works.reset_index(drop = True).reset_index().rename(columns = {"index":"doc"})
+    df_works["doc"] = df_works["doc"] + 1
+    df_works = df_works[papers_cols]
+
+    yield df_works,network_html,summary
+
+    chain = make_rag_papers_chain(llm)
+    result = chain.astream_log({"question": query,"docs": docs,"language":"English"})
+    path_answer = "/logs/StrOutputParser/streamed_output/-"
+
+    async for op in result:
+
+        op = op.ops[0]
+
+        if op['path'] == path_answer: # reforulated question
+            new_token = op['value'] # str
+            summary += new_token
+        else:
+            continue
+        yield df_works,network_html,summary
 
 
 # --------------------------------------------------------------------
@@ -167,389 +446,232 @@ async def chat_poc(
 # --------------------------------------------------------------------
 
 
-
-
-    print(config_open)
-    new_config_visibility_status = not config_open
-    return Modal(visible=new_config_visibility_status), new_config_visibility_status
-
-
-def update_sources_number_display(
-    sources_textbox, figures_cards, current_graphs, papers_html
-):
-    sources_number = sources_textbox.count("<h2>")
-    figures_number = figures_cards.count("<h2>")
-    graphs_number = current_graphs.count("<iframe")
-    papers_number = papers_html.count("<h2>")
-    sources_notif_label = f"Sources ({sources_number})"
-    figures_notif_label = f"Figures ({figures_number})"
-    graphs_notif_label = f"Graphs ({graphs_number})"
-    papers_notif_label = f"Papers ({papers_number})"
-    recommended_content_notif_label = (
-        f"Recommended content ({figures_number + graphs_number + papers_number})"
-    )
 
-
-
-
-        gr.update(label=graphs_notif_label),
-        gr.update(label=papers_notif_label),
-    )
 
-…
-        [
-            chatbot,
-            new_sources_hmtl,
-            output_query,
-            output_language,
-            new_figures,
-            current_graphs,
-            follow_up_examples.dataset,
-        ],
-        concurrency_limit=8,
-        api_name=f"chat_{examples_hidden.elem_id}",
-    )
-    .then(
-        finish_chat,
-        None,
-        [textbox],
-        api_name=f"finish_chat_{follow_up_examples_hidden.elem_id}",
-    )
-)
-
-elif tab_name == "France - Local Q&A":
-    print("chat poc - message sent")
-    # Event for textbox
-    (
-        textbox.submit(
-            start_chat,
-            [textbox, chatbot, search_only],
-            [textbox, tabs, chatbot, sources_raw],
-            queue=False,
-            api_name=f"start_chat_{textbox.elem_id}",
-        )
-        .then(
-            chat_poc,
-            [
-                textbox,
-                chatbot,
-                dropdown_audience,
-                dropdown_sources,
-                dropdown_reports,
-                dropdown_external_sources,
-                search_only,
-            ],
-            [
-                chatbot,
-                new_sources_hmtl,
-                output_query,
-                output_language,
-                new_figures,
-                current_graphs,
-                follow_up_examples.dataset,
-            ],
-            concurrency_limit=8,
-            api_name=f"chat_{textbox.elem_id}",
-        )
-        .then(
-            finish_chat, None, [textbox], api_name=f"finish_chat_{textbox.elem_id}"
-        )
-    )
-    # Event for examples_hidden
-    (
-        examples_hidden.change(
-            start_chat,
-            [examples_hidden, chatbot, search_only],
-            [examples_hidden, tabs, chatbot, sources_raw],
-            queue=False,
-            api_name=f"start_chat_{examples_hidden.elem_id}",
-        )
-        .then(
-            chat_poc,
-            [
-                examples_hidden,
-                chatbot,
-                dropdown_audience,
-                dropdown_sources,
-                dropdown_reports,
-                dropdown_external_sources,
-                search_only,
-            ],
-            [
-                chatbot,
-                new_sources_hmtl,
-                output_query,
-                output_language,
-                new_figures,
-                current_graphs,
-                follow_up_examples.dataset,
-            ],
-            concurrency_limit=8,
-            api_name=f"chat_{examples_hidden.elem_id}",
-        )
-        .then(
-            finish_chat,
-            None,
-            [textbox],
-            api_name=f"finish_chat_{examples_hidden.elem_id}",
-        )
-    )
-    (
-        follow_up_examples_hidden.change(
-            start_chat,
-            [follow_up_examples_hidden, chatbot, search_only],
-            [follow_up_examples_hidden, tabs, chatbot, sources_raw],
-            queue=False,
-            api_name=f"start_chat_{examples_hidden.elem_id}",
-        )
-        .then(
-            chat,
-            [
-                follow_up_examples_hidden,
-                chatbot,
-                dropdown_audience,
-                dropdown_sources,
-                dropdown_reports,
-                dropdown_external_sources,
-                search_only,
-            ],
-            [
-                chatbot,
-                new_sources_hmtl,
-                output_query,
-                output_language,
-                new_figures,
-                current_graphs,
-                follow_up_examples.dataset,
-            ],
-            concurrency_limit=8,
-            api_name=f"chat_{examples_hidden.elem_id}",
-        )
-        .then(
-            finish_chat,
-            None,
-            [textbox],
-            api_name=f"finish_chat_{follow_up_examples_hidden.elem_id}",
-        )
-    )
-
-    new_sources_hmtl.change(
-        lambda x: x, inputs=[new_sources_hmtl], outputs=[sources_textbox]
-    )
-    current_graphs.change(
-        lambda x: x, inputs=[current_graphs], outputs=[graphs_container]
-    )
-
-
-
-
-)
 
-# Update sources numbers
-for component in [sources_textbox, figures_cards, current_graphs, papers_html]:
-    component.change(
-        update_sources_number_display,
-        [sources_textbox, figures_cards, current_graphs, papers_html],
-        [tab_recommended_content, tab_sources, tab_figures, tab_graphs, tab_papers],
-    )
 
-
-
-
-
-
-
-    )
 
-# if tab_name == "France - Local Q&A": # Not untill results are good enough
-#     # Drias search
-#     textbox.submit(ask_vanna, [textbox], [vanna_sql_query ,vanna_table, vanna_display])
 
-
-
-with gr.Blocks(
-    title="Climate Q&A",
-    css_paths=os.getcwd() + "/style.css",
-    theme=theme,
-    elem_id="main-component",
-) as demo:
-    config_components = create_config_modal()
 
-
-
-
-
-
 
-
 
-    event_handling(cqa_components, config_components, tab_name="ClimateQ&A")
-    event_handling(
-        local_cqa_components, config_components, tab_name="France - Local Q&A"
-    )
 
-
 
-
 
-
 
-demo
-demo.launch(ssr_mode=False)
+init_prompt = """
+Hello, I am ClimateQ&A, a conversational assistant designed to help you understand climate change and biodiversity loss. I will answer your questions by **sifting through the IPCC and IPBES scientific reports**.
+
+❓ How to use
+- **Language**: You can ask me your questions in any language.
+- **Audience**: You can specify your audience (children, general public, experts) to get a more adapted answer.
+- **Sources**: You can choose to search in the IPCC or IPBES reports, or both.
+
+⚠️ Limitations
+*Please note that the AI is not perfect and may sometimes give irrelevant answers. If you are not satisfied with the answer, please ask a more specific question or report your feedback to help us improve the system.*
+
+What do you want to learn ?
+"""
+
+
+def vote(data: gr.LikeData):
+    if data.liked:
+        print(data.value)
+    else:
+        print(data)
+
+
+
+with gr.Blocks(title="Climate Q&A", css="style.css", theme=theme,elem_id = "main-component") as demo:
+    # user_id_state = gr.State([user_id])
+
+    with gr.Tab("ClimateQ&A"):
+
+        with gr.Row(elem_id="chatbot-row"):
+            with gr.Column(scale=2):
+                # state = gr.State([system_template])
+                chatbot = gr.Chatbot(
+                    value=[(None,init_prompt)],
+                    show_copy_button=True,show_label = False,elem_id="chatbot",layout = "panel",
+                    avatar_images = (None,"https://i.ibb.co/YNyd5W2/logo4.png"),
+                )#,avatar_images = ("assets/logo4.png",None))
+
+                # bot.like(vote,None,None)
+
+
+
+                with gr.Row(elem_id = "input-message"):
+                    textbox=gr.Textbox(placeholder="Ask me anything here!",show_label=False,scale=7,lines = 1,interactive = True,elem_id="input-textbox")
+                    # submit = gr.Button("",elem_id = "submit-button",scale = 1,interactive = True,icon = "https://static-00.iconduck.com/assets.00/settings-icon-2048x2046-cw28eevx.png")
+
+
+            with gr.Column(scale=1, variant="panel",elem_id = "right-panel"):
+
+
+                with gr.Tabs() as tabs:
+                    with gr.TabItem("Examples",elem_id = "tab-examples",id = 0):
+
+                        examples_hidden = gr.Textbox(visible = False)
+                        first_key = list(QUESTIONS.keys())[0]
+                        dropdown_samples = gr.Dropdown(QUESTIONS.keys(),value = first_key,interactive = True,show_label = True,label = "Select a category of sample questions",elem_id = "dropdown-samples")
+
+                        samples = []
+                        for i,key in enumerate(QUESTIONS.keys()):
+
+                            examples_visible = True if i == 0 else False
+
+                            with gr.Row(visible = examples_visible) as group_examples:
+
+                                examples_questions = gr.Examples(
+                                    QUESTIONS[key],
+                                    [examples_hidden],
+                                    examples_per_page=8,
+                                    run_on_click=False,
+                                    elem_id=f"examples{i}",
+                                    api_name=f"examples{i}",
+                                    # label = "Click on the example question or enter your own",
+                                    # cache_examples=True,
+                                )
+
+                            samples.append(group_examples)
+
+
+                    with gr.Tab("Sources",elem_id = "tab-citations",id = 1):
+                        sources_textbox = gr.HTML(show_label=False, elem_id="sources-textbox")
+                        docs_textbox = gr.State("")
+
+                    # with Modal(visible = False) as config_modal:
+                    with gr.Tab("Configuration",elem_id = "tab-config",id = 2):
+
+                        gr.Markdown("Reminder: You can talk in any language, ClimateQ&A is multi-lingual!")
+
+
+                        dropdown_sources = gr.CheckboxGroup(
+                            ["IPCC", "IPBES","IPOS"],
+                            label="Select source",
+                            value=["IPCC"],
+                            interactive=True,
+                        )
+
+                        dropdown_reports = gr.Dropdown(
+                            POSSIBLE_REPORTS,
+                            label="Or select specific reports",
+                            multiselect=True,
+                            value=None,
+                            interactive=True,
+                        )
+
+                        dropdown_audience = gr.Dropdown(
+                            ["Children","General public","Experts"],
+                            label="Select audience",
+                            value="Experts",
+                            interactive=True,
+                        )
+
+                        output_query = gr.Textbox(label="Query used for retrieval",show_label = True,elem_id = "reformulated-query",lines = 2,interactive = False)
+                        output_language = gr.Textbox(label="Language",show_label = True,elem_id = "language",lines = 1,interactive = False)
+
+
+
+
+
+    #---------------------------------------------------------------------------------------
+    # OTHER TABS
+    #---------------------------------------------------------------------------------------
+
+
+    with gr.Tab("Figures",elem_id = "tab-images",elem_classes = "max-height other-tabs"):
+        gallery_component = gr.Gallery()
+
+    with gr.Tab("Papers (beta)",elem_id = "tab-papers",elem_classes = "max-height other-tabs"):
+
+        with gr.Row():
+            with gr.Column(scale=1):
+                query_papers = gr.Textbox(placeholder="Question",show_label=False,lines = 1,interactive = True,elem_id="query-papers")
+                keywords_papers = gr.Textbox(placeholder="Keywords",show_label=False,lines = 1,interactive = True,elem_id="keywords-papers")
+                after = gr.Slider(minimum=1950,maximum=2023,step=1,value=1960,label="Publication date",show_label=True,interactive=True,elem_id="date-papers")
+                search_papers = gr.Button("Search",elem_id="search-papers",interactive=True)
+
+            with gr.Column(scale=7):
+
+                with gr.Tab("Summary",elem_id="papers-summary-tab"):
+                    papers_summary = gr.Markdown(visible=True,elem_id="papers-summary")
+
+                with gr.Tab("Relevant papers",elem_id="papers-results-tab"):
+                    papers_dataframe = gr.Dataframe(visible=True,elem_id="papers-table",headers = papers_cols)
+
+                with gr.Tab("Citations network",elem_id="papers-network-tab"):
+                    citations_network = gr.HTML(visible=True,elem_id="papers-citations-network")
+
+
+
+    with gr.Tab("About",elem_classes = "max-height other-tabs"):
+        with gr.Row():
+            with gr.Column(scale=1):
+                gr.Markdown("See more info at [https://climateqa.com](https://climateqa.com/docs/intro/)")
+
+
+    def start_chat(query,history):
+        history = history + [(query,None)]
+        history = [tuple(x) for x in history]
+        return (gr.update(interactive = False),gr.update(selected=1),history)
+
+    def finish_chat():
+        return (gr.update(interactive = True,value = ""))
+
+    (textbox
+        .submit(start_chat, [textbox,chatbot], [textbox,tabs,chatbot],queue = False,api_name = "start_chat_textbox")
+        .then(chat, [textbox,chatbot,dropdown_audience, dropdown_sources,dropdown_reports], [chatbot,sources_textbox,output_query,output_language,gallery_component,query_papers,keywords_papers],concurrency_limit = 8,api_name = "chat_textbox")
+        .then(finish_chat, None, [textbox],api_name = "finish_chat_textbox")
     )
+
+    (examples_hidden
+        .change(start_chat, [examples_hidden,chatbot], [textbox,tabs,chatbot],queue = False,api_name = "start_chat_examples")
+        .then(chat, [examples_hidden,chatbot,dropdown_audience, dropdown_sources,dropdown_reports], [chatbot,sources_textbox,output_query,output_language,gallery_component,query_papers,keywords_papers],concurrency_limit = 8,api_name = "chat_examples")
+        .then(finish_chat, None, [textbox],api_name = "finish_chat_examples")
     )
 
 
+    def change_sample_questions(key):
+        index = list(QUESTIONS.keys()).index(key)
+        visible_bools = [False] * len(samples)
+        visible_bools[index] = True
+        return [gr.update(visible=visible_bools[i]) for i in range(len(samples))]
+
+
+    dropdown_samples.change(change_sample_questions,dropdown_samples,samples)
+
+    query_papers.submit(generate_keywords,[query_papers], [keywords_papers])
+    search_papers.click(find_papers,[query_papers,keywords_papers,after], [papers_dataframe,citations_network,papers_summary])
+
+    # # textbox.submit(predict_climateqa,[textbox,bot],[None,bot,sources_textbox])
+    # (textbox
+    #     .submit(answer_user, [textbox,examples_hidden, bot], [textbox, bot],queue = False)
+    #     .success(change_tab,None,tabs)
+    #     .success(fetch_sources,[textbox,dropdown_sources], [textbox,sources_textbox,docs_textbox,output_query,output_language])
+    #     .success(answer_bot, [textbox,bot,docs_textbox,output_query,output_language,dropdown_audience], [textbox,bot],queue = True)
+    #     .success(lambda x : textbox,[textbox],[textbox])
+    # )
+
+    # (examples_hidden
+    #     .change(answer_user_example, [textbox,examples_hidden, bot], [textbox, bot],queue = False)
+    #     .success(change_tab,None,tabs)
+    #     .success(fetch_sources,[textbox,dropdown_sources], [textbox,sources_textbox,docs_textbox,output_query,output_language])
+    #     .success(answer_bot, [textbox,bot,docs_textbox,output_query,output_language,dropdown_audience], [textbox,bot],queue=True)
+    #     .success(lambda x : textbox,[textbox],[textbox])
+    # )
+    # submit_button.click(answer_user, [textbox, bot], [textbox, bot], queue=True).then(
+    #     answer_bot, [textbox,bot,dropdown_audience,dropdown_sources], [textbox,bot,sources_textbox]
+    # )
+
+    # with Modal(visible=True) as first_modal:
+    #     gr.Markdown("# Welcome to ClimateQ&A !")
+
+    #     gr.Markdown("### Examples")
+
+    #     examples = gr.Examples(
+    #         ["Yo ça roule","ça boume"],
+    #         [examples_hidden],
+    #         examples_per_page=8,
+    #         run_on_click=False,
+    #         elem_id="examples",
+    #         api_name="examples",
+    #     )
+
+
+    # submit.click(lambda: Modal(visible=True), None, config_modal)
+
+demo.queue()
+
+demo.launch()
climateqa/chat.py
DELETED
@@ -1,194 +0,0 @@
-import os
-from datetime import datetime
-import gradio as gr
-# from .agent import agent
-from gradio import ChatMessage
-from langgraph.graph.state import CompiledStateGraph
-import json
-
-from .handle_stream_events import (
-    init_audience,
-    handle_retrieved_documents,
-    convert_to_docs_to_html,
-    stream_answer,
-    handle_retrieved_owid_graphs,
-)
-from .logging import (
-    log_interaction
-)
-
-# Chat functions
-def start_chat(query, history, search_only):
-    history = history + [ChatMessage(role="user", content=query)]
-    if not search_only:
-        return (gr.update(interactive=False), gr.update(selected=1), history, [])
-    else:
-        return (gr.update(interactive=False), gr.update(selected=2), history, [])
-
-def finish_chat():
-    return gr.update(interactive=True, value="")
-
-def handle_numerical_data(event):
-    if event["name"] == "retrieve_drias_data" and event["event"] == "on_chain_end":
-        numerical_data = event["data"]["output"]["drias_data"]
-        sql_query = event["data"]["output"]["drias_sql_query"]
-        return numerical_data, sql_query
-    return None, None
-
-# Main chat function
-async def chat_stream(
-    agent : CompiledStateGraph,
-    query: str,
-    history: list[ChatMessage],
-    audience: str,
-    sources: list[str],
-    reports: list[str],
-    relevant_content_sources_selection: list[str],
-    search_only: bool,
-    share_client,
-    user_id: str
-) -> tuple[list, str, str, str, list, str]:
-    """Process a chat query and return response with relevant sources and visualizations.
-
-    Args:
-        query (str): The user's question
-        history (list): Chat message history
-        audience (str): Target audience type
-        sources (list): Knowledge base sources to search
-        reports (list): Specific reports to search within sources
-        relevant_content_sources_selection (list): Types of content to retrieve (figures, papers, etc)
-        search_only (bool): Whether to only search without generating answer
-
-    Yields:
-        tuple: Contains:
-            - history: Updated chat history
-            - docs_html: HTML of retrieved documents
-            - output_query: Processed query
-            - output_language: Detected language
-            - related_contents: Related content
-            - graphs_html: HTML of relevant graphs
-    """
-    # Log incoming question
-    date_now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
-    print(f">> NEW QUESTION ({date_now}) : {query}")
-
-    audience_prompt = init_audience(audience)
-    sources = sources or ["IPCC", "IPBES"]
-    reports = reports or []
-    relevant_history_discussion = history[-2:] if len(history) > 1 else []
-
-    # Prepare inputs for agent
-    inputs = {
-        "user_input": query,
-        "audience": audience_prompt,
-        "sources_input": sources,
-        "relevant_content_sources_selection": relevant_content_sources_selection,
-        "search_only": search_only,
-        "reports": reports,
-        "chat_history": relevant_history_discussion,
-    }
-
-    # Get streaming events from agent
-    result = agent.astream_events(inputs, version="v1")
-
-    # Initialize state variables
-    docs = []
-    related_contents = []
-    docs_html = ""
-    new_docs_html = ""
-    output_query = ""
-    output_language = ""
-    output_keywords = ""
-    start_streaming = False
-    graphs_html = ""
-    used_documents = []
-    retrieved_contents = []
-    answer_message_content = ""
-    vanna_data = {}
-    follow_up_examples = gr.Dataset(samples=[])
-
-    # Define processing steps
-    steps_display = {
-        "categorize_intent": ("🔄️ Analyzing user message", True),
-        "transform_query": ("🔄️ Thinking step by step to answer the question", True),
-        "retrieve_documents": ("🔄️ Searching in the knowledge base", False),
-        "retrieve_local_data": ("🔄️ Searching in the knowledge base", False),
-    }
-
-    try:
-        # Process streaming events
-        async for event in result:
-
-            if "langgraph_node" in event["metadata"]:
-                node = event["metadata"]["langgraph_node"]
-
-                # Handle document retrieval
-                if event["event"] == "on_chain_end" and event["name"] in ["retrieve_documents","retrieve_local_data"] and event["data"]["output"] != None:
-                    history, used_documents, retrieved_contents = handle_retrieved_documents(
-                        event, history, used_documents, retrieved_contents
-                    )
-                # Handle Vanna retrieval
-                # if event["event"] == "on_chain_end" and event["name"] in ["retrieve_documents","retrieve_local_data"] and event["data"]["output"] != None:
-                #     df_output_vanna, sql_query = handle_numerical_data(
-                #         event
-                #     )
-                #     vanna_data = {"df_output": df_output_vanna, "sql_query": sql_query}
-
-
-                if event["event"] == "on_chain_end" and event["name"] == "answer_search" :
-                    docs = event["data"]["input"]["documents"]
-                    docs_html = convert_to_docs_to_html(docs)
-                    related_contents = event["data"]["input"]["related_contents"]
-
-                # Handle intent categorization
-                elif (event["event"] == "on_chain_end" and
-                    node == "categorize_intent" and
-                    event["name"] == "_write"):
-                    intent = event["data"]["output"]["intent"]
-                    output_language = event["data"]["output"].get("language", "English")
-                    history[-1].content = f"Language identified: {output_language}\nIntent identified: {intent}"
-
-                # Handle processing steps display
-                elif event["name"] in steps_display and event["event"] == "on_chain_start":
-                    event_description, display_output = steps_display[node]
-                    if (not hasattr(history[-1], 'metadata') or
-                        history[-1].metadata["title"] != event_description):
-                        history.append(ChatMessage(
-                            role="assistant",
-                            content="",
-                            metadata={'title': event_description}
-                        ))
-
-                # Handle answer streaming
-                elif (event["name"] != "transform_query" and
-                    event["event"] == "on_chat_model_stream" and
-                    node in ["answer_rag","answer_rag_no_docs", "answer_search", "answer_chitchat"]):
-                    history, start_streaming, answer_message_content = stream_answer(
-                        history, event, start_streaming, answer_message_content
-                    )
-
-                # Handle graph retrieval
-                elif event["name"] in ["retrieve_graphs", "retrieve_graphs_ai"] and event["event"] == "on_chain_end":
-                    graphs_html = handle_retrieved_owid_graphs(event, graphs_html)
-
-                # Handle query transformation
-                if event["name"] == "transform_query" and event["event"] == "on_chain_end":
-                    if hasattr(history[-1], "content"):
-                        sub_questions = [q["question"] + "-> relevant sources : " + str(q["sources"]) for q in event["data"]["output"]["questions_list"]]
-                        history[-1].content += "Decompose question into sub-questions:\n\n - " + "\n - ".join(sub_questions)
-
-            # Handle follow up questions
-            if event["name"] == "generate_follow_up" and event["event"] == "on_chain_end":
-                follow_up_examples = event["data"]["output"].get("follow_up_questions", [])
-                follow_up_examples = gr.Dataset(samples= [ [question] for question in follow_up_examples ])
-
-            yield history, docs_html, output_query, output_language, related_contents, graphs_html, follow_up_examples#, vanna_data
-
-    except Exception as e:
-        print(f"Event {event} has failed")
-        raise gr.Error(str(e))
-
-    # Call the function to log interaction
-    log_interaction(history, output_query, sources, docs, share_client, user_id)
-
-    yield history, docs_html, output_query, output_language, related_contents, graphs_html, follow_up_examples#, vanna_data
climateqa/constants.py
CHANGED
@@ -1,6 +1,4 @@
 POSSIBLE_REPORTS = [
-    "IPBES IABWFH SPM",
-    "IPBES CBL SPM",
     "IPCC AR6 WGI SPM",
     "IPCC AR6 WGI FR",
     "IPCC AR6 WGI TS",
@@ -44,60 +42,4 @@ POSSIBLE_REPORTS = [
     "IPBES IAS A C5",
     "IPBES IAS A C6",
     "IPBES IAS A SPM"
-]
-
-OWID_CATEGORIES = ['Access to Energy', 'Agricultural Production',
-    'Agricultural Regulation & Policy', 'Air Pollution',
-    'Animal Welfare', 'Antibiotics', 'Biodiversity', 'Biofuels',
-    'Biological & Chemical Weapons', 'CO2 & Greenhouse Gas Emissions',
-    'COVID-19', 'Clean Water', 'Clean Water & Sanitation',
-    'Climate Change', 'Crop Yields', 'Diet Compositions',
-    'Electricity', 'Electricity Mix', 'Energy', 'Energy Efficiency',
-    'Energy Prices', 'Environmental Impacts of Food Production',
-    'Environmental Protection & Regulation', 'Famines', 'Farm Size',
-    'Fertilizers', 'Fish & Overfishing', 'Food Supply', 'Food Trade',
-    'Food Waste', 'Food and Agriculture', 'Forests & Deforestation',
-    'Fossil Fuels', 'Future Population Growth',
-    'Hunger & Undernourishment', 'Indoor Air Pollution', 'Land Use',
-    'Land Use & Yields in Agriculture', 'Lead Pollution',
-    'Meat & Dairy Production', 'Metals & Minerals',
-    'Natural Disasters', 'Nuclear Energy', 'Nuclear Weapons',
-    'Oil Spills', 'Outdoor Air Pollution', 'Ozone Layer', 'Pandemics',
-    'Pesticides', 'Plastic Pollution', 'Renewable Energy', 'Soil',
-    'Transport', 'Urbanization', 'Waste Management', 'Water Pollution',
-    'Water Use & Stress', 'Wildfires']
-
-
-DOCUMENT_METADATA_DEFAULT_VALUES = {
-    "chunk_type": "",
-    "document_id": "",
-    "document_number": 0.0,
-    "element_id": "",
-    "figure_code": "",
-    "file_size": "",
-    "image_path": "",
-    "n_pages": 0.0,
-    "name": "",
-    "num_characters": 0.0,
-    "num_tokens": 0.0,
-    "num_tokens_approx": 0.0,
-    "num_words": 0.0,
-    "page_number": 0,
-    "release_date": 0.0,
-    "report_type": "",
-    "section_header": "",
-    "short_name": "",
-    "source": "",
-    "toc_level0": "",
-    "toc_level1": "",
-    "toc_level2": "",
-    "toc_level3": "",
-    "url": "",
-    "similarity_score": 0.0,
-    "content": "",
-    "reranking_score": 0.0,
-    "query_used_for_retrieval": "",
-    "sources_used": [""],
-    "question_used": "",
-    "index_used": ""
-}
+]

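For context on what the removed DOCUMENT_METADATA_DEFAULT_VALUES mapping is typically for: a defaults table like this is merged under a document's real metadata so downstream code can rely on every key existing. A minimal sketch of that pattern; the fill_missing_metadata helper is hypothetical and not part of this repo:

# Hypothetical helper: backfill absent metadata keys from a defaults mapping
# like the removed DOCUMENT_METADATA_DEFAULT_VALUES (subset shown here).
defaults = {"page_number": 0, "source": "", "similarity_score": 0.0}

def fill_missing_metadata(metadata: dict, defaults: dict) -> dict:
    # Dict unpacking: real values override the defaults.
    return {**defaults, **metadata}

print(fill_missing_metadata({"source": "IPCC"}, defaults))
# -> {'page_number': 0, 'source': 'IPCC', 'similarity_score': 0.0}
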
climateqa/engine/chains/__init__.py
DELETED
File without changes

climateqa/engine/chains/answer_ai_impact.py
DELETED
@@ -1,46 +0,0 @@
-from langchain.prompts import ChatPromptTemplate
-from langchain_core.output_parsers import StrOutputParser
-
-
-prompt_template = """
-You are ClimateQ&A, an helpful AI Assistant specialized in answering climate-related questions using info from the IPCC and/or IPBES reports.
-Always stay true to climate and nature science and do not make up information.
-If you do not know the answer, just say you do not know.
-
-## Guidelines
-- Explain that the environmental impact of AI is not covered by the IPCC or IPBES reports, but you can recommend info based on the sources below
-- Answer the question in the original language of the question
-
-## Sources
-- You can propose to visit this page https://climateqa.com/docs/carbon-footprint/ to learn more about ClimateQ&A's own carbon footprint
-- You can recommend to look at the work of the AI & climate expert scientist Sasha Luccioni with in in particular those papers
-    - Power Hungry Processing: Watts Driving the Cost of AI Deployment? - https://arxiv.org/abs/2311.16863 - about the carbon footprint at the inference stage of AI models
-    - Counting Carbon: A Survey of Factors Influencing the Emissions of Machine Learning - https://arxiv.org/abs/2302.08476
-    - Estimating the Carbon Footprint of BLOOM, a 176B Parameter Language Model - https://arxiv.org/abs/2211.02001 - about the carbon footprint of training a large language model
-- You can also recommend the following tools to calculate the carbon footprint of AI models
-    - CodeCarbon - https://github.com/mlco2/codecarbon to measure the carbon footprint of your code
-    - Ecologits - https://ecologits.ai/ to measure the carbon footprint of using LLMs APIs such
-"""
-
-
-def make_ai_impact_chain(llm):
-
-    prompt = ChatPromptTemplate.from_messages([
-        ("system", prompt_template),
-        ("user", "{question}")
-    ])
-
-    chain = prompt | llm | StrOutputParser()
-    chain = chain.with_config({"run_name":"ai_impact_chain"})
-
-    return chain
-
-def make_ai_impact_node(llm):
-
-    ai_impact_chain = make_ai_impact_chain(llm)
-
-    async def answer_ai_impact(state,config):
-        answer = await ai_impact_chain.ainvoke({"question":state["user_input"]},config)
-        return {"answer":answer}
-
-    return answer_ai_impact

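The deleted module above follows the prompt | llm | parser factory pattern that recurs across these chain files. A minimal runnable sketch of that pattern, using a RunnableLambda as a stand-in for a real chat model (the fake LLM and prompt text here are illustrative only, not the app's real wiring):

from langchain.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableLambda

fake_llm = RunnableLambda(lambda prompt_value: "stub answer")  # stand-in for a chat model

prompt = ChatPromptTemplate.from_messages([
    ("system", "You are a helpful assistant."),
    ("user", "{question}"),
])
chain = (prompt | fake_llm | StrOutputParser()).with_config({"run_name": "demo_chain"})
print(chain.invoke({"question": "What is the footprint of AI?"}))  # -> "stub answer"
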
climateqa/engine/chains/answer_chitchat.py
DELETED
@@ -1,56 +0,0 @@
-from langchain.prompts import ChatPromptTemplate
-from langchain_core.output_parsers import StrOutputParser
-
-
-chitchat_prompt_template = """
-You are ClimateQ&A, an helpful AI Assistant specialized in answering climate-related questions using info from the IPCC and/or IPBES reports.
-Always stay true to climate and nature science and do not make up information.
-If you do not know the answer, just say you do not know.
-
-## Guidelines
-- If it's a conversational question, you can normally chat with the user
-- If the question is not related to any topic about the environment, refuse to answer and politely ask the user to ask another question about the environment
-- If the user ask if you speak any language, you can say you speak all languages :)
-- If the user ask about the bot itself "ClimateQ&A", you can explain that you are an AI assistant specialized in answering climate-related questions using info from the IPCC and/or IPBES reports and propose to visit the website here https://climateqa.com/docs/intro/ for more information
-- If the question is about ESG regulations, standards, or frameworks like the CSRD, TCFD, SASB, GRI, CDP, etc., you can explain that this is not a topic covered by the IPCC or IPBES reports.
-- Precise that you are specialized in finding trustworthy information from the scientific reports of the IPCC and IPBES and other scientific litterature
-- If relevant you can propose up to 3 example of questions they could ask from the IPCC or IPBES reports from the examples below
-- Always answer in the original language of the question
-
-## Examples of questions you can suggest (in the original language of the question)
-"What evidence do we have of climate change?",
-"Are human activities causing global warming?",
-"What are the impacts of climate change?",
-"Can climate change be reversed?",
-"What is the difference between climate change and global warming?",
-"""
-
-
-def make_chitchat_chain(llm):
-
-    prompt = ChatPromptTemplate.from_messages([
-        ("system", chitchat_prompt_template),
-        ("user", "{question}")
-    ])
-
-    chain = prompt | llm | StrOutputParser()
-    chain = chain.with_config({"run_name":"chitchat_chain"})
-
-    return chain
-
-
-
-def make_chitchat_node(llm):
-
-    chitchat_chain = make_chitchat_chain(llm)
-
-    async def answer_chitchat(state,config):
-        print("---- Answer chitchat ----")
-
-        answer = await chitchat_chain.ainvoke({"question":state["user_input"]},config)
-        state["answer"] = answer
-        return state
-        # return {"answer":answer}
-
-    return answer_chitchat
-

climateqa/engine/chains/chitchat_categorization.py
DELETED
@@ -1,43 +0,0 @@
-
-from langchain_core.pydantic_v1 import BaseModel, Field
-from typing import List
-from typing import Literal
-from langchain.prompts import ChatPromptTemplate
-from langchain_core.utils.function_calling import convert_to_openai_function
-from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
-
-
-class IntentCategorizer(BaseModel):
-    """Analyzing the user message input"""
-
-    environment: bool = Field(
-        description="Return 'True' if the question relates to climate change, the environment, nature, etc. (Example: should I eat fish?). Return 'False' if the question is just chit chat or not related to the environment or climate change.",
-    )
-
-
-def make_chitchat_intent_categorization_chain(llm):
-
-    openai_functions = [convert_to_openai_function(IntentCategorizer)]
-    llm_with_functions = llm.bind(functions = openai_functions,function_call={"name":"IntentCategorizer"})
-
-    prompt = ChatPromptTemplate.from_messages([
-        ("system", "You are a helpful assistant, you will analyze, translate and reformulate the user input message using the function provided"),
-        ("user", "input: {input}")
-    ])
-
-    chain = prompt | llm_with_functions | JsonOutputFunctionsParser()
-    return chain
-
-
-def make_chitchat_intent_categorization_node(llm):
-
-    categorization_chain = make_chitchat_intent_categorization_chain(llm)
-
-    def categorize_message(state):
-        output = categorization_chain.invoke({"input": state["user_input"]})
-        print(f"\n\nChit chat output intent categorization: {output}\n")
-        state["search_graphs_chitchat"] = output["environment"]
-        print(f"\n\nChit chat output intent categorization: {state}\n")
-        return state
-
-    return categorize_message

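The function-calling pattern used above (and in several of the other deleted chains) converts a Pydantic schema into an OpenAI function definition before binding it to the LLM. A sketch of just that conversion step, runnable without any LLM; the imports match the deleted module's (on newer langchain-core you would use pydantic directly instead of the deprecated pydantic_v1 shim):

from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.utils.function_calling import convert_to_openai_function

class IntentCategorizer(BaseModel):
    """Analyzing the user message input"""
    environment: bool = Field(description="True if the question relates to the environment")

# Produces {'name': 'IntentCategorizer', 'description': ..., 'parameters': {...JSON schema...}}
print(convert_to_openai_function(IntentCategorizer))
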
climateqa/engine/chains/follow_up.py
DELETED
@@ -1,33 +0,0 @@
-from typing import List
-from langchain.prompts import ChatPromptTemplate
-
-
-FOLLOW_UP_TEMPLATE = """Based on the previous question and answer, generate 2-3 relevant follow-up questions that would help explore the topic further.
-
-Previous Question: {user_input}
-Previous Answer: {answer}
-
-Generate short, concise, focused follow-up questions
-You don't need a full question as it will be reformulated later as a standalone question with the context. Eg. "Details the first point"
-"""
-
-def make_follow_up_node(llm):
-    prompt = ChatPromptTemplate.from_template(FOLLOW_UP_TEMPLATE)
-
-    def generate_follow_up(state):
-        print("---- Generate_follow_up ----")
-        if not state.get("answer"):
-            return state
-
-        response = llm.invoke(prompt.format(
-            user_input=state["user_input"],
-            answer=state["answer"]
-        ))
-
-        # Extract questions from response
-        follow_ups = [q.strip() for q in response.content.split("\n") if q.strip()]
-        state["follow_up_questions"] = follow_ups
-
-        return state
-
-    return generate_follow_up

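The only post-processing in generate_follow_up is a line split over the raw LLM text. A sketch of that step in isolation, with an illustrative response string standing in for llm.invoke(...).content:

# Illustrative raw completion; in the deleted node this comes from the LLM.
raw = "1. Details the first point\n\n2. What about oceans?\n"
follow_ups = [q.strip() for q in raw.split("\n") if q.strip()]
print(follow_ups)  # -> ['1. Details the first point', '2. What about oceans?']
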
climateqa/engine/chains/graph_retriever.py
DELETED
@@ -1,130 +0,0 @@
-import sys
-import os
-from contextlib import contextmanager
-
-from ..reranker import rerank_docs
-from ..graph_retriever import retrieve_graphs  # GraphRetriever
-from ...utils import remove_duplicates_keep_highest_score
-
-
-def divide_into_parts(target, parts):
-    # Base value for each part
-    base = target // parts
-    # Remainder to distribute
-    remainder = target % parts
-    # List to hold the result
-    result = []
-
-    for i in range(parts):
-        if i < remainder:
-            # These parts get base value + 1
-            result.append(base + 1)
-        else:
-            # The rest get the base value
-            result.append(base)
-
-    return result
-
-
-@contextmanager
-def suppress_output():
-    # Open a null device
-    with open(os.devnull, 'w') as devnull:
-        # Store the original stdout and stderr
-        old_stdout = sys.stdout
-        old_stderr = sys.stderr
-        # Redirect stdout and stderr to the null device
-        sys.stdout = devnull
-        sys.stderr = devnull
-        try:
-            yield
-        finally:
-            # Restore stdout and stderr
-            sys.stdout = old_stdout
-            sys.stderr = old_stderr
-
-
-def make_graph_retriever_node(vectorstore, reranker, rerank_by_question=True, k_final=15, k_before_reranking=100):
-
-    async def node_retrieve_graphs(state):
-        print("---- Retrieving graphs ----")
-
-        POSSIBLE_SOURCES = ["IEA", "OWID"]
-        # questions = state["remaining_questions"] if state["remaining_questions"] is not None and state["remaining_questions"]!=[] else [state["query"]]
-        questions = state["questions_list"] if state["questions_list"] is not None and state["questions_list"]!=[] else [state["query"]]
-
-        # sources_input = state["sources_input"]
-        sources_input = ["auto"]
-
-        auto_mode = "auto" in sources_input
-
-        # There are several options to get the final top k
-        # Option 1 - Get 100 documents by question and rerank by question
-        # Option 2 - Get 100/n documents by question and rerank the total
-        if rerank_by_question:
-            k_by_question = divide_into_parts(k_final,len(questions))
-
-        docs = []
-
-        for i,q in enumerate(questions):
-
-            question = q["question"] if isinstance(q, dict) else q
-
-            print(f"Subquestion {i}: {question}")
-
-            # If auto mode, we use all sources
-            if auto_mode:
-                sources = POSSIBLE_SOURCES
-            # Otherwise, we use the config
-            else:
-                sources = sources_input
-
-            if any([x in POSSIBLE_SOURCES for x in sources]):
-
-                sources = [x for x in sources if x in POSSIBLE_SOURCES]
-
-                # Search the document store using the retriever
-                docs_question = await retrieve_graphs(
-                    query = question,
-                    vectorstore = vectorstore,
-                    sources = sources,
-                    k_total = k_before_reranking,
-                    threshold = 0.5,
-                )
-                # docs_question = retriever.get_relevant_documents(question)
-
-                # Rerank
-                if reranker is not None and docs_question!=[]:
-                    with suppress_output():
-                        docs_question = rerank_docs(reranker,docs_question,question)
-                else:
-                    # Add a default reranking score
-                    for doc in docs_question:
-                        doc.metadata["reranking_score"] = doc.metadata["similarity_score"]
-
-                # If rerank by question we select the top documents for each question
-                if rerank_by_question:
-                    docs_question = docs_question[:k_by_question[i]]
-
-                # Add sources used in the metadata
-                for doc in docs_question:
-                    doc.metadata["sources_used"] = sources
-
-                print(f"{len(docs_question)} graphs retrieved for subquestion {i + 1}: {docs_question}")
-
-                docs.extend(docs_question)
-
-            else:
-                print(f"There are no graphs which match the sources filtered on. Sources filtered on: {sources}. Sources available: {POSSIBLE_SOURCES}.")
-
-        # Remove duplicates and keep the duplicate document with the highest reranking score
-        docs = remove_duplicates_keep_highest_score(docs)
-
-        # Sorting the list in descending order by rerank_score
-        # Then select the top k
-        docs = sorted(docs, key=lambda x: x.metadata["reranking_score"], reverse=True)
-        docs = docs[:k_final]
-
-        return {"recommended_content": docs}
-
-    return node_retrieve_graphs

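divide_into_parts (duplicated again in retrieve_documents.py below) is the budget-splitting step: it spreads the k_final document budget across sub-questions as evenly as possible, giving the remainder to the first ones. A standalone sketch of the same logic, written compactly with divmod:

def divide_into_parts(target: int, parts: int) -> list[int]:
    base, remainder = divmod(target, parts)
    # The first `remainder` parts get one extra item.
    return [base + 1 if i < remainder else base for i in range(parts)]

print(divide_into_parts(15, 4))  # -> [4, 4, 4, 3]
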
climateqa/engine/chains/intent_categorization.py
DELETED
@@ -1,97 +0,0 @@
-from langchain_core.pydantic_v1 import BaseModel, Field
-from typing import List
-from typing import Literal
-from langchain.prompts import ChatPromptTemplate
-from langchain_core.utils.function_calling import convert_to_openai_function
-from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
-
-
-class IntentCategorizer(BaseModel):
-    """Analyzing the user message input"""
-
-    language: str = Field(
-        description="Find the language of the message input in full words (ex: French, English, Spanish, ...), defaults to English",
-        default="English",
-    )
-    intent: str = Field(
-        enum=[
-            "ai_impact",
-            # "geo_info",
-            # "esg",
-            "search",
-            "chitchat",
-        ],
-        description="""
-            Categorize the user input in one of the following category
-            Any question
-
-            Examples:
-            - ai_impact = Environmental impacts of AI: "What are the environmental impacts of AI", "How does AI affect the environment"
-            - search = Searching for any quesiton about climate change, energy, biodiversity, nature, and everything we can find the IPCC or IPBES reports or scientific papers,
-            - chitchat = Any general question that is not related to the environment or climate change or just conversational, or if you don't think searching the IPCC or IPBES reports would be relevant
-        """,
-        # - geo_info = Geolocated info about climate change: Any question where the user wants to know localized impacts of climate change, eg: "What will be the temperature in Marseille in 2050"
-        # - esg = Any question about the ESG regulation, frameworks and standards like the CSRD, TCFD, SASB, GRI, CDP, etc.
-
-    )
-
-
-
-def make_intent_categorization_chain(llm):
-
-    openai_functions = [convert_to_openai_function(IntentCategorizer)]
-    llm_with_functions = llm.bind(functions = openai_functions,function_call={"name":"IntentCategorizer"})
-
-    prompt = ChatPromptTemplate.from_messages([
-        ("system", "You are a helpful assistant, you will analyze, detect the language, and categorize the user input message using the function provided. You MUST detect and return the language of the input message. Categorize the user input as ai ONLY if it is related to Artificial Intelligence, search if it is related to the environment, climate change, energy, biodiversity, nature, etc. and chitchat if it is just general conversation."),
-        ("user", "input: {input}")
-    ])
-
-    chain = prompt | llm_with_functions | JsonOutputFunctionsParser()
-    return chain
-
-
-def make_intent_categorization_node(llm):
-
-    categorization_chain = make_intent_categorization_chain(llm)
-
-    def categorize_message(state):
-        print("---- Categorize_message ----")
-        print(f"Input state: {state}")
-
-        output = categorization_chain.invoke({"input": state["user_input"]})
-        print(f"\n\nRaw output from categorization: {output}\n")
-
-        if "language" not in output:
-            print("WARNING: Language field missing from output, setting default to English")
-            output["language"] = "English"
-        else:
-            print(f"Language detected: {output['language']}")
-
-        output["query"] = state["user_input"]
-        print(f"Final output: {output}")
-        return output
-
-    return categorize_message
-
-
-
-
-# SAMPLE_QUESTIONS = [
-#     "Est-ce que l'IA a un impact sur l'environnement ?",
-#     "Que dit le GIEC sur l'impact de l'IA",
-#     "Qui sont les membres du GIEC",
-#     "What is the impact of El Nino ?",
-#     "Yo",
-#     "Hello ça va bien ?",
-#     "Par qui as tu été créé ?",
-#     "What role do cloud formations play in modulating the Earth's radiative balance, and how are they represented in current climate models?",
-#     "Which industries have the highest GHG emissions?",
-#     "What are invasive alien species and how do they threaten biodiversity and ecosystems?",
-#     "Are human activities causing global warming?",
-#     "What is the motivation behind mining the deep seabed?",
-#     "Tu peux m'écrire un poème sur le changement climatique ?",
-#     "Tu peux m'écrire un poème sur les bonbons ?",
-#     "What will be the temperature in 2100 in Strasbourg?",
-#     "C'est quoi le lien entre biodiversity and changement climatique ?",
-# ]

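categorize_message guards against a missing "language" key in the parsed function-call output before enriching the result with the original query. The same guard can be written with setdefault; a sketch with an illustrative output dict (no LLM involved):

output = {"intent": "search"}             # illustrative parsed LLM output, "language" missing
output.setdefault("language", "English")  # same effect as the explicit `if "language" not in output`
output["query"] = "Which industries have the highest GHG emissions?"
print(output)
# -> {'intent': 'search', 'language': 'English', 'query': 'Which industries have the highest GHG emissions?'}
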
climateqa/engine/chains/keywords_extraction.py
DELETED
@@ -1,40 +0,0 @@
-
-from langchain_core.pydantic_v1 import BaseModel, Field
-from typing import List
-from typing import Literal
-from langchain.prompts import ChatPromptTemplate
-from langchain_core.utils.function_calling import convert_to_openai_function
-from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
-
-
-class KeywordExtraction(BaseModel):
-    """
-    Analyzing the user query to extract keywords to feed a search engine
-    """
-
-    keywords: List[str] = Field(
-        description="""
-            Extract the keywords from the user query to feed a search engine as a list
-            Avoid adding super specific keywords to prefer general keywords
-            Maximum 3 keywords
-
-            Examples:
-            - "What is the impact of deep sea mining ?" -> ["deep sea mining"]
-            - "How will El Nino be impacted by climate change" -> ["el nino","climate change"]
-            - "Is climate change a hoax" -> ["climate change","hoax"]
-        """
-    )
-
-
-def make_keywords_extraction_chain(llm):
-
-    openai_functions = [convert_to_openai_function(KeywordExtraction)]
-    llm_with_functions = llm.bind(functions = openai_functions,function_call={"name":"KeywordExtraction"})
-
-    prompt = ChatPromptTemplate.from_messages([
-        ("system", "You are a helpful assistant"),
-        ("user", "input: {input}")
-    ])
-
-    chain = prompt | llm_with_functions | JsonOutputFunctionsParser()
-    return chain

climateqa/engine/chains/query_transformation.py
DELETED
@@ -1,300 +0,0 @@
-
-
-from langchain_core.pydantic_v1 import BaseModel, Field
-from typing import List
-from typing import Literal
-from langchain.prompts import ChatPromptTemplate
-from langchain_core.utils.function_calling import convert_to_openai_function
-from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
-
-# OLD QUERY ANALYSIS
-# keywords: List[str] = Field(
-#     description="""
-#     Extract the keywords from the user query to feed a search engine as a list
-#     Maximum 3 keywords
-
-#     Examples:
-#     - "What is the impact of deep sea mining ?" -> deep sea mining
-#     - "How will El Nino be impacted by climate change" -> el nino;climate change
-#     - "Is climate change a hoax" -> climate change;hoax
-#     """
-# )
-
-# alternative_queries: List[str] = Field(
-#     description="""
-#     Generate alternative search questions from the user query to feed a search engine
-#     """
-# )
-
-# step_back_question: str = Field(
-#     description="""
-#     You are an expert at world knowledge. Your task is to step back and paraphrase a question to a more generic step-back question, which is easier to answer.
-#     This questions should help you get more context and information about the user query
-#     """
-# )
-# - OpenAlex is for any other questions that are not in the previous categories but could be found in the scientific litterature
-#
-
-
-# topics: List[Literal[
-#     "Climate change",
-#     "Biodiversity",
-#     "Energy",
-#     "Decarbonization",
-#     "Climate science",
-#     "Nature",
-#     "Climate policy and justice",
-#     "Oceans",
-#     "Deep sea mining",
-#     "ESG and regulations",
-#     "CSRD",
-# ]] = Field(
-#     ...,
-#     description = """
-#     Choose the topics that are most relevant to the user query, ex: Climate change, Energy, Biodiversity, ...
-#     """,
-# )
-# date: str = Field(description="The date or period mentioned, ex: 2050, between 2020 and 2050")
-# location:Location
-
-
-
-ROUTING_INDEX = {
-    "IPx":["IPCC", "IPBES", "IPOS"],
-    "POC": ["AcclimaTerra", "PCAET","Biodiv"],
-    "OpenAlex":["OpenAlex"],
-}
-
-POSSIBLE_SOURCES = [y for values in ROUTING_INDEX.values() for y in values]
-
-# Prompt from the original paper https://arxiv.org/pdf/2305.14283
-# Query Rewriting for Retrieval-Augmented Large Language Models
-class QueryDecomposition(BaseModel):
-    """
-    Decompose the user query into smaller parts to think step by step to answer this question
-    Act as a simple planning agent
-    """
-
-    questions: List[str] = Field(
-        description="""
-        Think step by step to answer this question, and provide one or several search engine questions in the provided language for knowledge that you need.
-        Suppose that the user is looking for information about climate change, energy, biodiversity, nature, and everything we can find the IPCC reports and scientific literature
-        - If it's already a standalone and explicit question, just return the reformulated question for the search engine
-        - If you need to decompose the question, output a list of maximum 2 to 3 questions
-        """
-    )
-
-
-class Location(BaseModel):
-    country:str = Field(...,description="The country if directly mentioned or inferred from the location (cities, regions, adresses), ex: France, USA, ...")
-    location:str = Field(...,description="The specific place if mentioned (cities, regions, addresses), ex: Marseille, New York, Wisconsin, ...")
-
-class QueryTranslation(BaseModel):
-    """Translate the query into a given language"""
-
-    question : str = Field(
-        description="""
-        Translate the questions into the given language
-        If the question is alrealdy in the given language, just return the same question
-        """,
-    )
-
-
-class QueryAnalysis(BaseModel):
-    """
-    Analyze the user query to extract the relevant sources
-
-    Deprecated:
-    Analyzing the user query to extract topics, sources and date
-    Also do query expansion to get alternative search queries
-    Also provide simple keywords to feed a search engine
-    """
-
-    sources: List[Literal["IPCC", "IPBES", "IPOS", "AcclimaTerra", "PCAET","Biodiv"]] = Field( #,"OpenAlex"]] = Field(
-        ...,
-        description="""
-        Given a user question choose which documents would be most relevant for answering their question,
-        - IPCC is for questions about climate change, energy, impacts, and everything we can find the IPCC reports
-        - IPBES is for questions about biodiversity and nature
-        - IPOS is for questions about the ocean and deep sea mining
-        - AcclimaTerra is for questions about any specific place in, or close to, the french region "Nouvelle-Aquitaine"
-        - PCAET is the Plan Climat Eneregie Territorial for the city of Paris
-        - Biodiv is the Biodiversity plan for the city of Paris
-        """,
-    )
-
-
-
-def make_query_decomposition_chain(llm):
-    """Chain to decompose a query into smaller parts to think step by step to answer this question
-
-    Args:
-        llm (_type_): _description_
-
-    Returns:
-        _type_: _description_
-    """
-
-    openai_functions = [convert_to_openai_function(QueryDecomposition)]
-    llm_with_functions = llm.bind(functions = openai_functions,function_call={"name":"QueryDecomposition"})
-
-    prompt = ChatPromptTemplate.from_messages([
-        ("system", "You are a helpful assistant, you will analyze, translate and reformulate the user input message using the function provided"),
-        ("user", "input: {input}")
-    ])
-
-    chain = prompt | llm_with_functions | JsonOutputFunctionsParser()
-    return chain
-
-
-def make_query_analysis_chain(llm):
-    """Analyze the user query to extract the relevant sources"""
-
-    openai_functions = [convert_to_openai_function(QueryAnalysis)]
-    llm_with_functions = llm.bind(functions = openai_functions,function_call={"name":"QueryAnalysis"})
-
-
-
-    prompt = ChatPromptTemplate.from_messages([
-        ("system", "You are a helpful assistant, you will analyze the user input message using the function provided"),
-        ("user", "input: {input}")
-    ])
-
-
-    chain = prompt | llm_with_functions | JsonOutputFunctionsParser()
-    return chain
-
-
-def make_query_translation_chain(llm):
-    """Analyze the user query to extract the relevant sources"""
-
-    openai_functions = [convert_to_openai_function(QueryTranslation)]
-    llm_with_functions = llm.bind(functions = openai_functions,function_call={"name":"QueryTranslation"})
-
-
-
-    prompt = ChatPromptTemplate.from_messages([
-        ("system", "You are a helpful assistant, translate the question into {language}"),
-        ("user", "input: {input}")
-    ])
-
-
-    chain = prompt | llm_with_functions | JsonOutputFunctionsParser()
-    return chain
-
-def group_by_sources_types(sources):
-    sources_types = {}
-    IPx_sources = ["IPCC", "IPBES", "IPOS"]
-    local_sources = ["AcclimaTerra", "PCAET","Biodiv"]
-    if any(source in IPx_sources for source in sources):
-        sources_types["IPx"] = list(set(sources).intersection(IPx_sources))
-    if any(source in local_sources for source in sources):
-        sources_types["POC"] = list(set(sources).intersection(local_sources))
-    return sources_types
-
-
-def make_query_transform_node(llm,k_final=15):
-    """
-    Creates a query transformation node that processes and transforms a given query state.
-    Args:
-        llm: The language model to be used for query decomposition and rewriting.
-        k_final (int, optional): The final number of questions to be generated. Defaults to 15.
-    Returns:
-        function: A function that takes a query state and returns a transformed state.
-        The returned function performs the following steps:
-        1. Checks if the query should be processed in auto mode based on the state.
-        2. Retrieves the input sources from the state or defaults to a predefined routing index.
-        3. Decomposes the query using the decomposition chain.
-        4. Analyzes each decomposed question using the rewriter chain.
-        5. Ensures that the sources returned by the language model are valid.
-        6. Explodes the questions into multiple questions with different sources based on the mode.
-        7. Constructs a new state with the transformed questions and their respective sources.
-    """
-
-
-    decomposition_chain = make_query_decomposition_chain(llm)
-    query_analysis_chain = make_query_analysis_chain(llm)
-    query_translation_chain = make_query_translation_chain(llm)
-
-    def transform_query(state):
-        print("---- Transform query ----")
-
-        auto_mode = state.get("sources_auto", True)
-        sources_input = state.get("sources_input", ROUTING_INDEX["IPx"])
-
-
-        new_state = {}
-
-        # Decomposition
-        decomposition_output = decomposition_chain.invoke({"input":state["query"]})
-        new_state.update(decomposition_output)
-
-
-        # Query Analysis
-        questions = []
-        for question in new_state["questions"]:
-            question_state = {"question":question}
-            query_analysis_output = query_analysis_chain.invoke({"input":question})
-
-            # TODO WARNING llm should always return smthg
-            # The case when the llm does not return any sources or wrong ouput
-            if not query_analysis_output["sources"] or not all(source in ["IPCC", "IPBS", "IPOS","AcclimaTerra", "PCAET","Biodiv"] for source in query_analysis_output["sources"]):
-                query_analysis_output["sources"] = ["IPCC", "IPBES", "IPOS"]
-
-            sources_types = group_by_sources_types(query_analysis_output["sources"])
-            for source_type,sources in sources_types.items():
-                question_state = {
-                    "question":question,
-                    "sources":sources,
-                    "source_type":source_type
-                }
-
-                questions.append(question_state)
-
-        # Translate question into the document language
-        for q in questions:
-            if q["source_type"]=="IPx":
-                translation_output = query_translation_chain.invoke({"input":q["question"],"language":"English"})
-                q["question"] = translation_output["question"]
-            elif q["source_type"]=="POC":
-                translation_output = query_translation_chain.invoke({"input":q["question"],"language":"French"})
-                q["question"] = translation_output["question"]
-
-        # Explode the questions into multiple questions with different sources
-        new_questions = []
-        for q in questions:
-            question,sources,source_type = q["question"],q["sources"], q["source_type"]
-
-            # If not auto mode we take the configuration
-            if not auto_mode:
-                sources = sources_input
-
-            for index,index_sources in ROUTING_INDEX.items():
-                selected_sources = list(set(sources).intersection(index_sources))
-                if len(selected_sources) > 0:
-                    new_questions.append({"question":question,"sources":selected_sources,"index":index, "source_type":source_type})
-
-        # # Add the number of questions to search
-        # k_by_question = k_final // len(new_questions)
-        # for q in new_questions:
-        #     q["k"] = k_by_question
-
-        # new_state["questions"] = new_questions
-        # new_state["remaining_questions"] = new_questions
-
-        n_questions = {
-            "total":len(new_questions),
-            "IPx":len([q for q in new_questions if q["index"] == "IPx"]),
-            "POC":len([q for q in new_questions if q["index"] == "POC"]),
-        }
-
-        new_state = {
-            "questions_list":new_questions,
-            "n_questions":n_questions,
-            "handled_questions_index":[],
-        }
-        print("New questions")
-        print(new_questions)
-        return new_state
-
-    return transform_query

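The routing step in transform_query maps each validated source onto an index family via ROUTING_INDEX ("IPx" for the international reports, "POC" for the local French documents). A standalone sketch of that grouping logic, equivalent in effect to the deleted group_by_sources_types:

ROUTING_INDEX = {
    "IPx": ["IPCC", "IPBES", "IPOS"],
    "POC": ["AcclimaTerra", "PCAET", "Biodiv"],
}

def group_by_sources_types(sources: list[str]) -> dict[str, list[str]]:
    grouped = {}
    for index, index_sources in ROUTING_INDEX.items():
        # Keep only the sources that belong to this index family.
        selected = sorted(set(sources) & set(index_sources))
        if selected:
            grouped[index] = selected
    return grouped

print(group_by_sources_types(["IPCC", "PCAET"]))  # -> {'IPx': ['IPCC'], 'POC': ['PCAET']}
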
climateqa/engine/chains/retrieve_documents.py
DELETED
@@ -1,705 +0,0 @@
-import sys
-import os
-from contextlib import contextmanager
-
-from langchain_core.tools import tool
-from langchain_core.runnables import chain
-from langchain_core.runnables import RunnableParallel, RunnablePassthrough
-from langchain_core.runnables import RunnableLambda
-
-from ..reranker import rerank_docs, rerank_and_sort_docs
-# from ...knowledge.retriever import ClimateQARetriever
-from ...knowledge.openalex import OpenAlexRetriever
-from .keywords_extraction import make_keywords_extraction_chain
-from ..utils import log_event
-from langchain_core.vectorstores import VectorStore
-from typing import List
-from langchain_core.documents.base import Document
-from ..llm import get_llm
-from .prompts import retrieve_chapter_prompt_template
-from langchain_core.prompts import ChatPromptTemplate
-from langchain_core.output_parsers import StrOutputParser
-from ..vectorstore import get_pinecone_vectorstore
-from ..embeddings import get_embeddings_function
-import ast
-
-import asyncio
-
-from typing import Any, Dict, List, Tuple
-
-
-def divide_into_parts(target, parts):
-    # Base value for each part
-    base = target // parts
-    # Remainder to distribute
-    remainder = target % parts
-    # List to hold the result
-    result = []
-
-    for i in range(parts):
-        if i < remainder:
-            # These parts get base value + 1
-            result.append(base + 1)
-        else:
-            # The rest get the base value
-            result.append(base)
-
-    return result
-
-
-@contextmanager
-def suppress_output():
-    # Open a null device
-    with open(os.devnull, 'w') as devnull:
-        # Store the original stdout and stderr
-        old_stdout = sys.stdout
-        old_stderr = sys.stderr
-        # Redirect stdout and stderr to the null device
-        sys.stdout = devnull
-        sys.stderr = devnull
-        try:
-            yield
-        finally:
-            # Restore stdout and stderr
-            sys.stdout = old_stdout
-            sys.stderr = old_stderr
-
-
-@tool
-def query_retriever(question):
-    """Just a dummy tool to simulate the retriever query"""
-    return question
-
-def _add_sources_used_in_metadata(docs,sources,question,index):
-    for doc in docs:
-        doc.metadata["sources_used"] = sources
-        doc.metadata["question_used"] = question
-        doc.metadata["index_used"] = index
-    return docs
-
-def _get_k_summary_by_question(n_questions):
-    if n_questions == 0:
-        return 0
-    elif n_questions == 1:
-        return 5
-    elif n_questions == 2:
-        return 3
-    elif n_questions == 3:
-        return 2
-    else:
-        return 1
-
-def _get_k_images_by_question(n_questions):
-    if n_questions == 0:
-        return 0
-    elif n_questions == 1:
-        return 7
-    elif n_questions == 2:
-        return 5
-    elif n_questions == 3:
-        return 3
-    else:
-        return 1
-
-def _add_metadata_and_score(docs: List) -> Document:
-    # Add score to metadata
-    docs_with_metadata = []
-    for i,(doc,score) in enumerate(docs):
-        doc.page_content = doc.page_content.replace("\r\n"," ")
-        doc.metadata["similarity_score"] = score
-        doc.metadata["content"] = doc.page_content
-        if doc.metadata["page_number"] != "N/A":
-            doc.metadata["page_number"] = int(doc.metadata["page_number"]) + 1
-        else:
-            doc.metadata["page_number"] = 1
-        # doc.page_content = f"""Doc {i+1} - {doc.metadata['short_name']}: {doc.page_content}"""
-        docs_with_metadata.append(doc)
-    return docs_with_metadata
-
-def remove_duplicates_chunks(docs):
-    # Remove duplicates or almost duplicates
-    docs = sorted(docs,key=lambda x: x[1],reverse=True)
-    seen = set()
-    result = []
-    for doc in docs:
-        if doc[0].page_content not in seen:
-            seen.add(doc[0].page_content)
-            result.append(doc)
-    return result
-
-def get_ToCs(version: str) :
-
-    filters_text = {
-        "chunk_type":"toc",
-        "version": version
-    }
-    embeddings_function = get_embeddings_function()
-    vectorstore = get_pinecone_vectorstore(embeddings_function, index_name="climateqa-v2")
-    tocs = vectorstore.similarity_search_with_score(query="",filter = filters_text)
-
-    # remove duplicates or almost duplicates
-    tocs = remove_duplicates_chunks(tocs)
-
-    return tocs
-
-async def get_POC_relevant_documents(
-    query: str,
-    vectorstore:VectorStore,
-    sources:list = ["Acclimaterra","PCAET","Plan Biodiversite"],
-    search_figures:bool = False,
-    search_only:bool = False,
-    k_documents:int = 10,
-    threshold:float = 0.6,
-    k_images: int = 5,
-    reports:list = [],
-    min_size:int = 200,
-) :
-    # Prepare base search kwargs
-    filters = {}
-    docs_question = []
-    docs_images = []
-
-    # TODO add source selection
-    # if len(reports) > 0:
-    #     filters["short_name"] = {"$in":reports}
-    # else:
-    #     filters["source"] = { "$in": sources}
-
-    filters_text = {
-        **filters,
-        "chunk_type":"text",
-        # "report_type": {}, # TODO to be completed to choose the right documents / chapters according to the analysis of the question
-    }
-
-    docs_question = vectorstore.similarity_search_with_score(query=query,filter = filters_text,k = k_documents)
-    # remove duplicates or almost duplicates
-    docs_question = remove_duplicates_chunks(docs_question)
-    docs_question = [x for x in docs_question if x[1] > threshold]
-
-    if search_figures:
-        # Images
-        filters_image = {
-            **filters,
-            "chunk_type":"image"
-        }
-        docs_images = vectorstore.similarity_search_with_score(query=query,filter = filters_image,k = k_images)
-
-    docs_question, docs_images = _add_metadata_and_score(docs_question), _add_metadata_and_score(docs_images)
-
-    docs_question = [x for x in docs_question if len(x.page_content) > min_size]
-
-    return {
-        "docs_question" : docs_question,
-        "docs_images" : docs_images
-    }
-
-async def get_POC_documents_by_ToC_relevant_documents(
-    query: str,
-    tocs: list,
-    vectorstore:VectorStore,
-    version: str,
-    sources:list = ["Acclimaterra","PCAET","Plan Biodiversite"],
-    search_figures:bool = False,
-    search_only:bool = False,
-    k_documents:int = 10,
-    threshold:float = 0.6,
-    k_images: int = 5,
-    reports:list = [],
-    min_size:int = 200,
-    proportion: float = 0.5,
-) :
-    """
-    Args:
-        - tocs : list with the table of contents of each document
-        - version : version of the parsed documents (e.g. "v4")
-        - proportion : share of documents retrieved using ToCs
-    """
-    # Prepare base search kwargs
-    filters = {}
-    docs_question = []
-    docs_images = []
-
-    # TODO add source selection
-    # if len(reports) > 0:
-    #     filters["short_name"] = {"$in":reports}
-    # else:
-    #     filters["source"] = { "$in": sources}
-
-    k_documents_toc = round(k_documents * proportion)
-
-    relevant_tocs = await get_relevant_toc_level_for_query(query, tocs)
-
-    print(f"Relevant ToCs : {relevant_tocs}")
-    # Transform the ToC dict {"document": str, "chapter": str} into a list of string
-    toc_filters = [toc['chapter'] for toc in relevant_tocs]
-
-    filters_text_toc = {
-        **filters,
-        "chunk_type":"text",
-        "toc_level0": {"$in": toc_filters},
-        "version": version
-        # "report_type": {}, # TODO to be completed to choose the right documents / chapters according to the analysis of the question
-    }
-
-    docs_question = vectorstore.similarity_search_with_score(query=query,filter = filters_text_toc,k = k_documents_toc)
-
-    filters_text = {
-        **filters,
-        "chunk_type":"text",
-        "version": version
-        # "report_type": {}, # TODO to be completed to choose the right documents / chapters according to the analysis of the question
-    }
-
-    docs_question += vectorstore.similarity_search_with_score(query=query,filter = filters_text,k = k_documents - k_documents_toc)
-
-    # remove duplicates or almost duplicates
-    docs_question = remove_duplicates_chunks(docs_question)
-    docs_question = [x for x in docs_question if x[1] > threshold]
-
-    if search_figures:
-        # Images
-        filters_image = {
-            **filters,
-            "chunk_type":"image"
-        }
-        docs_images = vectorstore.similarity_search_with_score(query=query,filter = filters_image,k = k_images)
-
-    docs_question, docs_images = _add_metadata_and_score(docs_question), _add_metadata_and_score(docs_images)
-
-    docs_question = [x for x in docs_question if len(x.page_content) > min_size]
-
-    return {
-        "docs_question" : docs_question,
-        "docs_images" : docs_images
-    }
-
-
-async def get_IPCC_relevant_documents(
-    query: str,
-    vectorstore:VectorStore,
-    sources:list = ["IPCC","IPBES","IPOS"],
-    search_figures:bool = False,
-    reports:list = [],
-    threshold:float = 0.6,
-    k_summary:int = 3,
-    k_total:int = 10,
-    k_images: int = 5,
-    namespace:str = "vectors",
-    min_size:int = 200,
-    search_only:bool = False,
-) :
-
-    # Check if all elements in the list are either IPCC or IPBES
-    assert isinstance(sources,list)
-    assert sources
-    assert all([x in ["IPCC","IPBES","IPOS"] for x in sources])
-    assert k_total > k_summary, "k_total should be greater than k_summary"
-
-    # Prepare base search kwargs
-    filters = {}
-
-    if len(reports) > 0:
-        filters["short_name"] = {"$in":reports}
-    else:
-        filters["source"] = { "$in": sources}
-
-    # INIT
-    docs_summaries = []
-    docs_full = []
-    docs_images = []
-
-    if search_only:
-        # Only search for images if search_only is True
-        if search_figures:
-            filters_image = {
-                **filters,
-                "chunk_type":"image"
-            }
-            docs_images = vectorstore.similarity_search_with_score(query=query,filter = filters_image,k = k_images)
-            docs_images = _add_metadata_and_score(docs_images)
-    else:
-        # Regular search flow for text and optionally images
-        # Search for k_summary documents in the summaries dataset
-        filters_summaries = {
-            **filters,
-            "chunk_type":"text",
-            "report_type": { "$in":["SPM"]},
-        }
-
-        docs_summaries = vectorstore.similarity_search_with_score(query=query,filter = filters_summaries,k = k_summary)
-        docs_summaries = [x for x in docs_summaries if x[1] > threshold]
-
-        # Search for k_total - k_summary documents in the full reports dataset
-        filters_full = {
-            **filters,
-            "chunk_type":"text",
-            "report_type": { "$nin":["SPM"]},
-        }
-        docs_full = vectorstore.similarity_search_with_score(query=query,filter = filters_full,k = k_total)
-
-        if search_figures:
-            # Images
-            filters_image = {
-                **filters,
-                "chunk_type":"image"
-            }
-            docs_images = vectorstore.similarity_search_with_score(query=query,filter = filters_image,k = k_images)
-
-        docs_summaries, docs_full, docs_images = _add_metadata_and_score(docs_summaries), _add_metadata_and_score(docs_full), _add_metadata_and_score(docs_images)
-
-        # Filter if length are below threshold
-        docs_summaries = [x for x in docs_summaries if len(x.page_content) > min_size]
-        docs_full = [x for x in docs_full if len(x.page_content) > min_size]
-
-    return {
-        "docs_summaries" : docs_summaries,
-        "docs_full" : docs_full,
-        "docs_images" : docs_images,
-    }
-
-
-
-def concatenate_documents(index, source_type, docs_question_dict, k_by_question, k_summary_by_question, k_images_by_question):
-    # Keep the right number of documents - The k_summary documents from SPM are placed in front
-    if source_type == "IPx":
-        docs_question = docs_question_dict["docs_summaries"][:k_summary_by_question] + docs_question_dict["docs_full"][:(k_by_question - k_summary_by_question)]
-    elif source_type == "POC" :
-        docs_question = docs_question_dict["docs_question"][:k_by_question]
-    else :
-        raise ValueError("source_type should be either Vector or POC")
-    # docs_question = [doc for key in docs_question_dict.keys() for doc in docs_question_dict[key]][:(k_by_question)]
-
-    images_question = docs_question_dict["docs_images"][:k_images_by_question]
-
-    return docs_question, images_question
-
-
-
-# The chain callback is not necessary, but it propagates the langchain callbacks to the astream_events logger to display intermediate results
-# @chain
-async def retrieve_documents(
-    current_question: Dict[str, Any],
-    config: Dict[str, Any],
-    source_type: str,
-    vectorstore: VectorStore,
-    reranker: Any,
-    version: str = "",
-    search_figures: bool = False,
-    search_only: bool = False,
-    reports: list = [],
-    rerank_by_question: bool = True,
-    k_images_by_question: int = 5,
-    k_before_reranking: int = 100,
-    k_by_question: int = 5,
-    k_summary_by_question: int = 3,
-    tocs: list = [],
-    by_toc=False
-) -> Tuple[List[Document], List[Document]]:
-    """
-    Unpack the first question of the remaining questions, and retrieve and rerank corresponding documents, based on the question and selected_sources
-
-    Args:
-        state (dict): The current state containing documents, related content, relevant content sources, remaining questions and n_questions.
-        current_question (dict): The current question being processed.
-        config (dict): Configuration settings for logging and other purposes.
-        vectorstore (object): The vector store used to retrieve relevant documents.
-        reranker (object): The reranker used to rerank the retrieved documents.
-        llm (object): The language model used for processing.
-        rerank_by_question (bool, optional): Whether to rerank documents by question. Defaults to True.
-        k_final (int, optional): The final number of documents to retrieve. Defaults to 15.
-        k_before_reranking (int, optional): The number of documents to retrieve before reranking. Defaults to 100.
-        k_summary (int, optional): The number of summary documents to retrieve. Defaults to 5.
-        k_images (int, optional): The number of image documents to retrieve. Defaults to 5.
-    Returns:
-        dict: The updated state containing the retrieved and reranked documents, related content, and remaining questions.
-    """
-    sources = current_question["sources"]
-    question = current_question["question"]
-    index = current_question["index"]
-    source_type = current_question["source_type"]
-
-    print(f"Retrieve documents for question: {question}")
-    await log_event({"question":question,"sources":sources,"index":index},"log_retriever",config)
-
-    print(f"""---- Retrieve documents from {current_question["source_type"]}----""")
-
-
-    if source_type == "IPx":
-        docs_question_dict = await get_IPCC_relevant_documents(
-            query = question,
-            vectorstore=vectorstore,
-            search_figures = search_figures,
-            sources = sources,
-            min_size = 200,
-            k_summary = k_before_reranking-1,
-            k_total = k_before_reranking,
-            k_images = k_images_by_question,
-            threshold = 0.5,
-            search_only = search_only,
-            reports = reports,
-        )
-
-    if source_type == 'POC':
-        if by_toc == True:
-            print("---- Retrieve documents by ToC----")
-            docs_question_dict = await get_POC_documents_by_ToC_relevant_documents(
-                query=question,
-                tocs = tocs,
-                vectorstore=vectorstore,
-                version=version,
-                search_figures = search_figures,
-                sources = sources,
-                threshold = 0.5,
-                search_only = search_only,
-                reports = reports,
-                min_size= 200,
-                k_documents= k_before_reranking,
-                k_images= k_by_question
-            )
-        else :
-            docs_question_dict = await get_POC_relevant_documents(
-                query = question,
-                vectorstore=vectorstore,
-                search_figures = search_figures,
-                sources = sources,
-                threshold = 0.5,
-                search_only = search_only,
-                reports = reports,
-                min_size= 200,
-                k_documents= k_before_reranking,
-                k_images= k_by_question
-            )
-
-    # Rerank
-    if reranker is not None and rerank_by_question:
-        with suppress_output():
-            for key in docs_question_dict.keys():
-                docs_question_dict[key] = rerank_and_sort_docs(reranker,docs_question_dict[key],question)
-    else:
-        # Add a default reranking score
-        for key in docs_question_dict.keys():
-            if isinstance(docs_question_dict[key], list) and len(docs_question_dict[key]) > 0:
-                for doc in docs_question_dict[key]:
-                    doc.metadata["reranking_score"] = doc.metadata["similarity_score"]
-
-    # Keep the right number of documents
-    docs_question, images_question = concatenate_documents(index, source_type, docs_question_dict, k_by_question, k_summary_by_question, k_images_by_question)
-
-    # Rerank the documents to put the most relevant in front
-    if reranker is not None and rerank_by_question:
-        docs_question = rerank_and_sort_docs(reranker, docs_question, question)
-
-    # Add sources used in the metadata
-    docs_question = _add_sources_used_in_metadata(docs_question,sources,question,index)
-    images_question = _add_sources_used_in_metadata(images_question,sources,question,index)
-
-    return docs_question, images_question
-
-
-async def retrieve_documents_for_all_questions(
-    search_figures,
-    search_only,
-    reports,
-    questions_list,
-    n_questions,
-    config,
-    source_type,
-    to_handle_questions_index,
-    vectorstore,
-    reranker,
-    rerank_by_question=True,
-    k_final=15,
-    k_before_reranking=100,
-    version: str = "",
-    tocs: list[dict] = [],
-    by_toc: bool = False
-):
-    """
-    Retrieve documents in parallel for all questions.
-    """
-    # to_handle_questions_index = [x for x in state["questions_list"] if x["source_type"] == "IPx"]
-
-    # TODO split les questions selon le type de sources dans le state question + conditions sur le nombre de questions traités par type de source
-    # search_figures = "Figures (IPCC/IPBES)" in state["relevant_content_sources_selection"]
-    # search_only = state["search_only"]
-    # reports = state["reports"]
-    # questions_list = state["questions_list"]
-
-    # k_by_question = k_final // state["n_questions"]["total"]
-    # k_summary_by_question = _get_k_summary_by_question(state["n_questions"]["total"])
-    # k_images_by_question = _get_k_images_by_question(state["n_questions"]["total"])
-    k_by_question = k_final // n_questions
-    k_summary_by_question = _get_k_summary_by_question(n_questions)
-    k_images_by_question = _get_k_images_by_question(n_questions)
-    k_before_reranking=100
-
-    print(f"Source type here is {source_type}")
-    tasks = [
-        retrieve_documents(
-            current_question=question,
-            config=config,
-            source_type=source_type,
-            vectorstore=vectorstore,
-            reranker=reranker,
-            search_figures=search_figures,
-            search_only=search_only,
-            reports=reports,
-            rerank_by_question=rerank_by_question,
-            k_images_by_question=k_images_by_question,
-            k_before_reranking=k_before_reranking,
-            k_by_question=k_by_question,
-            k_summary_by_question=k_summary_by_question,
-            tocs=tocs,
-            version=version,
-            by_toc=by_toc
-        )
-        for i, question in enumerate(questions_list) if i in to_handle_questions_index
-    ]
-    results = await asyncio.gather(*tasks)
-    # Combine results
-    new_state = {"documents": [], "related_contents": [], "handled_questions_index": to_handle_questions_index}
-    for docs_question, images_question in results:
-        new_state["documents"].extend(docs_question)
-        new_state["related_contents"].extend(images_question)
-    return new_state
-
-# ToC Retriever
-async def get_relevant_toc_level_for_query(
-    query: str,
-    tocs: list[Document],
-) -> list[dict] :
-
-    doc_list = []
-    for doc in tocs:
-        doc_name = doc[0].metadata['name']
-        toc = doc[0].page_content
-        doc_list.append({'document': doc_name, 'toc': toc})
-
-    llm = get_llm(provider="openai",max_tokens = 1024,temperature = 0.0)
-
-    prompt = ChatPromptTemplate.from_template(retrieve_chapter_prompt_template)
-    chain = prompt | llm | StrOutputParser()
-    response = chain.invoke({"query": query, "doc_list": doc_list})
-
-    try:
-        relevant_tocs = ast.literal_eval(response)
-    except Exception as e:
-        print(f" Failed to parse the result because of : {e}")
-
-    return relevant_tocs
-
-
-def make_IPx_retriever_node(vectorstore,reranker,llm,rerank_by_question=True, k_final=15, k_before_reranking=100, k_summary=5):
|
593 |
-
|
594 |
-
async def retrieve_IPx_docs(state, config):
|
595 |
-
source_type = "IPx"
|
596 |
-
IPx_questions_index = [i for i, x in enumerate(state["questions_list"]) if x["source_type"] == "IPx"]
|
597 |
-
|
598 |
-
search_figures = "Figures (IPCC/IPBES)" in state["relevant_content_sources_selection"]
|
599 |
-
search_only = state["search_only"]
|
600 |
-
reports = state["reports"]
|
601 |
-
questions_list = state["questions_list"]
|
602 |
-
n_questions=state["n_questions"]["total"]
|
603 |
-
|
604 |
-
state = await retrieve_documents_for_all_questions(
|
605 |
-
search_figures=search_figures,
|
606 |
-
search_only=search_only,
|
607 |
-
reports=reports,
|
608 |
-
questions_list=questions_list,
|
609 |
-
n_questions=n_questions,
|
610 |
-
config=config,
|
611 |
-
source_type=source_type,
|
612 |
-
to_handle_questions_index=IPx_questions_index,
|
613 |
-
vectorstore=vectorstore,
|
614 |
-
reranker=reranker,
|
615 |
-
rerank_by_question=rerank_by_question,
|
616 |
-
k_final=k_final,
|
617 |
-
k_before_reranking=k_before_reranking,
|
618 |
-
)
|
619 |
-
return state
|
620 |
-
|
621 |
-
return retrieve_IPx_docs
|
622 |
-
|
623 |
-
|
624 |
-
def make_POC_retriever_node(vectorstore,reranker,llm,rerank_by_question=True, k_final=15, k_before_reranking=100, k_summary=5):
|
625 |
-
|
626 |
-
async def retrieve_POC_docs_node(state, config):
|
627 |
-
source_type = "POC"
|
628 |
-
POC_questions_index = [i for i, x in enumerate(state["questions_list"]) if x["source_type"] == "POC"]
|
629 |
-
|
630 |
-
search_figures = "Figures (IPCC/IPBES)" in state["relevant_content_sources_selection"]
|
631 |
-
search_only = state["search_only"]
|
632 |
-
reports = state["reports"]
|
633 |
-
questions_list = state["questions_list"]
|
634 |
-
n_questions=state["n_questions"]["total"]
|
635 |
-
|
636 |
-
state = await retrieve_documents_for_all_questions(
|
637 |
-
search_figures=search_figures,
|
638 |
-
search_only=search_only,
|
639 |
-
reports=reports,
|
640 |
-
questions_list=questions_list,
|
641 |
-
n_questions=n_questions,
|
642 |
-
config=config,
|
643 |
-
source_type=source_type,
|
644 |
-
to_handle_questions_index=POC_questions_index,
|
645 |
-
vectorstore=vectorstore,
|
646 |
-
reranker=reranker,
|
647 |
-
rerank_by_question=rerank_by_question,
|
648 |
-
k_final=k_final,
|
649 |
-
k_before_reranking=k_before_reranking,
|
650 |
-
)
|
651 |
-
return state
|
652 |
-
|
653 |
-
return retrieve_POC_docs_node
|
654 |
-
|
655 |
-
|
656 |
-
def make_POC_by_ToC_retriever_node(
|
657 |
-
vectorstore: VectorStore,
|
658 |
-
reranker,
|
659 |
-
llm,
|
660 |
-
version: str = "",
|
661 |
-
rerank_by_question=True,
|
662 |
-
k_final=15,
|
663 |
-
k_before_reranking=100,
|
664 |
-
k_summary=5,
|
665 |
-
):
|
666 |
-
|
667 |
-
async def retrieve_POC_docs_node(state, config):
|
668 |
-
search_figures = "Figures (IPCC/IPBES)" in state["relevant_content_sources_selection"]
|
669 |
-
search_only = state["search_only"]
|
670 |
-
search_only = state["search_only"]
|
671 |
-
reports = state["reports"]
|
672 |
-
questions_list = state["questions_list"]
|
673 |
-
n_questions=state["n_questions"]["total"]
|
674 |
-
|
675 |
-
tocs = get_ToCs(version=version)
|
676 |
-
|
677 |
-
source_type = "POC"
|
678 |
-
POC_questions_index = [i for i, x in enumerate(state["questions_list"]) if x["source_type"] == "POC"]
|
679 |
-
|
680 |
-
state = await retrieve_documents_for_all_questions(
|
681 |
-
search_figures=search_figures,
|
682 |
-
search_only=search_only,
|
683 |
-
config=config,
|
684 |
-
reports=reports,
|
685 |
-
questions_list=questions_list,
|
686 |
-
n_questions=n_questions,
|
687 |
-
source_type=source_type,
|
688 |
-
to_handle_questions_index=POC_questions_index,
|
689 |
-
vectorstore=vectorstore,
|
690 |
-
reranker=reranker,
|
691 |
-
rerank_by_question=rerank_by_question,
|
692 |
-
k_final=k_final,
|
693 |
-
k_before_reranking=k_before_reranking,
|
694 |
-
tocs=tocs,
|
695 |
-
version=version,
|
696 |
-
by_toc=True
|
697 |
-
)
|
698 |
-
return state
|
699 |
-
|
700 |
-
return retrieve_POC_docs_node
|
701 |
-
|
702 |
-
|
703 |
-
|
704 |
-
|
705 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
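
For review context, here is a minimal, hypothetical driver for the deleted retriever factory above. The vectorstore, reranker and llm are assumed to be constructed elsewhere in the app, the state keys simply mirror those read by retrieve_IPx_docs, and the empty config is only a placeholder for the usual LangGraph run config:

import asyncio

retrieve_ipx = make_IPx_retriever_node(vectorstore, reranker, llm)

state = {
    "questions_list": [{"question": "What drives sea level rise?",
                        "sources": ["IPCC"], "index": 0, "source_type": "IPx"}],
    "relevant_content_sources_selection": ["Figures (IPCC/IPBES)"],
    "search_only": False,
    "reports": [],
    "n_questions": {"total": 1},
}

# The node is async and expects (state, config); retrieved documents and
# related contents are accumulated into the returned partial state.
new_state = asyncio.run(retrieve_ipx(state, config={}))
print(len(new_state["documents"]), "documents retrieved")
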
climateqa/engine/chains/retrieve_papers.py
DELETED
@@ -1,95 +0,0 @@
-from climateqa.engine.keywords import make_keywords_chain
-from climateqa.engine.llm import get_llm
-from climateqa.knowledge.openalex import OpenAlex
-from climateqa.engine.chains.answer_rag import make_rag_papers_chain
-from front.utils import make_html_papers
-from climateqa.engine.reranker import get_reranker
-
-oa = OpenAlex()
-
-llm = get_llm(provider="openai",max_tokens = 1024,temperature = 0.0)
-reranker = get_reranker("nano")
-
-
-papers_cols_widths = {
-    "id":100,
-    "title":300,
-    "doi":100,
-    "publication_year":100,
-    "abstract":500,
-    "is_oa":50,
-}
-
-papers_cols = list(papers_cols_widths.keys())
-papers_cols_widths = list(papers_cols_widths.values())
-
-
-
-def generate_keywords(query):
-    chain = make_keywords_chain(llm)
-    keywords = chain.invoke(query)
-    keywords = " AND ".join(keywords["keywords"])
-    return keywords
-
-
-async def find_papers(query,after, relevant_content_sources_selection, reranker= reranker):
-    if "Papers (OpenAlex)" in relevant_content_sources_selection:
-        summary = ""
-        keywords = generate_keywords(query)
-        df_works = oa.search(keywords,after = after)
-
-        print(f"Found {len(df_works)} papers")
-
-        if not df_works.empty:
-            df_works = df_works.dropna(subset=["abstract"])
-            df_works = df_works[df_works["abstract"] != ""].reset_index(drop = True)
-            df_works = oa.rerank(query,df_works,reranker)
-            df_works = df_works.sort_values("rerank_score",ascending=False)
-            docs_html = []
-            for i in range(10):
-                docs_html.append(make_html_papers(df_works, i))
-            docs_html = "".join(docs_html)
-            G = oa.make_network(df_works)
-
-            height = "750px"
-            network = oa.show_network(G,color_by = "rerank_score",notebook=False,height = height)
-            network_html = network.generate_html()
-
-            network_html = network_html.replace("'", "\"")
-            css_to_inject = "<style>#mynetwork { border: none !important; } .card { border: none !important; }</style>"
-            network_html = network_html + css_to_inject
-
-
-            network_html = f"""<iframe style="width: 100%; height: {height};margin:0 auto" name="result" allow="midi; geolocation; microphone; camera;
-            display-capture; encrypted-media;" sandbox="allow-modals allow-forms
-            allow-scripts allow-same-origin allow-popups
-            allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
-            allowpaymentrequest="" frameborder="0" srcdoc='{network_html}'></iframe>"""
-
-
-            docs = df_works["content"].head(10).tolist()
-
-            df_works = df_works.reset_index(drop = True).reset_index().rename(columns = {"index":"doc"})
-            df_works["doc"] = df_works["doc"] + 1
-            df_works = df_works[papers_cols]
-
-            yield docs_html, network_html, summary
-
-            chain = make_rag_papers_chain(llm)
-            result = chain.astream_log({"question": query,"docs": docs,"language":"English"})
-            path_answer = "/logs/StrOutputParser/streamed_output/-"
-
-            async for op in result:
-
-                op = op.ops[0]
-
-                if op['path'] == path_answer: # reformulated question
-                    new_token = op['value'] # str
-                    summary += new_token
-                else:
-                    continue
-                yield docs_html, network_html, summary
-        else:
-            print("No papers found")
-    else:
-        yield "","", ""
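
A sketch of how the deleted find_papers generator was consumed; this is illustrative only, assuming the module-level OpenAlex client and reranker initialize correctly and that after acts as a publication-year cutoff:

import asyncio

async def main():
    # Each iteration yields (docs_html, network_html, summary); the summary
    # grows token by token as the RAG chain streams its answer.
    async for docs_html, network_html, summary in find_papers(
        query="impact of deep sea mining",
        after=2015,
        relevant_content_sources_selection=["Papers (OpenAlex)"],
    ):
        print(f"summary so far: {len(summary)} chars")

asyncio.run(main())
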
climateqa/engine/chains/retriever.py
DELETED
@@ -1,126 +0,0 @@
-# import sys
-# import os
-# from contextlib import contextmanager
-
-# from ..reranker import rerank_docs
-# from ...knowledge.retriever import ClimateQARetriever
-
-
-
-
-# def divide_into_parts(target, parts):
-#     # Base value for each part
-#     base = target // parts
-#     # Remainder to distribute
-#     remainder = target % parts
-#     # List to hold the result
-#     result = []
-
-#     for i in range(parts):
-#         if i < remainder:
-#             # These parts get base value + 1
-#             result.append(base + 1)
-#         else:
-#             # The rest get the base value
-#             result.append(base)
-
-#     return result
-
-
-# @contextmanager
-# def suppress_output():
-#     # Open a null device
-#     with open(os.devnull, 'w') as devnull:
-#         # Store the original stdout and stderr
-#         old_stdout = sys.stdout
-#         old_stderr = sys.stderr
-#         # Redirect stdout and stderr to the null device
-#         sys.stdout = devnull
-#         sys.stderr = devnull
-#         try:
-#             yield
-#         finally:
-#             # Restore stdout and stderr
-#             sys.stdout = old_stdout
-#             sys.stderr = old_stderr
-
-
-
-# def make_retriever_node(vectorstore,reranker,rerank_by_question=True, k_final=15, k_before_reranking=100, k_summary=5):
-
-#     def retrieve_documents(state):
-
-#         POSSIBLE_SOURCES = ["IPCC","IPBES","IPOS"] # ,"OpenAlex"]
-#         questions = state["questions"]
-
-#         # Use sources from the user input or from the LLM detection
-#         if "sources_input" not in state or state["sources_input"] is None:
-#             sources_input = ["auto"]
-#         else:
-#             sources_input = state["sources_input"]
-#         auto_mode = "auto" in sources_input
-
-#         # There are several options to get the final top k
-#         # Option 1 - Get 100 documents by question and rerank by question
-#         # Option 2 - Get 100/n documents by question and rerank the total
-#         if rerank_by_question:
-#             k_by_question = divide_into_parts(k_final,len(questions))
-
-#         docs = []
-
-#         for i,q in enumerate(questions):
-
-#             sources = q["sources"]
-#             question = q["question"]
-
-#             # If auto mode, we use the sources detected by the LLM
-#             if auto_mode:
-#                 sources = [x for x in sources if x in POSSIBLE_SOURCES]
-
-#             # Otherwise, we use the config
-#             else:
-#                 sources = sources_input
-
-#             # Search the document store using the retriever
-#             # Configure high top k for further reranking step
-#             retriever = ClimateQARetriever(
-#                 vectorstore=vectorstore,
-#                 sources = sources,
-#                 # reports = ias_reports,
-#                 min_size = 200,
-#                 k_summary = k_summary,
-#                 k_total = k_before_reranking,
-#                 threshold = 0.5,
-#             )
-#             docs_question = retriever.get_relevant_documents(question)
-
-#             # Rerank
-#             if reranker is not None:
-#                 with suppress_output():
-#                     docs_question = rerank_docs(reranker,docs_question,question)
-#             else:
-#                 # Add a default reranking score
-#                 for doc in docs_question:
-#                     doc.metadata["reranking_score"] = doc.metadata["similarity_score"]
-
-#             # If rerank by question we select the top documents for each question
-#             if rerank_by_question:
-#                 docs_question = docs_question[:k_by_question[i]]
-
-#             # Add sources used in the metadata
-#             for doc in docs_question:
-#                 doc.metadata["sources_used"] = sources
-
-#             # Add to the list of docs
-#             docs.extend(docs_question)
-
-#         # Sorting the list in descending order by rerank_score
-#         # Then select the top k
-#         docs = sorted(docs, key=lambda x: x.metadata["reranking_score"], reverse=True)
-#         docs = docs[:k_final]
-
-#         new_state = {"documents":docs}
-#         return new_state
-
-#     return retrieve_documents
-
climateqa/engine/chains/sample_router.py
DELETED
@@ -1,66 +0,0 @@
-
-# from typing import List
-# from typing import Literal
-# from langchain.prompts import ChatPromptTemplate
-# from langchain_core.utils.function_calling import convert_to_openai_function
-# from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
-
-# # https://livingdatalab.com/posts/2023-11-05-openai-function-calling-with-langchain.html
-
-# class Location(BaseModel):
-#     country:str = Field(...,description="The country if directly mentioned or inferred from the location (cities, regions, adresses), ex: France, USA, ...")
-#     location:str = Field(...,description="The specific place if mentioned (cities, regions, addresses), ex: Marseille, New York, Wisconsin, ...")
-
-# class QueryAnalysis(BaseModel):
-#     """Analyzing the user query"""
-
-#     language: str = Field(
-#         description="Find the language of the query in full words (ex: French, English, Spanish, ...), defaults to English"
-#     )
-#     intent: str = Field(
-#         enum=[
-#             "Environmental impacts of AI",
-#             "Geolocated info about climate change",
-#             "Climate change",
-#             "Biodiversity",
-#             "Deep sea mining",
-#             "Chitchat",
-#         ],
-#         description="""
-#         Categorize the user query in one of the following category,
-
-#         Examples:
-#         - Geolocated info about climate change: "What will be the temperature in Marseille in 2050"
-#         - Climate change: "What is radiative forcing", "How much will
-#         """,
-#     )
-#     sources: List[Literal["IPCC", "IPBES", "IPOS"]] = Field(
-#         ...,
-#         description="""
-#         Given a user question choose which documents would be most relevant for answering their question,
-#         - IPCC is for questions about climate change, energy, impacts, and everything we can find the IPCC reports
-#         - IPBES is for questions about biodiversity and nature
-#         - IPOS is for questions about the ocean and deep sea mining
-
-#         """,
-#     )
-#     date: str = Field(description="The date or period mentioned, ex: 2050, between 2020 and 2050")
-#     location:Location
-#     # query: str = Field(
-#     #     description = """
-#     #     Translate to english and reformulate the following user message to be a short standalone question, in the context of an educational discussion about climate change.
-#     #     The reformulated question will used in a search engine
-#     #     By default, assume that the user is asking information about the last century,
-#     #     Use the following examples
-
-#     #     ### Examples:
-#     #     La technologie nous sauvera-t-elle ? -> Can technology help humanity mitigate the effects of climate change?
-#     #     what are our reserves in fossil fuel? -> What are the current reserves of fossil fuels and how long will they last?
-#     #     what are the main causes of climate change? -> What are the main causes of climate change in the last century?
-
-#     #     Question in English:
-#     #     """
-#     # )
-
-# openai_functions = [convert_to_openai_function(QueryAnalysis)]
-# llm2 = llm.bind(functions = openai_functions,function_call={"name":"QueryAnalysis"})
climateqa/engine/chains/set_defaults.py
DELETED
@@ -1,13 +0,0 @@
-def set_defaults(state):
-    print("---- Setting defaults ----")
-
-    if not state["audience"] or state["audience"] is None:
-        state.update({"audience": "experts"})
-
-    sources_input = state["sources_input"] if "sources_input" in state else ["auto"]
-    state.update({"sources_input": sources_input})
-
-    # if not state["sources_input"] or state["sources_input"] is None:
-    #     state.update({"sources_input": ["auto"]})
-
-    return state
climateqa/engine/chains/standalone_question.py
DELETED
@@ -1,42 +0,0 @@
-from langchain.prompts import ChatPromptTemplate
-
-def make_standalone_question_chain(llm):
-    prompt = ChatPromptTemplate.from_messages([
-        ("system", """You are a helpful assistant that transforms user questions into standalone questions
-        by incorporating context from the chat history if needed. The output should be a self-contained
-        question that can be understood without any additional context.
-
-        Examples:
-        Chat History: "Let's talk about renewable energy"
-        User Input: "What about solar?"
-        Output: "What are the key aspects of solar energy as a renewable energy source?"
-
-        Chat History: "What causes global warming?"
-        User Input: "And what are its effects?"
-        Output: "What are the effects of global warming on the environment and society?"
-        """),
-        ("user", """Chat History: {chat_history}
-        User Question: {question}
-
-        Transform this into a standalone question:
-        Make sure to keep the original language of the question.""")
-    ])
-
-    chain = prompt | llm
-    return chain
-
-def make_standalone_question_node(llm):
-    standalone_chain = make_standalone_question_chain(llm)
-
-    def transform_to_standalone(state):
-        chat_history = state.get("chat_history", "")
-        if chat_history == "":
-            return {}
-        output = standalone_chain.invoke({
-            "chat_history": chat_history,
-            "question": state["user_input"]
-        })
-        state["user_input"] = output.content
-        return state
-
-    return transform_to_standalone
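
For reviewers checking the behavior being removed, a minimal sketch of the standalone-question chain in isolation; the llm is assumed to come from the project's get_llm factory:

from climateqa.engine.llm import get_llm

llm = get_llm(provider="openai", temperature=0.0)
chain = make_standalone_question_chain(llm)

# The chain rewrites a follow-up into a self-contained question.
result = chain.invoke({
    "chat_history": "User asked about the causes of global warming.",
    "question": "And what are its effects?",
})
print(result.content)
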
climateqa/engine/chains/translation.py
DELETED
@@ -1,42 +0,0 @@
-
-from langchain_core.pydantic_v1 import BaseModel, Field
-from typing import List
-from typing import Literal
-from langchain.prompts import ChatPromptTemplate
-from langchain_core.utils.function_calling import convert_to_openai_function
-from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
-
-
-class Translation(BaseModel):
-    """Analyzing the user message input"""
-
-    translation: str = Field(
-        description="Translate the message input to English",
-    )
-
-
-def make_translation_chain(llm):
-
-    openai_functions = [convert_to_openai_function(Translation)]
-    llm_with_functions = llm.bind(functions = openai_functions,function_call={"name":"Translation"})
-
-    prompt = ChatPromptTemplate.from_messages([
-        ("system", "You are a helpful assistant, you will translate the user input message to English using the function provided"),
-        ("user", "input: {input}")
-    ])
-
-    chain = prompt | llm_with_functions | JsonOutputFunctionsParser()
-    return chain
-
-
-def make_translation_node(llm):
-    translation_chain = make_translation_chain(llm)
-
-    def translate_query(state):
-        print("---- Translate query ----")
-
-        user_input = state["user_input"]
-        translation = translation_chain.invoke({"input":user_input})
-        return {"query":translation["translation"]}
-
-    return translate_query
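
Likewise for the removed translation chain, a hypothetical sketch: JsonOutputFunctionsParser returns the parsed function-call arguments, so the result is a plain dict keyed by the Translation field:

from climateqa.engine.llm import get_llm

llm = get_llm(provider="openai", temperature=0.0)
translate = make_translation_chain(llm)

out = translate.invoke({"input": "Quelles sont les causes du changement climatique ?"})
print(out["translation"])  # e.g. "What are the causes of climate change?"
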
climateqa/engine/embeddings.py
CHANGED
@@ -2,7 +2,7 @@
 from langchain_community.embeddings import HuggingFaceBgeEmbeddings
 from langchain_community.embeddings import HuggingFaceEmbeddings
 
-def get_embeddings_function(version = "v1.2",query_instruction = "Represent this sentence for searching relevant passages: "):
+def get_embeddings_function(version = "v1.2"):
 
     if version == "v1.2":
 
@@ -10,12 +10,12 @@ def get_embeddings_function(version = "v1.2",query_instruction = "Represent this sentence for searching relevant passages: "):
         # Best embedding model at a reasonable size at the moment (2023-11-22)
 
         model_name = "BAAI/bge-base-en-v1.5"
-        encode_kwargs = {'normalize_embeddings': True}
+        encode_kwargs = {'normalize_embeddings': True} # set True to compute cosine similarity
         print("Loading embeddings model: ", model_name)
         embeddings_function = HuggingFaceBgeEmbeddings(
            model_name=model_name,
            encode_kwargs=encode_kwargs,
-           query_instruction=query_instruction,
+           query_instruction="Represent this sentence for searching relevant passages: "
        )
 
    else:
 
@@ -23,6 +23,3 @@ def get_embeddings_function(version = "v1.2",query_instruction = "Represent this sentence for searching relevant passages: "):
        embeddings_function = HuggingFaceEmbeddings(model_name = "sentence-transformers/multi-qa-mpnet-base-dot-v1")
 
    return embeddings_function
-
-
-
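
A quick usage sketch of the function after this change (assuming the sentence-transformers dependencies are installed). BGE models prepend the query instruction at query time only; documents are embedded without it:

embeddings = get_embeddings_function(version="v1.2")

q_vec = embeddings.embed_query("What is radiative forcing?")
d_vecs = embeddings.embed_documents(["Radiative forcing measures the change in Earth's energy balance."])
print(len(q_vec), len(d_vecs[0]))  # both 768-dimensional for bge-base
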
climateqa/engine/graph.py
DELETED
@@ -1,346 +0,0 @@
-import sys
-import os
-from contextlib import contextmanager
-
-from langchain.schema import Document
-from langgraph.graph import END, StateGraph
-from langchain_core.runnables.graph import CurveStyle, MermaidDrawMethod
-
-from typing_extensions import TypedDict
-from typing import List, Dict
-
-import operator
-from typing import Annotated
-import pandas as pd
-from IPython.display import display, HTML, Image
-
-from .chains.answer_chitchat import make_chitchat_node
-from .chains.answer_ai_impact import make_ai_impact_node
-from .chains.query_transformation import make_query_transform_node
-from .chains.translation import make_translation_node
-from .chains.intent_categorization import make_intent_categorization_node
-from .chains.retrieve_documents import make_IPx_retriever_node, make_POC_retriever_node, make_POC_by_ToC_retriever_node
-from .chains.answer_rag import make_rag_node
-from .chains.graph_retriever import make_graph_retriever_node
-from .chains.chitchat_categorization import make_chitchat_intent_categorization_node
-from .chains.standalone_question import make_standalone_question_node
-from .chains.follow_up import make_follow_up_node # Add this import
-
-class GraphState(TypedDict):
-    """
-    Represents the state of our graph.
-    """
-    user_input : str
-    chat_history : str
-    language : str
-    intent : str
-    search_graphs_chitchat : bool
-    query: str
-    questions_list : List[dict]
-    handled_questions_index : Annotated[list[int], operator.add]
-    n_questions : int
-    answer: str
-    audience: str = "experts"
-    sources_input: List[str] = ["IPCC","IPBES"] # Deprecated -> used only graphs that can only be OWID
-    relevant_content_sources_selection: List[str] = ["Figures (IPCC/IPBES)"]
-    sources_auto: bool = True
-    min_year: int = 1960
-    max_year: int = None
-    documents: Annotated[List[Document], operator.add]
-    related_contents : Annotated[List[Document], operator.add] # Images
-    recommended_content : List[Document] # OWID Graphs # TODO merge with related_contents
-    search_only : bool = False
-    reports : List[str] = []
-    follow_up_questions: List[str] = []
-
-def dummy(state):
-    return
-
-def search(state): #TODO
-    return
-
-def answer_search(state): #TODO
-    return
-
-def route_intent(state):
-    intent = state["intent"]
-    if intent in ["chitchat","esg"]:
-        return "answer_chitchat"
-    # elif intent == "ai_impact":
-    #     return "answer_ai_impact"
-    else:
-        # Search route
-        return "answer_climate"
-
-def chitchat_route_intent(state):
-    intent = state["search_graphs_chitchat"]
-    if intent is True:
-        return END #TODO
-    elif intent is False:
-        return END
-
-def route_translation(state):
-    if state["language"].lower() == "english":
-        return "transform_query"
-    else:
-        return "transform_query"
-        # return "translate_query" #TODO : add translation
-
-
-def route_based_on_relevant_docs(state,threshold_docs=0.2):
-    docs = [x for x in state["documents"] if x.metadata["reranking_score"] > threshold_docs]
-    print("Route : ", ["answer_rag" if len(docs) > 0 else "answer_rag_no_docs"])
-    if len(docs) > 0:
-        return "answer_rag"
-    else:
-        return "answer_rag_no_docs"
-
-def route_continue_retrieve_documents(state):
-    index_question_ipx = [i for i, x in enumerate(state["questions_list"]) if x["source_type"] == "IPx"]
-    questions_ipx_finished = all(elem in state["handled_questions_index"] for elem in index_question_ipx)
-    if questions_ipx_finished:
-        return "end_retrieve_IPx_documents"
-    else:
-        return "retrieve_documents"
-
-
-def route_retrieve_documents(state):
-    sources_to_retrieve = []
-
-    if "Graphs (OurWorldInData)" in state["relevant_content_sources_selection"]:
-        sources_to_retrieve.append("retrieve_graphs")
-
-    if sources_to_retrieve == []:
-        return END
-    return sources_to_retrieve
-
-def route_follow_up(state):
-    if state["follow_up_questions"]:
-        return "process_follow_up"
-    return END
-
-def make_id_dict(values):
-    return {k:k for k in values}
-
-def make_graph_agent(llm, vectorstore_ipcc, vectorstore_graphs, vectorstore_region, reranker, threshold_docs=0.2):
-
-    workflow = StateGraph(GraphState)
-
-    # Define the node functions
-    standalone_question_node = make_standalone_question_node(llm)
-    categorize_intent = make_intent_categorization_node(llm)
-    transform_query = make_query_transform_node(llm)
-    translate_query = make_translation_node(llm)
-    answer_chitchat = make_chitchat_node(llm)
-    answer_ai_impact = make_ai_impact_node(llm)
-    retrieve_documents = make_IPx_retriever_node(vectorstore_ipcc, reranker, llm)
-    retrieve_graphs = make_graph_retriever_node(vectorstore_graphs, reranker)
-    # retrieve_local_data = make_POC_retriever_node(vectorstore_region, reranker, llm)
-    answer_rag = make_rag_node(llm, with_docs=True)
-    answer_rag_no_docs = make_rag_node(llm, with_docs=False)
-    chitchat_categorize_intent = make_chitchat_intent_categorization_node(llm)
-    generate_follow_up = make_follow_up_node(llm)
-
-    # Define the nodes
-    # workflow.add_node("set_defaults", set_defaults)
-    workflow.add_node("standalone_question", standalone_question_node)
-    workflow.add_node("categorize_intent", categorize_intent)
-    workflow.add_node("answer_climate", dummy)
-    workflow.add_node("answer_search", answer_search)
-    workflow.add_node("transform_query", transform_query)
-    workflow.add_node("translate_query", translate_query)
-    workflow.add_node("answer_chitchat", answer_chitchat)
-    workflow.add_node("chitchat_categorize_intent", chitchat_categorize_intent)
-    workflow.add_node("retrieve_graphs", retrieve_graphs)
-    # workflow.add_node("retrieve_local_data", retrieve_local_data)
-    workflow.add_node("retrieve_graphs_chitchat", retrieve_graphs)
-    workflow.add_node("retrieve_documents", retrieve_documents)
-    workflow.add_node("answer_rag", answer_rag)
-    workflow.add_node("answer_rag_no_docs", answer_rag_no_docs)
-    workflow.add_node("generate_follow_up", generate_follow_up)
-    # workflow.add_node("process_follow_up", standalone_question_node)
-
-    # Entry point
-    workflow.set_entry_point("standalone_question")
-
-    # CONDITIONAL EDGES
-    workflow.add_conditional_edges(
-        "categorize_intent",
-        route_intent,
-        make_id_dict(["answer_chitchat","answer_climate"])
-    )
-
-    workflow.add_conditional_edges(
-        "chitchat_categorize_intent",
-        chitchat_route_intent,
-        make_id_dict(["retrieve_graphs_chitchat", END])
-    )
-
-    workflow.add_conditional_edges(
-        "answer_climate",
-        route_translation,
-        make_id_dict(["translate_query","transform_query"])
-    )
-
-    workflow.add_conditional_edges(
-        "answer_search",
-        lambda x : route_based_on_relevant_docs(x,threshold_docs=threshold_docs),
-        make_id_dict(["answer_rag","answer_rag_no_docs"])
-    )
-    workflow.add_conditional_edges(
-        "transform_query",
-        route_retrieve_documents,
-        make_id_dict(["retrieve_graphs", END])
-    )
-
-    # workflow.add_conditional_edges(
-    #     "generate_follow_up",
-    #     route_follow_up,
-    #     make_id_dict(["process_follow_up", END])
-    # )
-
-    # Define the edges
-    workflow.add_edge("standalone_question", "categorize_intent")
-    workflow.add_edge("translate_query", "transform_query")
-    workflow.add_edge("transform_query", "retrieve_documents") #TODO put back
-    # workflow.add_edge("transform_query", "retrieve_local_data")
-    # workflow.add_edge("transform_query", END) # TODO remove
-
-    workflow.add_edge("retrieve_graphs", END)
-    workflow.add_edge("answer_rag", "generate_follow_up")
-    workflow.add_edge("answer_rag_no_docs", "generate_follow_up")
-    workflow.add_edge("answer_chitchat", "chitchat_categorize_intent")
-    workflow.add_edge("retrieve_graphs_chitchat", END)
-
-    # workflow.add_edge("retrieve_local_data", "answer_search")
-    workflow.add_edge("retrieve_documents", "answer_search")
-    workflow.add_edge("generate_follow_up",END)
-    # workflow.add_edge("process_follow_up", "categorize_intent")
-
-    # Compile
-    app = workflow.compile()
-    return app
-
-def make_graph_agent_poc(llm, vectorstore_ipcc, vectorstore_graphs, vectorstore_region, reranker, version:str, threshold_docs=0.2):
-    """_summary_
-
-    Args:
-        llm (_type_): _description_
-        vectorstore_ipcc (_type_): _description_
-        vectorstore_graphs (_type_): _description_
-        vectorstore_region (_type_): _description_
-        reranker (_type_): _description_
-        version (str): version of the parsed documents (e.g "v4")
-        threshold_docs (float, optional): _description_. Defaults to 0.2.
-
-    Returns:
-        _type_: _description_
-    """
-
-
-    workflow = StateGraph(GraphState)
-
-    # Define the node functions
-    standalone_question_node = make_standalone_question_node(llm)
-
-    categorize_intent = make_intent_categorization_node(llm)
-    transform_query = make_query_transform_node(llm)
-    translate_query = make_translation_node(llm)
-    answer_chitchat = make_chitchat_node(llm)
-    answer_ai_impact = make_ai_impact_node(llm)
-    retrieve_documents = make_IPx_retriever_node(vectorstore_ipcc, reranker, llm)
-    retrieve_graphs = make_graph_retriever_node(vectorstore_graphs, reranker)
-    # retrieve_local_data = make_POC_retriever_node(vectorstore_region, reranker, llm)
-    retrieve_local_data = make_POC_by_ToC_retriever_node(vectorstore_region, reranker, llm, version=version)
-    answer_rag = make_rag_node(llm, with_docs=True)
-    answer_rag_no_docs = make_rag_node(llm, with_docs=False)
-    chitchat_categorize_intent = make_chitchat_intent_categorization_node(llm)
-    generate_follow_up = make_follow_up_node(llm)
-
-    # Define the nodes
-    # workflow.add_node("set_defaults", set_defaults)
-    workflow.add_node("standalone_question", standalone_question_node)
-    workflow.add_node("categorize_intent", categorize_intent)
-    workflow.add_node("answer_climate", dummy)
-    workflow.add_node("answer_search", answer_search)
-    # workflow.add_node("end_retrieve_local_documents", dummy)
-    # workflow.add_node("end_retrieve_IPx_documents", dummy)
-    workflow.add_node("transform_query", transform_query)
-    workflow.add_node("translate_query", translate_query)
-    workflow.add_node("answer_chitchat", answer_chitchat)
-    workflow.add_node("chitchat_categorize_intent", chitchat_categorize_intent)
-    workflow.add_node("retrieve_graphs", retrieve_graphs)
-    workflow.add_node("retrieve_local_data", retrieve_local_data)
-    workflow.add_node("retrieve_graphs_chitchat", retrieve_graphs)
-    workflow.add_node("retrieve_documents", retrieve_documents)
-    workflow.add_node("answer_rag", answer_rag)
-    workflow.add_node("answer_rag_no_docs", answer_rag_no_docs)
-    workflow.add_node("generate_follow_up", generate_follow_up)
-
-    # Entry point
-    workflow.set_entry_point("standalone_question")
-
-    # CONDITIONAL EDGES
-    workflow.add_conditional_edges(
-        "categorize_intent",
-        route_intent,
-        make_id_dict(["answer_chitchat","answer_climate"])
-    )
-
-    workflow.add_conditional_edges(
-        "chitchat_categorize_intent",
-        chitchat_route_intent,
-        make_id_dict(["retrieve_graphs_chitchat", END])
-    )
-
-    workflow.add_conditional_edges(
-        "answer_climate",
-        route_translation,
-        make_id_dict(["translate_query","transform_query"])
-    )
-
-    workflow.add_conditional_edges(
-        "answer_search",
-        lambda x : route_based_on_relevant_docs(x,threshold_docs=threshold_docs),
-        make_id_dict(["answer_rag","answer_rag_no_docs"])
-    )
-    workflow.add_conditional_edges(
-        "transform_query",
-        route_retrieve_documents,
-        make_id_dict(["retrieve_graphs", END])
-    )
-
-    # Define the edges
-    workflow.add_edge("standalone_question", "categorize_intent")
-    workflow.add_edge("translate_query", "transform_query")
-    workflow.add_edge("transform_query", "retrieve_documents") #TODO put back
-    workflow.add_edge("transform_query", "retrieve_local_data")
-    # workflow.add_edge("transform_query", END) # TODO remove
-
-    workflow.add_edge("retrieve_graphs", END)
-    workflow.add_edge("answer_rag", "generate_follow_up")
-    workflow.add_edge("answer_rag_no_docs", "generate_follow_up")
-    workflow.add_edge("answer_chitchat", "chitchat_categorize_intent")
-    workflow.add_edge("retrieve_graphs_chitchat", END)
-
-    workflow.add_edge("retrieve_local_data", "answer_search")
-    workflow.add_edge("retrieve_documents", "answer_search")
-    workflow.add_edge("generate_follow_up",END)
-
-
-    # Compile
-    app = workflow.compile()
-    return app
-
-
-
-
-def display_graph(app):
-
-    display(
-        Image(
-            app.get_graph(xray = True).draw_mermaid_png(
-                draw_method=MermaidDrawMethod.API,
-            )
-        )
-    )
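
Before its removal, the compiled graph was the engine's main entry point; here is a minimal, hypothetical invocation (vectorstores, reranker and llm assumed to be loaded elsewhere, and only a subset of GraphState keys supplied):

import asyncio

app = make_graph_agent(llm, vectorstore_ipcc, vectorstore_graphs, vectorstore_region, reranker)

inputs = {
    "user_input": "Is sea level rise accelerating?",
    "chat_history": "",
    "audience": "experts",
    "search_only": False,
    "reports": [],
    "relevant_content_sources_selection": ["Figures (IPCC/IPBES)"],
}

# Compiled LangGraph apps expose ainvoke for a single end-to-end run.
final_state = asyncio.run(app.ainvoke(inputs))
print(final_state["answer"])
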
climateqa/engine/graph_retriever.py
DELETED
@@ -1,88 +0,0 @@
-from langchain_core.retrievers import BaseRetriever
-from langchain_core.documents.base import Document
-from langchain_core.vectorstores import VectorStore
-from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
-
-from typing import List
-
-# class GraphRetriever(BaseRetriever):
-#     vectorstore:VectorStore
-#     sources:list = ["OWID"] # add OurWorldInData later # will need to be integrated with the other retriever
-#     threshold:float = 0.5
-#     k_total:int = 10
-
-#     def _get_relevant_documents(
-#         self, query: str, *, run_manager: CallbackManagerForRetrieverRun
-#     ) -> List[Document]:
-
-#         # Check if all elements in the list are IEA or OWID
-#         assert isinstance(self.sources,list)
-#         assert self.sources
-#         assert any([x in ["OWID"] for x in self.sources])
-
-#         # Prepare base search kwargs
-#         filters = {}
-
-#         filters["source"] = {"$in": self.sources}
-
-#         docs = self.vectorstore.similarity_search_with_score(query=query, filter=filters, k=self.k_total)
-
-#         # Filter if scores are below threshold
-#         docs = [x for x in docs if x[1] > self.threshold]
-
-#         # Remove duplicate documents
-#         unique_docs = []
-#         seen_docs = []
-#         for i, doc in enumerate(docs):
-#             if doc[0].page_content not in seen_docs:
-#                 unique_docs.append(doc)
-#                 seen_docs.append(doc[0].page_content)
-
-#         # Add score to metadata
-#         results = []
-#         for i,(doc,score) in enumerate(unique_docs):
-#             doc.metadata["similarity_score"] = score
-#             doc.metadata["content"] = doc.page_content
-#             results.append(doc)
-
-#         return results
-
-async def retrieve_graphs(
-    query: str,
-    vectorstore:VectorStore,
-    sources:list = ["OWID"], # add OurWorldInData later # will need to be integrated with the other retriever
-    threshold:float = 0.5,
-    k_total:int = 10,
-)-> List[Document]:
-
-    # Check if all elements in the list are IEA or OWID
-    assert isinstance(sources,list)
-    assert sources
-    assert any([x in ["OWID"] for x in sources])
-
-    # Prepare base search kwargs
-    filters = {}
-
-    filters["source"] = {"$in": sources}
-
-    docs = vectorstore.similarity_search_with_score(query=query, filter=filters, k=k_total)
-
-    # Filter if scores are below threshold
-    docs = [x for x in docs if x[1] > threshold]
-
-    # Remove duplicate documents
-    unique_docs = []
-    seen_docs = []
-    for i, doc in enumerate(docs):
-        if doc[0].page_content not in seen_docs:
-            unique_docs.append(doc)
-            seen_docs.append(doc[0].page_content)
-
-    # Add score to metadata
-    results = []
-    for i,(doc,score) in enumerate(unique_docs):
-        doc.metadata["similarity_score"] = score
-        doc.metadata["content"] = doc.page_content
-        results.append(doc)
-
-    return results
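
A sketch of calling the deleted graph retriever directly; retrieve_graphs is declared async even though the underlying similarity search is synchronous, so it must be awaited (vectorstore_graphs is an assumed OWID graphs vectorstore):

import asyncio

graphs = asyncio.run(retrieve_graphs(
    query="global CO2 emissions per country",
    vectorstore=vectorstore_graphs,
))
for doc in graphs:
    # similarity_score and content are attached to metadata by the function.
    print(doc.metadata["similarity_score"], doc.page_content[:80])
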
climateqa/engine/keywords.py
CHANGED
@@ -11,12 +11,10 @@ class KeywordsOutput(BaseModel):
 
     keywords: list = Field(
         description="""
-        Generate 1 or 2 relevant keywords from the user query to ask a search engine for scientific research papers.
-        Do not use special characters or accents.
+        Generate 1 or 2 relevant keywords from the user query to ask a search engine for scientific research papers.
 
         Example:
         - "What is the impact of deep sea mining ?" -> ["deep sea mining"]
-        - "Quel est l'impact de l'exploitation minière en haute mer ?" -> ["deep sea mining"]
         - "How will El Nino be impacted by climate change" -> ["el nino"]
         - "Is climate change a hoax" -> [Climate change","hoax"]
         """
climateqa/engine/llm/__init__.py
CHANGED
@@ -1,6 +1,5 @@
 from climateqa.engine.llm.openai import get_llm as get_openai_llm
 from climateqa.engine.llm.azure import get_llm as get_azure_llm
-from climateqa.engine.llm.ollama import get_llm as get_ollama_llm
 
 
 def get_llm(provider="openai",**kwargs):
@@ -9,8 +8,6 @@ def get_llm(provider="openai",**kwargs):
         return get_openai_llm(**kwargs)
     elif provider == "azure":
         return get_azure_llm(**kwargs)
-    elif provider == "ollama":
-        return get_ollama_llm(**kwargs)
     else:
         raise ValueError(f"Unknown provider: {provider}")
 
climateqa/engine/llm/ollama.py
DELETED
@@ -1,6 +0,0 @@
-
-
-from langchain_community.llms import Ollama
-
-def get_llm(model="llama3", **kwargs):
-    return Ollama(model=model, **kwargs)
climateqa/engine/llm/openai.py
CHANGED
@@ -7,7 +7,7 @@ try:
 except Exception:
     pass
 
-def get_llm(model="gpt-
+def get_llm(model="gpt-3.5-turbo-0125",max_tokens=1024, temperature=0.0, streaming=True,timeout=30, **kwargs):
 
     llm = ChatOpenAI(
         model=model,
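
A quick sanity check of the restored default model, as a hypothetical snippet (requires OPENAI_API_KEY in the environment):

from climateqa.engine.llm.openai import get_llm

llm = get_llm(temperature=0.0, streaming=False)  # defaults to gpt-3.5-turbo-0125 after this change
print(llm.invoke("Say hello in one word.").content)
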
climateqa/engine/{chains/prompts.py → prompts.py}
RENAMED
@@ -36,41 +36,13 @@ You are given a question and extracted passages of the IPCC and/or IPBES reports
|
|
36 |
"""
|
37 |
|
38 |
|
39 |
-
# answer_prompt_template_old = """
|
40 |
-
# You are ClimateQ&A, an AI Assistant created by Ekimetrics. You are given a question and extracted passages of reports. Provide a clear and structured answer based on the passages provided, the context and the guidelines.
|
41 |
-
|
42 |
-
# Guidelines:
|
43 |
-
# - If the passages have useful facts or numbers, use them in your answer.
|
44 |
-
# - When you use information from a passage, mention where it came from by using [Doc i] at the end of the sentence. i stands for the number of the document.
|
45 |
-
# - Do not use the sentence 'Doc i says ...' to say where information came from.
|
46 |
-
# - If the same thing is said in more than one document, you can mention all of them like this: [Doc i, Doc j, Doc k]
|
47 |
-
# - Do not just summarize each passage one by one. Group your summaries to highlight the key parts in the explanation.
|
48 |
-
# - If it makes sense, use bullet points and lists to make your answers easier to understand.
|
49 |
-
# - You do not need to use every passage. Only use the ones that help answer the question.
|
50 |
-
# - If the documents do not have the information needed to answer the question, just say you do not have enough information.
|
51 |
-
# - Consider by default that the question is about the past century unless it is specified otherwise.
|
52 |
-
# - If the passage is the caption of a picture, you can still use it as part of your answer as any other document.
|
53 |
-
|
54 |
-
# -----------------------
|
55 |
-
# Passages:
|
56 |
-
# {context}
|
57 |
-
|
58 |
-
# -----------------------
|
59 |
-
# Question: {query} - Explained to {audience}
|
60 |
-
# Answer in {language} with the passages citations:
|
61 |
-
# """
|
62 |
-
|
 answer_prompt_template = """
-You are ClimateQ&A, an AI Assistant created by Ekimetrics. You are given a question and extracted passages of reports. Provide a clear and structured answer based on the passages provided, the context and the guidelines.
 
 Guidelines:
 - If the passages have useful facts or numbers, use them in your answer.
 - When you use information from a passage, mention where it came from by using [Doc i] at the end of the sentence. i stands for the number of the document.
-
-- The different sources are IPCC, IPBES, PPCP (for Plan Climat Air Energie Territorial de Paris), PBDP (for Plan Biodiversité de Paris), Acclimaterra (Rapport scientifique de la région Nouvelle Aquitaine en France).
-- If the reports are local (like PPCP, PBDP, Acclimaterra), consider that the information is specific to the region and not global. If the document is about a nearby region (for example, an extract from Acclimaterra for a question about Britain), explicitly state the concerned region.
-- Do not mention that you are using specific extract documents, but mention only the source information. "According to IPCC, ..." rather than "According to the provided document from IPCC ..."
-- Make a clear distinction between information from IPCC, IPBES, Acclimaterra that are scientific reports and PPCP, PBDP that are strategic reports. Strategic reports should not be taken as verified facts, but as political or strategic decisions.
 - If the same thing is said in more than one document, you can mention all of them like this: [Doc i, Doc j, Doc k]
 - Do not just summarize each passage one by one. Group your summaries to highlight the key parts in the explanation.
 - If it makes sense, use bullet points and lists to make your answers easier to understand.
@@ -79,16 +51,16 @@ Guidelines:
 - Consider by default that the question is about the past century unless it is specified otherwise.
 - If the passage is the caption of a picture, you can still use it as part of your answer as any other document.
 
-
 -----------------------
 Passages:
 {context}
 
 -----------------------
-Question: {query} - Explained to {audience}
 Answer in {language} with the passages citations:
 """
 
 papers_prompt_template = """
 You are ClimateQ&A, an AI Assistant created by Ekimetrics. You are given a question and extracted abstracts of scientific papers. Provide a clear and structured answer based on the abstracts provided, the context and the guidelines.
 
@@ -165,7 +137,7 @@ Guidelines:
 - If the question is not related to environmental issues, never never answer it. Say it's not your role.
 - Make paragraphs by starting new lines to make your answers more readable.
 
-Question: {question}
 Answer in {language}:
 """
 
@@ -175,77 +147,4 @@ audience_prompts = {
     "children": "6 year old children that don't know anything about science and climate change and need metaphors to learn",
     "general": "the general public who know the basics in science and climate change and want to learn more about it without technical terms. Still use references to passages.",
     "experts": "expert and climate scientists that are not afraid of technical terms",
-}
-
-
-answer_prompt_graph_template = """
-Given the user question and a list of graphs which are related to the question, rank the graphs based on relevance to the user question. ALWAYS follow the guidelines given below.
-
-### Guidelines ###
-- Keep all the graphs that are given to you.
-- NEVER modify the graph HTML embedding, the category or the source leave them exactly as they are given.
-- Return the ranked graphs as a list of dictionaries with keys 'embedding', 'category', and 'source'.
-- Return a valid JSON output.
-
------------------------
-User question:
-{query}
-
-Graphs and their HTML embedding:
-{recommended_content}
-
------------------------
-{format_instructions}
-
-Output the result as json with a key "graphs" containing a list of dictionaries of the relevant graphs with keys 'embedding', 'category', and 'source'. Do not modify the graph HTML embedding, the category or the source. Do not put any message or text before or after the JSON output.
-"""
-
-retrieve_chapter_prompt_template = """Given the user question and a list of documents with their table of contents, retrieve the 5 most relevant level 0 chapters which could help to answer to the question while taking account their sub-chapters.
-
-The table of contents is structured like that :
-{{
-    "level": 0,
-    "Chapter 1": {{}},
-    "Chapter 2" : {{
-        "level": 1,
-        "Chapter 2.1": {{
-            ...
-        }}
-    }},
-}}
-
-Here level is the level of the chapter. For example, Chapter 1 and Chapter 2 are at level 0, and Chapter 2.1 is at level 1.
-
-### Guidelines ###
-- Keep all the list of documents that is given to you
-- Each chapter must keep **EXACTLY** its assigned level in the table of contents. **DO NOT MODIFY THE LEVELS. **
-- Check systematically the level of a chapter before including it in the answer.
-- Return **valid JSON** result.
-
---------------------
-User question :
-{query}
-
-List of documents with their table of contents :
-{doc_list}
-
---------------------
-
-Return a JSON result with a list of relevant chapters with the following keys **WITHOUT** the json markdown indicator ```json at the beginning:
-- "document" : the document in which we can find the chapter
-- "chapter" : the title of the chapter
-
-**IMPORTANT : Make sure that the levels of the answer are exactly the same as the ones in the table of contents**
-
-Example of a JSON response:
-[
-    {{
-        "document": "Document A",
-        "chapter": "Chapter 1",
-    }},
-    {{
-        "document": "Document B",
-        "chapter": "Chapter 5",
-    }}
-]
-"""
 """
 
 
 answer_prompt_template = """
+You are ClimateQ&A, an AI Assistant created by Ekimetrics. You are given a question and extracted passages of the IPCC and/or IPBES reports. Provide a clear and structured answer based on the passages provided, the context and the guidelines.
 
 Guidelines:
 - If the passages have useful facts or numbers, use them in your answer.
 - When you use information from a passage, mention where it came from by using [Doc i] at the end of the sentence. i stands for the number of the document.
+- Do not use the sentence 'Doc i says ...' to say where information came from.
 - If the same thing is said in more than one document, you can mention all of them like this: [Doc i, Doc j, Doc k]
 - Do not just summarize each passage one by one. Group your summaries to highlight the key parts in the explanation.
 - If it makes sense, use bullet points and lists to make your answers easier to understand.
 - Consider by default that the question is about the past century unless it is specified otherwise.
 - If the passage is the caption of a picture, you can still use it as part of your answer as any other document.
 
 -----------------------
 Passages:
 {context}
 
 -----------------------
+Question: {question} - Explained to {audience}
 Answer in {language} with the passages citations:
 """
 
+
 papers_prompt_template = """
 You are ClimateQ&A, an AI Assistant created by Ekimetrics. You are given a question and extracted abstracts of scientific papers. Provide a clear and structured answer based on the abstracts provided, the context and the guidelines.
 
 - If the question is not related to environmental issues, never never answer it. Say it's not your role.
 - Make paragraphs by starting new lines to make your answers more readable.
 
+Question: {question}
 Answer in {language}:
 """
 
     "children": "6 year old children that don't know anything about science and climate change and need metaphors to learn",
     "general": "the general public who know the basics in science and climate change and want to learn more about it without technical terms. Still use references to passages.",
     "experts": "expert and climate scientists that are not afraid of technical terms",
+}
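For reviewers who want to sanity-check the renamed module: these templates are plain LangChain f-string templates, so the set of placeholders each one declares ({context}, {question}, {audience}, {language}) is its whole contract. A minimal sketch of filling the answer prompt (the example values are made up; only the module path and variable names come from this diff):

```python
from langchain_core.prompts import ChatPromptTemplate

from climateqa.engine.prompts import answer_prompt_template

prompt = ChatPromptTemplate.from_template(answer_prompt_template)

# Every placeholder the template declares must be supplied;
# a missing one raises a KeyError at format time.
messages = prompt.format_messages(
    context="Doc 1 (IPCC): Global surface temperature rose about 1.1°C above 1850-1900 levels.",
    question="How much has the planet warmed?",
    audience="the general public who know the basics in science and climate change",
    language="English",
)
print(messages[0].content)
```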
climateqa/engine/{chains/answer_rag.py → rag.py}
RENAMED
@@ -2,16 +2,17 @@ from operator import itemgetter
 
 from langchain_core.prompts import ChatPromptTemplate
 from langchain_core.output_parsers import StrOutputParser
 from langchain_core.prompts.prompt import PromptTemplate
 from langchain_core.prompts.base import format_document
 
-from climateqa.engine.
-from climateqa.engine.
-import time
-from
 
-
-DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="Source : {source} - Content : {page_content}")
 
 def _combine_documents(
     docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, sep="\n\n"
@@ -39,54 +40,72 @@ def get_text_docs(x):
 def get_image_docs(x):
     return [doc for doc in x if doc.metadata["chunk_type"] == "image"]
 
-
     prompt = ChatPromptTemplate.from_template(answer_prompt_template)
-
-        "context":lambda x : _combine_documents(x["documents"]),
-        "context_length":lambda x : print("CONTEXT LENGTH : " , len(_combine_documents(x["documents"]))),
-        "query":itemgetter("query"),
-        "language":itemgetter("language"),
-        "audience":itemgetter("audience"),
-    } | prompt | llm | StrOutputParser())
-    return chain
 
-
-
-    return chain
 
-
-
-async def answer_rag(state,config):
-    print("---- Answer RAG ----")
-    start_time = time.time()
-    chat_history = state.get("chat_history",[])
-    print("Sources used : " + "\n".join([x.metadata["short_name"] + " - page " + str(x.metadata["page_number"]) for x in state["documents"]]))
 
-
-
-    print("RAG elapsed time: ", elapsed_time)
-    print("Answer size : ", len(answer))
-
-    chat_history.append({"question":state["query"],"answer":answer})
-
-    return {"answer":answer,"chat_history": chat_history}
 
 
 def make_rag_papers_chain(llm):
 
     prompt = ChatPromptTemplate.from_template(papers_prompt_template)
     input_documents = {
         "context":lambda x : _combine_documents(x["docs"]),
         **pass_values(["question","language"])
@@ -112,4 +131,4 @@ def make_illustration_chain(llm):
     }
 
     illustration_chain = input_description_images | prompt_with_images | llm | StrOutputParser()
-    return illustration_chain
 
 from langchain_core.prompts import ChatPromptTemplate
 from langchain_core.output_parsers import StrOutputParser
+from langchain_core.runnables import RunnablePassthrough, RunnableLambda, RunnableBranch
 from langchain_core.prompts.prompt import PromptTemplate
 from langchain_core.prompts.base import format_document
 
+from climateqa.engine.reformulation import make_reformulation_chain
+from climateqa.engine.prompts import answer_prompt_template,answer_prompt_without_docs_template,answer_prompt_images_template
+from climateqa.engine.prompts import papers_prompt_template
+from climateqa.engine.utils import pass_values, flatten_dict,prepare_chain,rename_chain
+from climateqa.engine.keywords import make_keywords_chain
 
+DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="{page_content}")
 
 def _combine_documents(
     docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, sep="\n\n"
 
 def get_image_docs(x):
     return [doc for doc in x if doc.metadata["chunk_type"] == "image"]
 
+
+def make_rag_chain(retriever,llm):
+
+    # Construct the prompt
     prompt = ChatPromptTemplate.from_template(answer_prompt_template)
+    prompt_without_docs = ChatPromptTemplate.from_template(answer_prompt_without_docs_template)
 
+    # ------- CHAIN 0 - Reformulation
+    reformulation = make_reformulation_chain(llm)
+    reformulation = prepare_chain(reformulation,"reformulation")
 
+    # ------- Find all keywords from the reformulated query
+    keywords = make_keywords_chain(llm)
+    keywords = {"keywords":itemgetter("question") | keywords}
+    keywords = prepare_chain(keywords,"keywords")
 
+    # ------- CHAIN 1
+    # Retrieved documents
+    find_documents = {"docs": itemgetter("question") | retriever} | RunnablePassthrough()
+    find_documents = prepare_chain(find_documents,"find_documents")
 
+    # ------- CHAIN 2
+    # Construct inputs for the llm
+    input_documents = {
+        "context":lambda x : _combine_documents(x["docs"]),
+        **pass_values(["question","audience","language","keywords"])
+    }
+
+    # ------- CHAIN 3
+    # Bot answer
+    llm_final = rename_chain(llm,"answer")
+
+    answer_with_docs = {
+        "answer": input_documents | prompt | llm_final | StrOutputParser(),
+        **pass_values(["question","audience","language","query","docs","keywords"]),
+    }
+
+    answer_without_docs = {
+        "answer": prompt_without_docs | llm_final | StrOutputParser(),
+        **pass_values(["question","audience","language","query","docs","keywords"]),
+    }
+
+    # def has_images(x):
+    #     image_docs = [doc for doc in x["docs"] if doc.metadata["chunk_type"]=="image"]
+    #     return len(image_docs) > 0
 
+    def has_docs(x):
+        return len(x["docs"]) > 0
 
+    answer = RunnableBranch(
+        (lambda x: has_docs(x), answer_with_docs),
+        answer_without_docs,
+    )
 
 
+    # ------- FINAL CHAIN
+    # Build the final chain
+    rag_chain = reformulation | keywords | find_documents | answer
+
+    return rag_chain
 
 
 def make_rag_papers_chain(llm):
 
     prompt = ChatPromptTemplate.from_template(papers_prompt_template)
+
     input_documents = {
         "context":lambda x : _combine_documents(x["docs"]),
         **pass_values(["question","language"])
 
     }
 
     illustration_chain = input_description_images | prompt_with_images | llm | StrOutputParser()
+    return illustration_chain
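`make_rag_chain` now assembles reformulation, keyword extraction, retrieval and a branched answer into one runnable. A hedged sketch of how it would be wired up; `llm` and `vectorstore` stand in for objects built elsewhere in the repo, and the exact input keys depend on `make_reformulation_chain`, which is not part of this diff:

```python
from climateqa.engine.rag import make_rag_chain
from climateqa.engine.retriever import ClimateQARetriever

llm = ...          # any LangChain chat model (the repo builds one in climateqa/engine/llm)
vectorstore = ...  # a LangChain VectorStore over the indexed report chunks

retriever = ClimateQARetriever(vectorstore=vectorstore, sources=["IPCC", "IPBES"])
rag_chain = make_rag_chain(retriever, llm)

# The reformulation step is expected to derive "question" and "language" from the
# raw query, so the caller supplies the user-facing fields (key names are assumptions).
result = rag_chain.invoke({"query": "Is climate change caused by humans?",
                           "audience": "the general public"})
print(result["answer"])
```

The `RunnableBranch` at the end is the notable design choice: when retrieval returns nothing, the chain falls back to `answer_prompt_without_docs_template` instead of hallucinating citations over an empty context.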
climateqa/engine/{chains/reformulation.py → reformulation.py}
RENAMED
@@ -3,7 +3,7 @@ from langchain.output_parsers.structured import StructuredOutputParser, ResponseSchema
 from langchain_core.prompts import PromptTemplate
 from langchain_core.runnables import RunnablePassthrough, RunnableLambda, RunnableBranch
 
-from climateqa.engine.chains.prompts import reformulation_prompt_template
+from climateqa.engine.prompts import reformulation_prompt_template
 from climateqa.engine.utils import pass_values, flatten_dict
 
 
climateqa/engine/reranker.py
DELETED
@@ -1,55 +0,0 @@
-import os
-from dotenv import load_dotenv
-from scipy.special import expit, logit
-from rerankers import Reranker
-from sentence_transformers import CrossEncoder
-
-load_dotenv()
-
-def get_reranker(model = "nano", cohere_api_key = None):
-
-    assert model in ["nano","tiny","small","large", "jina"]
-
-    if model == "nano":
-        reranker = Reranker('ms-marco-TinyBERT-L-2-v2', model_type='flashrank')
-    elif model == "tiny":
-        reranker = Reranker('ms-marco-MiniLM-L-12-v2', model_type='flashrank')
-    elif model == "small":
-        reranker = Reranker("mixedbread-ai/mxbai-rerank-xsmall-v1", model_type='cross-encoder')
-    elif model == "large":
-        if cohere_api_key is None:
-            cohere_api_key = os.environ["COHERE_API_KEY"]
-        reranker = Reranker("cohere", lang='en', api_key = cohere_api_key)
-    elif model == "jina":
-        # Reached token quota so does not work
-        reranker = Reranker("jina-reranker-v2-base-multilingual", api_key = os.getenv("JINA_RERANKER_API_KEY"))
-        # does not work without a GPU? and it returns a different structure anyway, so the retriever node code would need changing
-        # reranker = CrossEncoder("jinaai/jina-reranker-v2-base-multilingual", automodel_args={"torch_dtype": "auto"}, trust_remote_code=True,)
-    return reranker
-
-
-
-def rerank_docs(reranker,docs,query):
-    if docs == []:
-        return []
-
-    # Get a list of texts from langchain docs
-    input_docs = [x.page_content for x in docs]
-
-    # Rerank using rerankers library
-    results = reranker.rank(query=query, docs=input_docs)
-
-    # Prepare langchain list of docs
-    docs_reranked = []
-    for result in results.results:
-        doc_id = result.document.doc_id
-        doc = docs[doc_id]
-        doc.metadata["reranking_score"] = result.score
-        doc.metadata["query_used_for_retrieval"] = query
-        docs_reranked.append(doc)
-    return docs_reranked
-
-def rerank_and_sort_docs(reranker, docs, query):
-    docs_reranked = rerank_docs(reranker,docs,query)
-    docs_reranked = sorted(docs_reranked, key=lambda x: x.metadata["reranking_score"], reverse=True)
-    return docs_reranked
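For context on what is being dropped: these helpers wrapped the `rerankers` library and stamped each Document with a `reranking_score` before sorting. Typical pre-PR usage looked roughly like this (the query string and slicing are illustrative):

```python
from climateqa.engine.reranker import get_reranker, rerank_and_sort_docs

reranker = get_reranker("nano")   # local flashrank model; "large" called the Cohere API
docs = ...                        # list of LangChain Documents returned by the retriever

reranked = rerank_and_sort_docs(reranker, docs, query="impact of drought on crops")
for doc in reranked[:3]:
    print(doc.metadata["reranking_score"], doc.page_content[:80])
```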
climateqa/engine/retriever.py
ADDED
@@ -0,0 +1,163 @@
+# https://github.com/langchain-ai/langchain/issues/8623
+
+import pandas as pd
+
+from langchain_core.retrievers import BaseRetriever
+from langchain_core.vectorstores import VectorStoreRetriever
+from langchain_core.documents.base import Document
+from langchain_core.vectorstores import VectorStore
+from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
+
+from typing import List
+from pydantic import Field
+
+class ClimateQARetriever(BaseRetriever):
+    vectorstore:VectorStore
+    sources:list = ["IPCC","IPBES","IPOS"]
+    reports:list = []
+    threshold:float = 0.6
+    k_summary:int = 3
+    k_total:int = 10
+    namespace:str = "vectors"
+    min_size:int = 200
+
+
+    def _get_relevant_documents(
+        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
+    ) -> List[Document]:
+
+        # Check if all elements in the list are either IPCC or IPBES
+        assert isinstance(self.sources,list)
+        assert all([x in ["IPCC","IPBES","IPOS"] for x in self.sources])
+        assert self.k_total > self.k_summary, "k_total should be greater than k_summary"
+
+        # Prepare base search kwargs
+        filters = {}
+
+        if len(self.reports) > 0:
+            filters["short_name"] = {"$in":self.reports}
+        else:
+            filters["source"] = { "$in":self.sources}
+
+        # Search for k_summary documents in the summaries dataset
+        filters_summaries = {
+            **filters,
+            "report_type": { "$in":["SPM"]},
+        }
+
+        docs_summaries = self.vectorstore.similarity_search_with_score(query=query,filter = filters_summaries,k = self.k_summary)
+        docs_summaries = [x for x in docs_summaries if x[1] > self.threshold]
+
+        # Search for k_total - k_summary documents in the full reports dataset
+        filters_full = {
+            **filters,
+            "report_type": { "$nin":["SPM"]},
+        }
+        k_full = self.k_total - len(docs_summaries)
+        docs_full = self.vectorstore.similarity_search_with_score(query=query,filter = filters_full,k = k_full)
+
+        # Concatenate documents
+        docs = docs_summaries + docs_full
+
+        # Filter out documents that are too short
+        docs = [x for x in docs if len(x[0].page_content) > self.min_size]
+        # docs = [x for x in docs if x[1] > self.threshold]
+
+        # Add score to metadata
+        results = []
+        for i,(doc,score) in enumerate(docs):
+            doc.metadata["similarity_score"] = score
+            doc.metadata["content"] = doc.page_content
+            doc.metadata["page_number"] = int(doc.metadata["page_number"]) + 1
+            # doc.page_content = f"""Doc {i+1} - {doc.metadata['short_name']}: {doc.page_content}"""
+            results.append(doc)
+
+        # Sort by score
+        # results = sorted(results,key = lambda x : x.metadata["similarity_score"],reverse = True)
+
+        return results
+
+
+
+
+# def filter_summaries(df,k_summary = 3,k_total = 10):
+#     # assert source in ["IPCC","IPBES","ALL"], "source arg should be in (IPCC,IPBES,ALL)"
+
+#     # # Filter by source
+#     # if source == "IPCC":
+#     #     df = df.loc[df["source"]=="IPCC"]
+#     # elif source == "IPBES":
+#     #     df = df.loc[df["source"]=="IPBES"]
+#     # else:
+#     #     pass
+
+#     # Separate summaries and full reports
+#     df_summaries = df.loc[df["report_type"].isin(["SPM","TS"])]
+#     df_full = df.loc[~df["report_type"].isin(["SPM","TS"])]
+
+#     # Find passages from summaries dataset
+#     passages_summaries = df_summaries.head(k_summary)
+
+#     # Find passages from full reports dataset
+#     passages_fullreports = df_full.head(k_total - len(passages_summaries))
+
+#     # Concatenate passages
+#     passages = pd.concat([passages_summaries,passages_fullreports],axis = 0,ignore_index = True)
+#     return passages
+
+
+
+
+# def retrieve_with_summaries(query,retriever,k_summary = 3,k_total = 10,sources = ["IPCC","IPBES"],max_k = 100,threshold = 0.555,as_dict = True,min_length = 300):
+#     assert max_k > k_total
+
+#     validated_sources = ["IPCC","IPBES"]
+#     sources = [x for x in sources if x in validated_sources]
+#     filters = {
+#         "source": { "$in": sources },
+#     }
+#     print(filters)
+
+#     # Retrieve documents
+#     docs = retriever.retrieve(query,top_k = max_k,filters = filters)
+
+#     # Filter by score
+#     docs = [{**x.meta,"score":x.score,"content":x.content} for x in docs if x.score > threshold]
+
+#     if len(docs) == 0:
+#         return []
+#     res = pd.DataFrame(docs)
+#     passages_df = filter_summaries(res,k_summary,k_total)
+#     if as_dict:
+#         contents = passages_df["content"].tolist()
+#         meta = passages_df.drop(columns = ["content"]).to_dict(orient = "records")
+#         passages = []
+#         for i in range(len(contents)):
+#             passages.append({"content":contents[i],"meta":meta[i]})
+#         return passages
+#     else:
+#         return passages_df
+
+
+
+
+# def retrieve(query,sources = ["IPCC"],threshold = 0.555,k = 10):
+
+
+#     print("hellooooo")
+
+#     # Reformulate queries
+#     reformulated_query,language = reformulate(query)
+
+#     print(reformulated_query)
+
+#     # Retrieve documents
+#     passages = retrieve_with_summaries(reformulated_query,retriever,k_total = k,k_summary = 3,as_dict = True,sources = sources,threshold = threshold)
+#     response = {
+#         "query":query,
+#         "reformulated_query":reformulated_query,
+#         "language":language,
+#         "sources":passages,
+#         "prompts":{"init_prompt":init_prompt,"sources_prompt":sources_prompt},
+#     }
+#     return response
+
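The retriever implements a two-pass strategy: up to `k_summary` chunks from Summaries for Policymakers (SPM), topped up to `k_total` with chunks from the full reports. A quick sketch of driving it on its own (assumptions: `vectorstore` is built elsewhere, its documents carry the metadata fields used above, and langchain-core is recent enough that retrievers are Runnables):

```python
from climateqa.engine.retriever import ClimateQARetriever

vectorstore = ...  # LangChain VectorStore over the indexed report chunks (built elsewhere)

retriever = ClimateQARetriever(
    vectorstore=vectorstore,
    sources=["IPCC"],
    k_summary=3,   # up to 3 chunks from SPM documents
    k_total=10,    # topped up to 10 with chunks from the full reports
)
docs = retriever.invoke("What sea level rise is projected by 2100?")
for doc in docs:
    print(doc.metadata["short_name"],
          doc.metadata["page_number"],
          doc.metadata["similarity_score"])
```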
climateqa/engine/talk_to_data/config.py
DELETED
@@ -1,11 +0,0 @@
-# Path configuration for climateqa project
-
-# IPCC dataset path
-IPCC_DATASET_URL = "hf://datasets/ekimetrics/ipcc-atlas"
-
-# DRIAS dataset paths
-DRIAS_DATASET_URL = "hf://datasets/timeki/drias_db"
-
-# Table paths
-DRIAS_MEAN_ANNUAL_TEMPERATURE_PATH = f"{DRIAS_DATASET_URL}/mean_annual_temperature.parquet"
-IPCC_COORDINATES_PATH = f"{IPCC_DATASET_URL}/coordinates.parquet"
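The two dataset URLs removed here use the `hf://` protocol, which pandas can resolve through the huggingface_hub fsspec integration. A one-liner of the kind this config enabled (assuming `huggingface_hub` is installed):

```python
import pandas as pd  # resolving hf:// paths requires the huggingface_hub package

# DRIAS_MEAN_ANNUAL_TEMPERATURE_PATH from the deleted config:
df = pd.read_parquet("hf://datasets/timeki/drias_db/mean_annual_temperature.parquet")
print(df.columns.tolist())
```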
climateqa/engine/talk_to_data/drias/config.py
DELETED
@@ -1,124 +0,0 @@
-
-from climateqa.engine.talk_to_data.ui_config import PRECIPITATION_COLORSCALE, TEMPERATURE_COLORSCALE
-
-
-DRIAS_TABLES = [
-    "total_winter_precipitation",
-    "total_summer_precipitation",
-    "total_annual_precipitation",
-    "total_remarkable_daily_precipitation",
-    "frequency_of_remarkable_daily_precipitation",
-    "extreme_precipitation_intensity",
-    "mean_winter_temperature",
-    "mean_summer_temperature",
-    "mean_annual_temperature",
-    "number_of_tropical_nights",
-    "maximum_summer_temperature",
-    "number_of_days_with_tx_above_30",
-    "number_of_days_with_tx_above_35",
-    "number_of_days_with_a_dry_ground",
-]
-
-DRIAS_INDICATOR_COLUMNS_PER_TABLE = {
-    "total_winter_precipitation": "total_winter_precipitation",
-    "total_summer_precipitation": "total_summer_precipitation",
-    "total_annual_precipitation": "total_annual_precipitation",
-    "total_remarkable_daily_precipitation": "total_remarkable_daily_precipitation",
-    "frequency_of_remarkable_daily_precipitation": "frequency_of_remarkable_daily_precipitation",
-    "extreme_precipitation_intensity": "extreme_precipitation_intensity",
-    "mean_winter_temperature": "mean_winter_temperature",
-    "mean_summer_temperature": "mean_summer_temperature",
-    "mean_annual_temperature": "mean_annual_temperature",
-    "number_of_tropical_nights": "number_tropical_nights",
-    "maximum_summer_temperature": "maximum_summer_temperature",
-    "number_of_days_with_tx_above_30": "number_of_days_with_tx_above_30",
-    "number_of_days_with_tx_above_35": "number_of_days_with_tx_above_35",
-    "number_of_days_with_a_dry_ground": "number_of_days_with_dry_ground"
-}
-
-DRIAS_MODELS = [
-    'ALL',
-    'RegCM4-6_MPI-ESM-LR',
-    'RACMO22E_EC-EARTH',
-    'RegCM4-6_HadGEM2-ES',
-    'HadREM3-GA7_EC-EARTH',
-    'HadREM3-GA7_CNRM-CM5',
-    'REMO2015_NorESM1-M',
-    'SMHI-RCA4_EC-EARTH',
-    'WRF381P_NorESM1-M',
-    'ALADIN63_CNRM-CM5',
-    'CCLM4-8-17_MPI-ESM-LR',
-    'HIRHAM5_IPSL-CM5A-MR',
-    'HadREM3-GA7_HadGEM2-ES',
-    'SMHI-RCA4_IPSL-CM5A-MR',
-    'HIRHAM5_NorESM1-M',
-    'REMO2009_MPI-ESM-LR',
-    'CCLM4-8-17_HadGEM2-ES'
-]
-# Mapping between indicator columns and their units
-DRIAS_INDICATOR_TO_UNIT = {
-    "total_winter_precipitation": "mm",
-    "total_summer_precipitation": "mm",
-    "total_annual_precipitation": "mm",
-    "total_remarkable_daily_precipitation": "mm",
-    "frequency_of_remarkable_daily_precipitation": "days",
-    "extreme_precipitation_intensity": "mm",
-    "mean_winter_temperature": "°C",
-    "mean_summer_temperature": "°C",
-    "mean_annual_temperature": "°C",
-    "number_tropical_nights": "days",
-    "maximum_summer_temperature": "°C",
-    "number_of_days_with_tx_above_30": "days",
-    "number_of_days_with_tx_above_35": "days",
-    "number_of_days_with_dry_ground": "days"
-}
-
-DRIAS_PLOT_PARAMETERS = [
-    'year',
-    'location'
-]
-
-DRIAS_INDICATOR_TO_COLORSCALE = {
-    "total_winter_precipitation": PRECIPITATION_COLORSCALE,
-    "total_summer_precipitation": PRECIPITATION_COLORSCALE,
-    "total_annual_precipitation": PRECIPITATION_COLORSCALE,
-    "total_remarkable_daily_precipitation": PRECIPITATION_COLORSCALE,
-    "frequency_of_remarkable_daily_precipitation": PRECIPITATION_COLORSCALE,
-    "extreme_precipitation_intensity": PRECIPITATION_COLORSCALE,
-    "mean_winter_temperature":TEMPERATURE_COLORSCALE,
-    "mean_summer_temperature":TEMPERATURE_COLORSCALE,
-    "mean_annual_temperature":TEMPERATURE_COLORSCALE,
-    "number_tropical_nights": TEMPERATURE_COLORSCALE,
-    "maximum_summer_temperature":TEMPERATURE_COLORSCALE,
-    "number_of_days_with_tx_above_30": TEMPERATURE_COLORSCALE,
-    "number_of_days_with_tx_above_35": TEMPERATURE_COLORSCALE,
-    "number_of_days_with_dry_ground": TEMPERATURE_COLORSCALE
-}
-
-DRIAS_UI_TEXT = """
-Hi, I'm **Talk to Drias**, designed to answer your questions using [**DRIAS - TRACC 2023**](https://www.drias-climat.fr/accompagnement/sections/401) data.
-I'll answer by displaying a list of SQL queries, graphs and data most relevant to your question.
-
-You can ask me anything about these climate indicators: **temperature**, **precipitation** or **drought**.
-You can specify **location** and/or **year**.
-You can choose from a list of climate models. By default, we take the **average of each model**.
-
-For example, you can ask:
-- What will the temperature be like in Paris?
-- What will be the total rainfall in France in 2030?
-- How frequent will extreme events be in Lyon?
-
-**Example of indicators in the data**:
-- Mean temperature (annual, winter, summer)
-- Total precipitation (annual, winter, summer)
-- Number of days with remarkable precipitations, with dry ground, with temperature above 30°C
-
-⚠️ **Limitations**:
-- You can't ask anything that isn't related to **DRIAS - TRACC 2023** data.
-- You can only ask about **locations in France**.
-- If you specify a year, there may be **no data for that year for some models**.
-- You **cannot compare two models**.
-
-🛈 **Information**
-Please note that we **log your questions for meta-analysis purposes**, so avoid sharing any sensitive or personal information.
-"""
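These dictionaries formed a small lookup layer from a SQL table name to its value column, display unit and colorscale. Note that a few table names differ from their column names (tropical nights, dry ground), which is exactly what DRIAS_INDICATOR_COLUMNS_PER_TABLE resolved. A sketch of the pre-PR lookups:

```python
from climateqa.engine.talk_to_data.drias.config import (
    DRIAS_INDICATOR_COLUMNS_PER_TABLE,
    DRIAS_INDICATOR_TO_UNIT,
    DRIAS_INDICATOR_TO_COLORSCALE,
)

table = "number_of_tropical_nights"

column = DRIAS_INDICATOR_COLUMNS_PER_TABLE[table]    # -> "number_tropical_nights"
unit = DRIAS_INDICATOR_TO_UNIT[column]               # -> "days"
colorscale = DRIAS_INDICATOR_TO_COLORSCALE[column]   # -> TEMPERATURE_COLORSCALE
print(f"{table}: plot '{column}' in {unit}")
```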
climateqa/engine/talk_to_data/drias/plot_informations.py
DELETED
@@ -1,88 +0,0 @@
-from climateqa.engine.talk_to_data.drias.config import DRIAS_INDICATOR_TO_UNIT
-
-def indicator_evolution_informations(
-    indicator: str,
-    params: dict[str, str]
-) -> str:
-    unit = DRIAS_INDICATOR_TO_UNIT[indicator]
-    if "location" not in params:
-        raise ValueError('"location" must be provided in params')
-    location = params["location"]
-    return f"""
-This plot shows how the climate indicator **{indicator}** evolves over time in **{location}**.
-
-It combines both historical observations and future projections according to the climate scenario RCP8.5.
-
-The x-axis represents the years, and the y-axis shows the value of the indicator ({unit}).
-
-A 10-year rolling average curve is displayed to give a better idea of the overall trend.
-
-**Data source:**
-- The data come from the DRIAS TRACC data. The data were initially extracted from [the DRIAS website](https://www.drias-climat.fr/drias_prod/accueil/okapiWebDrias/index.jsp?iddrias=climat) and then preprocessed to a tabular format and uploaded as parquet in this [Hugging Face dataset](https://huggingface.co/datasets/timeki/drias_db).
-- For each year and climate model, the value of {indicator} in {location} is collected, to build the time series.
-- The coordinates used for {location} correspond to the closest available point in the DRIAS database, which uses a regular grid with a spatial resolution of 8 km.
-- The indicator values shown are those for the selected climate model.
-- If ALL climate model is selected, the average value of the indicator between all the climate models is used.
-"""
-
-def indicator_number_of_days_per_year_informations(
-    indicator: str,
-    params: dict[str, str]
-) -> str:
-    unit = DRIAS_INDICATOR_TO_UNIT[indicator]
-    if "location" not in params:
-        raise ValueError('"location" must be provided in params')
-    location = params["location"]
-    return f"""
-This plot displays a bar chart showing the yearly frequency of the climate indicator **{indicator}** in **{location}**.
-
-The x-axis represents the years, and the y-axis shows the frequency of {indicator} ({unit}) per year.
-
-**Data source:**
-- The data come from the DRIAS TRACC data. The data were initially extracted from [the DRIAS website](https://www.drias-climat.fr/drias_prod/accueil/okapiWebDrias/index.jsp?iddrias=climat) and then preprocessed to a tabular format and uploaded as parquet in this [Hugging Face dataset](https://huggingface.co/datasets/timeki/drias_db).
-- For each year and climate model, the value of {indicator} in {location} is collected, to build the time series.
-- The coordinates used for {location} correspond to the closest available point in the DRIAS database, which uses a regular grid with a spatial resolution of 8 km.
-- The indicator values shown are those for the selected climate model.
-- If ALL climate model is selected, the average value of the indicator between all the climate models is used.
-"""
-
-def distribution_of_indicator_for_given_year_informations(
-    indicator: str,
-    params: dict[str, str]
-) -> str:
-    unit = DRIAS_INDICATOR_TO_UNIT[indicator]
-    year = params["year"]
-    if year is None:
-        year = 2030
-    return f"""
-This plot shows a histogram of the distribution of the climate indicator **{indicator}** across all locations for the year **{year}**.
-
-It allows you to visualize how the values of {indicator} ({unit}) are spread for a given year.
-
-**Data source:**
-- The data come from the DRIAS TRACC data. The data were initially extracted from [the DRIAS website](https://www.drias-climat.fr/drias_prod/accueil/okapiWebDrias/index.jsp?iddrias=climat) and then preprocessed to a tabular format and uploaded as parquet in this [Hugging Face dataset](https://huggingface.co/datasets/timeki/drias_db).
-- For each grid point in the dataset and climate model, the value of {indicator} for the year {year} is extracted.
-- The indicator values shown are those for the selected climate model.
-- If ALL climate model is selected, the average value of the indicator between all the climate models is used.
-"""
-
-def map_of_france_of_indicator_for_given_year_informations(
-    indicator: str,
-    params: dict[str, str]
-) -> str:
-    unit = DRIAS_INDICATOR_TO_UNIT[indicator]
-    year = params["year"]
-    if year is None:
-        year = 2030
-    return f"""
-This plot displays a choropleth map showing the spatial distribution of **{indicator}** across all regions of France for the year **{year}**.
-
-Each region is colored according to the value of the indicator ({unit}), allowing you to visually compare how {indicator} varies geographically within France for the selected year and climate model.
-
-**Data source:**
-- The data come from the DRIAS TRACC data. The data were initially extracted from [the DRIAS website](https://www.drias-climat.fr/drias_prod/accueil/okapiWebDrias/index.jsp?iddrias=climat) and then preprocessed to a tabular format and uploaded as parquet in this [Hugging Face dataset](https://huggingface.co/datasets/timeki/drias_db).
-- For each region of France, the value of {indicator} in {year} and for the selected climate model is extracted and mapped to its geographic coordinates.
-- The regions correspond to 8 km squares centered on the grid points of the DRIAS dataset.
-- The indicator values shown are those for the selected climate model.
-- If ALL climate model is selected, the average value of the indicator between all the climate models is used.
-"""
climateqa/engine/talk_to_data/drias/plots.py
DELETED
@@ -1,434 +0,0 @@
-import os
-import geojson
-from math import cos, radians
-from typing import Callable
-import pandas as pd
-from plotly.graph_objects import Figure
-import plotly.graph_objects as go
-from climateqa.engine.talk_to_data.drias.plot_informations import distribution_of_indicator_for_given_year_informations, indicator_evolution_informations, indicator_number_of_days_per_year_informations, map_of_france_of_indicator_for_given_year_informations
-from climateqa.engine.talk_to_data.objects.plot import Plot
-from climateqa.engine.talk_to_data.drias.queries import (
-    indicator_for_given_year_query,
-    indicator_per_year_at_location_query,
-)
-from climateqa.engine.talk_to_data.drias.config import DRIAS_INDICATOR_TO_COLORSCALE, DRIAS_INDICATOR_TO_UNIT
-
-def generate_geojson_polygons(latitudes: list[float], longitudes: list[float], indicators: list[float]) -> geojson.FeatureCollection:
-    side_km = 8
-    delta_lat = side_km / 111
-    features = []
-    for idx, (lat, lon, val) in enumerate(zip(latitudes, longitudes, indicators)):
-        delta_lon = side_km / (111 * cos(radians(lat)))
-        half_lat = delta_lat / 2
-        half_lon = delta_lon / 2
-        features.append(geojson.Feature(
-            geometry=geojson.Polygon([[
-                [lon - half_lon, lat - half_lat],
-                [lon + half_lon, lat - half_lat],
-                [lon + half_lon, lat + half_lat],
-                [lon - half_lon, lat + half_lat],
-                [lon - half_lon, lat - half_lat]
-            ]]),
-            properties={"value": val},
-            id=str(idx)
-        ))
-
-    return geojson.FeatureCollection(features)
-
-def plot_indicator_evolution_at_location(params: dict) -> Callable[..., Figure]:
-    """Generates a function to plot indicator evolution over time at a location.
-
-    This function creates a line plot showing how a climate indicator changes
-    over time at a specific location. It handles temperature, precipitation,
-    and other climate indicators.
-
-    Args:
-        params (dict): Dictionary containing:
-            - indicator_column (str): The column name for the indicator
-            - location (str): The location to plot
-            - model (str): The climate model to use
-
-    Returns:
-        Callable[..., Figure]: A function that takes a DataFrame and returns a plotly Figure
-
-    Example:
-        >>> plot_func = plot_indicator_evolution_at_location({
-        ...     'indicator_column': 'mean_temperature',
-        ...     'location': 'Paris',
-        ...     'model': 'ALL'
-        ... })
-        >>> fig = plot_func(df)
-    """
-    indicator = params["indicator_column"]
-    location = params["location"]
-    indicator_label = " ".join([word.capitalize() for word in indicator.split("_")])
-    unit = DRIAS_INDICATOR_TO_UNIT.get(indicator, "")
-
-    def plot_data(df: pd.DataFrame) -> Figure:
-        """Generates the actual plot from the data.
-
-        Args:
-            df (pd.DataFrame): DataFrame containing the data to plot
-
-        Returns:
-            Figure: A plotly Figure object showing the indicator evolution
-        """
-        fig = go.Figure()
-        if df['model'].nunique() != 1:
-            df_avg = df.groupby("year", as_index=False)[indicator].mean()
-
-            # Transform to list to avoid pandas encoding
-            indicators = df_avg[indicator].astype(float).tolist()
-            years = df_avg["year"].astype(int).tolist()
-
-            # Compute the 10-year rolling average
-            rolling_window = 10
-            sliding_averages = (
-                df_avg[indicator]
-                .rolling(window=rolling_window, min_periods=rolling_window)
-                .mean()
-                .astype(float)
-                .tolist()
-            )
-            model_label = "Model Average"
-
-            # Only add rolling average if we have enough data points
-            if len([x for x in sliding_averages if pd.notna(x)]) > 0:
-                # Sliding average dashed line
-                fig.add_scatter(
-                    x=years,
-                    y=sliding_averages,
-                    mode="lines",
-                    name="10 years rolling average",
-                    line=dict(dash="dash"),
-                    marker=dict(color="#d62728"),
-                    hovertemplate=f"10-year average: %{{y:.2f}} {unit}<br>Year: %{{x}}<extra></extra>"
-                )
-
-        else:
-            df_model = df
-
-            # Transform to list to avoid pandas encoding
-            indicators = df_model[indicator].astype(float).tolist()
-            years = df_model["year"].astype(int).tolist()
-
-            # Compute the 10-year rolling average
-            rolling_window = 10
-            sliding_averages = (
-                df_model[indicator]
-                .rolling(window=rolling_window, min_periods=rolling_window)
-                .mean()
-                .astype(float)
-                .tolist()
-            )
-            model_label = f"Model : {df['model'].unique()[0]}"
-
-            # Only add rolling average if we have enough data points
-            if len([x for x in sliding_averages if pd.notna(x)]) > 0:
-                # Sliding average dashed line
-                fig.add_scatter(
-                    x=years,
-                    y=sliding_averages,
-                    mode="lines",
-                    name="10 years rolling average",
-                    line=dict(dash="dash"),
-                    marker=dict(color="#d62728"),
-                    hovertemplate=f"10-year average: %{{y:.2f}} {unit}<br>Year: %{{x}}<extra></extra>"
-                )
-
-        # Indicator per year plot
-        fig.add_scatter(
-            x=years,
-            y=indicators,
-            name=f"Yearly {indicator_label}",
-            mode="lines",
-            marker=dict(color="#1f77b4"),
-            hovertemplate=f"{indicator_label}: %{{y:.2f}} {unit}<br>Year: %{{x}}<extra></extra>"
-        )
-        fig.update_layout(
-            title=f"Evolution of {indicator_label} in {location} ({model_label})",
-            xaxis_title="Year",
-            yaxis_title=f"{indicator_label} ({unit})",
-            template="plotly_white",
-            height=900,
-        )
-        return fig
-
-    return plot_data
-
-
-indicator_evolution_at_location: Plot = {
-    "name": "Indicator evolution at location",
-    "description": "Plot an evolution of the indicator at a certain location",
-    "params": ["indicator_column", "location", "model"],
-    "plot_function": plot_indicator_evolution_at_location,
-    "sql_query": indicator_per_year_at_location_query,
-    "plot_information": indicator_evolution_informations,
-    'short_name': 'Evolution'
-}
-
-
-def plot_indicator_number_of_days_per_year_at_location(
-    params: dict,
-) -> Callable[..., Figure]:
-    """Generates a function to plot the number of days per year for an indicator.
-
-    This function creates a bar chart showing the frequency of certain climate
-    events (like days above a temperature threshold) per year at a specific location.
-
-    Args:
-        params (dict): Dictionary containing:
-            - indicator_column (str): The column name for the indicator
-            - location (str): The location to plot
-            - model (str): The climate model to use
-
-    Returns:
-        Callable[..., Figure]: A function that takes a DataFrame and returns a plotly Figure
-    """
-    indicator = params["indicator_column"]
-    location = params["location"]
-    indicator_label = " ".join([word.capitalize() for word in indicator.split("_")])
-    unit = DRIAS_INDICATOR_TO_UNIT.get(indicator, "")
-
-    def plot_data(df: pd.DataFrame) -> Figure:
-        """Generate the figure thanks to the dataframe
-
-        Args:
-            df (pd.DataFrame): pandas dataframe with the required data
-
-        Returns:
-            Figure: Plotly figure
-        """
-        fig = go.Figure()
-        if df['model'].nunique() != 1:
-            df_avg = df.groupby("year", as_index=False)[indicator].mean()
-
-            # Transform to list to avoid pandas encoding
-            indicators = df_avg[indicator].astype(float).tolist()
-            years = df_avg["year"].astype(int).tolist()
-            model_label = "Model Average"
-
-        else:
-            df_model = df
-            # Transform to list to avoid pandas encoding
-            indicators = df_model[indicator].astype(float).tolist()
-            years = df_model["year"].astype(int).tolist()
-            model_label = f"Model : {df['model'].unique()[0]}"
-
-
-        # Bar plot
-        fig.add_trace(
-            go.Bar(
-                x=years,
-                y=indicators,
-                width=0.5,
-                marker=dict(color="#1f77b4"),
-                hovertemplate=f"{indicator_label}: %{{y:.2f}} {unit}<br>Year: %{{x}}<extra></extra>"
-            )
-        )
-
-        fig.update_layout(
-            title=f"{indicator_label} in {location} ({model_label})",
-            xaxis_title="Year",
-            yaxis_title=f"{indicator_label} ({unit})",
-            yaxis=dict(range=[0, max(indicators)]),
-            bargap=0.5,
-            height=900,
-            template="plotly_white",
-        )
-
-        return fig
-
-    return plot_data
-
-
-indicator_number_of_days_per_year_at_location: Plot = {
-    "name": "Indicator number of days per year at location",
-    "description": "Plot a barchart of the number of days per year of a certain indicator at a certain location. It is appropriate for frequency indicator.",
-    "params": ["indicator_column", "location", "model"],
-    "plot_function": plot_indicator_number_of_days_per_year_at_location,
-    "sql_query": indicator_per_year_at_location_query,
-    "plot_information": indicator_number_of_days_per_year_informations,
-    "short_name": "Yearly Frequency",
-}
-
-
-def plot_distribution_of_indicator_for_given_year(
-    params: dict,
-) -> Callable[..., Figure]:
-    """Generates a function to plot the distribution of an indicator for a year.
-
-    This function creates a histogram showing the distribution of a climate
-    indicator across different locations for a specific year.
-
-    Args:
-        params (dict): Dictionary containing:
-            - indicator_column (str): The column name for the indicator
-            - year (str): The year to plot
-            - model (str): The climate model to use
-
-    Returns:
-        Callable[..., Figure]: A function that takes a DataFrame and returns a plotly Figure
-    """
-    indicator = params["indicator_column"]
-    year = params["year"]
-    if year is None:
-        year = 2030
-    indicator_label = " ".join([word.capitalize() for word in indicator.split("_")])
-    unit = DRIAS_INDICATOR_TO_UNIT.get(indicator, "")
-
-    def plot_data(df: pd.DataFrame) -> Figure:
-        """Generate the figure thanks to the dataframe
-
-        Args:
-            df (pd.DataFrame): pandas dataframe with the required data
-
-        Returns:
-            Figure: Plotly figure
-        """
-        fig = go.Figure()
-        if df['model'].nunique() != 1:
-            df_avg = df.groupby(["latitude", "longitude"], as_index=False)[
-                indicator
-            ].mean()
-
-            # Transform to list to avoid pandas encoding
-            indicators = df_avg[indicator].astype(float).tolist()
-            model_label = "Model Average"
-
-        else:
-            df_model = df
-
-            # Transform to list to avoid pandas encoding
-            indicators = df_model[indicator].astype(float).tolist()
-            model_label = f"Model : {df['model'].unique()[0]}"
-
-
-        fig.add_trace(
-            go.Histogram(
-                x=indicators,
-                opacity=0.8,
-                histnorm="percent",
-                marker=dict(color="#1f77b4"),
-                hovertemplate=f"{indicator_label}: %{{x:.2f}} {unit}<br>Frequency: %{{y:.2f}}%<extra></extra>"
-            )
-        )
-
-        fig.update_layout(
-            title=f"Distribution of {indicator_label} in {year} ({model_label})",
-            xaxis_title=f"{indicator_label} ({unit})",
-            yaxis_title="Frequency (%)",
-            plot_bgcolor="rgba(0, 0, 0, 0)",
-            showlegend=False,
-            height=900,
-        )
-
-        return fig
-
-    return plot_data
-
-
-distribution_of_indicator_for_given_year: Plot = {
-    "name": "Distribution of an indicator for a given year",
-    "description": "Plot an histogram of the distribution for a given year of the values of an indicator",
-    "params": ["indicator_column", "model", "year"],
-    "plot_function": plot_distribution_of_indicator_for_given_year,
-    "sql_query": indicator_for_given_year_query,
-    "plot_information": distribution_of_indicator_for_given_year_informations,
-    'short_name': 'Distribution'
-}
-
-
-def plot_map_of_france_of_indicator_for_given_year(
-    params: dict,
-) -> Callable[..., Figure]:
-    """Generates a function to plot a map of France for an indicator.
-
-    This function creates a choropleth map of France showing the spatial
-    distribution of a climate indicator for a specific year.
-
-    Args:
-        params (dict): Dictionary containing:
-            - indicator_column (str): The column name for the indicator
-            - year (str): The year to plot
-            - model (str): The climate model to use
-
-    Returns:
-        Callable[..., Figure]: A function that takes a DataFrame and returns a plotly Figure
-    """
-    indicator = params["indicator_column"]
-    year = params["year"]
-    if year is None:
-        year = 2030
-    indicator_label = " ".join([word.capitalize() for word in indicator.split("_")])
-    unit = DRIAS_INDICATOR_TO_UNIT.get(indicator, "")
-
-    def plot_data(df: pd.DataFrame) -> Figure:
-        fig = go.Figure()
-        if df['model'].nunique() != 1:
-            df_avg = df.groupby(["latitude", "longitude"], as_index=False)[
-                indicator
-            ].mean()
-
-            indicators = df_avg[indicator].astype(float).tolist()
-            latitudes = df_avg["latitude"].astype(float).tolist()
-            longitudes = df_avg["longitude"].astype(float).tolist()
-            model_label = "Model Average"
-
-        else:
-            df_model = df
-
-            # Transform to list to avoid pandas encoding
-            indicators = df_model[indicator].astype(float).tolist()
-            latitudes = df_model["latitude"].astype(float).tolist()
-            longitudes = df_model["longitude"].astype(float).tolist()
-            model_label = f"Model : {df['model'].unique()[0]}"
-
-
-
-        geojson_data = generate_geojson_polygons(latitudes, longitudes, indicators)
-
-        fig = go.Figure(go.Choroplethmapbox(
-            geojson=geojson_data,
-            locations=[str(i) for i in range(len(indicators))],
-            featureidkey="id",
-            z=indicators,
-            colorscale=DRIAS_INDICATOR_TO_COLORSCALE[indicator],
-            zmin=min(indicators),
-            zmax=max(indicators),
-            marker_opacity=0.7,
-            marker_line_width=0,
-            colorbar_title=f"{indicator_label} ({unit})",
-            text=[f"{indicator_label}: {value:.2f} {unit}" for value in indicators],  # Add hover text showing the indicator value
-            hoverinfo="text"
-        ))
-
-        fig.update_layout(
-            mapbox_style="open-street-map",  # Use OpenStreetMap
-            mapbox_zoom=5,
-            height=900,
-            mapbox_center={"lat": 46.6, "lon": 2.0},
-            coloraxis_colorbar=dict(title=f"{indicator_label} ({unit})"),  # Add legend
-            title=f"{indicator_label} in {year} in France ({model_label}) "  # Title
-        )
-        return fig
-
-    return plot_data
-
-
-map_of_france_of_indicator_for_given_year: Plot = {
-    "name": "Map of France of an indicator for a given year",
-    "description": "Heatmap on the map of France of the values of an indicator for a given year",
-    "params": ["indicator_column", "year", "model"],
-    "plot_function": plot_map_of_france_of_indicator_for_given_year,
-    "sql_query": indicator_for_given_year_query,
-    "plot_information": map_of_france_of_indicator_for_given_year_informations,
-    'short_name': 'Map of France'
-}
-
-DRIAS_PLOTS = [
-    indicator_evolution_at_location,
-    indicator_number_of_days_per_year_at_location,
-    distribution_of_indicator_for_given_year,
-    map_of_france_of_indicator_for_given_year,
-]
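For reference on the removed design: each `Plot` entry bundled a SQL builder, a figure factory and a description string, so a caller could stitch them together roughly like this (duckdb is an assumption here; any engine able to read the parquet URLs would do, and the parameter values are illustrative):

```python
import duckdb

from climateqa.engine.talk_to_data.drias.plots import indicator_evolution_at_location

plot = indicator_evolution_at_location
params = {
    "indicator_column": "mean_annual_temperature",
    "location": "Paris",
    "latitude": 48.85,    # the SQL builder filters on grid coordinates
    "longitude": 2.35,
    "model": "ALL",
}

sql = plot["sql_query"]("mean_annual_temperature", params)  # build the query string
df = duckdb.sql(sql).df()               # year, indicator value and model per row
fig = plot["plot_function"](params)(df)  # factory returns a df -> Figure closure
fig.show()
```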
climateqa/engine/talk_to_data/drias/queries.py
DELETED
@@ -1,83 +0,0 @@
-from typing import TypedDict
-
-from climateqa.engine.talk_to_data.config import DRIAS_DATASET_URL
-
-
-class IndicatorPerYearAtLocationQueryParams(TypedDict, total=False):
-    """Parameters for querying an indicator's values over time at a location.
-
-    This class defines the parameters needed to query climate indicator data
-    for a specific location over multiple years.
-
-    Attributes:
-        indicator_column (str): The column name for the climate indicator
-        latitude (str): The latitude coordinate of the location
-        longitude (str): The longitude coordinate of the location
-        model (str): The climate model to use (optional)
-    """
-    indicator_column: str
-    latitude: str
-    longitude: str
-    model: str
-
-
-def indicator_per_year_at_location_query(
-    table: str, params: IndicatorPerYearAtLocationQueryParams
-) -> str:
-    """SQL query to get the evolution of an indicator per year at a given location.
-
-    Args:
-        table (str): SQL table of the indicator
-        params (IndicatorPerYearAtLocationQueryParams): dictionary with the required params for the query
-
-    Returns:
-        str: the SQL query
-    """
-    indicator_column = params.get("indicator_column")
-    latitude = params.get("latitude")
-    longitude = params.get("longitude")
-
-    if indicator_column is None or latitude is None or longitude is None:
-        # If one parameter is missing, return an empty query
-        return ""
-
-    table = f"'{DRIAS_DATASET_URL}/{table.lower()}.parquet'"
-
-    sql_query = f"SELECT year, {indicator_column}, model\nFROM {table}\nWHERE latitude = {latitude}\nAND longitude = {longitude}\nORDER BY year"
-
-    return sql_query
-
-
-class IndicatorForGivenYearQueryParams(TypedDict, total=False):
-    """Parameters for querying an indicator's values across locations for a year.
-
-    This class defines the parameters needed to query climate indicator data
-    across different locations for a specific year.
-
-    Attributes:
-        indicator_column (str): The column name for the climate indicator
-        year (str): The year to query
-        model (str): The climate model to use (optional)
-    """
-    indicator_column: str
-    year: str
-    model: str
-
-
-def indicator_for_given_year_query(
-    table: str, params: IndicatorForGivenYearQueryParams
-) -> str:
-    """SQL query to get the values of an indicator with their latitudes, longitudes and models for a given year.

-    Args:
-        table (str): SQL table of the indicator
-        params (IndicatorForGivenYearQueryParams): dictionary with the required params for the query
-
-    Returns:
-        str: the SQL query
-    """
-    indicator_column = params.get("indicator_column")
-    year = params.get("year")
-    if year is None:
-        year = 2050  # Default projection year
-    if indicator_column is None:
-        # If a required parameter is missing, return an empty query
-        return ""
-
-    table = f"'{DRIAS_DATASET_URL}/{table.lower()}.parquet'"
-
-    sql_query = f"SELECT {indicator_column}, latitude, longitude, model\nFROM {table}\nWHERE year = {year}"
-    return sql_query
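For illustration, with `indicator_column="mean_annual_temperature"`, `latitude=48.85` and `longitude=2.35`, `indicator_per_year_at_location_query` on the table `mean_annual_temperature` produces a string of this shape (`<DRIAS_DATASET_URL>` stands for the configured dataset URL; the coordinates are arbitrary):

    SELECT year, mean_annual_temperature, model
    FROM '<DRIAS_DATASET_URL>/mean_annual_temperature.parquet'
    WHERE latitude = 48.85
    AND longitude = 2.35
    ORDER BY year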
climateqa/engine/talk_to_data/input_processing.py
DELETED
@@ -1,257 +0,0 @@
-from typing import Any, Literal, Optional, cast
-import ast
-from langchain_core.prompts import ChatPromptTemplate
-from geopy.geocoders import Nominatim
-from climateqa.engine.llm import get_llm
-import duckdb
-from climateqa.engine.talk_to_data.config import DRIAS_MEAN_ANNUAL_TEMPERATURE_PATH, IPCC_COORDINATES_PATH
-from climateqa.engine.talk_to_data.objects.llm_outputs import ArrayOutput
-from climateqa.engine.talk_to_data.objects.location import Location
-from climateqa.engine.talk_to_data.objects.plot import Plot
-from climateqa.engine.talk_to_data.objects.states import State
-
-
-async def detect_location_with_openai(sentence: str) -> str:
-    """
-    Detects locations in a sentence using OpenAI's API via LangChain.
-    """
-    llm = get_llm()
-
-    prompt = f"""
-    Extract all locations (cities, countries, states, or geographical areas) mentioned in the following sentence.
-    Return the result as a Python list. If no locations are mentioned, return an empty list.
-
-    Sentence: "{sentence}"
-    """
-
-    response = await llm.ainvoke(prompt)
-    location_list = ast.literal_eval(response.content.strip("```python\n").strip())
-    if location_list:
-        return location_list[0]
-    else:
-        return ""
-
-
-def loc_to_coords(location: str) -> tuple[float, float]:
-    """Converts a location name to geographic coordinates.
-
-    This function uses the Nominatim geocoding service to convert
-    a location name (e.g., a city name) to its latitude and longitude.
-
-    Args:
-        location (str): The name of the location to geocode
-
-    Returns:
-        tuple[float, float]: A tuple containing (latitude, longitude)
-
-    Raises:
-        AttributeError: If the location cannot be found
-    """
-    geolocator = Nominatim(user_agent="city_to_latlong", timeout=5)
-    coords = geolocator.geocode(location)
-    return (coords.latitude, coords.longitude)
-
-
-def coords_to_country(coords: tuple[float, float]) -> tuple[str, str]:
-    """Converts geographic coordinates to a country name.
-
-    This function uses the Nominatim reverse geocoding service to convert
-    latitude and longitude coordinates to a country name.
-
-    Args:
-        coords (tuple[float, float]): A tuple containing (latitude, longitude)
-
-    Returns:
-        tuple[str, str]: A tuple containing (country_code, country_name)
-
-    Raises:
-        AttributeError: If the coordinates cannot be found
-    """
-    geolocator = Nominatim(user_agent="latlong_to_country")
-    location = geolocator.reverse(coords)
-    address = location.raw['address']
-    return address['country_code'].upper(), address['country']
-
-
-def nearest_neighbour_sql(location: tuple, mode: Literal['DRIAS', 'IPCC']) -> tuple[str, str, Optional[str]]:
-    """Finds the closest available grid point to the given (latitude, longitude) in the selected dataset."""
-    long = round(location[1], 3)
-    lat = round(location[0], 3)
-    conn = duckdb.connect()
-
-    if mode == 'DRIAS':
-        table_path = f"'{DRIAS_MEAN_ANNUAL_TEMPERATURE_PATH}'"
-        results = conn.sql(
-            f"SELECT latitude, longitude FROM {table_path} WHERE latitude BETWEEN {lat - 0.3} AND {lat + 0.3} AND longitude BETWEEN {long - 0.3} AND {long + 0.3}"
-        ).fetchdf()
-    else:
-        table_path = f"'{IPCC_COORDINATES_PATH}'"
-        results = conn.sql(
-            f"SELECT latitude, longitude, admin1 FROM {table_path} WHERE latitude BETWEEN {lat - 0.5} AND {lat + 0.5} AND longitude BETWEEN {long - 0.5} AND {long + 0.5}"
-        ).fetchdf()
-
-    if len(results) == 0:
-        return "", "", ""
-
-    if 'admin1' in results.columns:
-        admin1 = results['admin1'].iloc[0]
-    else:
-        admin1 = None
-    return results['latitude'].iloc[0], results['longitude'].iloc[0], admin1
-
-
-async def detect_year_with_openai(sentence: str) -> str:
-    """
-    Detects years in a sentence using OpenAI's API via LangChain.
-    """
-    llm = get_llm()
-
-    prompt = """
-    Extract all years mentioned in the following sentence.
-    Return the result as a Python list. If no years are mentioned, return an empty list.
-
-    Sentence: "{sentence}"
-    """
-
-    prompt = ChatPromptTemplate.from_template(prompt)
-    structured_llm = llm.with_structured_output(ArrayOutput)
-    chain = prompt | structured_llm
-    response: ArrayOutput = await chain.ainvoke({"sentence": sentence})
-    years_list = ast.literal_eval(response['array'])
-    if len(years_list) > 0:
-        return years_list[0]
-    else:
-        return ""
-
-
-async def detect_relevant_tables(user_question: str, plot: Plot, llm, table_names_list: list[str]) -> list[str]:
-    """Identifies relevant tables for a plot based on user input.
-
-    This function uses an LLM to analyze the user's question and the plot
-    description to determine which tables in the DRIAS database would be
-    most relevant for generating the requested visualization.
-
-    Args:
-        user_question (str): The user's question about climate data
-        plot (Plot): The plot configuration object
-        llm: The language model instance to use for the analysis
-        table_names_list (list[str]): Names of the candidate tables
-
-    Returns:
-        list[str]: A list of table names that are relevant for the plot
-
-    Example:
-        >>> await detect_relevant_tables(
-        ...     "What will the temperature be like in Paris?",
-        ...     indicator_evolution_at_location,
-        ...     llm,
-        ...     table_names_list
-        ... )
-        ['mean_annual_temperature', 'mean_summer_temperature']
-    """
-    prompt = (
-        f"You are helping to build a plot following this description: {plot['description']}."
-        f"You are given a list of tables and a user question."
-        f"Based on the description of the plot, decide which tables are appropriate for that kind of plot."
-        f"Write the 3 most relevant tables to use. Answer only a Python list of table names."
-        f"### List of tables: {table_names_list}"
-        f"### User question: {user_question}"
-        f"### List of table names: "
-    )
-
-    table_names = ast.literal_eval(
-        (await llm.ainvoke(prompt)).content.strip("```python\n").strip()
-    )
-    return table_names
-
-
-async def detect_relevant_plots(user_question: str, llm, plot_list: list[Plot]) -> list[str]:
-    """Uses an LLM to rank the plots whose descriptions best match the user question."""
-    plots_description = ""
-    for plot in plot_list:
-        plots_description += "Name: " + plot["name"]
-        plots_description += " - Description: " + plot["description"] + "\n"
-
-    prompt = (
-        "You are helping to answer a question with insightful visualizations.\n"
-        "You are given a user question and a list of plots with their name and description.\n"
-        "Based on the descriptions of the plots, select ALL plots that could provide a useful answer to this question. "
-        "Include any plot that could show relevant information, even if their perspectives (such as time series or spatial distribution) are different.\n"
-        "For example, for a question like 'What will be the total rainfall in China in 2050?', both a time series plot and a spatial map plot could be relevant.\n"
-        "Return only a Python list of plot names sorted from the most relevant one to the least relevant one.\n"
-        f"### Descriptions of the plots: {plots_description}"
-        f"### User question: {user_question}\n"
-        f"### Names of the plots: "
-    )
-
-    plot_names = ast.literal_eval(
-        (await llm.ainvoke(prompt)).content.strip("```python\n").strip()
-    )
-    return plot_names
-
-
-async def find_location(user_input: str, mode: Literal['DRIAS', 'IPCC'] = 'DRIAS') -> Location:
-    print("---- Find location in user input ----")
-    location = await detect_location_with_openai(user_input)
-    output: Location = {
-        'location': location,
-        'longitude': None,
-        'latitude': None,
-        'country_code': None,
-        'country_name': None,
-        'admin1': None
-    }
-
-    if location:
-        coords = loc_to_coords(location)
-        country_code, country_name = coords_to_country(coords)
-        neighbour = nearest_neighbour_sql(coords, mode)
-        output.update({
-            "latitude": neighbour[0],
-            "longitude": neighbour[1],
-            "country_code": country_code,
-            "country_name": country_name,
-            "admin1": neighbour[2]
-        })
-        output = cast(Location, output)
-    return output
-
-
-async def find_year(user_input: str) -> str | None:
-    """Extracts year information from user input using an LLM.
-
-    This function uses an LLM to identify and extract year information from the
-    user's query, which is used to filter data in subsequent queries.
-
-    Args:
-        user_input (str): The user's query text
-
-    Returns:
-        str | None: The extracted year, or None if no year is found
-    """
-    print("---- Find year ----")
-    year = await detect_year_with_openai(user_input)
-    if year == "":
-        return None
-    return year
-
-
-async def find_relevant_plots(state: State, llm, plots: list[Plot]) -> list[str]:
-    print("---- Find relevant plots ----")
-    relevant_plots = await detect_relevant_plots(state['user_input'], llm, plots)
-    return relevant_plots
-
-
-async def find_relevant_tables_per_plot(state: State, plot: Plot, llm, tables: list[str]) -> list[str]:
-    print(f"---- Find relevant tables for {plot['name']} ----")
-    relevant_tables = await detect_relevant_tables(state['user_input'], plot, llm, tables)
-    return relevant_tables
-
-
-async def find_param(state: State, param_name: str, mode: Literal['DRIAS', 'IPCC'] = 'DRIAS') -> dict[str, Optional[str]] | Location | None:
-    """Dispatches to the appropriate extraction method for the desired parameter.
-
-    Args:
-        state (State): state of the workflow
-        param_name (str): name of the desired parameter
-        mode (Literal['DRIAS', 'IPCC']): which dataset's grid to use
-
-    Returns:
-        dict[str, Optional[str]] | Location | None: the extracted parameter, if any
-    """
-    if param_name == 'location':
-        location = await find_location(state['user_input'], mode)
-        return location
-    if param_name == 'year':
-        year = await find_year(state['user_input'])
-        return {'year': year}
-    return None
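A minimal sketch of the synchronous geocoding round trip that `find_location` relies on (network access to the Nominatim service is assumed, and a custom `user_agent` is required by its usage policy):

    from geopy.geocoders import Nominatim

    geolocator = Nominatim(user_agent="city_to_latlong", timeout=5)
    coords = geolocator.geocode("Paris")  # forward geocoding
    print(coords.latitude, coords.longitude)  # ~48.85, ~2.35
    address = geolocator.reverse((coords.latitude, coords.longitude)).raw["address"]
    print(address["country_code"].upper(), address["country"])  # FR, France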
climateqa/engine/talk_to_data/ipcc/config.py
DELETED
@@ -1,98 +0,0 @@
-from climateqa.engine.talk_to_data.ui_config import PRECIPITATION_COLORSCALE, TEMPERATURE_COLORSCALE
-from climateqa.engine.talk_to_data.config import IPCC_DATASET_URL
-
-
-# IPCC_DATASET_URL = "hf://datasets/ekimetrics/ipcc-atlas"
-IPCC_TABLES = [
-    "mean_temperature",
-    "total_precipitation",
-]
-
-IPCC_INDICATOR_COLUMNS_PER_TABLE = {
-    "mean_temperature": "mean_temperature",
-    "total_precipitation": "total_precipitation"
-}
-
-IPCC_INDICATOR_TO_UNIT = {
-    "mean_temperature": "°C",
-    "total_precipitation": "mm/day"
-}
-
-IPCC_SCENARIO = [
-    "historical",
-    "ssp126",
-    "ssp245",
-    "ssp370",
-    "ssp585",
-]
-
-IPCC_MODELS = []
-
-IPCC_PLOT_PARAMETERS = [
-    'year',
-    'location'
-]
-
-MACRO_COUNTRIES = [
-    'JP', 'IN', 'MH', 'PT', 'ID', 'SJ', 'MX', 'CN', 'GL', 'PN',
-    'AR', 'AQ', 'PF', 'BR', 'SH', 'GS', 'ZA', 'NZ', 'TF',
-]
-
-HUGE_MACRO_COUNTRIES = [
-    'CL', 'CA', 'AU', 'US', 'RU'
-]
-
-IPCC_INDICATOR_TO_COLORSCALE = {
-    "mean_temperature": TEMPERATURE_COLORSCALE,
-    "total_precipitation": PRECIPITATION_COLORSCALE
-}
-
-IPCC_UI_TEXT = """
-Hi, I'm **Talk to IPCC**, designed to answer your questions using [**IPCC - ATLAS**](https://interactive-atlas.ipcc.ch/regional-information#eyJ0eXBlIjoiQVRMQVMiLCJjb21tb25zIjp7ImxhdCI6OTc3MiwibG5nIjo0MDA2OTIsInpvb20iOjQsInByb2oiOiJFUFNHOjU0MDMwIiwibW9kZSI6ImNvbXBsZXRlX2F0bGFzIn0sInByaW1hcnkiOnsic2NlbmFyaW8iOiJzc3A1ODUiLCJwZXJpb2QiOiIyIiwic2Vhc29uIjoieWVhciIsImRhdGFzZXQiOiJDTUlQNiIsInZhcmlhYmxlIjoidGFzIiwidmFsdWVUeXBlIjoiQU5PTUFMWSIsImhhdGNoaW5nIjoiU0lNUExFIiwicmVnaW9uU2V0IjoiYXI2IiwiYmFzZWxpbmUiOiJwcmVJbmR1c3RyaWFsIiwicmVnaW9uc1NlbGVjdGVkIjpbXX0sInBsb3QiOnsiYWN0aXZlVGFiIjoicGx1bWUiLCJtYXNrIjoibm9uZSIsInNjYXR0ZXJZTWFnIjpudWxsLCJzY2F0dGVyWVZhciI6bnVsbCwic2hvd2luZyI6ZmFsc2V9fQ==) data.
-I'll answer by displaying a list of SQL queries, graphs and data most relevant to your question.
-
-You can ask me anything about these climate indicators: **temperature** or **precipitation**.
-You can specify **location** and/or **year**.
-By default, we take the **median of each climate model**.
-
-Currently available charts:
-- Yearly evolution of an indicator at a specific location (historical + SSP projections)
-- Yearly spatial distribution of an indicator in a specific country
-
-Currently available indicators:
-- Mean temperature
-- Total precipitation
-
-For example, you can ask:
-- What will the temperature be like in Paris?
-- What will be the total rainfall in the USA in 2030?
-- How will the average temperature evolve in China?
-
-⚠️ **Limitations**:
-- You can't ask anything that isn't related to **IPCC - ATLAS** data.
-- You cannot ask about **several locations at the same time**.
-- If you specify a year **before 1850 or after 2100**, there will be **no data**.
-- You **cannot compare two models**.
-
-🛈 **Information**
-Please note that we **log your questions for meta-analysis purposes**, so avoid sharing any sensitive or personal information.
-"""
climateqa/engine/talk_to_data/ipcc/plot_informations.py
DELETED
@@ -1,50 +0,0 @@
-from climateqa.engine.talk_to_data.ipcc.config import IPCC_INDICATOR_TO_UNIT
-
-
-def indicator_evolution_informations(
-    indicator: str,
-    params: dict[str, str],
-) -> str:
-    if "location" not in params:
-        raise ValueError('"location" must be provided in params')
-    location = params["location"]
-
-    unit = IPCC_INDICATOR_TO_UNIT[indicator]
-    return f"""
-This plot shows how the climate indicator **{indicator}** evolves over time in **{location}**.
-
-It combines both historical observations (from 1950 to 2015) and future projections (from 2016 to 2100) for the different SSP climate scenarios (SSP126, SSP245, SSP370 and SSP585).
-
-The x-axis represents the years (from 1950 to 2100), and the y-axis shows the value of the {indicator} ({unit}).
-
-Each line corresponds to a different scenario, allowing you to compare how {indicator} might change under various future conditions.
-
-**Data source:**
-- The data comes from the CMIP6 IPCC ATLAS data. The data were initially extracted from [this referenced website](https://digital.csic.es/handle/10261/332744) and then preprocessed to a tabular format and uploaded as parquet in this [Hugging Face dataset](https://huggingface.co/datasets/Ekimetrics/ipcc-atlas).
-- The underlying data is retrieved by aggregating yearly values of {indicator} for the selected location, across all available scenarios. This means the system collects, for each year, the value of {indicator} in {location}, both for the historical period and for each scenario, to build the time series.
-- The coordinates used for {location} correspond to the closest available point in the IPCC database, which uses a regular grid with a spatial resolution of 1 degree.
-"""
-
-
-def choropleth_map_informations(
-    indicator: str,
-    params: dict[str, str],
-) -> str:
-    unit = IPCC_INDICATOR_TO_UNIT[indicator]
-    if "location" not in params:
-        raise ValueError('"location" must be provided in params')
-    location = params["location"]
-    country_name = params["country_name"]
-    year = params["year"]
-    if year is None:
-        year = 2050
-
-    return f"""
-This plot displays a choropleth map showing the spatial distribution of **{indicator}** across all regions of the country **{location}** ({country_name}) for the year **{year}** and the chosen scenario.
-
-Each grid point is colored according to the value of the indicator ({unit}), allowing you to visually compare how {indicator} varies geographically within the country for the selected year and scenario.
-
-**Data source:**
-- The data comes from the CMIP6 IPCC ATLAS data. The data were initially extracted from [this referenced website](https://digital.csic.es/handle/10261/332744) and then preprocessed to a tabular format and uploaded as parquet in this [Hugging Face dataset](https://huggingface.co/datasets/Ekimetrics/ipcc-atlas).
-- For each grid point of the country **{location}** ({country_name}), the value of {indicator} in {year} and for the selected scenario is extracted and mapped to its geographic coordinates.
-- The grid points correspond to 1-degree squares centered on the grid points of the IPCC dataset. Each grid point has been mapped to a country using [**reverse_geocoder**](https://github.com/thampiman/reverse-geocoder).
-- The coordinates used for each region are those of the closest available grid point in the IPCC database, which uses a regular grid with a spatial resolution of 1 degree.
-"""
climateqa/engine/talk_to_data/ipcc/plots.py
DELETED
@@ -1,189 +0,0 @@
-from typing import Callable
-from plotly.graph_objects import Figure
-import plotly.graph_objects as go
-import pandas as pd
-import geojson
-
-from climateqa.engine.talk_to_data.ipcc.config import IPCC_INDICATOR_TO_COLORSCALE, IPCC_INDICATOR_TO_UNIT, IPCC_SCENARIO
-from climateqa.engine.talk_to_data.ipcc.plot_informations import choropleth_map_informations, indicator_evolution_informations
-from climateqa.engine.talk_to_data.ipcc.queries import indicator_for_given_year_query, indicator_per_year_at_location_query
-from climateqa.engine.talk_to_data.objects.plot import Plot
-
-
-def generate_geojson_polygons(latitudes: list[float], longitudes: list[float], indicators: list[float]) -> geojson.FeatureCollection:
-    features = [
-        geojson.Feature(
-            geometry=geojson.Polygon([[
-                [lon - 0.5, lat - 0.5],
-                [lon + 0.5, lat - 0.5],
-                [lon + 0.5, lat + 0.5],
-                [lon - 0.5, lat + 0.5],
-                [lon - 0.5, lat - 0.5]
-            ]]),
-            properties={"value": val},
-            id=str(idx)
-        )
-        for idx, (lat, lon, val) in enumerate(zip(latitudes, longitudes, indicators))
-    ]
-
-    geojson_data = geojson.FeatureCollection(features)
-    return geojson_data
-
-
-def plot_indicator_evolution_at_location_historical_and_projections(
-    params: dict,
-) -> Callable[[pd.DataFrame], Figure]:
-    """
-    Returns a function that generates a line plot showing the evolution of a climate indicator
-    (e.g., temperature, rainfall) over time at a specific location, including both historical data
-    and future projections for different climate scenarios.
-
-    Args:
-        params (dict): Dictionary with:
-            - indicator_column (str): Name of the climate indicator column to plot.
-            - location (str): Location (e.g., country, city) for which to plot the indicator.
-
-    Returns:
-        Callable[[pd.DataFrame], Figure]: Function that takes a DataFrame and returns a Plotly Figure
-        showing the indicator's evolution over time, with scenario lines and historical data.
-    """
-    indicator = params["indicator_column"]
-    location = params["location"]
-    indicator_label = " ".join(word.capitalize() for word in indicator.split("_"))
-    unit = IPCC_INDICATOR_TO_UNIT.get(indicator, "")
-
-    def plot_data(df: pd.DataFrame) -> Figure:
-        df = df.sort_values(by='year')
-        years = df['year'].astype(int).tolist()
-        indicators = df[indicator].astype(float).tolist()
-        scenarios = df['scenario'].astype(str).tolist()
-
-        # Find last historical value for continuity
-        last_historical = [(y, v) for y, v, s in zip(years, indicators, scenarios) if s == 'historical']
-        last_historical_year, last_historical_indicator = last_historical[-1] if last_historical else (None, None)
-
-        fig = go.Figure()
-        for scenario in IPCC_SCENARIO:
-            x = [y for y, s in zip(years, scenarios) if s == scenario]
-            y = [v for v, s in zip(indicators, scenarios) if s == scenario]
-            # Connect historical to scenario
-            if scenario != 'historical' and last_historical_indicator is not None:
-                x = [last_historical_year] + x
-                y = [last_historical_indicator] + y
-            fig.add_trace(go.Scatter(
-                x=x,
-                y=y,
-                mode='lines',
-                name=scenario
-            ))
-
-        fig.update_layout(
-            title=f'Yearly Evolution of {indicator_label} in {location} (Historical + SSP Scenarios)',
-            xaxis_title='Year',
-            yaxis_title=f'{indicator_label} ({unit})',
-            legend_title='Scenario',
-            height=800,
-        )
-        return fig
-
-    return plot_data
-
-
-indicator_evolution_at_location_historical_and_projections: Plot = {
-    "name": "Indicator Evolution at Location (Historical + Projections)",
-    "description": (
-        "Shows how a climate indicator (e.g., rainfall, temperature) changes over time at a specific location, "
-        "including historical data and future projections. "
-        "Useful for questions about the value or trend of an indicator at a location for any year, "
-        "such as 'What will be the total rainfall in China in 2050?' or 'How does rainfall evolve in China over time?'. "
-        "Parameters: indicator_column (the climate variable), location (e.g., country, city)."
-    ),
-    "params": ["indicator_column", "location"],
-    "plot_function": plot_indicator_evolution_at_location_historical_and_projections,
-    "sql_query": indicator_per_year_at_location_query,
-    "plot_information": indicator_evolution_informations,
-    "short_name": "Evolution"
-}
-
-
-def plot_choropleth_map_of_country_indicator_for_specific_year(
-    params: dict,
-) -> Callable[[pd.DataFrame], Figure]:
-    """
-    Returns a function that generates a choropleth map (heatmap) showing the spatial distribution
-    of a climate indicator (e.g., temperature, rainfall) across all regions of a country for a specific year.
-
-    Args:
-        params (dict): Dictionary with:
-            - indicator_column (str): Name of the climate indicator column to plot.
-            - year (str or int, optional): Year for which to plot the indicator (default: 2050).
-            - country_name (str): Name of the country.
-            - location (str): Location (country or region) for the map.
-
-    Returns:
-        Callable[[pd.DataFrame], Figure]: Function that takes a DataFrame and returns a Plotly Figure
-        showing the indicator's spatial distribution as a choropleth map for the specified year.
-    """
-    indicator = params["indicator_column"]
-    year = params.get('year')
-    if year is None:
-        year = 2050
-    country_name = params['country_name']
-    location = params['location']
-    indicator_label = " ".join(word.capitalize() for word in indicator.split("_"))
-    unit = IPCC_INDICATOR_TO_UNIT.get(indicator, "")
-
-    def plot_data(df: pd.DataFrame) -> Figure:
-        indicators = df[indicator].astype(float).tolist()
-        latitudes = df["latitude"].astype(float).tolist()
-        longitudes = df["longitude"].astype(float).tolist()
-
-        geojson_data = generate_geojson_polygons(latitudes, longitudes, indicators)
-
-        fig = go.Figure(go.Choroplethmapbox(
-            geojson=geojson_data,
-            locations=[str(i) for i in range(len(indicators))],
-            featureidkey="id",
-            z=indicators,
-            colorscale=IPCC_INDICATOR_TO_COLORSCALE[indicator],
-            zmin=min(indicators),
-            zmax=max(indicators),
-            marker_opacity=0.7,
-            marker_line_width=0,
-            colorbar_title=f"{indicator_label} ({unit})",
-            text=[f"{indicator_label}: {value:.2f} {unit}" for value in indicators],  # Add hover text showing the indicator value
-            hoverinfo="text"
-        ))
-
-        fig.update_layout(
-            mapbox_style="open-street-map",
-            mapbox_zoom=2,
-            height=800,
-            mapbox_center={
-                "lat": latitudes[len(latitudes)//2] if latitudes else 0,
-                "lon": longitudes[len(longitudes)//2] if longitudes else 0
-            },
-            coloraxis_colorbar=dict(title=f"{indicator_label} ({unit})"),
-            title=f"{indicator_label} in {year} in {location} ({country_name})"
-        )
-        return fig
-
-    return plot_data
-
-
-choropleth_map_of_country_indicator_for_specific_year: Plot = {
-    "name": "Choropleth Map of a Country's Indicator Distribution for a Specific Year",
-    "description": (
-        "Displays a map showing the spatial distribution of a climate indicator (e.g., rainfall, temperature) "
-        "across all regions of a country for a specific year. "
-        "Can answer questions about the value of an indicator in a country or region for a given year, "
-        "such as 'What will be the total rainfall in China in 2050?' or 'How is rainfall distributed across China in 2050?'. "
-        "Parameters: indicator_column (the climate variable), year, location (country name)."
-    ),
-    "params": ["indicator_column", "year", "location"],
-    "plot_function": plot_choropleth_map_of_country_indicator_for_specific_year,
-    "sql_query": indicator_for_given_year_query,
-    "plot_information": choropleth_map_informations,
-    "short_name": "Map",
-}
-
-IPCC_PLOTS = [
-    indicator_evolution_at_location_historical_and_projections,
-    choropleth_map_of_country_indicator_for_specific_year
-]
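For reference, each feature emitted by `generate_geojson_polygons` is a 1°×1° square centred on its grid point, keyed by a string index that `featureidkey="id"` matches against. A single grid point at (lat 48.5, lon 2.5) with value 12.3 would yield (sketch):

    import geojson

    feature = geojson.Feature(
        geometry=geojson.Polygon([[
            [2.0, 48.0], [3.0, 48.0], [3.0, 49.0], [2.0, 49.0], [2.0, 48.0]
        ]]),
        properties={"value": 12.3},
        id="0",
    )
    collection = geojson.FeatureCollection([feature])
    print(collection["features"][0]["id"])  # "0", matched by the choropleth's locations list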
climateqa/engine/talk_to_data/ipcc/queries.py
DELETED
@@ -1,144 +0,0 @@
-from typing import TypedDict, Optional
-
-from climateqa.engine.talk_to_data.ipcc.config import HUGE_MACRO_COUNTRIES, MACRO_COUNTRIES
-from climateqa.engine.talk_to_data.config import IPCC_DATASET_URL
-
-
-class IndicatorPerYearAtLocationQueryParams(TypedDict, total=False):
-    """
-    Parameters for querying the evolution of an indicator per year at a specific location.
-
-    Attributes:
-        indicator_column (str): Name of the climate indicator column.
-        latitude (str): Latitude of the location.
-        longitude (str): Longitude of the location.
-        country_code (str): Country code.
-        admin1 (str): Administrative region (optional).
-    """
-    indicator_column: str
-    latitude: str
-    longitude: str
-    country_code: str
-    admin1: Optional[str]
-
-
-def indicator_per_year_at_location_query(
-    table: str, params: IndicatorPerYearAtLocationQueryParams
-) -> str:
-    """
-    Builds an SQL query to get the evolution of an indicator per year at a specific location.
-
-    Args:
-        table (str): SQL table of the indicator.
-        params (IndicatorPerYearAtLocationQueryParams): Dictionary with the required params for the query.
-
-    Returns:
-        str: The SQL query string, or an empty string if required parameters are missing.
-    """
-    indicator_column = params.get("indicator_column")
-    latitude = params.get("latitude")
-    longitude = params.get("longitude")
-    country_code = params.get("country_code")
-    admin1 = params.get("admin1")
-
-    if not all([indicator_column, latitude, longitude, country_code]):
-        return ""
-
-    if country_code in MACRO_COUNTRIES:
-        table_path = f"'{IPCC_DATASET_URL}/{table.lower()}/{country_code}_macro.parquet'"
-        sql_query = f"""
-        SELECT year, scenario, AVG({indicator_column}) as {indicator_column}
-        FROM {table_path}
-        WHERE latitude = {latitude} AND longitude = {longitude} AND year >= 1950
-        GROUP BY scenario, year
-        ORDER BY year, scenario
-        """
-    elif country_code in HUGE_MACRO_COUNTRIES:
-        table_path = f"'{IPCC_DATASET_URL}/{table.lower()}/{country_code}_macro.parquet'"
-        sql_query = f"""
-        SELECT year, scenario, {indicator_column}
-        FROM {table_path}
-        WHERE latitude = {latitude} AND longitude = {longitude} AND year >= 1950
-        ORDER BY year, scenario
-        """
-    else:
-        table_path = f"'{IPCC_DATASET_URL}/{table.lower()}/{country_code}.parquet'"
-        sql_query = f"""
-        WITH medians_per_month AS (
-            SELECT year, scenario, month, MEDIAN({indicator_column}) AS median_value
-            FROM {table_path}
-            WHERE latitude = {latitude} AND longitude = {longitude} AND year >= 1950
-            GROUP BY scenario, year, month
-        )
-        SELECT year, scenario, AVG(median_value) AS {indicator_column}
-        FROM medians_per_month
-        GROUP BY scenario, year
-        ORDER BY year, scenario
-        """
-    return sql_query.strip()
-
-
-class IndicatorForGivenYearQueryParams(TypedDict, total=False):
-    """
-    Parameters for querying an indicator's values across locations for a specific year.
-
-    Attributes:
-        indicator_column (str): The column name for the climate indicator.
-        year (str): The year to query.
-        country_code (str): The country code.
-    """
-    indicator_column: str
-    year: str
-    country_code: str
-
-
-def indicator_for_given_year_query(
-    table: str, params: IndicatorForGivenYearQueryParams
-) -> str:
-    """
-    Builds an SQL query to get the values of an indicator with their latitudes, longitudes,
-    and scenarios for a given year.
-
-    Args:
-        table (str): SQL table of the indicator.
-        params (IndicatorForGivenYearQueryParams): Dictionary with the required params for the query.
-
-    Returns:
-        str: The SQL query string, or an empty string if required parameters are missing.
-    """
-    indicator_column = params.get("indicator_column")
-    year = params.get("year") or 2050
-    country_code = params.get("country_code")
-
-    if not all([indicator_column, year, country_code]):
-        return ""
-
-    if country_code in MACRO_COUNTRIES:
-        table_path = f"'{IPCC_DATASET_URL}/{table.lower()}/{country_code}_macro.parquet'"
-        sql_query = f"""
-        SELECT latitude, longitude, scenario, AVG({indicator_column}) as {indicator_column}
-        FROM {table_path}
-        WHERE year = {year}
-        GROUP BY latitude, longitude, scenario
-        ORDER BY latitude, longitude, scenario
-        """
-    elif country_code in HUGE_MACRO_COUNTRIES:
-        table_path = f"'{IPCC_DATASET_URL}/{table.lower()}/{country_code}_macro.parquet'"
-        sql_query = f"""
-        SELECT latitude, longitude, scenario, {indicator_column}
-        FROM {table_path}
-        WHERE year = {year}
-        ORDER BY latitude, longitude, scenario
-        """
-    else:
-        table_path = f"'{IPCC_DATASET_URL}/{table.lower()}/{country_code}.parquet'"
-        sql_query = f"""
-        WITH medians_per_month AS (
-            SELECT latitude, longitude, scenario, month, MEDIAN({indicator_column}) AS median_value
-            FROM {table_path}
-            WHERE year = {year}
-            GROUP BY latitude, longitude, scenario, month
-        )
-        SELECT latitude, longitude, scenario, AVG(median_value) AS {indicator_column}
-        FROM medians_per_month
-        GROUP BY latitude, longitude, scenario
-        ORDER BY latitude, longitude, scenario
-        """
-
-    return sql_query.strip()
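As an illustration, for a country in MACRO_COUNTRIES such as 'PT', the first branch of `indicator_per_year_at_location_query` yields a query of this shape (coordinates arbitrary, `<IPCC_DATASET_URL>` standing for the configured dataset URL):

    SELECT year, scenario, AVG(mean_temperature) as mean_temperature
    FROM '<IPCC_DATASET_URL>/mean_temperature/PT_macro.parquet'
    WHERE latitude = 38.7 AND longitude = -9.1 AND year >= 1950
    GROUP BY scenario, year
    ORDER BY year, scenario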
climateqa/engine/talk_to_data/main.py
DELETED
@@ -1,124 +0,0 @@
-from climateqa.engine.talk_to_data.workflow.drias import drias_workflow
-from climateqa.engine.talk_to_data.workflow.ipcc import ipcc_workflow
-from climateqa.logging import log_drias_interaction_to_huggingface
-
-
-async def ask_drias(query: str, index_state: int = 0, user_id: str | None = None) -> tuple:
-    """Main function to process a DRIAS query and return results.
-
-    This function orchestrates the DRIAS workflow, processing a user query to generate
-    SQL queries, dataframes, and visualizations. It handles multiple results and allows
-    pagination through them.
-
-    Args:
-        query (str): The user's question about climate data
-        index_state (int, optional): The index of the result to return. Defaults to 0.
-        user_id (str | None, optional): Identifier used when logging the interaction.
-
-    Returns:
-        tuple: A tuple containing:
-            - sql_query (str): The SQL query used for the selected result
-            - dataframe (pd.DataFrame): The resulting data
-            - figure (Figure): The rendered visualization
-            - plot_information (str): Description of the selected plot
-            - sql_queries (list): All generated SQL queries
-            - result_dataframes (list): All resulting dataframes
-            - figures (list): All figure generation functions
-            - plot_informations (list): All plot descriptions
-            - index_state (int): Current result index
-            - plot_title_list (list): Titles of the available plots
-            - error (str): Error message if any
-    """
-    final_state = await drias_workflow(query)
-    sql_queries = []
-    result_dataframes = []
-    figures = []
-    plot_title_list = []
-    plot_informations = []
-
-    for output_title, output in final_state['outputs'].items():
-        if output['status'] == 'OK':
-            if output['table'] is not None:
-                plot_title_list.append(output_title)
-            if output['plot_information'] is not None:
-                plot_informations.append(output['plot_information'])
-            if output['sql_query'] is not None:
-                sql_queries.append(output['sql_query'])
-            if output['dataframe'] is not None:
-                result_dataframes.append(output['dataframe'])
-            if output['figure'] is not None:
-                figures.append(output['figure'])
-
-    if "error" in final_state and final_state["error"] != "":
-        # Empty result: no query, dataframe, figure or plot information; index reset to 0
-        return None, None, None, None, [], [], [], [], 0, [], final_state["error"]
-
-    sql_query = sql_queries[index_state]
-    dataframe = result_dataframes[index_state]
-    figure = figures[index_state](dataframe)
-    plot_information = plot_informations[index_state]
-
-    log_drias_interaction_to_huggingface(query, sql_query, user_id)
-
-    return sql_query, dataframe, figure, plot_information, sql_queries, result_dataframes, figures, plot_informations, index_state, plot_title_list, ""
-
-
-async def ask_ipcc(query: str, index_state: int = 0, user_id: str | None = None) -> tuple:
-    """Main function to process an IPCC query and return results.
-
-    This function orchestrates the IPCC workflow, processing a user query to generate
-    SQL queries, dataframes, and visualizations. It handles multiple results and allows
-    pagination through them.
-
-    Args:
-        query (str): The user's question about climate data
-        index_state (int, optional): The index of the result to return. Defaults to 0.
-        user_id (str | None, optional): Identifier used when logging the interaction.
-
-    Returns:
-        tuple: Same structure as ask_drias.
-    """
-    final_state = await ipcc_workflow(query)
-    sql_queries = []
-    result_dataframes = []
-    figures = []
-    plot_title_list = []
-    plot_informations = []
-
-    for output_title, output in final_state['outputs'].items():
-        if output['status'] == 'OK':
-            if output['table'] is not None:
-                plot_title_list.append(output_title)
-            if output['plot_information'] is not None:
-                plot_informations.append(output['plot_information'])
-            if output['sql_query'] is not None:
-                sql_queries.append(output['sql_query'])
-            if output['dataframe'] is not None:
-                result_dataframes.append(output['dataframe'])
-            if output['figure'] is not None:
-                figures.append(output['figure'])
-
-    if "error" in final_state and final_state["error"] != "":
-        # Empty result: no query, dataframe, figure or plot information; index reset to 0
-        return None, None, None, None, [], [], [], [], 0, [], final_state["error"]
-
-    sql_query = sql_queries[index_state]
-    dataframe = result_dataframes[index_state]
-    figure = figures[index_state](dataframe)
-    plot_information = plot_informations[index_state]
-
-    log_drias_interaction_to_huggingface(query, sql_query, user_id)
-
-    return sql_query, dataframe, figure, plot_information, sql_queries, result_dataframes, figures, plot_informations, index_state, plot_title_list, ""
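Both entry points are coroutines, so a caller needs an event loop. A minimal driver sketch (the query text is illustrative):

    import asyncio

    from climateqa.engine.talk_to_data.main import ask_drias

    async def main() -> None:
        (sql_query, dataframe, figure, plot_information,
         sql_queries, result_dataframes, figures, plot_informations,
         index_state, plot_title_list, error) = await ask_drias(
            "What will the temperature be like in Paris?", index_state=0
        )
        if error:
            print(error)
        else:
            figure.show()  # Plotly figure for the first (most relevant) plot

    asyncio.run(main())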
climateqa/engine/talk_to_data/myVanna.py
DELETED
@@ -1,13 +0,0 @@
-import os
-
-from dotenv import load_dotenv
-from vanna.openai import OpenAI_Chat
-
-from climateqa.engine.talk_to_data.vanna_class import MyCustomVectorDB
-
-load_dotenv()
-
-OPENAI_API_KEY = os.getenv('THEO_API_KEY')
-
-
-class MyVanna(MyCustomVectorDB, OpenAI_Chat):
-    def __init__(self, config=None):
-        MyCustomVectorDB.__init__(self, config=config)
-        OpenAI_Chat.__init__(self, config=config)
climateqa/engine/talk_to_data/objects/llm_outputs.py
DELETED
@@ -1,13 +0,0 @@
-from typing import Annotated, TypedDict
-
-
-class ArrayOutput(TypedDict):
-    """Represents the output of a function that returns an array.
-
-    This class is used to type-hint functions that return arrays,
-    ensuring consistent return types across the codebase.
-
-    Attributes:
-        array (str): A syntactically valid Python array string
-    """
-    array: Annotated[str, "Syntactically valid python array."]
climateqa/engine/talk_to_data/objects/location.py
DELETED
@@ -1,12 +0,0 @@
-from typing import Optional, TypedDict
-
-
-class Location(TypedDict):
-    location: str
-    latitude: Optional[str]
-    longitude: Optional[str]
-    country_code: Optional[str]
-    country_name: Optional[str]
-    admin1: Optional[str]
climateqa/engine/talk_to_data/objects/plot.py
DELETED
@@ -1,23 +0,0 @@
-from typing import Callable, TypedDict
-from plotly.graph_objects import Figure
-
-
-class Plot(TypedDict):
-    """Represents a plot configuration in the DRIAS system.
-
-    This class defines the structure for configuring different types of plots
-    that can be generated from climate data.
-
-    Attributes:
-        name (str): The name of the plot type
-        description (str): A description of what the plot shows
-        params (list[str]): List of required parameters for the plot
-        plot_function (Callable[..., Callable[..., Figure]]): Function to generate the plot
-        sql_query (Callable[..., str]): Function to generate the SQL query for the plot
-        plot_information (Callable[..., str]): Function to generate the plot's description text
-        short_name (str): A short display name for the plot
-    """
-    name: str
-    description: str
-    params: list[str]
-    plot_function: Callable[..., Callable[..., Figure]]
-    sql_query: Callable[..., str]
-    plot_information: Callable[..., str]
-    short_name: str
climateqa/engine/talk_to_data/objects/states.py
DELETED
@@ -1,19 +0,0 @@
-from typing import Any, Callable, Optional, TypedDict
-from plotly.graph_objects import Figure
-import pandas as pd
-from climateqa.engine.talk_to_data.objects.plot import Plot
-
-
-class TTDOutput(TypedDict):
-    status: str
-    plot: Plot
-    table: str
-    sql_query: Optional[str]
-    dataframe: Optional[pd.DataFrame]
-    figure: Optional[Callable[..., Figure]]
-    plot_information: Optional[str]
-
-
-class State(TypedDict):
-    user_input: str
-    plots: list[str]
-    outputs: dict[str, TTDOutput]
-    error: Optional[str]
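For reference, a State instance as the DRIAS/IPCC workflows populate it might look like (sketch; the field values are illustrative):

    state: State = {
        "user_input": "What will the temperature be like in Paris?",
        "plots": ["Indicator Evolution at Location (Historical + Projections)"],
        "outputs": {},  # filled with one TTDOutput per selected plot/table
        "error": None,
    }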