This view is limited to 50 files because the diff contains too many changes. See the raw diff for the complete changeset.
Files changed (50)
  1. .gitattributes +0 -2
  2. .gitignore +0 -13
  3. README.md +1 -1
  4. app.py +608 -486
  5. climateqa/chat.py +0 -194
  6. climateqa/constants.py +1 -59
  7. climateqa/engine/chains/__init__.py +0 -0
  8. climateqa/engine/chains/answer_ai_impact.py +0 -46
  9. climateqa/engine/chains/answer_chitchat.py +0 -56
  10. climateqa/engine/chains/chitchat_categorization.py +0 -43
  11. climateqa/engine/chains/follow_up.py +0 -33
  12. climateqa/engine/chains/graph_retriever.py +0 -130
  13. climateqa/engine/chains/intent_categorization.py +0 -97
  14. climateqa/engine/chains/keywords_extraction.py +0 -40
  15. climateqa/engine/chains/query_transformation.py +0 -300
  16. climateqa/engine/chains/retrieve_documents.py +0 -705
  17. climateqa/engine/chains/retrieve_papers.py +0 -95
  18. climateqa/engine/chains/retriever.py +0 -126
  19. climateqa/engine/chains/sample_router.py +0 -66
  20. climateqa/engine/chains/set_defaults.py +0 -13
  21. climateqa/engine/chains/standalone_question.py +0 -42
  22. climateqa/engine/chains/translation.py +0 -42
  23. climateqa/engine/embeddings.py +3 -6
  24. climateqa/engine/graph.py +0 -346
  25. climateqa/engine/graph_retriever.py +0 -88
  26. climateqa/engine/keywords.py +1 -3
  27. climateqa/engine/llm/__init__.py +0 -3
  28. climateqa/engine/llm/ollama.py +0 -6
  29. climateqa/engine/llm/openai.py +1 -1
  30. climateqa/engine/{chains/prompts.py → prompts.py} +6 -107
  31. climateqa/engine/{chains/answer_rag.py → rag.py} +60 -41
  32. climateqa/engine/{chains/reformulation.py → reformulation.py} +1 -1
  33. climateqa/engine/reranker.py +0 -55
  34. climateqa/engine/retriever.py +163 -0
  35. climateqa/engine/talk_to_data/config.py +0 -11
  36. climateqa/engine/talk_to_data/drias/config.py +0 -124
  37. climateqa/engine/talk_to_data/drias/plot_informations.py +0 -88
  38. climateqa/engine/talk_to_data/drias/plots.py +0 -434
  39. climateqa/engine/talk_to_data/drias/queries.py +0 -83
  40. climateqa/engine/talk_to_data/input_processing.py +0 -257
  41. climateqa/engine/talk_to_data/ipcc/config.py +0 -98
  42. climateqa/engine/talk_to_data/ipcc/plot_informations.py +0 -50
  43. climateqa/engine/talk_to_data/ipcc/plots.py +0 -189
  44. climateqa/engine/talk_to_data/ipcc/queries.py +0 -144
  45. climateqa/engine/talk_to_data/main.py +0 -124
  46. climateqa/engine/talk_to_data/myVanna.py +0 -13
  47. climateqa/engine/talk_to_data/objects/llm_outputs.py +0 -13
  48. climateqa/engine/talk_to_data/objects/location.py +0 -12
  49. climateqa/engine/talk_to_data/objects/plot.py +0 -23
  50. climateqa/engine/talk_to_data/objects/states.py +0 -19
.gitattributes CHANGED
@@ -44,5 +44,3 @@ documents/climate_gpt_v2_only_giec.faiss filter=lfs diff=lfs merge=lfs -text
  documents/climate_gpt_v2.faiss filter=lfs diff=lfs merge=lfs -text
  climateqa_v3.db filter=lfs diff=lfs merge=lfs -text
  climateqa_v3.faiss filter=lfs diff=lfs merge=lfs -text
- data/drias/drias.db filter=lfs diff=lfs merge=lfs -text
- front/assets/*.png filter=lfs diff=lfs merge=lfs -text
.gitignore CHANGED
@@ -5,16 +5,3 @@ __pycache__/utils.cpython-38.pyc

  notebooks/
  *.pyc
-
- **/.ipynb_checkpoints/
- **/.flashrank_cache/
-
- data/
- sandbox/
-
- climateqa/talk_to_data/database/
- *.db
-
- data_ingestion/
- .vscode
- *old/
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 🌍
  colorFrom: blue
  colorTo: red
  sdk: gradio
- sdk_version: 5.0.2
+ sdk_version: 4.19.1
  app_file: app.py
  fullWidth: true
  pinned: false
app.py CHANGED
@@ -1,44 +1,52 @@
- # Import necessary libraries
- import os
- import gradio as gr

- from azure.storage.fileshare import ShareServiceClient

- # Import custom modules
- from climateqa.engine.embeddings import get_embeddings_function
- from climateqa.engine.llm import get_llm
- from climateqa.engine.vectorstore import get_pinecone_vectorstore
- from climateqa.engine.reranker import get_reranker
- from climateqa.engine.graph import make_graph_agent, make_graph_agent_poc
- from climateqa.engine.chains.retrieve_papers import find_papers
- from climateqa.chat import start_chat, chat_stream, finish_chat

- from front.tabs import create_config_modal, cqa_tab, create_about_tab
- from front.tabs import MainTabPanel, ConfigPanel
- from front.tabs.tab_drias import create_drias_tab
- from front.tabs.tab_ipcc import create_ipcc_tab

- from front.utils import process_figures
- from gradio_modal import Modal

  from utils import create_user_id
- import logging

- logging.basicConfig(level=logging.WARNING)
- os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"  # Suppresses INFO and WARNING logs
- logging.getLogger().setLevel(logging.WARNING)

  # Load environment variables in local mode
  try:
      from dotenv import load_dotenv
-
      load_dotenv()
  except Exception as e:
      pass

-
  # Set up Gradio Theme
  theme = gr.themes.Base(
      primary_hue="blue",
@@ -46,7 +54,15 @@ theme = gr.themes.Base(
      font=[gr.themes.GoogleFont("Poppins"), "ui-sans-serif", "system-ui", "sans-serif"],
  )

- # Azure Blob Storage credentials
  account_key = os.environ["BLOB_ACCOUNT_KEY"]
  if len(account_key) == 86:
      account_key += "=="
@@ -64,102 +80,365 @@ share_client = service.get_share_client(file_share_name)
  user_id = create_user_id()

  # Create vectorstore and retriever
- embeddings_function = get_embeddings_function()
- vectorstore = get_pinecone_vectorstore(
-     embeddings_function, index_name=os.getenv("PINECONE_API_INDEX")
- )
- vectorstore_graphs = get_pinecone_vectorstore(
-     embeddings_function,
-     index_name=os.getenv("PINECONE_API_INDEX_OWID"),
-     text_key="description",
- )
- vectorstore_region = get_pinecone_vectorstore(
-     embeddings_function, index_name=os.getenv("PINECONE_API_INDEX_LOCAL_V2")
- )

- llm = get_llm(provider="openai", max_tokens=1024, temperature=0.0)
- if os.environ["GRADIO_ENV"] == "local":
-     reranker = get_reranker("nano")
- else:
-     reranker = get_reranker("nano")
-
- agent = make_graph_agent(
-     llm=llm,
-     vectorstore_ipcc=vectorstore,
-     vectorstore_graphs=vectorstore_graphs,
-     vectorstore_region=vectorstore_region,
-     reranker=reranker,
-     threshold_docs=0.2,
- )
- agent_poc = make_graph_agent_poc(
-     llm=llm,
-     vectorstore_ipcc=vectorstore,
-     vectorstore_graphs=vectorstore_graphs,
-     vectorstore_region=vectorstore_region,
-     reranker=reranker,
-     threshold_docs=0,
-     version="v4",
- )  # TODO put back default 0.2
-
-
-
-
- async def chat(
-     query,
-     history,
-     audience,
-     sources,
-     reports,
-     relevant_content_sources_selection,
-     search_only,
- ):
-     print("chat cqa - message received")
-     # Ensure default values if components are not set
-     audience = audience or "Experts"
-     sources = sources or ["IPCC", "IPBES"]
-     reports = reports or []
-     relevant_content_sources_selection = relevant_content_sources_selection or ["Figures (IPCC/IPBES)"]
-     search_only = bool(search_only)  # Convert to boolean if None

-     async for event in chat_stream(
-         agent,
-         query,
-         history,
-         audience,
-         sources,
-         reports,
-         relevant_content_sources_selection,
-         search_only,
-         share_client,
-         user_id,
-     ):
-         yield event
-
-
- async def chat_poc(
-     query,
-     history,
-     audience,
-     sources,
-     reports,
-     relevant_content_sources_selection,
-     search_only,
- ):
-     print("chat poc - message received")
-     async for event in chat_stream(
-         agent_poc,
-         query,
-         history,
-         audience,
-         sources,
-         reports,
-         relevant_content_sources_selection,
-         search_only,
-         share_client,
-         user_id,
-     ):
-         yield event

  # --------------------------------------------------------------------
@@ -167,389 +446,232 @@ async def chat_poc(
  # --------------------------------------------------------------------

- # Function to update modal visibility
- def update_config_modal_visibility(config_open):
-     print(config_open)
-     new_config_visibility_status = not config_open
-     return Modal(visible=new_config_visibility_status), new_config_visibility_status
-
-
- def update_sources_number_display(
-     sources_textbox, figures_cards, current_graphs, papers_html
- ):
-     sources_number = sources_textbox.count("<h2>")
-     figures_number = figures_cards.count("<h2>")
-     graphs_number = current_graphs.count("<iframe")
-     papers_number = papers_html.count("<h2>")
-     sources_notif_label = f"Sources ({sources_number})"
-     figures_notif_label = f"Figures ({figures_number})"
-     graphs_notif_label = f"Graphs ({graphs_number})"
-     papers_notif_label = f"Papers ({papers_number})"
-     recommended_content_notif_label = (
-         f"Recommended content ({figures_number + graphs_number + papers_number})"
-     )

-     return (
-         gr.update(label=recommended_content_notif_label),
-         gr.update(label=sources_notif_label),
-         gr.update(label=figures_notif_label),
-         gr.update(label=graphs_notif_label),
-         gr.update(label=papers_notif_label),
-     )

- def config_event_handling(
-     main_tabs_components: list[MainTabPanel], config_componenets: ConfigPanel
- ):
-     config_open = config_componenets.config_open
-     config_modal = config_componenets.config_modal
-     close_config_modal = config_componenets.close_config_modal_button
-
-     for button in [close_config_modal] + [
-         main_tab_component.config_button for main_tab_component in main_tabs_components
-     ]:
-         button.click(
-             fn=update_config_modal_visibility,
-             inputs=[config_open],
-             outputs=[config_modal, config_open],
-         )
-
-
- def event_handling(
-     main_tab_components: MainTabPanel,
-     config_components: ConfigPanel,
-     tab_name="ClimateQ&A",
- ):
-     chatbot = main_tab_components.chatbot
-     textbox = main_tab_components.textbox
-     tabs = main_tab_components.tabs
-     sources_raw = main_tab_components.sources_raw
-     new_figures = main_tab_components.new_figures
-     current_graphs = main_tab_components.current_graphs
-     examples_hidden = main_tab_components.examples_hidden
-     sources_textbox = main_tab_components.sources_textbox
-     figures_cards = main_tab_components.figures_cards
-     gallery_component = main_tab_components.gallery_component
-     papers_direct_search = main_tab_components.papers_direct_search
-     papers_html = main_tab_components.papers_html
-     citations_network = main_tab_components.citations_network
-     papers_summary = main_tab_components.papers_summary
-     tab_recommended_content = main_tab_components.tab_recommended_content
-     tab_sources = main_tab_components.tab_sources
-     tab_figures = main_tab_components.tab_figures
-     tab_graphs = main_tab_components.tab_graphs
-     tab_papers = main_tab_components.tab_papers
-     graphs_container = main_tab_components.graph_container
-     follow_up_examples = main_tab_components.follow_up_examples
-     follow_up_examples_hidden = main_tab_components.follow_up_examples_hidden
-
-     dropdown_sources = config_components.dropdown_sources
-     dropdown_reports = config_components.dropdown_reports
-     dropdown_external_sources = config_components.dropdown_external_sources
-     search_only = config_components.search_only
-     dropdown_audience = config_components.dropdown_audience
-     after = config_components.after
-     output_query = config_components.output_query
-     output_language = config_components.output_language
-
-     new_sources_hmtl = gr.State([])
-     ttd_data = gr.State([])
-
-     if tab_name == "ClimateQ&A":
-         print("chat cqa - message sent")
-
-         # Event for textbox
-         (
-             textbox.submit(
-                 start_chat,
-                 [textbox, chatbot, search_only],
-                 [textbox, tabs, chatbot, sources_raw],
-                 queue=False,
-                 api_name=f"start_chat_{textbox.elem_id}",
-             )
-             .then(
-                 chat,
-                 [
-                     textbox,
-                     chatbot,
-                     dropdown_audience,
-                     dropdown_sources,
-                     dropdown_reports,
-                     dropdown_external_sources,
-                     search_only,
-                 ],
-                 [
-                     chatbot,
-                     new_sources_hmtl,
-                     output_query,
-                     output_language,
-                     new_figures,
-                     current_graphs,
-                     follow_up_examples.dataset,
-                 ],
-                 concurrency_limit=8,
-                 api_name=f"chat_{textbox.elem_id}",
-             )
-             .then(
-                 finish_chat, None, [textbox], api_name=f"finish_chat_{textbox.elem_id}"
-             )
-         )
-         # Event for examples_hidden
-         (
-             examples_hidden.change(
-                 start_chat,
-                 [examples_hidden, chatbot, search_only],
-                 [examples_hidden, tabs, chatbot, sources_raw],
-                 queue=False,
-                 api_name=f"start_chat_{examples_hidden.elem_id}",
-             )
-             .then(
-                 chat,
-                 [
-                     examples_hidden,
-                     chatbot,
-                     dropdown_audience,
-                     dropdown_sources,
-                     dropdown_reports,
-                     dropdown_external_sources,
-                     search_only,
-                 ],
-                 [
-                     chatbot,
-                     new_sources_hmtl,
-                     output_query,
-                     output_language,
-                     new_figures,
-                     current_graphs,
-                     follow_up_examples.dataset,
-                 ],
-                 concurrency_limit=8,
-                 api_name=f"chat_{examples_hidden.elem_id}",
-             )
-             .then(
-                 finish_chat,
-                 None,
-                 [textbox],
-                 api_name=f"finish_chat_{examples_hidden.elem_id}",
-             )
-         )
-         (
-             follow_up_examples_hidden.change(
-                 start_chat,
-                 [follow_up_examples_hidden, chatbot, search_only],
-                 [follow_up_examples_hidden, tabs, chatbot, sources_raw],
-                 queue=False,
-                 api_name=f"start_chat_{examples_hidden.elem_id}",
-             )
-             .then(
-                 chat,
-                 [
-                     follow_up_examples_hidden,
-                     chatbot,
-                     dropdown_audience,
-                     dropdown_sources,
-                     dropdown_reports,
-                     dropdown_external_sources,
-                     search_only,
-                 ],
-                 [
-                     chatbot,
-                     new_sources_hmtl,
-                     output_query,
-                     output_language,
-                     new_figures,
-                     current_graphs,
-                     follow_up_examples.dataset,
-                 ],
-                 concurrency_limit=8,
-                 api_name=f"chat_{examples_hidden.elem_id}",
-             )
-             .then(
-                 finish_chat,
-                 None,
-                 [textbox],
-                 api_name=f"finish_chat_{follow_up_examples_hidden.elem_id}",
-             )
-         )
-
-     elif tab_name == "France - Local Q&A":
-         print("chat poc - message sent")
-         # Event for textbox
-         (
-             textbox.submit(
-                 start_chat,
-                 [textbox, chatbot, search_only],
-                 [textbox, tabs, chatbot, sources_raw],
-                 queue=False,
-                 api_name=f"start_chat_{textbox.elem_id}",
-             )
-             .then(
-                 chat_poc,
-                 [
-                     textbox,
-                     chatbot,
-                     dropdown_audience,
-                     dropdown_sources,
-                     dropdown_reports,
-                     dropdown_external_sources,
-                     search_only,
-                 ],
-                 [
-                     chatbot,
-                     new_sources_hmtl,
-                     output_query,
-                     output_language,
-                     new_figures,
-                     current_graphs,
-                     follow_up_examples.dataset,
-                 ],
-                 concurrency_limit=8,
-                 api_name=f"chat_{textbox.elem_id}",
-             )
-             .then(
-                 finish_chat, None, [textbox], api_name=f"finish_chat_{textbox.elem_id}"
-             )
-         )
-         # Event for examples_hidden
-         (
-             examples_hidden.change(
-                 start_chat,
-                 [examples_hidden, chatbot, search_only],
-                 [examples_hidden, tabs, chatbot, sources_raw],
-                 queue=False,
-                 api_name=f"start_chat_{examples_hidden.elem_id}",
-             )
-             .then(
-                 chat_poc,
-                 [
-                     examples_hidden,
-                     chatbot,
-                     dropdown_audience,
-                     dropdown_sources,
-                     dropdown_reports,
-                     dropdown_external_sources,
-                     search_only,
-                 ],
-                 [
-                     chatbot,
-                     new_sources_hmtl,
-                     output_query,
-                     output_language,
-                     new_figures,
-                     current_graphs,
-                     follow_up_examples.dataset,
-                 ],
-                 concurrency_limit=8,
-                 api_name=f"chat_{examples_hidden.elem_id}",
-             )
-             .then(
-                 finish_chat,
-                 None,
-                 [textbox],
-                 api_name=f"finish_chat_{examples_hidden.elem_id}",
-             )
-         )
-         (
-             follow_up_examples_hidden.change(
-                 start_chat,
-                 [follow_up_examples_hidden, chatbot, search_only],
-                 [follow_up_examples_hidden, tabs, chatbot, sources_raw],
-                 queue=False,
-                 api_name=f"start_chat_{examples_hidden.elem_id}",
-             )
-             .then(
-                 chat,
-                 [
-                     follow_up_examples_hidden,
-                     chatbot,
-                     dropdown_audience,
-                     dropdown_sources,
-                     dropdown_reports,
-                     dropdown_external_sources,
-                     search_only,
-                 ],
-                 [
-                     chatbot,
-                     new_sources_hmtl,
-                     output_query,
-                     output_language,
-                     new_figures,
-                     current_graphs,
-                     follow_up_examples.dataset,
-                 ],
-                 concurrency_limit=8,
-                 api_name=f"chat_{examples_hidden.elem_id}",
-             )
-             .then(
-                 finish_chat,
-                 None,
-                 [textbox],
-                 api_name=f"finish_chat_{follow_up_examples_hidden.elem_id}",
-             )
-         )
-
-     new_sources_hmtl.change(
-         lambda x: x, inputs=[new_sources_hmtl], outputs=[sources_textbox]
-     )
-     current_graphs.change(
-         lambda x: x, inputs=[current_graphs], outputs=[graphs_container]
      )
-     new_figures.change(
-         process_figures,
-         inputs=[sources_raw, new_figures],
-         outputs=[sources_raw, figures_cards, gallery_component],
      )

-     # Update sources numbers
-     for component in [sources_textbox, figures_cards, current_graphs, papers_html]:
-         component.change(
-             update_sources_number_display,
-             [sources_textbox, figures_cards, current_graphs, papers_html],
-             [tab_recommended_content, tab_sources, tab_figures, tab_graphs, tab_papers],
-         )

-     # Search for papers
-     for component in [textbox, examples_hidden, papers_direct_search]:
-         component.submit(
-             find_papers,
-             [component, after, dropdown_external_sources],
-             [papers_html, citations_network, papers_summary],
-         )

-     # if tab_name == "France - Local Q&A":  # Not untill results are good enough
-     #     # Drias search
-     #     textbox.submit(ask_vanna, [textbox], [vanna_sql_query ,vanna_table, vanna_display])


- def main_ui():
-     # config_open = gr.State(True)
-     with gr.Blocks(
-         title="Climate Q&A",
-         css_paths=os.getcwd() + "/style.css",
-         theme=theme,
-         elem_id="main-component",
-     ) as demo:
-         config_components = create_config_modal()

-         with gr.Tabs():
-             cqa_components = cqa_tab(tab_name="ClimateQ&A")
-             local_cqa_components = cqa_tab(tab_name="France - Local Q&A")
-             drias_components = create_drias_tab(share_client=share_client, user_id=user_id)
-             ipcc_components = create_ipcc_tab(share_client=share_client, user_id=user_id)

-             create_about_tab()

-         event_handling(cqa_components, config_components, tab_name="ClimateQ&A")
-         event_handling(
-             local_cqa_components, config_components, tab_name="France - Local Q&A"
-         )

-         config_event_handling([cqa_components, local_cqa_components], config_components)

-         demo.queue()

-         return demo


- demo = main_ui()
- demo.launch(ssr_mode=False)
 
+ from climateqa.engine.embeddings import get_embeddings_function
+ embeddings_function = get_embeddings_function()

+ from climateqa.papers.openalex import OpenAlex
+ from sentence_transformers import CrossEncoder

+ reranker = CrossEncoder("mixedbread-ai/mxbai-rerank-xsmall-v1")
+ oa = OpenAlex()

+ import gradio as gr
+ import pandas as pd
+ import numpy as np
+ import os
+ import time
+ import re
+ import json

+ # from gradio_modal import Modal

+ from io import BytesIO
+ import base64
+
+ from datetime import datetime
+ from azure.storage.fileshare import ShareServiceClient

  from utils import create_user_id

+ # ClimateQ&A imports
+ from climateqa.engine.llm import get_llm
+ from climateqa.engine.rag import make_rag_chain
+ from climateqa.engine.vectorstore import get_pinecone_vectorstore
+ from climateqa.engine.retriever import ClimateQARetriever
+ from climateqa.engine.embeddings import get_embeddings_function
+ from climateqa.engine.prompts import audience_prompts
+ from climateqa.sample_questions import QUESTIONS
+ from climateqa.constants import POSSIBLE_REPORTS
+ from climateqa.utils import get_image_from_azure_blob_storage
+ from climateqa.engine.keywords import make_keywords_chain
+ from climateqa.engine.rag import make_rag_papers_chain
+
  # Load environment variables in local mode
  try:
      from dotenv import load_dotenv
      load_dotenv()
  except Exception as e:
      pass

  # Set up Gradio Theme
  theme = gr.themes.Base(
      primary_hue="blue",
      font=[gr.themes.GoogleFont("Poppins"), "ui-sans-serif", "system-ui", "sans-serif"],
  )

+
+
+ init_prompt = ""
+
+ system_template = {
+     "role": "system",
+     "content": init_prompt,
+ }
+
  account_key = os.environ["BLOB_ACCOUNT_KEY"]
  if len(account_key) == 86:
      account_key += "=="
  user_id = create_user_id()

+
+ def parse_output_llm_with_sources(output):
+     # Split the content into a list of text and "[Doc X]" references
+     content_parts = re.split(r'\[(Doc\s?\d+(?:,\s?Doc\s?\d+)*)\]', output)
+     parts = []
+     for part in content_parts:
+         if part.startswith("Doc"):
+             subparts = part.split(",")
+             subparts = [subpart.lower().replace("doc","").strip() for subpart in subparts]
+             subparts = [f"""<a href="#doc{subpart}" class="a-doc-ref" target="_self"><span class='doc-ref'><sup>{subpart}</sup></span></a>""" for subpart in subparts]
+             parts.append("".join(subparts))
+         else:
+             parts.append(part)
+     content_parts = "".join(parts)
+     return content_parts
+
+
  # Create vectorstore and retriever
+ vectorstore = get_pinecone_vectorstore(embeddings_function)
+ llm = get_llm(provider="openai",max_tokens = 1024,temperature = 0.0)

+
+ def make_pairs(lst):
+     """from a list of even lenght, make tupple pairs"""
+     return [(lst[i], lst[i + 1]) for i in range(0, len(lst), 2)]
+
+
+ def serialize_docs(docs):
+     new_docs = []
+     for doc in docs:
+         new_doc = {}
+         new_doc["page_content"] = doc.page_content
+         new_doc["metadata"] = doc.metadata
+         new_docs.append(new_doc)
+     return new_docs
+
+
+
+ async def chat(query,history,audience,sources,reports):
+     """taking a query and a message history, use a pipeline (reformulation, retriever, answering) to yield a tuple of:
+     (messages in gradio format, messages in langchain format, source documents)"""
+
+     print(f">> NEW QUESTION : {query}")
+
+     if audience == "Children":
+         audience_prompt = audience_prompts["children"]
+     elif audience == "General public":
+         audience_prompt = audience_prompts["general"]
+     elif audience == "Experts":
+         audience_prompt = audience_prompts["experts"]
+     else:
+         audience_prompt = audience_prompts["experts"]
+
+     # Prepare default values
+     if len(sources) == 0:
+         sources = ["IPCC"]
+
+     if len(reports) == 0:
+         reports = []
+
+     retriever = ClimateQARetriever(vectorstore=vectorstore,sources = sources,min_size = 200,reports = reports,k_summary = 3,k_total = 15,threshold=0.5)
+     rag_chain = make_rag_chain(retriever,llm)
+
+     inputs = {"query": query,"audience": audience_prompt}
+     result = rag_chain.astream_log(inputs) #{"callbacks":[MyCustomAsyncHandler()]})
+     # result = rag_chain.stream(inputs)
+
+     path_reformulation = "/logs/reformulation/final_output"
+     path_keywords = "/logs/keywords/final_output"
+     path_retriever = "/logs/find_documents/final_output"
+     path_answer = "/logs/answer/streamed_output_str/-"
+
+     docs_html = ""
+     output_query = ""
+     output_language = ""
+     output_keywords = ""
+     gallery = []
+
+     try:
+         async for op in result:
+
+             op = op.ops[0]
+
+             if op['path'] == path_reformulation: # reforulated question
+                 try:
+                     output_language = op['value']["language"] # str
+                     output_query = op["value"]["question"]
+                 except Exception as e:
+                     raise gr.Error(f"ClimateQ&A Error: {e} - The error has been noted, try another question and if the error remains, you can contact us :)")
+
+             if op["path"] == path_keywords:
+                 try:
+                     output_keywords = op['value']["keywords"] # str
+                     output_keywords = " AND ".join(output_keywords)
+                 except Exception as e:
+                     pass
+
+
+             elif op['path'] == path_retriever: # documents
+                 try:
+                     docs = op['value']['docs'] # List[Document]
+                     docs_html = []
+                     for i, d in enumerate(docs, 1):
+                         docs_html.append(make_html_source(d, i))
+                     docs_html = "".join(docs_html)
+                 except TypeError:
+                     print("No documents found")
+                     print("op: ",op)
+                     continue
+
+             elif op['path'] == path_answer: # final answer
+                 new_token = op['value'] # str
+                 # time.sleep(0.01)
+                 previous_answer = history[-1][1]
+                 previous_answer = previous_answer if previous_answer is not None else ""
+                 answer_yet = previous_answer + new_token
+                 answer_yet = parse_output_llm_with_sources(answer_yet)
+                 history[-1] = (query,answer_yet)
+
+
+
+             else:
+                 continue
+
+             history = [tuple(x) for x in history]
+             yield history,docs_html,output_query,output_language,gallery,output_query,output_keywords
+
+     except Exception as e:
+         raise gr.Error(f"{e}")
+
+
+     try:
+         # Log answer on Azure Blob Storage
+         if os.getenv("GRADIO_ENV") != "local":
+             timestamp = str(datetime.now().timestamp())
+             file = timestamp + ".json"
+             prompt = history[-1][0]
+             logs = {
+                 "user_id": str(user_id),
+                 "prompt": prompt,
+                 "query": prompt,
+                 "question":output_query,
+                 "sources":sources,
+                 "docs":serialize_docs(docs),
+                 "answer": history[-1][1],
+                 "time": timestamp,
+             }
+             log_on_azure(file, logs, share_client)
+     except Exception as e:
+         print(f"Error logging on Azure Blob Storage: {e}")
+         raise gr.Error(f"ClimateQ&A Error: {str(e)[:100]} - The error has been noted, try another question and if the error remains, you can contact us :)")
+
+     image_dict = {}
+     for i,doc in enumerate(docs):
+
+         if doc.metadata["chunk_type"] == "image":
+             try:
+                 key = f"Image {i+1}"
+                 image_path = doc.metadata["image_path"].split("documents/")[1]
+                 img = get_image_from_azure_blob_storage(image_path)
+
+                 # Convert the image to a byte buffer
+                 buffered = BytesIO()
+                 img.save(buffered, format="PNG")
+                 img_str = base64.b64encode(buffered.getvalue()).decode()
+
+                 # Embedding the base64 string in Markdown
+                 markdown_image = f"![Alt text](data:image/png;base64,{img_str})"
+                 image_dict[key] = {"img":img,"md":markdown_image,"caption":doc.page_content,"key":key,"figure_code":doc.metadata["figure_code"]}
+             except Exception as e:
+                 print(f"Skipped adding image {i} because of {e}")
+
+     if len(image_dict) > 0:
+
+         gallery = [x["img"] for x in list(image_dict.values())]
+         img = list(image_dict.values())[0]
+         img_md = img["md"]
+         img_caption = img["caption"]
+         img_code = img["figure_code"]
+         if img_code != "N/A":
+             img_name = f"{img['key']} - {img['figure_code']}"
+         else:
+             img_name = f"{img['key']}"
+
+         answer_yet = history[-1][1] + f"\n\n{img_md}\n<p class='chatbot-caption'><b>{img_name}</b> - {img_caption}</p>"
+         history[-1] = (history[-1][0],answer_yet)
+         history = [tuple(x) for x in history]
+
+     # gallery = [x.metadata["image_path"] for x in docs if (len(x.metadata["image_path"]) > 0 and "IAS" in x.metadata["image_path"])]
+     # if len(gallery) > 0:
+     #     gallery = list(set("|".join(gallery).split("|")))
+     #     gallery = [get_image_from_azure_blob_storage(x) for x in gallery]
+
+     yield history,docs_html,output_query,output_language,gallery,output_query,output_keywords
+
+
+ def make_html_source(source,i):
+     meta = source.metadata
+     # content = source.page_content.split(":",1)[1].strip()
+     content = source.page_content.strip()
+
+     toc_levels = []
+     for j in range(2):
+         level = meta[f"toc_level{j}"]
+         if level != "N/A":
+             toc_levels.append(level)
+         else:
+             break
+     toc_levels = " > ".join(toc_levels)
+
+     if len(toc_levels) > 0:
+         name = f"<b>{toc_levels}</b><br/>{meta['name']}"
+     else:
+         name = meta['name']
+
+     if meta["chunk_type"] == "text":
+
+         card = f"""
+             <div class="card" id="doc{i}">
+                 <div class="card-content">
+                     <h2>Doc {i} - {meta['short_name']} - Page {int(meta['page_number'])}</h2>
+                     <p>{content}</p>
+                 </div>
+                 <div class="card-footer">
+                     <span>{name}</span>
+                     <a href="{meta['url']}#page={int(meta['page_number'])}" target="_blank" class="pdf-link">
+                         <span role="img" aria-label="Open PDF">🔗</span>
+                     </a>
+                 </div>
+             </div>
+         """
+
+     else:
+
+         if meta["figure_code"] != "N/A":
+             title = f"{meta['figure_code']} - {meta['short_name']}"
+         else:
+             title = f"{meta['short_name']}"
+
+         card = f"""
+             <div class="card card-image">
+                 <div class="card-content">
+                     <h2>Image {i} - {title} - Page {int(meta['page_number'])}</h2>
+                     <p>{content}</p>
+                     <p class='ai-generated'>AI-generated description</p>
+                 </div>
+                 <div class="card-footer">
+                     <span>{name}</span>
+                     <a href="{meta['url']}#page={int(meta['page_number'])}" target="_blank" class="pdf-link">
+                         <span role="img" aria-label="Open PDF">🔗</span>
+                     </a>
+                 </div>
+             </div>
+         """
+
+     return card
+
+
+
+ # else:
+ #     docs_string = "No relevant passages found in the climate science reports (IPCC and IPBES)"
+ #     complete_response = "**No relevant passages found in the climate science reports (IPCC and IPBES), you may want to ask a more specific question (specifying your question on climate issues).**"
+ #     messages.append({"role": "assistant", "content": complete_response})
+ #     gradio_format = make_pairs([a["content"] for a in messages[1:]])
+ #     yield gradio_format, messages, docs_string
+
+
+ def save_feedback(feed: str, user_id):
+     if len(feed) > 1:
+         timestamp = str(datetime.now().timestamp())
+         file = user_id + timestamp + ".json"
+         logs = {
+             "user_id": user_id,
+             "feedback": feed,
+             "time": timestamp,
+         }
+         log_on_azure(file, logs, share_client)
+         return "Feedback submitted, thank you!"
+
+
+
+
+ def log_on_azure(file, logs, share_client):
+     logs = json.dumps(logs)
+     file_client = share_client.get_file_client(file)
+     file_client.upload_file(logs)
+
+
+ def generate_keywords(query):
+     chain = make_keywords_chain(llm)
+     keywords = chain.invoke(query)
+     keywords = " AND ".join(keywords["keywords"])
+     return keywords
+
+
+
+ papers_cols_widths = {
+     "doc":50,
+     "id":100,
+     "title":300,
+     "doi":100,
+     "publication_year":100,
+     "abstract":500,
+     "rerank_score":100,
+     "is_oa":50,
+ }
+
+ papers_cols = list(papers_cols_widths.keys())
+ papers_cols_widths = list(papers_cols_widths.values())
+
+ async def find_papers(query, keywords,after):
+
+     summary = ""
+
+     df_works = oa.search(keywords,after = after)
+     df_works = df_works.dropna(subset=["abstract"])
+     df_works = oa.rerank(query,df_works,reranker)
+     df_works = df_works.sort_values("rerank_score",ascending=False)
+     G = oa.make_network(df_works)
+
+     height = "750px"
+     network = oa.show_network(G,color_by = "rerank_score",notebook=False,height = height)
+     network_html = network.generate_html()
+
+     network_html = network_html.replace("'", "\"")
+     css_to_inject = "<style>#mynetwork { border: none !important; } .card { border: none !important; }</style>"
+     network_html = network_html + css_to_inject
+
+     network_html = f"""<iframe style="width: 100%; height: {height};margin:0 auto" name="result" allow="midi; geolocation; microphone; camera;
+     display-capture; encrypted-media;" sandbox="allow-modals allow-forms
+     allow-scripts allow-same-origin allow-popups
+     allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
+     allowpaymentrequest="" frameborder="0" srcdoc='{network_html}'></iframe>"""
+
+     docs = df_works["content"].head(15).tolist()
+
+     df_works = df_works.reset_index(drop = True).reset_index().rename(columns = {"index":"doc"})
+     df_works["doc"] = df_works["doc"] + 1
+     df_works = df_works[papers_cols]
+
+     yield df_works,network_html,summary
+
+     chain = make_rag_papers_chain(llm)
+     result = chain.astream_log({"question": query,"docs": docs,"language":"English"})
+     path_answer = "/logs/StrOutputParser/streamed_output/-"
+
+     async for op in result:
+
+         op = op.ops[0]
+
+         if op['path'] == path_answer: # reforulated question
+             new_token = op['value'] # str
+             summary += new_token
+         else:
+             continue
+         yield df_works,network_html,summary


  # --------------------------------------------------------------------
  # --------------------------------------------------------------------

+ init_prompt = """
+ Hello, I am ClimateQ&A, a conversational assistant designed to help you understand climate change and biodiversity loss. I will answer your questions by **sifting through the IPCC and IPBES scientific reports**.

+ How to use
+ - **Language**: You can ask me your questions in any language.
+ - **Audience**: You can specify your audience (children, general public, experts) to get a more adapted answer.
+ - **Sources**: You can choose to search in the IPCC or IPBES reports, or both.

+ ⚠️ Limitations
+ *Please note that the AI is not perfect and may sometimes give irrelevant answers. If you are not satisfied with the answer, please ask a more specific question or report your feedback to help us improve the system.*

+ What do you want to learn ?
+ """
+
+
+ def vote(data: gr.LikeData):
+     if data.liked:
+         print(data.value)
+     else:
+         print(data)
+
+
+
+ with gr.Blocks(title="Climate Q&A", css="style.css", theme=theme,elem_id = "main-component") as demo:
+     # user_id_state = gr.State([user_id])
+
+     with gr.Tab("ClimateQ&A"):
+
+         with gr.Row(elem_id="chatbot-row"):
+             with gr.Column(scale=2):
+                 # state = gr.State([system_template])
+                 chatbot = gr.Chatbot(
+                     value=[(None,init_prompt)],
+                     show_copy_button=True,show_label = False,elem_id="chatbot",layout = "panel",
+                     avatar_images = (None,"https://i.ibb.co/YNyd5W2/logo4.png"),
+                 )#,avatar_images = ("assets/logo4.png",None))
+
+                 # bot.like(vote,None,None)
+
+
+
+                 with gr.Row(elem_id = "input-message"):
+                     textbox=gr.Textbox(placeholder="Ask me anything here!",show_label=False,scale=7,lines = 1,interactive = True,elem_id="input-textbox")
+                     # submit = gr.Button("",elem_id = "submit-button",scale = 1,interactive = True,icon = "https://static-00.iconduck.com/assets.00/settings-icon-2048x2046-cw28eevx.png")
+
+
+             with gr.Column(scale=1, variant="panel",elem_id = "right-panel"):
+
+
+                 with gr.Tabs() as tabs:
+                     with gr.TabItem("Examples",elem_id = "tab-examples",id = 0):
+
+                         examples_hidden = gr.Textbox(visible = False)
+                         first_key = list(QUESTIONS.keys())[0]
+                         dropdown_samples = gr.Dropdown(QUESTIONS.keys(),value = first_key,interactive = True,show_label = True,label = "Select a category of sample questions",elem_id = "dropdown-samples")
+
+                         samples = []
+                         for i,key in enumerate(QUESTIONS.keys()):
+
+                             examples_visible = True if i == 0 else False
+
+                             with gr.Row(visible = examples_visible) as group_examples:
+
+                                 examples_questions = gr.Examples(
+                                     QUESTIONS[key],
+                                     [examples_hidden],
+                                     examples_per_page=8,
+                                     run_on_click=False,
+                                     elem_id=f"examples{i}",
+                                     api_name=f"examples{i}",
+                                     # label = "Click on the example question or enter your own",
+                                     # cache_examples=True,
+                                 )
+
+                             samples.append(group_examples)
+
+
+                     with gr.Tab("Sources",elem_id = "tab-citations",id = 1):
+                         sources_textbox = gr.HTML(show_label=False, elem_id="sources-textbox")
+                         docs_textbox = gr.State("")
+
+                     # with Modal(visible = False) as config_modal:
+                     with gr.Tab("Configuration",elem_id = "tab-config",id = 2):
+
+                         gr.Markdown("Reminder: You can talk in any language, ClimateQ&A is multi-lingual!")
+
+
+                         dropdown_sources = gr.CheckboxGroup(
+                             ["IPCC", "IPBES","IPOS"],
+                             label="Select source",
+                             value=["IPCC"],
+                             interactive=True,
+                         )
+
+                         dropdown_reports = gr.Dropdown(
+                             POSSIBLE_REPORTS,
+                             label="Or select specific reports",
+                             multiselect=True,
+                             value=None,
+                             interactive=True,
+                         )
+
+                         dropdown_audience = gr.Dropdown(
+                             ["Children","General public","Experts"],
+                             label="Select audience",
+                             value="Experts",
+                             interactive=True,
+                         )
+
+                         output_query = gr.Textbox(label="Query used for retrieval",show_label = True,elem_id = "reformulated-query",lines = 2,interactive = False)
+                         output_language = gr.Textbox(label="Language",show_label = True,elem_id = "language",lines = 1,interactive = False)
+
+
+
+
+
+     #---------------------------------------------------------------------------------------
+     # OTHER TABS
+     #---------------------------------------------------------------------------------------
+
+
+     with gr.Tab("Figures",elem_id = "tab-images",elem_classes = "max-height other-tabs"):
+         gallery_component = gr.Gallery()
+
+     with gr.Tab("Papers (beta)",elem_id = "tab-papers",elem_classes = "max-height other-tabs"):
+
+         with gr.Row():
+             with gr.Column(scale=1):
+                 query_papers = gr.Textbox(placeholder="Question",show_label=False,lines = 1,interactive = True,elem_id="query-papers")
+                 keywords_papers = gr.Textbox(placeholder="Keywords",show_label=False,lines = 1,interactive = True,elem_id="keywords-papers")
+                 after = gr.Slider(minimum=1950,maximum=2023,step=1,value=1960,label="Publication date",show_label=True,interactive=True,elem_id="date-papers")
+                 search_papers = gr.Button("Search",elem_id="search-papers",interactive=True)
+
+             with gr.Column(scale=7):
+
+                 with gr.Tab("Summary",elem_id="papers-summary-tab"):
+                     papers_summary = gr.Markdown(visible=True,elem_id="papers-summary")
+
+                 with gr.Tab("Relevant papers",elem_id="papers-results-tab"):
+                     papers_dataframe = gr.Dataframe(visible=True,elem_id="papers-table",headers = papers_cols)
+
+                 with gr.Tab("Citations network",elem_id="papers-network-tab"):
+                     citations_network = gr.HTML(visible=True,elem_id="papers-citations-network")
+
+
+
+     with gr.Tab("About",elem_classes = "max-height other-tabs"):
+         with gr.Row():
+             with gr.Column(scale=1):
+                 gr.Markdown("See more info at [https://climateqa.com](https://climateqa.com/docs/intro/)")
+
+
+     def start_chat(query,history):
+         history = history + [(query,None)]
+         history = [tuple(x) for x in history]
+         return (gr.update(interactive = False),gr.update(selected=1),history)
+
+     def finish_chat():
+         return (gr.update(interactive = True,value = ""))
+
+     (textbox
+         .submit(start_chat, [textbox,chatbot], [textbox,tabs,chatbot],queue = False,api_name = "start_chat_textbox")
+         .then(chat, [textbox,chatbot,dropdown_audience, dropdown_sources,dropdown_reports], [chatbot,sources_textbox,output_query,output_language,gallery_component,query_papers,keywords_papers],concurrency_limit = 8,api_name = "chat_textbox")
+         .then(finish_chat, None, [textbox],api_name = "finish_chat_textbox")
      )
+
+     (examples_hidden
+         .change(start_chat, [examples_hidden,chatbot], [textbox,tabs,chatbot],queue = False,api_name = "start_chat_examples")
+         .then(chat, [examples_hidden,chatbot,dropdown_audience, dropdown_sources,dropdown_reports], [chatbot,sources_textbox,output_query,output_language,gallery_component,query_papers,keywords_papers],concurrency_limit = 8,api_name = "chat_examples")
+         .then(finish_chat, None, [textbox],api_name = "finish_chat_examples")
      )

+     def change_sample_questions(key):
+         index = list(QUESTIONS.keys()).index(key)
+         visible_bools = [False] * len(samples)
+         visible_bools[index] = True
+         return [gr.update(visible=visible_bools[i]) for i in range(len(samples))]
+

+     dropdown_samples.change(change_sample_questions,dropdown_samples,samples)

+     query_papers.submit(generate_keywords,[query_papers], [keywords_papers])
+     search_papers.click(find_papers,[query_papers,keywords_papers,after], [papers_dataframe,citations_network,papers_summary])

+     # # textbox.submit(predict_climateqa,[textbox,bot],[None,bot,sources_textbox])
+     # (textbox
+     #     .submit(answer_user, [textbox,examples_hidden, bot], [textbox, bot],queue = False)
+     #     .success(change_tab,None,tabs)
+     #     .success(fetch_sources,[textbox,dropdown_sources], [textbox,sources_textbox,docs_textbox,output_query,output_language])
+     #     .success(answer_bot, [textbox,bot,docs_textbox,output_query,output_language,dropdown_audience], [textbox,bot],queue = True)
+     #     .success(lambda x : textbox,[textbox],[textbox])
+     # )

+     # (examples_hidden
+     #     .change(answer_user_example, [textbox,examples_hidden, bot], [textbox, bot],queue = False)
+     #     .success(change_tab,None,tabs)
+     #     .success(fetch_sources,[textbox,dropdown_sources], [textbox,sources_textbox,docs_textbox,output_query,output_language])
+     #     .success(answer_bot, [textbox,bot,docs_textbox,output_query,output_language,dropdown_audience], [textbox,bot],queue=True)
+     #     .success(lambda x : textbox,[textbox],[textbox])
+     # )
+     # submit_button.click(answer_user, [textbox, bot], [textbox, bot], queue=True).then(
+     #     answer_bot, [textbox,bot,dropdown_audience,dropdown_sources], [textbox,bot,sources_textbox]
+     # )

+     # with Modal(visible=True) as first_modal:
+     #     gr.Markdown("# Welcome to ClimateQ&A !")

+     #     gr.Markdown("### Examples")

+     #     examples = gr.Examples(
+     #         ["Yo ça roule","ça boume"],
+     #         [examples_hidden],
+     #         examples_per_page=8,
+     #         run_on_click=False,
+     #         elem_id="examples",
+     #         api_name="examples",
+     #     )
+
+
+     # submit.click(lambda: Modal(visible=True), None, config_modal)
+

+ demo.queue()

+ demo.launch()
 
climateqa/chat.py DELETED
@@ -1,194 +0,0 @@
- import os
- from datetime import datetime
- import gradio as gr
- # from .agent import agent
- from gradio import ChatMessage
- from langgraph.graph.state import CompiledStateGraph
- import json
-
- from .handle_stream_events import (
-     init_audience,
-     handle_retrieved_documents,
-     convert_to_docs_to_html,
-     stream_answer,
-     handle_retrieved_owid_graphs,
- )
- from .logging import (
-     log_interaction
- )
-
- # Chat functions
- def start_chat(query, history, search_only):
-     history = history + [ChatMessage(role="user", content=query)]
-     if not search_only:
-         return (gr.update(interactive=False), gr.update(selected=1), history, [])
-     else:
-         return (gr.update(interactive=False), gr.update(selected=2), history, [])
-
- def finish_chat():
-     return gr.update(interactive=True, value="")
-
- def handle_numerical_data(event):
-     if event["name"] == "retrieve_drias_data" and event["event"] == "on_chain_end":
-         numerical_data = event["data"]["output"]["drias_data"]
-         sql_query = event["data"]["output"]["drias_sql_query"]
-         return numerical_data, sql_query
-     return None, None
-
- # Main chat function
- async def chat_stream(
-     agent : CompiledStateGraph,
-     query: str,
-     history: list[ChatMessage],
-     audience: str,
-     sources: list[str],
-     reports: list[str],
-     relevant_content_sources_selection: list[str],
-     search_only: bool,
-     share_client,
-     user_id: str
- ) -> tuple[list, str, str, str, list, str]:
-     """Process a chat query and return response with relevant sources and visualizations.
-
-     Args:
-         query (str): The user's question
-         history (list): Chat message history
-         audience (str): Target audience type
-         sources (list): Knowledge base sources to search
-         reports (list): Specific reports to search within sources
-         relevant_content_sources_selection (list): Types of content to retrieve (figures, papers, etc)
-         search_only (bool): Whether to only search without generating answer
-
-     Yields:
-         tuple: Contains:
-             - history: Updated chat history
-             - docs_html: HTML of retrieved documents
-             - output_query: Processed query
-             - output_language: Detected language
-             - related_contents: Related content
-             - graphs_html: HTML of relevant graphs
-     """
-     # Log incoming question
-     date_now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
-     print(f">> NEW QUESTION ({date_now}) : {query}")
-
-     audience_prompt = init_audience(audience)
-     sources = sources or ["IPCC", "IPBES"]
-     reports = reports or []
-     relevant_history_discussion = history[-2:] if len(history) > 1 else []
-
-     # Prepare inputs for agent
-     inputs = {
-         "user_input": query,
-         "audience": audience_prompt,
-         "sources_input": sources,
-         "relevant_content_sources_selection": relevant_content_sources_selection,
-         "search_only": search_only,
-         "reports": reports,
-         "chat_history": relevant_history_discussion,
-     }
-
-     # Get streaming events from agent
-     result = agent.astream_events(inputs, version="v1")
-
-     # Initialize state variables
-     docs = []
-     related_contents = []
-     docs_html = ""
-     new_docs_html = ""
-     output_query = ""
-     output_language = ""
-     output_keywords = ""
-     start_streaming = False
-     graphs_html = ""
-     used_documents = []
-     retrieved_contents = []
-     answer_message_content = ""
-     vanna_data = {}
-     follow_up_examples = gr.Dataset(samples=[])
-
-     # Define processing steps
-     steps_display = {
-         "categorize_intent": ("🔄️ Analyzing user message", True),
-         "transform_query": ("🔄️ Thinking step by step to answer the question", True),
-         "retrieve_documents": ("🔄️ Searching in the knowledge base", False),
-         "retrieve_local_data": ("🔄️ Searching in the knowledge base", False),
-     }
-
-     try:
-         # Process streaming events
-         async for event in result:
-
-             if "langgraph_node" in event["metadata"]:
-                 node = event["metadata"]["langgraph_node"]
-
-                 # Handle document retrieval
-                 if event["event"] == "on_chain_end" and event["name"] in ["retrieve_documents","retrieve_local_data"] and event["data"]["output"] != None:
-                     history, used_documents, retrieved_contents = handle_retrieved_documents(
-                         event, history, used_documents, retrieved_contents
-                     )
-                 # Handle Vanna retrieval
-                 # if event["event"] == "on_chain_end" and event["name"] in ["retrieve_documents","retrieve_local_data"] and event["data"]["output"] != None:
-                 #     df_output_vanna, sql_query = handle_numerical_data(
-                 #         event
-                 #     )
-                 #     vanna_data = {"df_output": df_output_vanna, "sql_query": sql_query}
-
-                 if event["event"] == "on_chain_end" and event["name"] == "answer_search" :
-                     docs = event["data"]["input"]["documents"]
-                     docs_html = convert_to_docs_to_html(docs)
-                     related_contents = event["data"]["input"]["related_contents"]
-
-                 # Handle intent categorization
-                 elif (event["event"] == "on_chain_end" and
-                     node == "categorize_intent" and
-                     event["name"] == "_write"):
-                     intent = event["data"]["output"]["intent"]
-                     output_language = event["data"]["output"].get("language", "English")
-                     history[-1].content = f"Language identified: {output_language}\nIntent identified: {intent}"
-
-                 # Handle processing steps display
-                 elif event["name"] in steps_display and event["event"] == "on_chain_start":
-                     event_description, display_output = steps_display[node]
-                     if (not hasattr(history[-1], 'metadata') or
-                         history[-1].metadata["title"] != event_description):
-                         history.append(ChatMessage(
-                             role="assistant",
-                             content="",
-                             metadata={'title': event_description}
-                         ))
-
-                 # Handle answer streaming
-                 elif (event["name"] != "transform_query" and
-                     event["event"] == "on_chat_model_stream" and
-                     node in ["answer_rag","answer_rag_no_docs", "answer_search", "answer_chitchat"]):
-                     history, start_streaming, answer_message_content = stream_answer(
-                         history, event, start_streaming, answer_message_content
-                     )
-
-                 # Handle graph retrieval
-                 elif event["name"] in ["retrieve_graphs", "retrieve_graphs_ai"] and event["event"] == "on_chain_end":
-                     graphs_html = handle_retrieved_owid_graphs(event, graphs_html)
-
-                 # Handle query transformation
-                 if event["name"] == "transform_query" and event["event"] == "on_chain_end":
-                     if hasattr(history[-1], "content"):
-                         sub_questions = [q["question"] + "-> relevant sources : " + str(q["sources"]) for q in event["data"]["output"]["questions_list"]]
-                         history[-1].content += "Decompose question into sub-questions:\n\n - " + "\n - ".join(sub_questions)
-
-                 # Handle follow up questions
-                 if event["name"] == "generate_follow_up" and event["event"] == "on_chain_end":
-                     follow_up_examples = event["data"]["output"].get("follow_up_questions", [])
-                     follow_up_examples = gr.Dataset(samples= [ [question] for question in follow_up_examples ])
-
-             yield history, docs_html, output_query, output_language, related_contents, graphs_html, follow_up_examples#, vanna_data
-
-     except Exception as e:
-         print(f"Event {event} has failed")
-         raise gr.Error(str(e))
-
-     # Call the function to log interaction
-     log_interaction(history, output_query, sources, docs, share_client, user_id)
-
-     yield history, docs_html, output_query, output_language, related_contents, graphs_html, follow_up_examples#, vanna_data
climateqa/constants.py CHANGED
@@ -1,6 +1,4 @@
  POSSIBLE_REPORTS = [
-     "IPBES IABWFH SPM",
-     "IPBES CBL SPM",
      "IPCC AR6 WGI SPM",
      "IPCC AR6 WGI FR",
      "IPCC AR6 WGI TS",
@@ -44,60 +42,4 @@ POSSIBLE_REPORTS = [
      "IPBES IAS A C5",
      "IPBES IAS A C6",
      "IPBES IAS A SPM"
- ]
-
- OWID_CATEGORIES = ['Access to Energy', 'Agricultural Production',
-     'Agricultural Regulation & Policy', 'Air Pollution',
-     'Animal Welfare', 'Antibiotics', 'Biodiversity', 'Biofuels',
-     'Biological & Chemical Weapons', 'CO2 & Greenhouse Gas Emissions',
-     'COVID-19', 'Clean Water', 'Clean Water & Sanitation',
-     'Climate Change', 'Crop Yields', 'Diet Compositions',
-     'Electricity', 'Electricity Mix', 'Energy', 'Energy Efficiency',
-     'Energy Prices', 'Environmental Impacts of Food Production',
-     'Environmental Protection & Regulation', 'Famines', 'Farm Size',
-     'Fertilizers', 'Fish & Overfishing', 'Food Supply', 'Food Trade',
-     'Food Waste', 'Food and Agriculture', 'Forests & Deforestation',
-     'Fossil Fuels', 'Future Population Growth',
-     'Hunger & Undernourishment', 'Indoor Air Pollution', 'Land Use',
-     'Land Use & Yields in Agriculture', 'Lead Pollution',
-     'Meat & Dairy Production', 'Metals & Minerals',
-     'Natural Disasters', 'Nuclear Energy', 'Nuclear Weapons',
-     'Oil Spills', 'Outdoor Air Pollution', 'Ozone Layer', 'Pandemics',
-     'Pesticides', 'Plastic Pollution', 'Renewable Energy', 'Soil',
-     'Transport', 'Urbanization', 'Waste Management', 'Water Pollution',
-     'Water Use & Stress', 'Wildfires']
-
-
- DOCUMENT_METADATA_DEFAULT_VALUES = {
-     "chunk_type": "",
-     "document_id": "",
-     "document_number": 0.0,
-     "element_id": "",
-     "figure_code": "",
-     "file_size": "",
-     "image_path": "",
-     "n_pages": 0.0,
-     "name": "",
-     "num_characters": 0.0,
-     "num_tokens": 0.0,
-     "num_tokens_approx": 0.0,
-     "num_words": 0.0,
-     "page_number": 0,
-     "release_date": 0.0,
-     "report_type": "",
-     "section_header": "",
-     "short_name": "",
-     "source": "",
-     "toc_level0": "",
-     "toc_level1": "",
-     "toc_level2": "",
-     "toc_level3": "",
-     "url": "",
-     "similarity_score": 0.0,
-     "content": "",
-     "reranking_score": 0.0,
-     "query_used_for_retrieval": "",
-     "sources_used": [""],
-     "question_used": "",
-     "index_used": ""
- }
+ ]
climateqa/engine/chains/__init__.py DELETED
File without changes
climateqa/engine/chains/answer_ai_impact.py DELETED
@@ -1,46 +0,0 @@
- from langchain.prompts import ChatPromptTemplate
- from langchain_core.output_parsers import StrOutputParser
-
-
- prompt_template = """
- You are ClimateQ&A, a helpful AI assistant specialized in answering climate-related questions using information from the IPCC and/or IPBES reports.
- Always stay true to climate and nature science and do not make up information.
- If you do not know the answer, just say you do not know.
-
- ## Guidelines
- - Explain that the environmental impact of AI is not covered by the IPCC or IPBES reports, but you can recommend information based on the sources below
- - Answer the question in the original language of the question
-
- ## Sources
- - You can propose to visit this page https://climateqa.com/docs/carbon-footprint/ to learn more about ClimateQ&A's own carbon footprint
- - You can recommend looking at the work of the AI & climate expert scientist Sasha Luccioni, in particular these papers
-     - Power Hungry Processing: Watts Driving the Cost of AI Deployment? - https://arxiv.org/abs/2311.16863 - about the carbon footprint at the inference stage of AI models
-     - Counting Carbon: A Survey of Factors Influencing the Emissions of Machine Learning - https://arxiv.org/abs/2302.08476
-     - Estimating the Carbon Footprint of BLOOM, a 176B Parameter Language Model - https://arxiv.org/abs/2211.02001 - about the carbon footprint of training a large language model
- - You can also recommend the following tools to calculate the carbon footprint of AI models
-     - CodeCarbon - https://github.com/mlco2/codecarbon - to measure the carbon footprint of your code
-     - Ecologits - https://ecologits.ai/ - to measure the carbon footprint of using LLM APIs
- """
-
-
- def make_ai_impact_chain(llm):
-
-     prompt = ChatPromptTemplate.from_messages([
-         ("system", prompt_template),
-         ("user", "{question}")
-     ])
-
-     chain = prompt | llm | StrOutputParser()
-     chain = chain.with_config({"run_name": "ai_impact_chain"})
-
-     return chain
-
- def make_ai_impact_node(llm):
-
-     ai_impact_chain = make_ai_impact_chain(llm)
-
-     async def answer_ai_impact(state, config):
-         answer = await ai_impact_chain.ainvoke({"question": state["user_input"]}, config)
-         return {"answer": answer}
-
-     return answer_ai_impact
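A minimal sketch of how a node like this was driven, assuming the repo's get_llm helper, an empty LangChain config, and an illustrative question (the import path is the module removed by this commit, so this only runs against the previous revision):

import asyncio

from climateqa.engine.llm import get_llm
from climateqa.engine.chains.answer_ai_impact import make_ai_impact_node  # removed by this commit

node = make_ai_impact_node(get_llm(provider="openai", max_tokens=1024, temperature=0.0))
result = asyncio.run(node({"user_input": "How much CO2 does training an LLM emit?"}, config={}))
print(result["answer"])
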
climateqa/engine/chains/answer_chitchat.py DELETED
@@ -1,56 +0,0 @@
- from langchain.prompts import ChatPromptTemplate
- from langchain_core.output_parsers import StrOutputParser
-
-
- chitchat_prompt_template = """
- You are ClimateQ&A, a helpful AI assistant specialized in answering climate-related questions using information from the IPCC and/or IPBES reports.
- Always stay true to climate and nature science and do not make up information.
- If you do not know the answer, just say you do not know.
-
- ## Guidelines
- - If it's a conversational question, you can normally chat with the user
- - If the question is not related to any topic about the environment, refuse to answer and politely ask the user to ask another question about the environment
- - If the user asks whether you speak any language, you can say you speak all languages :)
- - If the user asks about the bot itself, "ClimateQ&A", you can explain that you are an AI assistant specialized in answering climate-related questions using information from the IPCC and/or IPBES reports and propose to visit the website here https://climateqa.com/docs/intro/ for more information
- - If the question is about ESG regulations, standards, or frameworks like the CSRD, TCFD, SASB, GRI, CDP, etc., you can explain that this is not a topic covered by the IPCC or IPBES reports.
- - Specify that you are specialized in finding trustworthy information in the scientific reports of the IPCC and IPBES and other scientific literature
- - If relevant, you can propose up to 3 examples of questions they could ask about the IPCC or IPBES reports, drawn from the examples below
- - Always answer in the original language of the question
-
- ## Examples of questions you can suggest (in the original language of the question)
- "What evidence do we have of climate change?",
- "Are human activities causing global warming?",
- "What are the impacts of climate change?",
- "Can climate change be reversed?",
- "What is the difference between climate change and global warming?",
- """
-
-
- def make_chitchat_chain(llm):
-
-     prompt = ChatPromptTemplate.from_messages([
-         ("system", chitchat_prompt_template),
-         ("user", "{question}")
-     ])
-
-     chain = prompt | llm | StrOutputParser()
-     chain = chain.with_config({"run_name": "chitchat_chain"})
-
-     return chain
-
-
- def make_chitchat_node(llm):
-
-     chitchat_chain = make_chitchat_chain(llm)
-
-     async def answer_chitchat(state, config):
-         print("---- Answer chitchat ----")
-
-         answer = await chitchat_chain.ainvoke({"question": state["user_input"]}, config)
-         state["answer"] = answer
-         return state
-         # return {"answer": answer}
-
-     return answer_chitchat
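Because the chain is a plain prompt | llm | parser pipeline, it could also be exercised directly, outside the graph (sketch, same get_llm assumption as above, against the previous revision):

from climateqa.engine.llm import get_llm
from climateqa.engine.chains.answer_chitchat import make_chitchat_chain  # removed by this commit

chain = make_chitchat_chain(get_llm(provider="openai"))
print(chain.invoke({"question": "Bonjour, que sais-tu faire ?"}))  # should answer in French
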
climateqa/engine/chains/chitchat_categorization.py DELETED
@@ -1,43 +0,0 @@
-
- from langchain_core.pydantic_v1 import BaseModel, Field
- from typing import List
- from typing import Literal
- from langchain.prompts import ChatPromptTemplate
- from langchain_core.utils.function_calling import convert_to_openai_function
- from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
-
-
- class IntentCategorizer(BaseModel):
-     """Analyze the user message input"""
-
-     environment: bool = Field(
-         description="Return 'True' if the question relates to climate change, the environment, nature, etc. (Example: should I eat fish?). Return 'False' if the question is just chit chat or not related to the environment or climate change.",
-     )
-
-
- def make_chitchat_intent_categorization_chain(llm):
-
-     openai_functions = [convert_to_openai_function(IntentCategorizer)]
-     llm_with_functions = llm.bind(functions=openai_functions, function_call={"name": "IntentCategorizer"})
-
-     prompt = ChatPromptTemplate.from_messages([
-         ("system", "You are a helpful assistant, you will analyze, translate and reformulate the user input message using the function provided"),
-         ("user", "input: {input}")
-     ])
-
-     chain = prompt | llm_with_functions | JsonOutputFunctionsParser()
-     return chain
-
-
- def make_chitchat_intent_categorization_node(llm):
-
-     categorization_chain = make_chitchat_intent_categorization_chain(llm)
-
-     def categorize_message(state):
-         output = categorization_chain.invoke({"input": state["user_input"]})
-         print(f"\n\nChit chat output intent categorization: {output}\n")
-         state["search_graphs_chitchat"] = output["environment"]
-         print(f"\n\nChit chat state after intent categorization: {state}\n")
-         return state
-
-     return categorize_message
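Since the chain ends with JsonOutputFunctionsParser, its output is a plain dict shaped by the IntentCategorizer schema (sketch, same get_llm assumption, against the previous revision):

from climateqa.engine.llm import get_llm
from climateqa.engine.chains.chitchat_categorization import make_chitchat_intent_categorization_chain  # removed by this commit

chain = make_chitchat_intent_categorization_chain(get_llm(provider="openai"))
output = chain.invoke({"input": "Should I eat less meat?"})
# e.g. {"environment": True} -> drives the search_graphs_chitchat flag
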
climateqa/engine/chains/follow_up.py DELETED
@@ -1,33 +0,0 @@
- from typing import List
- from langchain.prompts import ChatPromptTemplate
-
-
- FOLLOW_UP_TEMPLATE = """Based on the previous question and answer, generate 2-3 relevant follow-up questions that would help explore the topic further.
-
- Previous Question: {user_input}
- Previous Answer: {answer}
-
- Generate short, concise, focused follow-up questions.
- You don't need a full question, as it will be reformulated later as a standalone question with the context. E.g. "Detail the first point"
- """
-
- def make_follow_up_node(llm):
-     prompt = ChatPromptTemplate.from_template(FOLLOW_UP_TEMPLATE)
-
-     def generate_follow_up(state):
-         print("---- Generate follow-up ----")
-         if not state.get("answer"):
-             return state
-
-         response = llm.invoke(prompt.format(
-             user_input=state["user_input"],
-             answer=state["answer"]
-         ))
-
-         # Extract one question per non-empty line of the response
-         follow_ups = [q.strip() for q in response.content.split("\n") if q.strip()]
-         state["follow_up_questions"] = follow_ups
-
-         return state
-
-     return generate_follow_up
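The node derives one follow-up per non-empty line of the completion; the parsing step in isolation, as a standalone check:

content = "What are the main drivers?\n\nDetail the first point\nHow does this differ by region?"
follow_ups = [q.strip() for q in content.split("\n") if q.strip()]
assert follow_ups == [
    "What are the main drivers?",
    "Detail the first point",
    "How does this differ by region?",
]
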
climateqa/engine/chains/graph_retriever.py DELETED
@@ -1,130 +0,0 @@
- import sys
- import os
- from contextlib import contextmanager
-
- from ..reranker import rerank_docs
- from ..graph_retriever import retrieve_graphs  # GraphRetriever
- from ...utils import remove_duplicates_keep_highest_score
-
-
- def divide_into_parts(target, parts):
-     # Base value for each part
-     base = target // parts
-     # Remainder to distribute
-     remainder = target % parts
-     # List to hold the result
-     result = []
-
-     for i in range(parts):
-         if i < remainder:
-             # These parts get base value + 1
-             result.append(base + 1)
-         else:
-             # The rest get the base value
-             result.append(base)
-
-     return result
-
-
- @contextmanager
- def suppress_output():
-     # Open a null device
-     with open(os.devnull, 'w') as devnull:
-         # Store the original stdout and stderr
-         old_stdout = sys.stdout
-         old_stderr = sys.stderr
-         # Redirect stdout and stderr to the null device
-         sys.stdout = devnull
-         sys.stderr = devnull
-         try:
-             yield
-         finally:
-             # Restore stdout and stderr
-             sys.stdout = old_stdout
-             sys.stderr = old_stderr
-
-
- def make_graph_retriever_node(vectorstore, reranker, rerank_by_question=True, k_final=15, k_before_reranking=100):
-
-     async def node_retrieve_graphs(state):
-         print("---- Retrieving graphs ----")
-
-         POSSIBLE_SOURCES = ["IEA", "OWID"]
-         # questions = state["remaining_questions"] if state["remaining_questions"] is not None and state["remaining_questions"] != [] else [state["query"]]
-         questions = state["questions_list"] if state["questions_list"] is not None and state["questions_list"] != [] else [state["query"]]
-
-         # sources_input = state["sources_input"]
-         sources_input = ["auto"]
-
-         auto_mode = "auto" in sources_input
-
-         # There are several options to get the final top k
-         # Option 1 - Get 100 documents by question and rerank by question
-         # Option 2 - Get 100/n documents by question and rerank the total
-         if rerank_by_question:
-             k_by_question = divide_into_parts(k_final, len(questions))
-
-         docs = []
-
-         for i, q in enumerate(questions):
-
-             question = q["question"] if isinstance(q, dict) else q
-
-             print(f"Subquestion {i}: {question}")
-
-             # If auto mode, we use all sources
-             if auto_mode:
-                 sources = POSSIBLE_SOURCES
-             # Otherwise, we use the config
-             else:
-                 sources = sources_input
-
-             if any([x in POSSIBLE_SOURCES for x in sources]):
-
-                 sources = [x for x in sources if x in POSSIBLE_SOURCES]
-
-                 # Search the document store using the retriever
-                 docs_question = await retrieve_graphs(
-                     query=question,
-                     vectorstore=vectorstore,
-                     sources=sources,
-                     k_total=k_before_reranking,
-                     threshold=0.5,
-                 )
-                 # docs_question = retriever.get_relevant_documents(question)
-
-                 # Rerank
-                 if reranker is not None and docs_question != []:
-                     with suppress_output():
-                         docs_question = rerank_docs(reranker, docs_question, question)
-                 else:
-                     # Add a default reranking score
-                     for doc in docs_question:
-                         doc.metadata["reranking_score"] = doc.metadata["similarity_score"]
-
-                 # If reranking by question, we select the top documents for each question
-                 if rerank_by_question:
-                     docs_question = docs_question[:k_by_question[i]]
-
-                 # Add sources used in the metadata
-                 for doc in docs_question:
-                     doc.metadata["sources_used"] = sources
-
-                 print(f"{len(docs_question)} graphs retrieved for subquestion {i + 1}: {docs_question}")
-
-                 docs.extend(docs_question)
-
-             else:
-                 print(f"There are no graphs which match the sources filtered on. Sources filtered on: {sources}. Sources available: {POSSIBLE_SOURCES}.")
-
-         # Remove duplicates, keeping the duplicate document with the highest reranking score
-         docs = remove_duplicates_keep_highest_score(docs)
-
-         # Sort the list in descending order by reranking score, then select the top k
-         docs = sorted(docs, key=lambda x: x.metadata["reranking_score"], reverse=True)
-         docs = docs[:k_final]
-
-         return {"recommended_content": docs}
-
-     return node_retrieve_graphs
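For reference, divide_into_parts spreads the k_final budget as evenly as possible across subquestions, handing the remainder to the first ones; a standalone sketch of the same logic:

def divide_into_parts(target: int, parts: int) -> list:
    base, remainder = divmod(target, parts)
    # The first `remainder` parts absorb the leftover documents
    return [base + 1 if i < remainder else base for i in range(parts)]

assert divide_into_parts(15, 4) == [4, 4, 4, 3]
assert divide_into_parts(15, 2) == [8, 7]
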
climateqa/engine/chains/intent_categorization.py DELETED
@@ -1,97 +0,0 @@
- from langchain_core.pydantic_v1 import BaseModel, Field
- from typing import List
- from typing import Literal
- from langchain.prompts import ChatPromptTemplate
- from langchain_core.utils.function_calling import convert_to_openai_function
- from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
-
-
- class IntentCategorizer(BaseModel):
-     """Analyze the user message input"""
-
-     language: str = Field(
-         description="Find the language of the message input in full words (ex: French, English, Spanish, ...), defaults to English",
-         default="English",
-     )
-     intent: str = Field(
-         enum=[
-             "ai_impact",
-             # "geo_info",
-             # "esg",
-             "search",
-             "chitchat",
-         ],
-         description="""
-             Categorize the user input into one of the following categories.
-
-             Examples:
-             - ai_impact = Environmental impacts of AI: "What are the environmental impacts of AI", "How does AI affect the environment"
-             - search = Searching for any question about climate change, energy, biodiversity, nature, and everything we can find in the IPCC or IPBES reports or scientific papers
-             - chitchat = Any general question that is not related to the environment or climate change or just conversational, or if you don't think searching the IPCC or IPBES reports would be relevant
-         """,
-         # - geo_info = Geolocated info about climate change: any question where the user wants to know localized impacts of climate change, e.g.: "What will be the temperature in Marseille in 2050"
-         # - esg = Any question about the ESG regulation, frameworks and standards like the CSRD, TCFD, SASB, GRI, CDP, etc.
-     )
-
-
- def make_intent_categorization_chain(llm):
-
-     openai_functions = [convert_to_openai_function(IntentCategorizer)]
-     llm_with_functions = llm.bind(functions=openai_functions, function_call={"name": "IntentCategorizer"})
-
-     prompt = ChatPromptTemplate.from_messages([
-         ("system", "You are a helpful assistant, you will analyze, detect the language, and categorize the user input message using the function provided. You MUST detect and return the language of the input message. Categorize the user input as ai_impact ONLY if it is related to Artificial Intelligence, search if it is related to the environment, climate change, energy, biodiversity, nature, etc. and chitchat if it is just general conversation."),
-         ("user", "input: {input}")
-     ])
-
-     chain = prompt | llm_with_functions | JsonOutputFunctionsParser()
-     return chain
-
-
- def make_intent_categorization_node(llm):
-
-     categorization_chain = make_intent_categorization_chain(llm)
-
-     def categorize_message(state):
-         print("---- Categorize message ----")
-         print(f"Input state: {state}")
-
-         output = categorization_chain.invoke({"input": state["user_input"]})
-         print(f"\n\nRaw output from categorization: {output}\n")
-
-         if "language" not in output:
-             print("WARNING: Language field missing from output, setting default to English")
-             output["language"] = "English"
-         else:
-             print(f"Language detected: {output['language']}")
-
-         output["query"] = state["user_input"]
-         print(f"Final output: {output}")
-         return output
-
-     return categorize_message
-
-
- # SAMPLE_QUESTIONS = [
- #     "Est-ce que l'IA a un impact sur l'environnement ?",
- #     "Que dit le GIEC sur l'impact de l'IA",
- #     "Qui sont les membres du GIEC",
- #     "What is the impact of El Nino ?",
- #     "Yo",
- #     "Hello ça va bien ?",
- #     "Par qui as tu été créé ?",
- #     "What role do cloud formations play in modulating the Earth's radiative balance, and how are they represented in current climate models?",
- #     "Which industries have the highest GHG emissions?",
- #     "What are invasive alien species and how do they threaten biodiversity and ecosystems?",
- #     "Are human activities causing global warming?",
- #     "What is the motivation behind mining the deep seabed?",
- #     "Tu peux m'écrire un poème sur le changement climatique ?",
- #     "Tu peux m'écrire un poème sur les bonbons ?",
- #     "What will be the temperature in 2100 in Strasbourg?",
- #     "C'est quoi le lien entre biodiversity and changement climatique ?",
- # ]
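A sketch of the dict this chain returns, before the node enriches it with the original query (same get_llm assumption, against the previous revision):

from climateqa.engine.llm import get_llm
from climateqa.engine.chains.intent_categorization import make_intent_categorization_chain  # removed by this commit

chain = make_intent_categorization_chain(get_llm(provider="openai"))
out = chain.invoke({"input": "Quel est l'impact d'El Nino ?"})
# e.g. {"language": "French", "intent": "search"}
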
climateqa/engine/chains/keywords_extraction.py DELETED
@@ -1,40 +0,0 @@
-
- from langchain_core.pydantic_v1 import BaseModel, Field
- from typing import List
- from typing import Literal
- from langchain.prompts import ChatPromptTemplate
- from langchain_core.utils.function_calling import convert_to_openai_function
- from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
-
-
- class KeywordExtraction(BaseModel):
-     """
-     Analyze the user query to extract keywords to feed a search engine
-     """
-
-     keywords: List[str] = Field(
-         description="""
-             Extract the keywords from the user query to feed a search engine, as a list.
-             Prefer general keywords over overly specific ones.
-             Maximum 3 keywords.
-
-             Examples:
-             - "What is the impact of deep sea mining ?" -> ["deep sea mining"]
-             - "How will El Nino be impacted by climate change" -> ["el nino", "climate change"]
-             - "Is climate change a hoax" -> ["climate change", "hoax"]
-         """
-     )
-
-
- def make_keywords_extraction_chain(llm):
-
-     openai_functions = [convert_to_openai_function(KeywordExtraction)]
-     llm_with_functions = llm.bind(functions=openai_functions, function_call={"name": "KeywordExtraction"})
-
-     prompt = ChatPromptTemplate.from_messages([
-         ("system", "You are a helpful assistant"),
-         ("user", "input: {input}")
-     ])
-
-     chain = prompt | llm_with_functions | JsonOutputFunctionsParser()
-     return chain
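Usage sketch (same get_llm assumption, against the previous revision); the parser returns the schema as a plain dict:

from climateqa.engine.llm import get_llm
from climateqa.engine.chains.keywords_extraction import make_keywords_extraction_chain  # removed by this commit

chain = make_keywords_extraction_chain(get_llm(provider="openai"))
out = chain.invoke({"input": "How will El Nino be impacted by climate change?"})
# e.g. {"keywords": ["el nino", "climate change"]}
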
climateqa/engine/chains/query_transformation.py DELETED
@@ -1,300 +0,0 @@
-
- from langchain_core.pydantic_v1 import BaseModel, Field
- from typing import List
- from typing import Literal
- from langchain.prompts import ChatPromptTemplate
- from langchain_core.utils.function_calling import convert_to_openai_function
- from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
-
- # OLD QUERY ANALYSIS
- # keywords: List[str] = Field(
- #     description="""
- #         Extract the keywords from the user query to feed a search engine as a list
- #         Maximum 3 keywords
- #
- #         Examples:
- #         - "What is the impact of deep sea mining ?" -> deep sea mining
- #         - "How will El Nino be impacted by climate change" -> el nino;climate change
- #         - "Is climate change a hoax" -> climate change;hoax
- #     """
- # )
-
- # alternative_queries: List[str] = Field(
- #     description="""
- #         Generate alternative search questions from the user query to feed a search engine
- #     """
- # )
-
- # step_back_question: str = Field(
- #     description="""
- #         You are an expert at world knowledge. Your task is to step back and paraphrase a question to a more generic step-back question, which is easier to answer.
- #         This question should help you get more context and information about the user query
- #     """
- # )
- # - OpenAlex is for any other questions that are not in the previous categories but could be found in the scientific literature
-
-
- # topics: List[Literal[
- #     "Climate change",
- #     "Biodiversity",
- #     "Energy",
- #     "Decarbonization",
- #     "Climate science",
- #     "Nature",
- #     "Climate policy and justice",
- #     "Oceans",
- #     "Deep sea mining",
- #     "ESG and regulations",
- #     "CSRD",
- # ]] = Field(
- #     ...,
- #     description="""
- #         Choose the topics that are most relevant to the user query, ex: Climate change, Energy, Biodiversity, ...
- #     """,
- # )
- # date: str = Field(description="The date or period mentioned, ex: 2050, between 2020 and 2050")
- # location: Location
-
-
- ROUTING_INDEX = {
-     "IPx": ["IPCC", "IPBES", "IPOS"],
-     "POC": ["AcclimaTerra", "PCAET", "Biodiv"],
-     "OpenAlex": ["OpenAlex"],
- }
-
- POSSIBLE_SOURCES = [y for values in ROUTING_INDEX.values() for y in values]
-
- # Prompt from the original paper https://arxiv.org/pdf/2305.14283
- # Query Rewriting for Retrieval-Augmented Large Language Models
- class QueryDecomposition(BaseModel):
-     """
-     Decompose the user query into smaller parts to think step by step to answer this question
-     Act as a simple planning agent
-     """
-
-     questions: List[str] = Field(
-         description="""
-             Think step by step to answer this question, and provide one or several search engine questions in the provided language for knowledge that you need.
-             Suppose that the user is looking for information about climate change, energy, biodiversity, nature, and everything we can find in the IPCC reports and scientific literature
-             - If it's already a standalone and explicit question, just return the reformulated question for the search engine
-             - If you need to decompose the question, output a list of maximum 2 to 3 questions
-         """
-     )
-
-
- class Location(BaseModel):
-     country: str = Field(..., description="The country if directly mentioned or inferred from the location (cities, regions, addresses), ex: France, USA, ...")
-     location: str = Field(..., description="The specific place if mentioned (cities, regions, addresses), ex: Marseille, New York, Wisconsin, ...")
-
- class QueryTranslation(BaseModel):
-     """Translate the query into a given language"""
-
-     question: str = Field(
-         description="""
-             Translate the questions into the given language
-             If the question is already in the given language, just return the same question
-         """,
-     )
-
-
- class QueryAnalysis(BaseModel):
-     """
-     Analyze the user query to extract the relevant sources
-
-     Deprecated:
-     Analyzing the user query to extract topics, sources and date
-     Also do query expansion to get alternative search queries
-     Also provide simple keywords to feed a search engine
-     """
-
-     sources: List[Literal["IPCC", "IPBES", "IPOS", "AcclimaTerra", "PCAET", "Biodiv"]] = Field(  # ,"OpenAlex"]] = Field(
-         ...,
-         description="""
-             Given a user question choose which documents would be most relevant for answering their question,
-             - IPCC is for questions about climate change, energy, impacts, and everything we can find in the IPCC reports
-             - IPBES is for questions about biodiversity and nature
-             - IPOS is for questions about the ocean and deep sea mining
-             - AcclimaTerra is for questions about any specific place in, or close to, the French region "Nouvelle-Aquitaine"
-             - PCAET is the Plan Climat Energie Territorial for the city of Paris
-             - Biodiv is the Biodiversity plan for the city of Paris
-         """,
-     )
-
-
- def make_query_decomposition_chain(llm):
-     """Chain to decompose a query into smaller parts in order to think step by step and answer the question
-
-     Args:
-         llm: the language model used for decomposition
-
-     Returns:
-         a chain mapping {"input": ...} to a {"questions": [...]} dict
-     """
-
-     openai_functions = [convert_to_openai_function(QueryDecomposition)]
-     llm_with_functions = llm.bind(functions=openai_functions, function_call={"name": "QueryDecomposition"})
-
-     prompt = ChatPromptTemplate.from_messages([
-         ("system", "You are a helpful assistant, you will analyze, translate and reformulate the user input message using the function provided"),
-         ("user", "input: {input}")
-     ])
-
-     chain = prompt | llm_with_functions | JsonOutputFunctionsParser()
-     return chain
-
-
- def make_query_analysis_chain(llm):
-     """Analyze the user query to extract the relevant sources"""
-
-     openai_functions = [convert_to_openai_function(QueryAnalysis)]
-     llm_with_functions = llm.bind(functions=openai_functions, function_call={"name": "QueryAnalysis"})
-
-     prompt = ChatPromptTemplate.from_messages([
-         ("system", "You are a helpful assistant, you will analyze the user input message using the function provided"),
-         ("user", "input: {input}")
-     ])
-
-     chain = prompt | llm_with_functions | JsonOutputFunctionsParser()
-     return chain
-
-
- def make_query_translation_chain(llm):
-     """Translate the user query into the target document language"""
-
-     openai_functions = [convert_to_openai_function(QueryTranslation)]
-     llm_with_functions = llm.bind(functions=openai_functions, function_call={"name": "QueryTranslation"})
-
-     prompt = ChatPromptTemplate.from_messages([
-         ("system", "You are a helpful assistant, translate the question into {language}"),
-         ("user", "input: {input}")
-     ])
-
-     chain = prompt | llm_with_functions | JsonOutputFunctionsParser()
-     return chain
-
- def group_by_sources_types(sources):
-     sources_types = {}
-     IPx_sources = ["IPCC", "IPBES", "IPOS"]
-     local_sources = ["AcclimaTerra", "PCAET", "Biodiv"]
-     if any(source in IPx_sources for source in sources):
-         sources_types["IPx"] = list(set(sources).intersection(IPx_sources))
-     if any(source in local_sources for source in sources):
-         sources_types["POC"] = list(set(sources).intersection(local_sources))
-     return sources_types
-
-
- def make_query_transform_node(llm, k_final=15):
-     """
-     Creates a query transformation node that processes and transforms a given query state.
-     Args:
-         llm: The language model to be used for query decomposition and rewriting.
-         k_final (int, optional): The final number of questions to be generated. Defaults to 15.
-     Returns:
-         function: A function that takes a query state and returns a transformed state.
-     The returned function performs the following steps:
-         1. Checks if the query should be processed in auto mode based on the state.
-         2. Retrieves the input sources from the state or defaults to a predefined routing index.
-         3. Decomposes the query using the decomposition chain.
-         4. Analyzes each decomposed question using the rewriter chain.
-         5. Ensures that the sources returned by the language model are valid.
-         6. Explodes the questions into multiple questions with different sources based on the mode.
-         7. Constructs a new state with the transformed questions and their respective sources.
-     """
-
-     decomposition_chain = make_query_decomposition_chain(llm)
-     query_analysis_chain = make_query_analysis_chain(llm)
-     query_translation_chain = make_query_translation_chain(llm)
-
-     def transform_query(state):
-         print("---- Transform query ----")
-
-         auto_mode = state.get("sources_auto", True)
-         sources_input = state.get("sources_input", ROUTING_INDEX["IPx"])
-
-         new_state = {}
-
-         # Decomposition
-         decomposition_output = decomposition_chain.invoke({"input": state["query"]})
-         new_state.update(decomposition_output)
-
-         # Query analysis
-         questions = []
-         for question in new_state["questions"]:
-             question_state = {"question": question}
-             query_analysis_output = query_analysis_chain.invoke({"input": question})
-
-             # TODO WARNING: the LLM should always return something
-             # Handle the case where the LLM returns no sources or an invalid output
-             if not query_analysis_output["sources"] or not all(source in ["IPCC", "IPBES", "IPOS", "AcclimaTerra", "PCAET", "Biodiv"] for source in query_analysis_output["sources"]):
-                 query_analysis_output["sources"] = ["IPCC", "IPBES", "IPOS"]
-
-             sources_types = group_by_sources_types(query_analysis_output["sources"])
-             for source_type, sources in sources_types.items():
-                 question_state = {
-                     "question": question,
-                     "sources": sources,
-                     "source_type": source_type
-                 }
-
-                 questions.append(question_state)
-
-         # Translate each question into the language of its document collection
-         for q in questions:
-             if q["source_type"] == "IPx":
-                 translation_output = query_translation_chain.invoke({"input": q["question"], "language": "English"})
-                 q["question"] = translation_output["question"]
-             elif q["source_type"] == "POC":
-                 translation_output = query_translation_chain.invoke({"input": q["question"], "language": "French"})
-                 q["question"] = translation_output["question"]
-
-         # Explode the questions into multiple questions with different sources
-         new_questions = []
-         for q in questions:
-             question, sources, source_type = q["question"], q["sources"], q["source_type"]
-
-             # If not auto mode, we take the configuration
-             if not auto_mode:
-                 sources = sources_input
-
-             for index, index_sources in ROUTING_INDEX.items():
-                 selected_sources = list(set(sources).intersection(index_sources))
-                 if len(selected_sources) > 0:
-                     new_questions.append({"question": question, "sources": selected_sources, "index": index, "source_type": source_type})
-
-         # # Add the number of questions to search
-         # k_by_question = k_final // len(new_questions)
-         # for q in new_questions:
-         #     q["k"] = k_by_question
-
-         # new_state["questions"] = new_questions
-         # new_state["remaining_questions"] = new_questions
-
-         n_questions = {
-             "total": len(new_questions),
-             "IPx": len([q for q in new_questions if q["index"] == "IPx"]),
-             "POC": len([q for q in new_questions if q["index"] == "POC"]),
-         }
-
-         new_state = {
-             "questions_list": new_questions,
-             "n_questions": n_questions,
-             "handled_questions_index": [],
-         }
-         print("New questions")
-         print(new_questions)
-         return new_state
-
-     return transform_query
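group_by_sources_types is a pure function, so the routing split can be checked in isolation (copied verbatim so it runs without the deleted module):

def group_by_sources_types(sources):
    sources_types = {}
    IPx_sources = ["IPCC", "IPBES", "IPOS"]
    local_sources = ["AcclimaTerra", "PCAET", "Biodiv"]
    if any(source in IPx_sources for source in sources):
        sources_types["IPx"] = list(set(sources).intersection(IPx_sources))
    if any(source in local_sources for source in sources):
        sources_types["POC"] = list(set(sources).intersection(local_sources))
    return sources_types

assert group_by_sources_types(["IPCC", "PCAET"]) == {"IPx": ["IPCC"], "POC": ["PCAET"]}
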
climateqa/engine/chains/retrieve_documents.py DELETED
@@ -1,705 +0,0 @@
- import sys
- import os
- from contextlib import contextmanager
-
- from langchain_core.tools import tool
- from langchain_core.runnables import chain
- from langchain_core.runnables import RunnableParallel, RunnablePassthrough
- from langchain_core.runnables import RunnableLambda
-
- from ..reranker import rerank_docs, rerank_and_sort_docs
- # from ...knowledge.retriever import ClimateQARetriever
- from ...knowledge.openalex import OpenAlexRetriever
- from .keywords_extraction import make_keywords_extraction_chain
- from ..utils import log_event
- from langchain_core.vectorstores import VectorStore
- from typing import List
- from langchain_core.documents.base import Document
- from ..llm import get_llm
- from .prompts import retrieve_chapter_prompt_template
- from langchain_core.prompts import ChatPromptTemplate
- from langchain_core.output_parsers import StrOutputParser
- from ..vectorstore import get_pinecone_vectorstore
- from ..embeddings import get_embeddings_function
- import ast
-
- import asyncio
-
- from typing import Any, Dict, List, Tuple
-
-
- def divide_into_parts(target, parts):
-     # Base value for each part
-     base = target // parts
-     # Remainder to distribute
-     remainder = target % parts
-     # List to hold the result
-     result = []
-
-     for i in range(parts):
-         if i < remainder:
-             # These parts get base value + 1
-             result.append(base + 1)
-         else:
-             # The rest get the base value
-             result.append(base)
-
-     return result
-
-
- @contextmanager
- def suppress_output():
-     # Open a null device
-     with open(os.devnull, 'w') as devnull:
-         # Store the original stdout and stderr
-         old_stdout = sys.stdout
-         old_stderr = sys.stderr
-         # Redirect stdout and stderr to the null device
-         sys.stdout = devnull
-         sys.stderr = devnull
-         try:
-             yield
-         finally:
-             # Restore stdout and stderr
-             sys.stdout = old_stdout
-             sys.stderr = old_stderr
-
-
- @tool
- def query_retriever(question):
-     """Just a dummy tool to simulate the retriever query"""
-     return question
-
- def _add_sources_used_in_metadata(docs, sources, question, index):
-     for doc in docs:
-         doc.metadata["sources_used"] = sources
-         doc.metadata["question_used"] = question
-         doc.metadata["index_used"] = index
-     return docs
-
- def _get_k_summary_by_question(n_questions):
-     if n_questions == 0:
-         return 0
-     elif n_questions == 1:
-         return 5
-     elif n_questions == 2:
-         return 3
-     elif n_questions == 3:
-         return 2
-     else:
-         return 1
-
- def _get_k_images_by_question(n_questions):
-     if n_questions == 0:
-         return 0
-     elif n_questions == 1:
-         return 7
-     elif n_questions == 2:
-         return 5
-     elif n_questions == 3:
-         return 3
-     else:
-         return 1
-
- def _add_metadata_and_score(docs: List) -> List[Document]:
-     # Add score to metadata
-     docs_with_metadata = []
-     for i, (doc, score) in enumerate(docs):
-         doc.page_content = doc.page_content.replace("\r\n", " ")
-         doc.metadata["similarity_score"] = score
-         doc.metadata["content"] = doc.page_content
-         if doc.metadata["page_number"] != "N/A":
-             doc.metadata["page_number"] = int(doc.metadata["page_number"]) + 1
-         else:
-             doc.metadata["page_number"] = 1
-         # doc.page_content = f"""Doc {i+1} - {doc.metadata['short_name']}: {doc.page_content}"""
-         docs_with_metadata.append(doc)
-     return docs_with_metadata
-
- def remove_duplicates_chunks(docs):
-     # Remove duplicates or almost duplicates
-     docs = sorted(docs, key=lambda x: x[1], reverse=True)
-     seen = set()
-     result = []
-     for doc in docs:
-         if doc[0].page_content not in seen:
-             seen.add(doc[0].page_content)
-             result.append(doc)
-     return result
-
- def get_ToCs(version: str):
-
-     filters_text = {
-         "chunk_type": "toc",
-         "version": version
-     }
-     embeddings_function = get_embeddings_function()
-     vectorstore = get_pinecone_vectorstore(embeddings_function, index_name="climateqa-v2")
-     tocs = vectorstore.similarity_search_with_score(query="", filter=filters_text)
-
-     # Remove duplicates or almost duplicates
-     tocs = remove_duplicates_chunks(tocs)
-
-     return tocs
-
- async def get_POC_relevant_documents(
-     query: str,
-     vectorstore: VectorStore,
-     sources: list = ["Acclimaterra", "PCAET", "Plan Biodiversite"],
-     search_figures: bool = False,
-     search_only: bool = False,
-     k_documents: int = 10,
-     threshold: float = 0.6,
-     k_images: int = 5,
-     reports: list = [],
-     min_size: int = 200,
- ):
-     # Prepare base search kwargs
-     filters = {}
-     docs_question = []
-     docs_images = []
-
-     # TODO add source selection
-     # if len(reports) > 0:
-     #     filters["short_name"] = {"$in": reports}
-     # else:
-     #     filters["source"] = {"$in": sources}
-
-     filters_text = {
-         **filters,
-         "chunk_type": "text",
-         # "report_type": {},  # TODO to be completed to choose the right documents / chapters according to the analysis of the question
-     }
-
-     docs_question = vectorstore.similarity_search_with_score(query=query, filter=filters_text, k=k_documents)
-     # Remove duplicates or almost duplicates
-     docs_question = remove_duplicates_chunks(docs_question)
-     docs_question = [x for x in docs_question if x[1] > threshold]
-
-     if search_figures:
-         # Images
-         filters_image = {
-             **filters,
-             "chunk_type": "image"
-         }
-         docs_images = vectorstore.similarity_search_with_score(query=query, filter=filters_image, k=k_images)
-
-     docs_question, docs_images = _add_metadata_and_score(docs_question), _add_metadata_and_score(docs_images)
-
-     docs_question = [x for x in docs_question if len(x.page_content) > min_size]
-
-     return {
-         "docs_question": docs_question,
-         "docs_images": docs_images
-     }
-
- async def get_POC_documents_by_ToC_relevant_documents(
-     query: str,
-     tocs: list,
-     vectorstore: VectorStore,
-     version: str,
-     sources: list = ["Acclimaterra", "PCAET", "Plan Biodiversite"],
-     search_figures: bool = False,
-     search_only: bool = False,
-     k_documents: int = 10,
-     threshold: float = 0.6,
-     k_images: int = 5,
-     reports: list = [],
-     min_size: int = 200,
-     proportion: float = 0.5,
- ):
-     """
-     Args:
-         - tocs: list with the table of contents of each document
-         - version: version of the parsed documents (e.g. "v4")
-         - proportion: share of documents retrieved using ToCs
-     """
-     # Prepare base search kwargs
-     filters = {}
-     docs_question = []
-     docs_images = []
-
-     # TODO add source selection
-     # if len(reports) > 0:
-     #     filters["short_name"] = {"$in": reports}
-     # else:
-     #     filters["source"] = {"$in": sources}
-
-     k_documents_toc = round(k_documents * proportion)
-
-     relevant_tocs = await get_relevant_toc_level_for_query(query, tocs)
-
-     print(f"Relevant ToCs: {relevant_tocs}")
-     # Transform the ToC dicts {"document": str, "chapter": str} into a list of strings
-     toc_filters = [toc['chapter'] for toc in relevant_tocs]
-
-     filters_text_toc = {
-         **filters,
-         "chunk_type": "text",
-         "toc_level0": {"$in": toc_filters},
-         "version": version
-         # "report_type": {},  # TODO to be completed to choose the right documents / chapters according to the analysis of the question
-     }
-
-     docs_question = vectorstore.similarity_search_with_score(query=query, filter=filters_text_toc, k=k_documents_toc)
-
-     filters_text = {
-         **filters,
-         "chunk_type": "text",
-         "version": version
-         # "report_type": {},  # TODO to be completed to choose the right documents / chapters according to the analysis of the question
-     }
-
-     docs_question += vectorstore.similarity_search_with_score(query=query, filter=filters_text, k=k_documents - k_documents_toc)
-
-     # Remove duplicates or almost duplicates
-     docs_question = remove_duplicates_chunks(docs_question)
-     docs_question = [x for x in docs_question if x[1] > threshold]
-
-     if search_figures:
-         # Images
-         filters_image = {
-             **filters,
-             "chunk_type": "image"
-         }
-         docs_images = vectorstore.similarity_search_with_score(query=query, filter=filters_image, k=k_images)
-
-     docs_question, docs_images = _add_metadata_and_score(docs_question), _add_metadata_and_score(docs_images)
-
-     docs_question = [x for x in docs_question if len(x.page_content) > min_size]
-
-     return {
-         "docs_question": docs_question,
-         "docs_images": docs_images
-     }
-
-
- async def get_IPCC_relevant_documents(
-     query: str,
-     vectorstore: VectorStore,
-     sources: list = ["IPCC", "IPBES", "IPOS"],
-     search_figures: bool = False,
-     reports: list = [],
-     threshold: float = 0.6,
-     k_summary: int = 3,
-     k_total: int = 10,
-     k_images: int = 5,
-     namespace: str = "vectors",
-     min_size: int = 200,
-     search_only: bool = False,
- ):
-
-     # Check that all elements in the list are among the allowed sources
-     assert isinstance(sources, list)
-     assert sources
-     assert all([x in ["IPCC", "IPBES", "IPOS"] for x in sources])
-     assert k_total > k_summary, "k_total should be greater than k_summary"
-
-     # Prepare base search kwargs
-     filters = {}
-
-     if len(reports) > 0:
-         filters["short_name"] = {"$in": reports}
-     else:
-         filters["source"] = {"$in": sources}
-
-     # Init
-     docs_summaries = []
-     docs_full = []
-     docs_images = []
-
-     if search_only:
-         # Only search for images if search_only is True
-         if search_figures:
-             filters_image = {
-                 **filters,
-                 "chunk_type": "image"
-             }
-             docs_images = vectorstore.similarity_search_with_score(query=query, filter=filters_image, k=k_images)
-             docs_images = _add_metadata_and_score(docs_images)
-     else:
-         # Regular search flow for text and optionally images
-         # Search for k_summary documents in the summaries dataset
-         filters_summaries = {
-             **filters,
-             "chunk_type": "text",
-             "report_type": {"$in": ["SPM"]},
-         }
-
-         docs_summaries = vectorstore.similarity_search_with_score(query=query, filter=filters_summaries, k=k_summary)
-         docs_summaries = [x for x in docs_summaries if x[1] > threshold]
-
-         # Search for k_total - k_summary documents in the full reports dataset
-         filters_full = {
-             **filters,
-             "chunk_type": "text",
-             "report_type": {"$nin": ["SPM"]},
-         }
-         docs_full = vectorstore.similarity_search_with_score(query=query, filter=filters_full, k=k_total)
-
-         if search_figures:
-             # Images
-             filters_image = {
-                 **filters,
-                 "chunk_type": "image"
-             }
-             docs_images = vectorstore.similarity_search_with_score(query=query, filter=filters_image, k=k_images)
-
-         docs_summaries, docs_full, docs_images = _add_metadata_and_score(docs_summaries), _add_metadata_and_score(docs_full), _add_metadata_and_score(docs_images)
-
-     # Filter out documents shorter than the minimum size
-     docs_summaries = [x for x in docs_summaries if len(x.page_content) > min_size]
-     docs_full = [x for x in docs_full if len(x.page_content) > min_size]
-
-     return {
-         "docs_summaries": docs_summaries,
-         "docs_full": docs_full,
-         "docs_images": docs_images,
-     }
-
-
- def concatenate_documents(index, source_type, docs_question_dict, k_by_question, k_summary_by_question, k_images_by_question):
-     # Keep the right number of documents - the k_summary documents from SPM are placed in front
-     if source_type == "IPx":
-         docs_question = docs_question_dict["docs_summaries"][:k_summary_by_question] + docs_question_dict["docs_full"][:(k_by_question - k_summary_by_question)]
-     elif source_type == "POC":
-         docs_question = docs_question_dict["docs_question"][:k_by_question]
-     else:
-         raise ValueError("source_type should be either IPx or POC")
-     # docs_question = [doc for key in docs_question_dict.keys() for doc in docs_question_dict[key]][:(k_by_question)]
-
-     images_question = docs_question_dict["docs_images"][:k_images_by_question]
-
-     return docs_question, images_question
-
-
- # The chain callback is not necessary, but it propagates the langchain callbacks to the astream_events logger to display intermediate results
- # @chain
- async def retrieve_documents(
-     current_question: Dict[str, Any],
-     config: Dict[str, Any],
-     source_type: str,
-     vectorstore: VectorStore,
-     reranker: Any,
-     version: str = "",
-     search_figures: bool = False,
-     search_only: bool = False,
-     reports: list = [],
-     rerank_by_question: bool = True,
-     k_images_by_question: int = 5,
-     k_before_reranking: int = 100,
-     k_by_question: int = 5,
-     k_summary_by_question: int = 3,
-     tocs: list = [],
-     by_toc=False
- ) -> Tuple[List[Document], List[Document]]:
-     """
-     Retrieve and rerank the documents for one question, based on the question and its selected sources.
-
-     Args:
-         current_question (dict): The question being processed, with its sources, index and source_type.
-         config (dict): Configuration settings for logging and other purposes.
-         source_type (str): The type of source to search ("IPx" or "POC").
-         vectorstore (object): The vector store used to retrieve relevant documents.
-         reranker (object): The reranker used to rerank the retrieved documents.
-         rerank_by_question (bool, optional): Whether to rerank documents by question. Defaults to True.
-         k_images_by_question (int, optional): The number of image documents to retrieve. Defaults to 5.
-         k_before_reranking (int, optional): The number of documents to retrieve before reranking. Defaults to 100.
-         k_by_question (int, optional): The number of documents to keep per question. Defaults to 5.
-         k_summary_by_question (int, optional): The number of summary documents to keep per question. Defaults to 3.
-     Returns:
-         tuple: The retrieved and reranked documents and the related images for the question.
-     """
-     sources = current_question["sources"]
-     question = current_question["question"]
-     index = current_question["index"]
-     source_type = current_question["source_type"]
-
-     print(f"Retrieve documents for question: {question}")
-     await log_event({"question": question, "sources": sources, "index": index}, "log_retriever", config)
-
-     print(f"""---- Retrieve documents from {current_question["source_type"]} ----""")
-
-     if source_type == "IPx":
-         docs_question_dict = await get_IPCC_relevant_documents(
-             query=question,
-             vectorstore=vectorstore,
-             search_figures=search_figures,
-             sources=sources,
-             min_size=200,
-             k_summary=k_before_reranking - 1,
-             k_total=k_before_reranking,
-             k_images=k_images_by_question,
-             threshold=0.5,
-             search_only=search_only,
-             reports=reports,
-         )
-
-     if source_type == 'POC':
-         if by_toc == True:
-             print("---- Retrieve documents by ToC ----")
-             docs_question_dict = await get_POC_documents_by_ToC_relevant_documents(
-                 query=question,
-                 tocs=tocs,
-                 vectorstore=vectorstore,
-                 version=version,
-                 search_figures=search_figures,
-                 sources=sources,
-                 threshold=0.5,
-                 search_only=search_only,
-                 reports=reports,
-                 min_size=200,
-                 k_documents=k_before_reranking,
-                 k_images=k_by_question
-             )
-         else:
-             docs_question_dict = await get_POC_relevant_documents(
-                 query=question,
-                 vectorstore=vectorstore,
-                 search_figures=search_figures,
-                 sources=sources,
-                 threshold=0.5,
-                 search_only=search_only,
-                 reports=reports,
-                 min_size=200,
-                 k_documents=k_before_reranking,
-                 k_images=k_by_question
-             )
-
-     # Rerank
-     if reranker is not None and rerank_by_question:
-         with suppress_output():
-             for key in docs_question_dict.keys():
-                 docs_question_dict[key] = rerank_and_sort_docs(reranker, docs_question_dict[key], question)
-     else:
-         # Add a default reranking score
-         for key in docs_question_dict.keys():
-             if isinstance(docs_question_dict[key], list) and len(docs_question_dict[key]) > 0:
-                 for doc in docs_question_dict[key]:
-                     doc.metadata["reranking_score"] = doc.metadata["similarity_score"]
-
-     # Keep the right number of documents
-     docs_question, images_question = concatenate_documents(index, source_type, docs_question_dict, k_by_question, k_summary_by_question, k_images_by_question)
-
-     # Rerank the documents to put the most relevant in front
-     if reranker is not None and rerank_by_question:
-         docs_question = rerank_and_sort_docs(reranker, docs_question, question)
-
-     # Add sources used in the metadata
-     docs_question = _add_sources_used_in_metadata(docs_question, sources, question, index)
-     images_question = _add_sources_used_in_metadata(images_question, sources, question, index)
-
-     return docs_question, images_question
-
-
- async def retrieve_documents_for_all_questions(
-     search_figures,
-     search_only,
-     reports,
-     questions_list,
-     n_questions,
-     config,
-     source_type,
-     to_handle_questions_index,
-     vectorstore,
-     reranker,
-     rerank_by_question=True,
-     k_final=15,
-     k_before_reranking=100,
-     version: str = "",
-     tocs: list[dict] = [],
-     by_toc: bool = False
- ):
-     """
-     Retrieve documents in parallel for all questions.
-     """
-     # to_handle_questions_index = [x for x in state["questions_list"] if x["source_type"] == "IPx"]
-
-     # TODO: split the questions by source type in the question state + add conditions on the number of questions handled per source type
-     # search_figures = "Figures (IPCC/IPBES)" in state["relevant_content_sources_selection"]
-     # search_only = state["search_only"]
-     # reports = state["reports"]
-     # questions_list = state["questions_list"]
-
-     # k_by_question = k_final // state["n_questions"]["total"]
-     # k_summary_by_question = _get_k_summary_by_question(state["n_questions"]["total"])
-     # k_images_by_question = _get_k_images_by_question(state["n_questions"]["total"])
-     k_by_question = k_final // n_questions
-     k_summary_by_question = _get_k_summary_by_question(n_questions)
-     k_images_by_question = _get_k_images_by_question(n_questions)
-     k_before_reranking = 100
-
-     print(f"Source type here is {source_type}")
-     tasks = [
-         retrieve_documents(
-             current_question=question,
-             config=config,
-             source_type=source_type,
-             vectorstore=vectorstore,
-             reranker=reranker,
-             search_figures=search_figures,
-             search_only=search_only,
-             reports=reports,
-             rerank_by_question=rerank_by_question,
-             k_images_by_question=k_images_by_question,
-             k_before_reranking=k_before_reranking,
-             k_by_question=k_by_question,
-             k_summary_by_question=k_summary_by_question,
-             tocs=tocs,
-             version=version,
-             by_toc=by_toc
-         )
-         for i, question in enumerate(questions_list) if i in to_handle_questions_index
-     ]
-     results = await asyncio.gather(*tasks)
-     # Combine results
-     new_state = {"documents": [], "related_contents": [], "handled_questions_index": to_handle_questions_index}
-     for docs_question, images_question in results:
-         new_state["documents"].extend(docs_question)
-         new_state["related_contents"].extend(images_question)
-     return new_state
-
- # ToC retriever
- async def get_relevant_toc_level_for_query(
-     query: str,
-     tocs: list[Document],
- ) -> list[dict]:
-
-     doc_list = []
-     for doc in tocs:
-         doc_name = doc[0].metadata['name']
-         toc = doc[0].page_content
-         doc_list.append({'document': doc_name, 'toc': toc})
-
-     llm = get_llm(provider="openai", max_tokens=1024, temperature=0.0)
-
-     prompt = ChatPromptTemplate.from_template(retrieve_chapter_prompt_template)
-     chain = prompt | llm | StrOutputParser()
-     response = chain.invoke({"query": query, "doc_list": doc_list})
-
-     try:
-         relevant_tocs = ast.literal_eval(response)
-     except Exception as e:
-         print(f"Failed to parse the result because of: {e}")
-         relevant_tocs = []  # fall back to an empty selection instead of returning an unbound name
-
-     return relevant_tocs
-
-
- def make_IPx_retriever_node(vectorstore, reranker, llm, rerank_by_question=True, k_final=15, k_before_reranking=100, k_summary=5):
-
-     async def retrieve_IPx_docs(state, config):
-         source_type = "IPx"
-         IPx_questions_index = [i for i, x in enumerate(state["questions_list"]) if x["source_type"] == "IPx"]
-
-         search_figures = "Figures (IPCC/IPBES)" in state["relevant_content_sources_selection"]
-         search_only = state["search_only"]
-         reports = state["reports"]
-         questions_list = state["questions_list"]
-         n_questions = state["n_questions"]["total"]
-
-         state = await retrieve_documents_for_all_questions(
-             search_figures=search_figures,
-             search_only=search_only,
-             reports=reports,
-             questions_list=questions_list,
-             n_questions=n_questions,
-             config=config,
-             source_type=source_type,
-             to_handle_questions_index=IPx_questions_index,
-             vectorstore=vectorstore,
-             reranker=reranker,
-             rerank_by_question=rerank_by_question,
-             k_final=k_final,
-             k_before_reranking=k_before_reranking,
-         )
-         return state
-
-     return retrieve_IPx_docs
-
-
- def make_POC_retriever_node(vectorstore, reranker, llm, rerank_by_question=True, k_final=15, k_before_reranking=100, k_summary=5):
-
-     async def retrieve_POC_docs_node(state, config):
-         source_type = "POC"
-         POC_questions_index = [i for i, x in enumerate(state["questions_list"]) if x["source_type"] == "POC"]
-
-         search_figures = "Figures (IPCC/IPBES)" in state["relevant_content_sources_selection"]
-         search_only = state["search_only"]
-         reports = state["reports"]
-         questions_list = state["questions_list"]
-         n_questions = state["n_questions"]["total"]
-
-         state = await retrieve_documents_for_all_questions(
-             search_figures=search_figures,
-             search_only=search_only,
-             reports=reports,
-             questions_list=questions_list,
-             n_questions=n_questions,
-             config=config,
-             source_type=source_type,
-             to_handle_questions_index=POC_questions_index,
-             vectorstore=vectorstore,
-             reranker=reranker,
-             rerank_by_question=rerank_by_question,
-             k_final=k_final,
-             k_before_reranking=k_before_reranking,
-         )
-         return state
-
-     return retrieve_POC_docs_node
-
-
- def make_POC_by_ToC_retriever_node(
-     vectorstore: VectorStore,
-     reranker,
-     llm,
-     version: str = "",
-     rerank_by_question=True,
-     k_final=15,
-     k_before_reranking=100,
-     k_summary=5,
- ):
-
-     async def retrieve_POC_docs_node(state, config):
-         search_figures = "Figures (IPCC/IPBES)" in state["relevant_content_sources_selection"]
-         search_only = state["search_only"]
-         reports = state["reports"]
-         questions_list = state["questions_list"]
-         n_questions = state["n_questions"]["total"]
-
-         tocs = get_ToCs(version=version)
-
-         source_type = "POC"
-         POC_questions_index = [i for i, x in enumerate(state["questions_list"]) if x["source_type"] == "POC"]
-
-         state = await retrieve_documents_for_all_questions(
-             search_figures=search_figures,
-             search_only=search_only,
-             config=config,
-             reports=reports,
-             questions_list=questions_list,
-             n_questions=n_questions,
-             source_type=source_type,
-             to_handle_questions_index=POC_questions_index,
-             vectorstore=vectorstore,
-             reranker=reranker,
-             rerank_by_question=rerank_by_question,
-             k_final=k_final,
-             k_before_reranking=k_before_reranking,
-             tocs=tocs,
-             version=version,
-             by_toc=True
-         )
-         return state
-
-     return retrieve_POC_docs_node
climateqa/engine/chains/retrieve_papers.py DELETED
@@ -1,95 +0,0 @@
-from climateqa.engine.keywords import make_keywords_chain
-from climateqa.engine.llm import get_llm
-from climateqa.knowledge.openalex import OpenAlex
-from climateqa.engine.chains.answer_rag import make_rag_papers_chain
-from front.utils import make_html_papers
-from climateqa.engine.reranker import get_reranker
-
-oa = OpenAlex()
-
-llm = get_llm(provider="openai", max_tokens=1024, temperature=0.0)
-reranker = get_reranker("nano")
-
-
-papers_cols_widths = {
-    "id": 100,
-    "title": 300,
-    "doi": 100,
-    "publication_year": 100,
-    "abstract": 500,
-    "is_oa": 50,
-}
-
-papers_cols = list(papers_cols_widths.keys())
-papers_cols_widths = list(papers_cols_widths.values())
-
-
-def generate_keywords(query):
-    chain = make_keywords_chain(llm)
-    keywords = chain.invoke(query)
-    keywords = " AND ".join(keywords["keywords"])
-    return keywords
-
-
-async def find_papers(query, after, relevant_content_sources_selection, reranker=reranker):
-    if "Papers (OpenAlex)" in relevant_content_sources_selection:
-        summary = ""
-        keywords = generate_keywords(query)
-        df_works = oa.search(keywords, after=after)
-
-        print(f"Found {len(df_works)} papers")
-
-        if not df_works.empty:
-            df_works = df_works.dropna(subset=["abstract"])
-            df_works = df_works[df_works["abstract"] != ""].reset_index(drop=True)
-            df_works = oa.rerank(query, df_works, reranker)
-            df_works = df_works.sort_values("rerank_score", ascending=False)
-            docs_html = []
-            for i in range(10):
-                docs_html.append(make_html_papers(df_works, i))
-            docs_html = "".join(docs_html)
-            G = oa.make_network(df_works)
-
-            height = "750px"
-            network = oa.show_network(G, color_by="rerank_score", notebook=False, height=height)
-            network_html = network.generate_html()
-
-            network_html = network_html.replace("'", "\"")
-            css_to_inject = "<style>#mynetwork { border: none !important; } .card { border: none !important; }</style>"
-            network_html = network_html + css_to_inject
-
-            network_html = f"""<iframe style="width: 100%; height: {height};margin:0 auto" name="result" allow="midi; geolocation; microphone; camera;
-            display-capture; encrypted-media;" sandbox="allow-modals allow-forms
-            allow-scripts allow-same-origin allow-popups
-            allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
-            allowpaymentrequest="" frameborder="0" srcdoc='{network_html}'></iframe>"""
-
-            docs = df_works["content"].head(10).tolist()
-
-            df_works = df_works.reset_index(drop=True).reset_index().rename(columns={"index": "doc"})
-            df_works["doc"] = df_works["doc"] + 1
-            df_works = df_works[papers_cols]
-
-            yield docs_html, network_html, summary
-
-            chain = make_rag_papers_chain(llm)
-            result = chain.astream_log({"question": query, "docs": docs, "language": "English"})
-            path_answer = "/logs/StrOutputParser/streamed_output/-"
-
-            async for op in result:
-                op = op.ops[0]
-                if op['path'] == path_answer:  # reformulated question
-                    new_token = op['value']  # str
-                    summary += new_token
-                else:
-                    continue
-                yield docs_html, network_html, summary
-        else:
-            print("No papers found")
-    else:
-        yield "", "", ""
climateqa/engine/chains/retriever.py DELETED
@@ -1,126 +0,0 @@
-# import sys
-# import os
-# from contextlib import contextmanager
-
-# from ..reranker import rerank_docs
-# from ...knowledge.retriever import ClimateQARetriever
-
-
-# def divide_into_parts(target, parts):
-#     # Base value for each part
-#     base = target // parts
-#     # Remainder to distribute
-#     remainder = target % parts
-#     # List to hold the result
-#     result = []
-
-#     for i in range(parts):
-#         if i < remainder:
-#             # These parts get base value + 1
-#             result.append(base + 1)
-#         else:
-#             # The rest get the base value
-#             result.append(base)
-
-#     return result
-
-
-# @contextmanager
-# def suppress_output():
-#     # Open a null device
-#     with open(os.devnull, 'w') as devnull:
-#         # Store the original stdout and stderr
-#         old_stdout = sys.stdout
-#         old_stderr = sys.stderr
-#         # Redirect stdout and stderr to the null device
-#         sys.stdout = devnull
-#         sys.stderr = devnull
-#         try:
-#             yield
-#         finally:
-#             # Restore stdout and stderr
-#             sys.stdout = old_stdout
-#             sys.stderr = old_stderr
-
-
-# def make_retriever_node(vectorstore, reranker, rerank_by_question=True, k_final=15, k_before_reranking=100, k_summary=5):
-
-#     def retrieve_documents(state):
-
-#         POSSIBLE_SOURCES = ["IPCC","IPBES","IPOS"] # ,"OpenAlex"]
-#         questions = state["questions"]
-
-#         # Use sources from the user input or from the LLM detection
-#         if "sources_input" not in state or state["sources_input"] is None:
-#             sources_input = ["auto"]
-#         else:
-#             sources_input = state["sources_input"]
-#         auto_mode = "auto" in sources_input
-
-#         # There are several options to get the final top k
-#         # Option 1 - Get 100 documents by question and rerank by question
-#         # Option 2 - Get 100/n documents by question and rerank the total
-#         if rerank_by_question:
-#             k_by_question = divide_into_parts(k_final, len(questions))
-
-#         docs = []
-
-#         for i, q in enumerate(questions):
-
-#             sources = q["sources"]
-#             question = q["question"]
-
-#             # If auto mode, we use the sources detected by the LLM
-#             if auto_mode:
-#                 sources = [x for x in sources if x in POSSIBLE_SOURCES]
-
-#             # Otherwise, we use the config
-#             else:
-#                 sources = sources_input
-
-#             # Search the document store using the retriever
-#             # Configure high top k for further reranking step
-#             retriever = ClimateQARetriever(
-#                 vectorstore=vectorstore,
-#                 sources = sources,
-#                 # reports = ias_reports,
-#                 min_size = 200,
-#                 k_summary = k_summary,
-#                 k_total = k_before_reranking,
-#                 threshold = 0.5,
-#             )
-#             docs_question = retriever.get_relevant_documents(question)
-
-#             # Rerank
-#             if reranker is not None:
-#                 with suppress_output():
-#                     docs_question = rerank_docs(reranker, docs_question, question)
-#             else:
-#                 # Add a default reranking score
-#                 for doc in docs_question:
-#                     doc.metadata["reranking_score"] = doc.metadata["similarity_score"]
-
-#             # If rerank by question we select the top documents for each question
-#             if rerank_by_question:
-#                 docs_question = docs_question[:k_by_question[i]]
-
-#             # Add sources used in the metadata
-#             for doc in docs_question:
-#                 doc.metadata["sources_used"] = sources
-
-#             # Add to the list of docs
-#             docs.extend(docs_question)
-
-#         # Sorting the list in descending order by rerank_score
-#         # Then select the top k
-#         docs = sorted(docs, key=lambda x: x.metadata["reranking_score"], reverse=True)
-#         docs = docs[:k_final]
-
-#         new_state = {"documents": docs}
-#         return new_state
-
-#     return retrieve_documents
-
climateqa/engine/chains/sample_router.py DELETED
@@ -1,66 +0,0 @@
-
-# from typing import List
-# from typing import Literal
-# from langchain.prompts import ChatPromptTemplate
-# from langchain_core.utils.function_calling import convert_to_openai_function
-# from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
-
-# # https://livingdatalab.com/posts/2023-11-05-openai-function-calling-with-langchain.html
-
-# class Location(BaseModel):
-#     country: str = Field(..., description="The country if directly mentioned or inferred from the location (cities, regions, addresses), ex: France, USA, ...")
-#     location: str = Field(..., description="The specific place if mentioned (cities, regions, addresses), ex: Marseille, New York, Wisconsin, ...")
-
-# class QueryAnalysis(BaseModel):
-#     """Analyzing the user query"""
-
-#     language: str = Field(
-#         description="Find the language of the query in full words (ex: French, English, Spanish, ...), defaults to English"
-#     )
-#     intent: str = Field(
-#         enum=[
-#             "Environmental impacts of AI",
-#             "Geolocated info about climate change",
-#             "Climate change",
-#             "Biodiversity",
-#             "Deep sea mining",
-#             "Chitchat",
-#         ],
-#         description="""
-#         Categorize the user query in one of the following categories,
-
-#         Examples:
-#         - Geolocated info about climate change: "What will be the temperature in Marseille in 2050"
-#         - Climate change: "What is radiative forcing", "How much will
-#         """,
-#     )
-#     sources: List[Literal["IPCC", "IPBES", "IPOS"]] = Field(
-#         ...,
-#         description="""
-#         Given a user question choose which documents would be most relevant for answering their question,
-#         - IPCC is for questions about climate change, energy, impacts, and everything we can find in the IPCC reports
-#         - IPBES is for questions about biodiversity and nature
-#         - IPOS is for questions about the ocean and deep sea mining
-
-#         """,
-#     )
-#     date: str = Field(description="The date or period mentioned, ex: 2050, between 2020 and 2050")
-#     location: Location
-#     # query: str = Field(
-#     #     description = """
-#     #     Translate to english and reformulate the following user message to be a short standalone question, in the context of an educational discussion about climate change.
-#     #     The reformulated question will be used in a search engine
-#     #     By default, assume that the user is asking information about the last century,
-#     #     Use the following examples
-
-#     #     ### Examples:
-#     #     La technologie nous sauvera-t-elle ? -> Can technology help humanity mitigate the effects of climate change?
-#     #     what are our reserves in fossil fuel? -> What are the current reserves of fossil fuels and how long will they last?
-#     #     what are the main causes of climate change? -> What are the main causes of climate change in the last century?
-
-#     #     Question in English:
-#     #     """
-#     # )
-
-# openai_functions = [convert_to_openai_function(QueryAnalysis)]
-# llm2 = llm.bind(functions = openai_functions, function_call={"name":"QueryAnalysis"})
climateqa/engine/chains/set_defaults.py DELETED
@@ -1,13 +0,0 @@
-def set_defaults(state):
-    print("---- Setting defaults ----")
-
-    if not state["audience"] or state["audience"] is None:
-        state.update({"audience": "experts"})
-
-    sources_input = state["sources_input"] if "sources_input" in state else ["auto"]
-    state.update({"sources_input": sources_input})
-
-    # if not state["sources_input"] or state["sources_input"] is None:
-    #     state.update({"sources_input": ["auto"]})
-
-    return state
climateqa/engine/chains/standalone_question.py DELETED
@@ -1,42 +0,0 @@
-from langchain.prompts import ChatPromptTemplate
-
-def make_standalone_question_chain(llm):
-    prompt = ChatPromptTemplate.from_messages([
-        ("system", """You are a helpful assistant that transforms user questions into standalone questions
-        by incorporating context from the chat history if needed. The output should be a self-contained
-        question that can be understood without any additional context.
-
-        Examples:
-        Chat History: "Let's talk about renewable energy"
-        User Input: "What about solar?"
-        Output: "What are the key aspects of solar energy as a renewable energy source?"
-
-        Chat History: "What causes global warming?"
-        User Input: "And what are its effects?"
-        Output: "What are the effects of global warming on the environment and society?"
-        """),
-        ("user", """Chat History: {chat_history}
-        User Question: {question}
-
-        Transform this into a standalone question:
-        Make sure to keep the original language of the question.""")
-    ])
-
-    chain = prompt | llm
-    return chain
-
-def make_standalone_question_node(llm):
-    standalone_chain = make_standalone_question_chain(llm)
-
-    def transform_to_standalone(state):
-        chat_history = state.get("chat_history", "")
-        if chat_history == "":
-            return {}
-        output = standalone_chain.invoke({
-            "chat_history": chat_history,
-            "question": state["user_input"]
-        })
-        state["user_input"] = output.content
-        return state
-
-    return transform_to_standalone
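
For reference, a minimal sketch of how the node factory above was wired before its removal. The `llm` setup is an assumption for illustration; any chat model usable in a `prompt | llm` pipeline works:

# Hypothetical wiring of the (now removed) standalone-question node
from climateqa.engine.llm import get_llm

llm = get_llm(provider="openai")          # assumed provider setup
node = make_standalone_question_node(llm)
state = {"chat_history": "Let's talk about renewable energy",
         "user_input": "What about solar?"}
state = node(state)                       # rewrites state["user_input"] into a standalone question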
climateqa/engine/chains/translation.py DELETED
@@ -1,42 +0,0 @@
-
-from langchain_core.pydantic_v1 import BaseModel, Field
-from typing import List
-from typing import Literal
-from langchain.prompts import ChatPromptTemplate
-from langchain_core.utils.function_calling import convert_to_openai_function
-from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
-
-
-class Translation(BaseModel):
-    """Analyzing the user message input"""
-
-    translation: str = Field(
-        description="Translate the message input to English",
-    )
-
-
-def make_translation_chain(llm):
-
-    openai_functions = [convert_to_openai_function(Translation)]
-    llm_with_functions = llm.bind(functions = openai_functions, function_call={"name":"Translation"})
-
-    prompt = ChatPromptTemplate.from_messages([
-        ("system", "You are a helpful assistant, you will translate the user input message to English using the function provided"),
-        ("user", "input: {input}")
-    ])
-
-    chain = prompt | llm_with_functions | JsonOutputFunctionsParser()
-    return chain
-
-
-def make_translation_node(llm):
-    translation_chain = make_translation_chain(llm)
-
-    def translate_query(state):
-        print("---- Translate query ----")
-
-        user_input = state["user_input"]
-        translation = translation_chain.invoke({"input": user_input})
-        return {"query": translation["translation"]}
-
-    return translate_query
climateqa/engine/embeddings.py CHANGED
@@ -2,7 +2,7 @@
 from langchain_community.embeddings import HuggingFaceBgeEmbeddings
 from langchain_community.embeddings import HuggingFaceEmbeddings
 
-def get_embeddings_function(version = "v1.2",query_instruction = "Represent this sentence for searching relevant passages: "):
+def get_embeddings_function(version = "v1.2"):
 
     if version == "v1.2":
 
@@ -10,12 +10,12 @@ def get_embeddings_function(version = "v1.2"):
        # Best embedding model at a reasonable size at the moment (2023-11-22)
 
        model_name = "BAAI/bge-base-en-v1.5"
-       encode_kwargs = {'normalize_embeddings': True,"show_progress_bar":False} # set True to compute cosine similarity
+       encode_kwargs = {'normalize_embeddings': True} # set True to compute cosine similarity
        print("Loading embeddings model: ", model_name)
        embeddings_function = HuggingFaceBgeEmbeddings(
            model_name=model_name,
            encode_kwargs=encode_kwargs,
-           query_instruction=query_instruction,
+           query_instruction="Represent this sentence for searching relevant passages: "
        )
 
    else:
@@ -23,6 +23,3 @@ def get_embeddings_function(version = "v1.2"):
 
        embeddings_function = HuggingFaceEmbeddings(model_name = "sentence-transformers/multi-qa-mpnet-base-dot-v1")
 
    return embeddings_function
-
-
-
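
For context, a sketch of how the simplified signature reads after this change; the query string below is an arbitrary example:

# Illustrative call of the post-change API
from climateqa.engine.embeddings import get_embeddings_function

embeddings_function = get_embeddings_function(version="v1.2")  # BGE model, fixed query instruction
vector = embeddings_function.embed_query("What is radiative forcing?")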
 
 
 
climateqa/engine/graph.py DELETED
@@ -1,346 +0,0 @@
-import sys
-import os
-from contextlib import contextmanager
-
-from langchain.schema import Document
-from langgraph.graph import END, StateGraph
-from langchain_core.runnables.graph import CurveStyle, MermaidDrawMethod
-
-from typing_extensions import TypedDict
-from typing import List, Dict
-
-import operator
-from typing import Annotated
-import pandas as pd
-from IPython.display import display, HTML, Image
-
-from .chains.answer_chitchat import make_chitchat_node
-from .chains.answer_ai_impact import make_ai_impact_node
-from .chains.query_transformation import make_query_transform_node
-from .chains.translation import make_translation_node
-from .chains.intent_categorization import make_intent_categorization_node
-from .chains.retrieve_documents import make_IPx_retriever_node, make_POC_retriever_node, make_POC_by_ToC_retriever_node
-from .chains.answer_rag import make_rag_node
-from .chains.graph_retriever import make_graph_retriever_node
-from .chains.chitchat_categorization import make_chitchat_intent_categorization_node
-from .chains.standalone_question import make_standalone_question_node
-from .chains.follow_up import make_follow_up_node  # Add this import
-
-class GraphState(TypedDict):
-    """
-    Represents the state of our graph.
-    """
-    user_input : str
-    chat_history : str
-    language : str
-    intent : str
-    search_graphs_chitchat : bool
-    query: str
-    questions_list : List[dict]
-    handled_questions_index : Annotated[list[int], operator.add]
-    n_questions : int
-    answer: str
-    audience: str = "experts"
-    sources_input: List[str] = ["IPCC","IPBES"] # Deprecated -> used only for graphs, which can only be OWID
-    relevant_content_sources_selection: List[str] = ["Figures (IPCC/IPBES)"]
-    sources_auto: bool = True
-    min_year: int = 1960
-    max_year: int = None
-    documents: Annotated[List[Document], operator.add]
-    related_contents : Annotated[List[Document], operator.add] # Images
-    recommended_content : List[Document] # OWID Graphs # TODO merge with related_contents
-    search_only : bool = False
-    reports : List[str] = []
-    follow_up_questions: List[str] = []
-
-def dummy(state):
-    return
-
-def search(state): # TODO
-    return
-
-def answer_search(state): # TODO
-    return
-
-def route_intent(state):
-    intent = state["intent"]
-    if intent in ["chitchat","esg"]:
-        return "answer_chitchat"
-    # elif intent == "ai_impact":
-    #     return "answer_ai_impact"
-    else:
-        # Search route
-        return "answer_climate"
-
-def chitchat_route_intent(state):
-    intent = state["search_graphs_chitchat"]
-    if intent is True:
-        return END # TODO
-    elif intent is False:
-        return END
-
-def route_translation(state):
-    if state["language"].lower() == "english":
-        return "transform_query"
-    else:
-        return "transform_query"
-        # return "translate_query" # TODO : add translation
-
-
-def route_based_on_relevant_docs(state, threshold_docs=0.2):
-    docs = [x for x in state["documents"] if x.metadata["reranking_score"] > threshold_docs]
-    print("Route : ", ["answer_rag" if len(docs) > 0 else "answer_rag_no_docs"])
-    if len(docs) > 0:
-        return "answer_rag"
-    else:
-        return "answer_rag_no_docs"
-
-def route_continue_retrieve_documents(state):
-    index_question_ipx = [i for i, x in enumerate(state["questions_list"]) if x["source_type"] == "IPx"]
-    questions_ipx_finished = all(elem in state["handled_questions_index"] for elem in index_question_ipx)
-    if questions_ipx_finished:
-        return "end_retrieve_IPx_documents"
-    else:
-        return "retrieve_documents"
-
-
-def route_retrieve_documents(state):
-    sources_to_retrieve = []
-
-    if "Graphs (OurWorldInData)" in state["relevant_content_sources_selection"]:
-        sources_to_retrieve.append("retrieve_graphs")
-
-    if sources_to_retrieve == []:
-        return END
-    return sources_to_retrieve
-
-def route_follow_up(state):
-    if state["follow_up_questions"]:
-        return "process_follow_up"
-    return END
-
-def make_id_dict(values):
-    return {k: k for k in values}
-
-def make_graph_agent(llm, vectorstore_ipcc, vectorstore_graphs, vectorstore_region, reranker, threshold_docs=0.2):
-
-    workflow = StateGraph(GraphState)
-
-    # Define the node functions
-    standalone_question_node = make_standalone_question_node(llm)
-    categorize_intent = make_intent_categorization_node(llm)
-    transform_query = make_query_transform_node(llm)
-    translate_query = make_translation_node(llm)
-    answer_chitchat = make_chitchat_node(llm)
-    answer_ai_impact = make_ai_impact_node(llm)
-    retrieve_documents = make_IPx_retriever_node(vectorstore_ipcc, reranker, llm)
-    retrieve_graphs = make_graph_retriever_node(vectorstore_graphs, reranker)
-    # retrieve_local_data = make_POC_retriever_node(vectorstore_region, reranker, llm)
-    answer_rag = make_rag_node(llm, with_docs=True)
-    answer_rag_no_docs = make_rag_node(llm, with_docs=False)
-    chitchat_categorize_intent = make_chitchat_intent_categorization_node(llm)
-    generate_follow_up = make_follow_up_node(llm)
-
-    # Define the nodes
-    # workflow.add_node("set_defaults", set_defaults)
-    workflow.add_node("standalone_question", standalone_question_node)
-    workflow.add_node("categorize_intent", categorize_intent)
-    workflow.add_node("answer_climate", dummy)
-    workflow.add_node("answer_search", answer_search)
-    workflow.add_node("transform_query", transform_query)
-    workflow.add_node("translate_query", translate_query)
-    workflow.add_node("answer_chitchat", answer_chitchat)
-    workflow.add_node("chitchat_categorize_intent", chitchat_categorize_intent)
-    workflow.add_node("retrieve_graphs", retrieve_graphs)
-    # workflow.add_node("retrieve_local_data", retrieve_local_data)
-    workflow.add_node("retrieve_graphs_chitchat", retrieve_graphs)
-    workflow.add_node("retrieve_documents", retrieve_documents)
-    workflow.add_node("answer_rag", answer_rag)
-    workflow.add_node("answer_rag_no_docs", answer_rag_no_docs)
-    workflow.add_node("generate_follow_up", generate_follow_up)
-    # workflow.add_node("process_follow_up", standalone_question_node)
-
-    # Entry point
-    workflow.set_entry_point("standalone_question")
-
-    # CONDITIONAL EDGES
-    workflow.add_conditional_edges(
-        "categorize_intent",
-        route_intent,
-        make_id_dict(["answer_chitchat","answer_climate"])
-    )
-
-    workflow.add_conditional_edges(
-        "chitchat_categorize_intent",
-        chitchat_route_intent,
-        make_id_dict(["retrieve_graphs_chitchat", END])
-    )
-
-    workflow.add_conditional_edges(
-        "answer_climate",
-        route_translation,
-        make_id_dict(["translate_query","transform_query"])
-    )
-
-    workflow.add_conditional_edges(
-        "answer_search",
-        lambda x : route_based_on_relevant_docs(x, threshold_docs=threshold_docs),
-        make_id_dict(["answer_rag","answer_rag_no_docs"])
-    )
-    workflow.add_conditional_edges(
-        "transform_query",
-        route_retrieve_documents,
-        make_id_dict(["retrieve_graphs", END])
-    )
-
-    # workflow.add_conditional_edges(
-    #     "generate_follow_up",
-    #     route_follow_up,
-    #     make_id_dict(["process_follow_up", END])
-    # )
-
-    # Define the edges
-    workflow.add_edge("standalone_question", "categorize_intent")
-    workflow.add_edge("translate_query", "transform_query")
-    workflow.add_edge("transform_query", "retrieve_documents") # TODO put back
-    # workflow.add_edge("transform_query", "retrieve_local_data")
-    # workflow.add_edge("transform_query", END) # TODO remove
-
-    workflow.add_edge("retrieve_graphs", END)
-    workflow.add_edge("answer_rag", "generate_follow_up")
-    workflow.add_edge("answer_rag_no_docs", "generate_follow_up")
-    workflow.add_edge("answer_chitchat", "chitchat_categorize_intent")
-    workflow.add_edge("retrieve_graphs_chitchat", END)
-
-    # workflow.add_edge("retrieve_local_data", "answer_search")
-    workflow.add_edge("retrieve_documents", "answer_search")
-    workflow.add_edge("generate_follow_up", END)
-    # workflow.add_edge("process_follow_up", "categorize_intent")
-
-    # Compile
-    app = workflow.compile()
-    return app
-
-def make_graph_agent_poc(llm, vectorstore_ipcc, vectorstore_graphs, vectorstore_region, reranker, version: str, threshold_docs=0.2):
-    """_summary_
-
-    Args:
-        llm (_type_): _description_
-        vectorstore_ipcc (_type_): _description_
-        vectorstore_graphs (_type_): _description_
-        vectorstore_region (_type_): _description_
-        reranker (_type_): _description_
-        version (str): version of the parsed documents (e.g "v4")
-        threshold_docs (float, optional): _description_. Defaults to 0.2.
-
-    Returns:
-        _type_: _description_
-    """
-
-    workflow = StateGraph(GraphState)
-
-    # Define the node functions
-    standalone_question_node = make_standalone_question_node(llm)
-
-    categorize_intent = make_intent_categorization_node(llm)
-    transform_query = make_query_transform_node(llm)
-    translate_query = make_translation_node(llm)
-    answer_chitchat = make_chitchat_node(llm)
-    answer_ai_impact = make_ai_impact_node(llm)
-    retrieve_documents = make_IPx_retriever_node(vectorstore_ipcc, reranker, llm)
-    retrieve_graphs = make_graph_retriever_node(vectorstore_graphs, reranker)
-    # retrieve_local_data = make_POC_retriever_node(vectorstore_region, reranker, llm)
-    retrieve_local_data = make_POC_by_ToC_retriever_node(vectorstore_region, reranker, llm, version=version)
-    answer_rag = make_rag_node(llm, with_docs=True)
-    answer_rag_no_docs = make_rag_node(llm, with_docs=False)
-    chitchat_categorize_intent = make_chitchat_intent_categorization_node(llm)
-    generate_follow_up = make_follow_up_node(llm)
-
-    # Define the nodes
-    # workflow.add_node("set_defaults", set_defaults)
-    workflow.add_node("standalone_question", standalone_question_node)
-    workflow.add_node("categorize_intent", categorize_intent)
-    workflow.add_node("answer_climate", dummy)
-    workflow.add_node("answer_search", answer_search)
-    # workflow.add_node("end_retrieve_local_documents", dummy)
-    # workflow.add_node("end_retrieve_IPx_documents", dummy)
-    workflow.add_node("transform_query", transform_query)
-    workflow.add_node("translate_query", translate_query)
-    workflow.add_node("answer_chitchat", answer_chitchat)
-    workflow.add_node("chitchat_categorize_intent", chitchat_categorize_intent)
-    workflow.add_node("retrieve_graphs", retrieve_graphs)
-    workflow.add_node("retrieve_local_data", retrieve_local_data)
-    workflow.add_node("retrieve_graphs_chitchat", retrieve_graphs)
-    workflow.add_node("retrieve_documents", retrieve_documents)
-    workflow.add_node("answer_rag", answer_rag)
-    workflow.add_node("answer_rag_no_docs", answer_rag_no_docs)
-    workflow.add_node("generate_follow_up", generate_follow_up)
-
-    # Entry point
-    workflow.set_entry_point("standalone_question")
-
-    # CONDITIONAL EDGES
-    workflow.add_conditional_edges(
-        "categorize_intent",
-        route_intent,
-        make_id_dict(["answer_chitchat","answer_climate"])
-    )
-
-    workflow.add_conditional_edges(
-        "chitchat_categorize_intent",
-        chitchat_route_intent,
-        make_id_dict(["retrieve_graphs_chitchat", END])
-    )
-
-    workflow.add_conditional_edges(
-        "answer_climate",
-        route_translation,
-        make_id_dict(["translate_query","transform_query"])
-    )
-
-    workflow.add_conditional_edges(
-        "answer_search",
-        lambda x : route_based_on_relevant_docs(x, threshold_docs=threshold_docs),
-        make_id_dict(["answer_rag","answer_rag_no_docs"])
-    )
-    workflow.add_conditional_edges(
-        "transform_query",
-        route_retrieve_documents,
-        make_id_dict(["retrieve_graphs", END])
-    )
-
-    # Define the edges
-    workflow.add_edge("standalone_question", "categorize_intent")
-    workflow.add_edge("translate_query", "transform_query")
-    workflow.add_edge("transform_query", "retrieve_documents") # TODO put back
-    workflow.add_edge("transform_query", "retrieve_local_data")
-    # workflow.add_edge("transform_query", END) # TODO remove
-
-    workflow.add_edge("retrieve_graphs", END)
-    workflow.add_edge("answer_rag", "generate_follow_up")
-    workflow.add_edge("answer_rag_no_docs", "generate_follow_up")
-    workflow.add_edge("answer_chitchat", "chitchat_categorize_intent")
-    workflow.add_edge("retrieve_graphs_chitchat", END)
-
-    workflow.add_edge("retrieve_local_data", "answer_search")
-    workflow.add_edge("retrieve_documents", "answer_search")
-    workflow.add_edge("generate_follow_up", END)
-
-
-    # Compile
-    app = workflow.compile()
-    return app
-
-
-
-
-def display_graph(app):
-
-    display(
-        Image(
-            app.get_graph(xray = True).draw_mermaid_png(
-                draw_method=MermaidDrawMethod.API,
-            )
-        )
-    )
climateqa/engine/graph_retriever.py DELETED
@@ -1,88 +0,0 @@
-from langchain_core.retrievers import BaseRetriever
-from langchain_core.documents.base import Document
-from langchain_core.vectorstores import VectorStore
-from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
-
-from typing import List
-
-# class GraphRetriever(BaseRetriever):
-#     vectorstore: VectorStore
-#     sources: list = ["OWID"]  # later: add OurWorldInData; will need integrating with the other retriever
-#     threshold: float = 0.5
-#     k_total: int = 10
-
-#     def _get_relevant_documents(
-#         self, query: str, *, run_manager: CallbackManagerForRetrieverRun
-#     ) -> List[Document]:
-
-#         # Check if all elements in the list are IEA or OWID
-#         assert isinstance(self.sources, list)
-#         assert self.sources
-#         assert any([x in ["OWID"] for x in self.sources])
-
-#         # Prepare base search kwargs
-#         filters = {}
-
-#         filters["source"] = {"$in": self.sources}
-
-#         docs = self.vectorstore.similarity_search_with_score(query=query, filter=filters, k=self.k_total)
-
-#         # Filter if scores are below threshold
-#         docs = [x for x in docs if x[1] > self.threshold]
-
-#         # Remove duplicate documents
-#         unique_docs = []
-#         seen_docs = []
-#         for i, doc in enumerate(docs):
-#             if doc[0].page_content not in seen_docs:
-#                 unique_docs.append(doc)
-#                 seen_docs.append(doc[0].page_content)
-
-#         # Add score to metadata
-#         results = []
-#         for i, (doc, score) in enumerate(unique_docs):
-#             doc.metadata["similarity_score"] = score
-#             doc.metadata["content"] = doc.page_content
-#             results.append(doc)
-
-#         return results
-
-async def retrieve_graphs(
-    query: str,
-    vectorstore: VectorStore,
-    sources: list = ["OWID"],  # later: add OurWorldInData; will need integrating with the other retriever
-    threshold: float = 0.5,
-    k_total: int = 10,
-) -> List[Document]:
-
-    # Check if all elements in the list are IEA or OWID
-    assert isinstance(sources, list)
-    assert sources
-    assert any([x in ["OWID"] for x in sources])
-
-    # Prepare base search kwargs
-    filters = {}
-
-    filters["source"] = {"$in": sources}
-
-    docs = vectorstore.similarity_search_with_score(query=query, filter=filters, k=k_total)
-
-    # Filter if scores are below threshold
-    docs = [x for x in docs if x[1] > threshold]
-
-    # Remove duplicate documents
-    unique_docs = []
-    seen_docs = []
-    for i, doc in enumerate(docs):
-        if doc[0].page_content not in seen_docs:
-            unique_docs.append(doc)
-            seen_docs.append(doc[0].page_content)
-
-    # Add score to metadata
-    results = []
-    for i, (doc, score) in enumerate(unique_docs):
-        doc.metadata["similarity_score"] = score
-        doc.metadata["content"] = doc.page_content
-        results.append(doc)
-
-    return results
climateqa/engine/keywords.py CHANGED
@@ -11,12 +11,10 @@ class KeywordsOutput(BaseModel):
 
     keywords: list = Field(
         description="""
-        Generate 1 or 2 relevant keywords from the user query to ask a search engine for scientific research papers. Answer only with English keywords.
-        Do not use special characters or accents.
+        Generate 1 or 2 relevant keywords from the user query to ask a search engine for scientific research papers.
 
         Example:
        - "What is the impact of deep sea mining ?" -> ["deep sea mining"]
-       - "Quel est l'impact de l'exploitation minière en haute mer ?" -> ["deep sea mining"]
        - "How will El Nino be impacted by climate change" -> ["el nino"]
        - "Is climate change a hoax" -> ["Climate change","hoax"]
        """
climateqa/engine/llm/__init__.py CHANGED
@@ -1,6 +1,5 @@
 from climateqa.engine.llm.openai import get_llm as get_openai_llm
 from climateqa.engine.llm.azure import get_llm as get_azure_llm
-from climateqa.engine.llm.ollama import get_llm as get_ollama_llm
 
 
 def get_llm(provider="openai",**kwargs):
@@ -9,8 +8,6 @@ def get_llm(provider="openai",**kwargs):
        return get_openai_llm(**kwargs)
    elif provider == "azure":
        return get_azure_llm(**kwargs)
-   elif provider == "ollama":
-       return get_ollama_llm(**kwargs)
    else:
        raise ValueError(f"Unknown provider: {provider}")
 
climateqa/engine/llm/ollama.py DELETED
@@ -1,6 +0,0 @@
-
-
-from langchain_community.llms import Ollama
-
-def get_llm(model="llama3", **kwargs):
-    return Ollama(model=model, **kwargs)
climateqa/engine/llm/openai.py CHANGED
@@ -7,7 +7,7 @@ try:
 except Exception:
     pass
 
-def get_llm(model="gpt-4o-mini",max_tokens=1024, temperature=0.0, streaming=True,timeout=30, **kwargs):
+def get_llm(model="gpt-3.5-turbo-0125",max_tokens=1024, temperature=0.0, streaming=True,timeout=30, **kwargs):
 
     llm = ChatOpenAI(
         model=model,
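
Taken together with the `llm/__init__.py` hunk above, provider dispatch after this change looks roughly like this. A sketch only; the kwargs shown are the ones the factory already accepts:

# Illustrative dispatch through climateqa.engine.llm.get_llm
from climateqa.engine.llm import get_llm

llm = get_llm(provider="openai", max_tokens=1024, temperature=0.0)  # -> ChatOpenAI("gpt-3.5-turbo-0125")
azure_llm = get_llm(provider="azure")                               # -> Azure-backed chat model
get_llm(provider="ollama")                                          # now raises ValueError: Unknown provider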
climateqa/engine/{chains/prompts.py → prompts.py} RENAMED
@@ -36,41 +36,13 @@ You are given a question and extracted passages of the IPCC and/or IPBES reports
 """
 
 
-# answer_prompt_template_old = """
-# You are ClimateQ&A, an AI Assistant created by Ekimetrics. You are given a question and extracted passages of reports. Provide a clear and structured answer based on the passages provided, the context and the guidelines.
-
-# Guidelines:
-# - If the passages have useful facts or numbers, use them in your answer.
-# - When you use information from a passage, mention where it came from by using [Doc i] at the end of the sentence. i stands for the number of the document.
-# - Do not use the sentence 'Doc i says ...' to say where information came from.
-# - If the same thing is said in more than one document, you can mention all of them like this: [Doc i, Doc j, Doc k]
-# - Do not just summarize each passage one by one. Group your summaries to highlight the key parts in the explanation.
-# - If it makes sense, use bullet points and lists to make your answers easier to understand.
-# - You do not need to use every passage. Only use the ones that help answer the question.
-# - If the documents do not have the information needed to answer the question, just say you do not have enough information.
-# - Consider by default that the question is about the past century unless it is specified otherwise.
-# - If the passage is the caption of a picture, you can still use it as part of your answer as any other document.
-
-# -----------------------
-# Passages:
-# {context}
-
-# -----------------------
-# Question: {query} - Explained to {audience}
-# Answer in {language} with the passages citations:
-# """
-
 answer_prompt_template = """
-You are ClimateQ&A, an AI Assistant created by Ekimetrics. You are given a question and extracted passages of reports. Provide a clear and structured answer based on the passages provided, the context and the guidelines.
+You are ClimateQ&A, an AI Assistant created by Ekimetrics. You are given a question and extracted passages of the IPCC and/or IPBES reports. Provide a clear and structured answer based on the passages provided, the context and the guidelines.
 
 Guidelines:
 - If the passages have useful facts or numbers, use them in your answer.
 - When you use information from a passage, mention where it came from by using [Doc i] at the end of the sentence. i stands for the number of the document.
-- You will receive passages from different reports, e.g., IPCC and PPCP. Make separate paragraphs and specify the source of the information in your answer, e.g., "According to IPCC, ...".
-- The different sources are IPCC, IPBES, PPCP (for Plan Climat Air Energie Territorial de Paris), PBDP (for Plan Biodiversité de Paris), Acclimaterra (Rapport scientifique de la région Nouvelle Aquitaine en France).
-- If the reports are local (like PPCP, PBDP, Acclimaterra), consider that the information is specific to the region and not global. If the document is about a nearby region (for example, an extract from Acclimaterra for a question about Britain), explicitly state the concerned region.
-- Do not mention that you are using specific extract documents, but mention only the source information. "According to IPCC, ..." rather than "According to the provided document from IPCC ..."
-- Make a clear distinction between information from IPCC, IPBES, Acclimaterra that are scientific reports and PPCP, PBDP that are strategic reports. Strategic reports should not be taken as verified facts, but as political or strategic decisions.
+- Do not use the sentence 'Doc i says ...' to say where information came from.
 - If the same thing is said in more than one document, you can mention all of them like this: [Doc i, Doc j, Doc k]
 - Do not just summarize each passage one by one. Group your summaries to highlight the key parts in the explanation.
 - If it makes sense, use bullet points and lists to make your answers easier to understand.
@@ -79,16 +51,16 @@ Guidelines:
 - Consider by default that the question is about the past century unless it is specified otherwise.
 - If the passage is the caption of a picture, you can still use it as part of your answer as any other document.
 
-
 -----------------------
 Passages:
 {context}
 
 -----------------------
-Question: {query} - Explained to {audience}
+Question: {question} - Explained to {audience}
 Answer in {language} with the passages citations:
 """
 
+
 papers_prompt_template = """
 You are ClimateQ&A, an AI Assistant created by Ekimetrics. You are given a question and extracted abstracts of scientific papers. Provide a clear and structured answer based on the abstracts provided, the context and the guidelines.
 
@@ -165,7 +137,7 @@ Guidelines:
 - If the question is not related to environmental issues, never never answer it. Say it's not your role.
 - Make paragraphs by starting new lines to make your answers more readable.
 
-Question: {query}
+Question: {question}
 Answer in {language}:
 """
 
@@ -175,77 +147,4 @@ audience_prompts = {
    "children": "6 year old children that don't know anything about science and climate change and need metaphors to learn",
    "general": "the general public who know the basics in science and climate change and want to learn more about it without technical terms. Still use references to passages.",
    "experts": "expert and climate scientists that are not afraid of technical terms",
-}
-
-
-answer_prompt_graph_template = """
-Given the user question and a list of graphs which are related to the question, rank the graphs based on relevance to the user question. ALWAYS follow the guidelines given below.
-
-### Guidelines ###
-- Keep all the graphs that are given to you.
-- NEVER modify the graph HTML embedding, the category or the source; leave them exactly as they are given.
-- Return the ranked graphs as a list of dictionaries with keys 'embedding', 'category', and 'source'.
-- Return a valid JSON output.
-
------------------------
-User question:
-{query}
-
-Graphs and their HTML embedding:
-{recommended_content}
-
------------------------
-{format_instructions}
-
-Output the result as json with a key "graphs" containing a list of dictionaries of the relevant graphs with keys 'embedding', 'category', and 'source'. Do not modify the graph HTML embedding, the category or the source. Do not put any message or text before or after the JSON output.
-"""
-
-retrieve_chapter_prompt_template = """Given the user question and a list of documents with their table of contents, retrieve the 5 most relevant level 0 chapters which could help to answer to the question while taking account their sub-chapters.
-
-The table of contents is structured like that :
-{{
-    "level": 0,
-    "Chapter 1": {{}},
-    "Chapter 2" : {{
-        "level": 1,
-        "Chapter 2.1": {{
-            ...
-        }}
-    }},
-}}
-
-Here level is the level of the chapter. For example, Chapter 1 and Chapter 2 are at level 0, and Chapter 2.1 is at level 1.
-
-### Guidelines ###
-- Keep all the list of documents that is given to you
-- Each chapter must keep **EXACTLY** its assigned level in the table of contents. **DO NOT MODIFY THE LEVELS.**
-- Check systematically the level of a chapter before including it in the answer.
-- Return **valid JSON** result.
-
---------------------
-User question :
-{query}
-
-List of documents with their table of contents :
-{doc_list}
-
---------------------
-
-Return a JSON result with a list of relevant chapters with the following keys **WITHOUT** the json markdown indicator ```json at the beginning:
-- "document" : the document in which we can find the chapter
-- "chapter" : the title of the chapter
-
-**IMPORTANT : Make sure that the levels of the answer are exactly the same as the ones in the table of contents**
-
-Example of a JSON response:
-[
-    {{
-        "document": "Document A",
-        "chapter": "Chapter 1",
-    }},
-    {{
-        "document": "Document B",
-        "chapter": "Chapter 5",
-    }}
-]
-"""
+}
climateqa/engine/{chains/answer_rag.py → rag.py} RENAMED
@@ -2,16 +2,17 @@ from operator import itemgetter
 
 from langchain_core.prompts import ChatPromptTemplate
 from langchain_core.output_parsers import StrOutputParser
+from langchain_core.runnables import RunnablePassthrough, RunnableLambda, RunnableBranch
 from langchain_core.prompts.prompt import PromptTemplate
 from langchain_core.prompts.base import format_document
 
-from climateqa.engine.chains.prompts import answer_prompt_template,answer_prompt_without_docs_template,answer_prompt_images_template
-from climateqa.engine.chains.prompts import papers_prompt_template
-import time
-from ..utils import rename_chain, pass_values
+from climateqa.engine.reformulation import make_reformulation_chain
+from climateqa.engine.prompts import answer_prompt_template,answer_prompt_without_docs_template,answer_prompt_images_template
+from climateqa.engine.prompts import papers_prompt_template
+from climateqa.engine.utils import pass_values, flatten_dict,prepare_chain,rename_chain
+from climateqa.engine.keywords import make_keywords_chain
 
-
-DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="Source : {source} - Content : {page_content}")
+DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="{page_content}")
 
 def _combine_documents(
     docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, sep="\n\n"
@@ -39,54 +40,72 @@ def get_text_docs(x):
 def get_image_docs(x):
     return [doc for doc in x if doc.metadata["chunk_type"] == "image"]
 
-def make_rag_chain(llm):
+
+def make_rag_chain(retriever,llm):
+
+    # Construct the prompt
     prompt = ChatPromptTemplate.from_template(answer_prompt_template)
-    chain = ({
-        "context":lambda x : _combine_documents(x["documents"]),
-        "context_length":lambda x : print("CONTEXT LENGTH : " , len(_combine_documents(x["documents"]))),
-        "query":itemgetter("query"),
-        "language":itemgetter("language"),
-        "audience":itemgetter("audience"),
-    } | prompt | llm | StrOutputParser())
-    return chain
+    prompt_without_docs = ChatPromptTemplate.from_template(answer_prompt_without_docs_template)
 
-def make_rag_chain_without_docs(llm):
-    prompt = ChatPromptTemplate.from_template(answer_prompt_without_docs_template)
-    chain = prompt | llm | StrOutputParser()
-    return chain
+    # ------- CHAIN 0 - Reformulation
+    reformulation = make_reformulation_chain(llm)
+    reformulation = prepare_chain(reformulation,"reformulation")
 
-def make_rag_node(llm,with_docs = True):
+    # ------- Find all keywords from the reformulated query
+    keywords = make_keywords_chain(llm)
+    keywords = {"keywords":itemgetter("question") | keywords}
+    keywords = prepare_chain(keywords,"keywords")
 
-    if with_docs:
-        rag_chain = make_rag_chain(llm)
-    else:
-        rag_chain = make_rag_chain_without_docs(llm)
-
-    async def answer_rag(state,config):
-        print("---- Answer RAG ----")
-        start_time = time.time()
-        chat_history = state.get("chat_history",[])
-        print("Sources used : " + "\n".join([x.metadata["short_name"] + " - page " + str(x.metadata["page_number"]) for x in state["documents"]]))
+    # ------- CHAIN 1
+    # Retrieved documents
+    find_documents = {"docs": itemgetter("question") | retriever} | RunnablePassthrough()
+    find_documents = prepare_chain(find_documents,"find_documents")
 
-        answer = await rag_chain.ainvoke(state,config)
+    # ------- CHAIN 2
+    # Construct inputs for the llm
+    input_documents = {
+        "context":lambda x : _combine_documents(x["docs"]),
+        **pass_values(["question","audience","language","keywords"])
+    }
+
+    # ------- CHAIN 3
+    # Bot answer
+    llm_final = rename_chain(llm,"answer")
+
+    answer_with_docs = {
+        "answer": input_documents | prompt | llm_final | StrOutputParser(),
+        **pass_values(["question","audience","language","query","docs","keywords"]),
+    }
+
+    answer_without_docs = {
+        "answer": prompt_without_docs | llm_final | StrOutputParser(),
+        **pass_values(["question","audience","language","query","docs","keywords"]),
+    }
+
+    # def has_images(x):
+    #     image_docs = [doc for doc in x["docs"] if doc.metadata["chunk_type"]=="image"]
+    #     return len(image_docs) > 0
 
-        end_time = time.time()
-        elapsed_time = end_time - start_time
-        print("RAG elapsed time: ", elapsed_time)
-        print("Answer size : ", len(answer))
-
-        chat_history.append({"question":state["query"],"answer":answer})
-
-        return {"answer":answer,"chat_history": chat_history}
+    def has_docs(x):
+        return len(x["docs"]) > 0
 
-    return answer_rag
+    answer = RunnableBranch(
+        (lambda x: has_docs(x), answer_with_docs),
+        answer_without_docs,
+    )
 
+    # ------- FINAL CHAIN
+    # Build the final chain
+    rag_chain = reformulation | keywords | find_documents | answer
+
+    return rag_chain
 
 
 def make_rag_papers_chain(llm):
 
     prompt = ChatPromptTemplate.from_template(papers_prompt_template)
+
     input_documents = {
        "context":lambda x : _combine_documents(x["docs"]),
        **pass_values(["question","language"])
@@ -112,4 +131,4 @@ def make_illustration_chain(llm):
    }
 
    illustration_chain = input_description_images | prompt_with_images | llm | StrOutputParser()
-   return illustration_chain
+   return illustration_chain
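
A usage sketch for the renamed module. The `retriever` object and the exact input keys are assumptions here — they depend on what the reformulation chain expects — but the keys shown match the `pass_values` calls in the hunk above:

# Hypothetical end-to-end invocation (illustrative only)
from climateqa.engine.rag import make_rag_chain

rag_chain = make_rag_chain(retriever, llm)  # retriever: e.g. the ClimateQARetriever added below
result = rag_chain.invoke({
    "query": "Is climate change a hoax?",   # assumed input key for the reformulation step
    "audience": "experts",
    "language": "English",
})
print(result["answer"])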
climateqa/engine/{chains/reformulation.py → reformulation.py} RENAMED
@@ -3,7 +3,7 @@ from langchain.output_parsers.structured import StructuredOutputParser, ResponseSchema
 from langchain_core.prompts import PromptTemplate
 from langchain_core.runnables import RunnablePassthrough, RunnableLambda, RunnableBranch
 
-from climateqa.engine.chains.prompts import reformulation_prompt_template
+from climateqa.engine.prompts import reformulation_prompt_template
 from climateqa.engine.utils import pass_values, flatten_dict
 
 
climateqa/engine/reranker.py DELETED
@@ -1,55 +0,0 @@
- import os
- from dotenv import load_dotenv
- from scipy.special import expit, logit
- from rerankers import Reranker
- from sentence_transformers import CrossEncoder
-
- load_dotenv()
-
- def get_reranker(model = "nano", cohere_api_key = None):
-
-     assert model in ["nano", "tiny", "small", "large", "jina"]
-
-     if model == "nano":
-         reranker = Reranker('ms-marco-TinyBERT-L-2-v2', model_type='flashrank')
-     elif model == "tiny":
-         reranker = Reranker('ms-marco-MiniLM-L-12-v2', model_type='flashrank')
-     elif model == "small":
-         reranker = Reranker("mixedbread-ai/mxbai-rerank-xsmall-v1", model_type='cross-encoder')
-     elif model == "large":
-         if cohere_api_key is None:
-             cohere_api_key = os.environ["COHERE_API_KEY"]
-         reranker = Reranker("cohere", lang='en', api_key = cohere_api_key)
-     elif model == "jina":
-         # Reached the token quota, so this does not work
-         reranker = Reranker("jina-reranker-v2-base-multilingual", api_key = os.getenv("JINA_RERANKER_API_KEY"))
-         # Does not work without a GPU? It also returns results in a different
-         # structure, so the retriever node code would need to change anyway.
-         # reranker = CrossEncoder("jinaai/jina-reranker-v2-base-multilingual", automodel_args={"torch_dtype": "auto"}, trust_remote_code=True,)
-     return reranker
-
-
- def rerank_docs(reranker, docs, query):
-     if docs == []:
-         return []
-
-     # Get a list of texts from langchain docs
-     input_docs = [x.page_content for x in docs]
-
-     # Rerank using the rerankers library
-     results = reranker.rank(query=query, docs=input_docs)
-
-     # Prepare langchain list of docs
-     docs_reranked = []
-     for result in results.results:
-         doc_id = result.document.doc_id
-         doc = docs[doc_id]
-         doc.metadata["reranking_score"] = result.score
-         doc.metadata["query_used_for_retrieval"] = query
-         docs_reranked.append(doc)
-     return docs_reranked
-
- def rerank_and_sort_docs(reranker, docs, query):
-     docs_reranked = rerank_docs(reranker, docs, query)
-     docs_reranked = sorted(docs_reranked, key=lambda x: x.metadata["reranking_score"], reverse=True)
-     return docs_reranked
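Before its deletion, a typical call into this module looked roughly like the sketch below; the `docs` list is assumed to hold langchain Documents retrieved upstream, and the `nano` model runs locally via flashrank:

```python
# Hedged usage sketch of the deleted reranker module (docs/query are illustrative).
reranker = get_reranker("nano")
docs_sorted = rerank_and_sort_docs(reranker, docs, query="impact of warming on glaciers")
for doc in docs_sorted:
    # Scores are attached to metadata by rerank_docs above.
    print(doc.metadata["reranking_score"], doc.page_content[:80])
```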
climateqa/engine/retriever.py ADDED
@@ -0,0 +1,163 @@
+ # https://github.com/langchain-ai/langchain/issues/8623
+
+ import pandas as pd
+
+ from langchain_core.retrievers import BaseRetriever
+ from langchain_core.vectorstores import VectorStoreRetriever
+ from langchain_core.documents.base import Document
+ from langchain_core.vectorstores import VectorStore
+ from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
+
+ from typing import List
+ from pydantic import Field
+
+ class ClimateQARetriever(BaseRetriever):
+     vectorstore: VectorStore
+     sources: list = ["IPCC", "IPBES", "IPOS"]
+     reports: list = []
+     threshold: float = 0.6
+     k_summary: int = 3
+     k_total: int = 10
+     namespace: str = "vectors"
+     min_size: int = 200
+
+     def _get_relevant_documents(
+         self, query: str, *, run_manager: CallbackManagerForRetrieverRun
+     ) -> List[Document]:
+
+         # Check that every requested source is supported
+         assert isinstance(self.sources, list)
+         assert all([x in ["IPCC", "IPBES", "IPOS"] for x in self.sources])
+         assert self.k_total > self.k_summary, "k_total should be greater than k_summary"
+
+         # Prepare base search filters
+         filters = {}
+
+         if len(self.reports) > 0:
+             filters["short_name"] = {"$in": self.reports}
+         else:
+             filters["source"] = {"$in": self.sources}
+
+         # Search for k_summary documents in the summaries dataset
+         filters_summaries = {
+             **filters,
+             "report_type": {"$in": ["SPM"]},
+         }
+
+         docs_summaries = self.vectorstore.similarity_search_with_score(query=query, filter=filters_summaries, k=self.k_summary)
+         docs_summaries = [x for x in docs_summaries if x[1] > self.threshold]
+
+         # Search for k_total - k_summary documents in the full reports dataset
+         filters_full = {
+             **filters,
+             "report_type": {"$nin": ["SPM"]},
+         }
+         k_full = self.k_total - len(docs_summaries)
+         docs_full = self.vectorstore.similarity_search_with_score(query=query, filter=filters_full, k=k_full)
+
+         # Concatenate documents
+         docs = docs_summaries + docs_full
+
+         # Drop documents whose content is too short
+         docs = [x for x in docs if len(x[0].page_content) > self.min_size]
+         # docs = [x for x in docs if x[1] > self.threshold]
+
+         # Add score to metadata
+         results = []
+         for i, (doc, score) in enumerate(docs):
+             doc.metadata["similarity_score"] = score
+             doc.metadata["content"] = doc.page_content
+             doc.metadata["page_number"] = int(doc.metadata["page_number"]) + 1
+             # doc.page_content = f"""Doc {i+1} - {doc.metadata['short_name']}: {doc.page_content}"""
+             results.append(doc)
+
+         # Sort by score
+         # results = sorted(results, key=lambda x: x.metadata["similarity_score"], reverse=True)
+
+         return results
+
+
+ # def filter_summaries(df, k_summary=3, k_total=10):
+ #     # assert source in ["IPCC","IPBES","ALL"], "source arg should be in (IPCC,IPBES,ALL)"
+
+ #     # # Filter by source
+ #     # if source == "IPCC":
+ #     #     df = df.loc[df["source"]=="IPCC"]
+ #     # elif source == "IPBES":
+ #     #     df = df.loc[df["source"]=="IPBES"]
+ #     # else:
+ #     #     pass
+
+ #     # Separate summaries and full reports
+ #     df_summaries = df.loc[df["report_type"].isin(["SPM", "TS"])]
+ #     df_full = df.loc[~df["report_type"].isin(["SPM", "TS"])]
+
+ #     # Find passages from summaries dataset
+ #     passages_summaries = df_summaries.head(k_summary)
+
+ #     # Find passages from full reports dataset
+ #     passages_fullreports = df_full.head(k_total - len(passages_summaries))
+
+ #     # Concatenate passages
+ #     passages = pd.concat([passages_summaries, passages_fullreports], axis=0, ignore_index=True)
+ #     return passages
+
+
+ # def retrieve_with_summaries(query, retriever, k_summary=3, k_total=10, sources=["IPCC", "IPBES"], max_k=100, threshold=0.555, as_dict=True, min_length=300):
+ #     assert max_k > k_total
+
+ #     validated_sources = ["IPCC", "IPBES"]
+ #     sources = [x for x in sources if x in validated_sources]
+ #     filters = {
+ #         "source": {"$in": sources},
+ #     }
+ #     print(filters)
+
+ #     # Retrieve documents
+ #     docs = retriever.retrieve(query, top_k=max_k, filters=filters)
+
+ #     # Filter by score
+ #     docs = [{**x.meta, "score": x.score, "content": x.content} for x in docs if x.score > threshold]
+
+ #     if len(docs) == 0:
+ #         return []
+ #     res = pd.DataFrame(docs)
+ #     passages_df = filter_summaries(res, k_summary, k_total)
+ #     if as_dict:
+ #         contents = passages_df["content"].tolist()
+ #         meta = passages_df.drop(columns=["content"]).to_dict(orient="records")
+ #         passages = []
+ #         for i in range(len(contents)):
+ #             passages.append({"content": contents[i], "meta": meta[i]})
+ #         return passages
+ #     else:
+ #         return passages_df
+
+
+ # def retrieve(query, sources=["IPCC"], threshold=0.555, k=10):
+
+ #     print("hellooooo")
+
+ #     # Reformulate queries
+ #     reformulated_query, language = reformulate(query)
+
+ #     print(reformulated_query)
+
+ #     # Retrieve documents
+ #     passages = retrieve_with_summaries(reformulated_query, retriever, k_total=k, k_summary=3, as_dict=True, sources=sources, threshold=threshold)
+ #     response = {
+ #         "query": query,
+ #         "reformulated_query": reformulated_query,
+ #         "language": language,
+ #         "sources": passages,
+ #         "prompts": {"init_prompt": init_prompt, "sources_prompt": sources_prompt},
+ #     }
+ #     return response
+
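A minimal sketch of wiring the new retriever into a chain. The `vectorstore` variable is an assumption here (the project's index is loaded elsewhere); it must be a langchain `VectorStore` whose document metadata carries the fields filtered on above (`source`, `report_type`, `short_name`, `page_number`):

```python
# Hypothetical usage -- vectorstore is assumed to be loaded elsewhere in the app.
retriever = ClimateQARetriever(
    vectorstore=vectorstore,
    sources=["IPCC", "IPBES"],
    k_summary=3,
    k_total=10,
)
docs = retriever.invoke("How much have the oceans warmed since 1900?")
for doc in docs:
    print(doc.metadata["similarity_score"], doc.page_content[:80])
```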
climateqa/engine/talk_to_data/config.py DELETED
@@ -1,11 +0,0 @@
- # Path configuration for the climateqa project
-
- # IPCC dataset path
- IPCC_DATASET_URL = "hf://datasets/ekimetrics/ipcc-atlas"
-
- # DRIAS dataset path
- DRIAS_DATASET_URL = "hf://datasets/timeki/drias_db"
-
- # Table paths
- DRIAS_MEAN_ANNUAL_TEMPERATURE_PATH = f"{DRIAS_DATASET_URL}/mean_annual_temperature.parquet"
- IPCC_COORDINATES_PATH = f"{IPCC_DATASET_URL}/coordinates.parquet"
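These URLs point at parquet files that the rest of `talk_to_data` queries directly with duckdb (as `nearest_neighbour_sql` in `input_processing.py` does). A small sketch of reading one of them, assuming network access and a duckdb build that resolves `hf://` paths:

```python
import duckdb

from climateqa.engine.talk_to_data.config import DRIAS_MEAN_ANNUAL_TEMPERATURE_PATH

# Peek at the first rows of the remote DRIAS table (data is fetched on the fly).
conn = duckdb.connect()
df = conn.sql(f"SELECT * FROM '{DRIAS_MEAN_ANNUAL_TEMPERATURE_PATH}' LIMIT 5").fetchdf()
print(df)
```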
climateqa/engine/talk_to_data/drias/config.py DELETED
@@ -1,124 +0,0 @@
-
- from climateqa.engine.talk_to_data.ui_config import PRECIPITATION_COLORSCALE, TEMPERATURE_COLORSCALE
-
-
- DRIAS_TABLES = [
-     "total_winter_precipitation",
-     "total_summer_precipitation",
-     "total_annual_precipitation",
-     "total_remarkable_daily_precipitation",
-     "frequency_of_remarkable_daily_precipitation",
-     "extreme_precipitation_intensity",
-     "mean_winter_temperature",
-     "mean_summer_temperature",
-     "mean_annual_temperature",
-     "number_of_tropical_nights",
-     "maximum_summer_temperature",
-     "number_of_days_with_tx_above_30",
-     "number_of_days_with_tx_above_35",
-     "number_of_days_with_a_dry_ground",
- ]
-
- DRIAS_INDICATOR_COLUMNS_PER_TABLE = {
-     "total_winter_precipitation": "total_winter_precipitation",
-     "total_summer_precipitation": "total_summer_precipitation",
-     "total_annual_precipitation": "total_annual_precipitation",
-     "total_remarkable_daily_precipitation": "total_remarkable_daily_precipitation",
-     "frequency_of_remarkable_daily_precipitation": "frequency_of_remarkable_daily_precipitation",
-     "extreme_precipitation_intensity": "extreme_precipitation_intensity",
-     "mean_winter_temperature": "mean_winter_temperature",
-     "mean_summer_temperature": "mean_summer_temperature",
-     "mean_annual_temperature": "mean_annual_temperature",
-     "number_of_tropical_nights": "number_tropical_nights",
-     "maximum_summer_temperature": "maximum_summer_temperature",
-     "number_of_days_with_tx_above_30": "number_of_days_with_tx_above_30",
-     "number_of_days_with_tx_above_35": "number_of_days_with_tx_above_35",
-     "number_of_days_with_a_dry_ground": "number_of_days_with_dry_ground"
- }
-
- DRIAS_MODELS = [
-     'ALL',
-     'RegCM4-6_MPI-ESM-LR',
-     'RACMO22E_EC-EARTH',
-     'RegCM4-6_HadGEM2-ES',
-     'HadREM3-GA7_EC-EARTH',
-     'HadREM3-GA7_CNRM-CM5',
-     'REMO2015_NorESM1-M',
-     'SMHI-RCA4_EC-EARTH',
-     'WRF381P_NorESM1-M',
-     'ALADIN63_CNRM-CM5',
-     'CCLM4-8-17_MPI-ESM-LR',
-     'HIRHAM5_IPSL-CM5A-MR',
-     'HadREM3-GA7_HadGEM2-ES',
-     'SMHI-RCA4_IPSL-CM5A-MR',
-     'HIRHAM5_NorESM1-M',
-     'REMO2009_MPI-ESM-LR',
-     'CCLM4-8-17_HadGEM2-ES'
- ]
-
- # Mapping between indicator columns and their units
- DRIAS_INDICATOR_TO_UNIT = {
-     "total_winter_precipitation": "mm",
-     "total_summer_precipitation": "mm",
-     "total_annual_precipitation": "mm",
-     "total_remarkable_daily_precipitation": "mm",
-     "frequency_of_remarkable_daily_precipitation": "days",
-     "extreme_precipitation_intensity": "mm",
-     "mean_winter_temperature": "°C",
-     "mean_summer_temperature": "°C",
-     "mean_annual_temperature": "°C",
-     "number_tropical_nights": "days",
-     "maximum_summer_temperature": "°C",
-     "number_of_days_with_tx_above_30": "days",
-     "number_of_days_with_tx_above_35": "days",
-     "number_of_days_with_dry_ground": "days"
- }
-
- DRIAS_PLOT_PARAMETERS = [
-     'year',
-     'location'
- ]
-
- DRIAS_INDICATOR_TO_COLORSCALE = {
-     "total_winter_precipitation": PRECIPITATION_COLORSCALE,
-     "total_summer_precipitation": PRECIPITATION_COLORSCALE,
-     "total_annual_precipitation": PRECIPITATION_COLORSCALE,
-     "total_remarkable_daily_precipitation": PRECIPITATION_COLORSCALE,
-     "frequency_of_remarkable_daily_precipitation": PRECIPITATION_COLORSCALE,
-     "extreme_precipitation_intensity": PRECIPITATION_COLORSCALE,
-     "mean_winter_temperature": TEMPERATURE_COLORSCALE,
-     "mean_summer_temperature": TEMPERATURE_COLORSCALE,
-     "mean_annual_temperature": TEMPERATURE_COLORSCALE,
-     "number_tropical_nights": TEMPERATURE_COLORSCALE,
-     "maximum_summer_temperature": TEMPERATURE_COLORSCALE,
-     "number_of_days_with_tx_above_30": TEMPERATURE_COLORSCALE,
-     "number_of_days_with_tx_above_35": TEMPERATURE_COLORSCALE,
-     "number_of_days_with_dry_ground": TEMPERATURE_COLORSCALE
- }
-
- DRIAS_UI_TEXT = """
- Hi, I'm **Talk to Drias**, designed to answer your questions using [**DRIAS - TRACC 2023**](https://www.drias-climat.fr/accompagnement/sections/401) data.
- I'll answer by displaying a list of the SQL queries, graphs and data most relevant to your question.
-
- You can ask me anything about these climate indicators: **temperature**, **precipitation** or **drought**.
- You can specify a **location** and/or a **year**.
- You can choose from a list of climate models. By default, we take the **average across all models**.
-
- For example, you can ask:
- - What will the temperature be like in Paris?
- - What will be the total rainfall in France in 2030?
- - How frequent will extreme events be in Lyon?
-
- **Examples of indicators in the data**:
- - Mean temperature (annual, winter, summer)
- - Total precipitation (annual, winter, summer)
- - Number of days with remarkable precipitation, with dry ground, or with temperature above 30°C
-
- ⚠️ **Limitations**:
- - You can't ask anything that isn't related to **DRIAS - TRACC 2023** data.
- - You can only ask about **locations in France**.
- - If you specify a year, there may be **no data for that year for some models**.
- - You **cannot compare two models**.
-
- 🛈 **Information**
- Please note that we **log your questions for meta-analysis purposes**, so avoid sharing any sensitive or personal information.
- """
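Note that a few table names differ from their indicator column (`number_of_tropical_nights` → `number_tropical_nights`, `number_of_days_with_a_dry_ground` → `number_of_days_with_dry_ground`), which is why the unit and colorscale mappings are keyed by column rather than table. A quick sketch of the intended lookup path:

```python
from climateqa.engine.talk_to_data.drias.config import (
    DRIAS_INDICATOR_COLUMNS_PER_TABLE,
    DRIAS_INDICATOR_TO_UNIT,
)

table = "number_of_tropical_nights"
column = DRIAS_INDICATOR_COLUMNS_PER_TABLE[table]  # -> "number_tropical_nights"
unit = DRIAS_INDICATOR_TO_UNIT[column]             # -> "days"
print(f"{table}: column={column}, unit={unit}")
```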
climateqa/engine/talk_to_data/drias/plot_informations.py DELETED
@@ -1,88 +0,0 @@
- from climateqa.engine.talk_to_data.drias.config import DRIAS_INDICATOR_TO_UNIT
-
- def indicator_evolution_informations(
-     indicator: str,
-     params: dict[str, str]
- ) -> str:
-     unit = DRIAS_INDICATOR_TO_UNIT[indicator]
-     if "location" not in params:
-         raise ValueError('"location" must be provided in params')
-     location = params["location"]
-     return f"""
- This plot shows how the climate indicator **{indicator}** evolves over time in **{location}**.
-
- It combines both historical observations and future projections under the RCP8.5 climate scenario.
-
- The x-axis represents the years, and the y-axis shows the value of the indicator ({unit}).
-
- A 10-year rolling average curve is displayed to give a better idea of the overall trend.
-
- **Data source:**
- - The data come from the DRIAS TRACC data. The data were initially extracted from [the DRIAS website](https://www.drias-climat.fr/drias_prod/accueil/okapiWebDrias/index.jsp?iddrias=climat) and then preprocessed to a tabular format and uploaded as parquet to this [Hugging Face dataset](https://huggingface.co/datasets/timeki/drias_db).
- - For each year and climate model, the value of {indicator} in {location} is collected to build the time series.
- - The coordinates used for {location} correspond to the closest available point in the DRIAS database, which uses a regular grid with a spatial resolution of 8 km.
- - The indicator values shown are those for the selected climate model.
- - If the ALL climate model is selected, the average of the indicator across all climate models is used.
- """
-
- def indicator_number_of_days_per_year_informations(
-     indicator: str,
-     params: dict[str, str]
- ) -> str:
-     unit = DRIAS_INDICATOR_TO_UNIT[indicator]
-     if "location" not in params:
-         raise ValueError('"location" must be provided in params')
-     location = params["location"]
-     return f"""
- This plot displays a bar chart showing the yearly frequency of the climate indicator **{indicator}** in **{location}**.
-
- The x-axis represents the years, and the y-axis shows the frequency of {indicator} ({unit}) per year.
-
- **Data source:**
- - The data come from the DRIAS TRACC data. The data were initially extracted from [the DRIAS website](https://www.drias-climat.fr/drias_prod/accueil/okapiWebDrias/index.jsp?iddrias=climat) and then preprocessed to a tabular format and uploaded as parquet to this [Hugging Face dataset](https://huggingface.co/datasets/timeki/drias_db).
- - For each year and climate model, the value of {indicator} in {location} is collected to build the time series.
- - The coordinates used for {location} correspond to the closest available point in the DRIAS database, which uses a regular grid with a spatial resolution of 8 km.
- - The indicator values shown are those for the selected climate model.
- - If the ALL climate model is selected, the average of the indicator across all climate models is used.
- """
-
- def distribution_of_indicator_for_given_year_informations(
-     indicator: str,
-     params: dict[str, str]
- ) -> str:
-     unit = DRIAS_INDICATOR_TO_UNIT[indicator]
-     year = params["year"]
-     if year is None:
-         year = 2030
-     return f"""
- This plot shows a histogram of the distribution of the climate indicator **{indicator}** across all locations for the year **{year}**.
-
- It allows you to visualize how the values of {indicator} ({unit}) are spread for a given year.
-
- **Data source:**
- - The data come from the DRIAS TRACC data. The data were initially extracted from [the DRIAS website](https://www.drias-climat.fr/drias_prod/accueil/okapiWebDrias/index.jsp?iddrias=climat) and then preprocessed to a tabular format and uploaded as parquet to this [Hugging Face dataset](https://huggingface.co/datasets/timeki/drias_db).
- - For each grid point in the dataset and each climate model, the value of {indicator} for the year {year} is extracted.
- - The indicator values shown are those for the selected climate model.
- - If the ALL climate model is selected, the average of the indicator across all climate models is used.
- """
-
- def map_of_france_of_indicator_for_given_year_informations(
-     indicator: str,
-     params: dict[str, str]
- ) -> str:
-     unit = DRIAS_INDICATOR_TO_UNIT[indicator]
-     year = params["year"]
-     if year is None:
-         year = 2030
-     return f"""
- This plot displays a choropleth map showing the spatial distribution of **{indicator}** across all regions of France for the year **{year}**.
-
- Each region is colored according to the value of the indicator ({unit}), allowing you to visually compare how {indicator} varies geographically within France for the selected year and climate model.
-
- **Data source:**
- - The data come from the DRIAS TRACC data. The data were initially extracted from [the DRIAS website](https://www.drias-climat.fr/drias_prod/accueil/okapiWebDrias/index.jsp?iddrias=climat) and then preprocessed to a tabular format and uploaded as parquet to this [Hugging Face dataset](https://huggingface.co/datasets/timeki/drias_db).
- - For each region of France, the value of {indicator} in {year} and for the selected climate model is extracted and mapped to its geographic coordinates.
- - The regions correspond to 8 km squares centered on the grid points of the DRIAS dataset.
- - The indicator values shown are those for the selected climate model.
- - If the ALL climate model is selected, the average of the indicator across all climate models is used.
- """
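Each helper returns a ready-to-render markdown string. A quick illustration (the location value is made up):

```python
from climateqa.engine.talk_to_data.drias.plot_informations import (
    indicator_evolution_informations,
)

text = indicator_evolution_informations(
    "mean_annual_temperature",
    {"location": "Paris"},
)
print(text)  # markdown block describing the evolution plot and its data source
```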
climateqa/engine/talk_to_data/drias/plots.py DELETED
@@ -1,434 +0,0 @@
- import os
- import geojson
- from math import cos, radians
- from typing import Callable
- import pandas as pd
- from plotly.graph_objects import Figure
- import plotly.graph_objects as go
- from climateqa.engine.talk_to_data.drias.plot_informations import distribution_of_indicator_for_given_year_informations, indicator_evolution_informations, indicator_number_of_days_per_year_informations, map_of_france_of_indicator_for_given_year_informations
- from climateqa.engine.talk_to_data.objects.plot import Plot
- from climateqa.engine.talk_to_data.drias.queries import (
-     indicator_for_given_year_query,
-     indicator_per_year_at_location_query,
- )
- from climateqa.engine.talk_to_data.drias.config import DRIAS_INDICATOR_TO_COLORSCALE, DRIAS_INDICATOR_TO_UNIT
-
- def generate_geojson_polygons(latitudes: list[float], longitudes: list[float], indicators: list[float]) -> geojson.FeatureCollection:
-     side_km = 8
-     delta_lat = side_km / 111
-     features = []
-     for idx, (lat, lon, val) in enumerate(zip(latitudes, longitudes, indicators)):
-         delta_lon = side_km / (111 * cos(radians(lat)))
-         half_lat = delta_lat / 2
-         half_lon = delta_lon / 2
-         features.append(geojson.Feature(
-             geometry=geojson.Polygon([[
-                 [lon - half_lon, lat - half_lat],
-                 [lon + half_lon, lat - half_lat],
-                 [lon + half_lon, lat + half_lat],
-                 [lon - half_lon, lat + half_lat],
-                 [lon - half_lon, lat - half_lat]
-             ]]),
-             properties={"value": val},
-             id=str(idx)
-         ))
-
-     return geojson.FeatureCollection(features)
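The 8 km squares rely on the approximation that one degree of latitude is about 111 km everywhere, while one degree of longitude shrinks by cos(latitude). A small check of the polygon produced for a single grid point, assuming the function above is in scope (coordinates and value are illustrative):

```python
from math import cos, radians

# One DRIAS grid point near Paris, with a made-up indicator value.
fc = generate_geojson_polygons([48.85], [2.35], [14.2])
ring = fc["features"][0]["geometry"]["coordinates"][0]

# The square is ~8 km per side: 8/111 degrees of latitude tall,
# and 8/(111*cos(lat)) degrees of longitude wide.
width_deg = ring[1][0] - ring[0][0]
assert abs(width_deg - 8 / (111 * cos(radians(48.85)))) < 1e-9
```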
-
- def plot_indicator_evolution_at_location(params: dict) -> Callable[..., Figure]:
-     """Generates a function to plot indicator evolution over time at a location.
-
-     This function creates a line plot showing how a climate indicator changes
-     over time at a specific location. It handles temperature, precipitation,
-     and other climate indicators.
-
-     Args:
-         params (dict): Dictionary containing:
-             - indicator_column (str): The column name for the indicator
-             - location (str): The location to plot
-             - model (str): The climate model to use
-
-     Returns:
-         Callable[..., Figure]: A function that takes a DataFrame and returns a plotly Figure
-
-     Example:
-         >>> plot_func = plot_indicator_evolution_at_location({
-         ...     'indicator_column': 'mean_temperature',
-         ...     'location': 'Paris',
-         ...     'model': 'ALL'
-         ... })
-         >>> fig = plot_func(df)
-     """
-     indicator = params["indicator_column"]
-     location = params["location"]
-     indicator_label = " ".join([word.capitalize() for word in indicator.split("_")])
-     unit = DRIAS_INDICATOR_TO_UNIT.get(indicator, "")
-
-     def plot_data(df: pd.DataFrame) -> Figure:
-         """Generates the actual plot from the data.
-
-         Args:
-             df (pd.DataFrame): DataFrame containing the data to plot
-
-         Returns:
-             Figure: A plotly Figure object showing the indicator evolution
-         """
-         fig = go.Figure()
-         if df['model'].nunique() != 1:
-             df_avg = df.groupby("year", as_index=False)[indicator].mean()
-
-             # Transform to list to avoid pandas encoding
-             indicators = df_avg[indicator].astype(float).tolist()
-             years = df_avg["year"].astype(int).tolist()
-
-             # Compute the 10-year rolling average
-             rolling_window = 10
-             sliding_averages = (
-                 df_avg[indicator]
-                 .rolling(window=rolling_window, min_periods=rolling_window)
-                 .mean()
-                 .astype(float)
-                 .tolist()
-             )
-             model_label = "Model Average"
-
-             # Only add the rolling average if we have enough data points
-             if len([x for x in sliding_averages if pd.notna(x)]) > 0:
-                 # Sliding average dashed line
-                 fig.add_scatter(
-                     x=years,
-                     y=sliding_averages,
-                     mode="lines",
-                     name="10 years rolling average",
-                     line=dict(dash="dash"),
-                     marker=dict(color="#d62728"),
-                     hovertemplate=f"10-year average: %{{y:.2f}} {unit}<br>Year: %{{x}}<extra></extra>"
-                 )
-
-         else:
-             df_model = df
-
-             # Transform to list to avoid pandas encoding
-             indicators = df_model[indicator].astype(float).tolist()
-             years = df_model["year"].astype(int).tolist()
-
-             # Compute the 10-year rolling average
-             rolling_window = 10
-             sliding_averages = (
-                 df_model[indicator]
-                 .rolling(window=rolling_window, min_periods=rolling_window)
-                 .mean()
-                 .astype(float)
-                 .tolist()
-             )
-             model_label = f"Model : {df['model'].unique()[0]}"
-
-             # Only add the rolling average if we have enough data points
-             if len([x for x in sliding_averages if pd.notna(x)]) > 0:
-                 # Sliding average dashed line
-                 fig.add_scatter(
-                     x=years,
-                     y=sliding_averages,
-                     mode="lines",
-                     name="10 years rolling average",
-                     line=dict(dash="dash"),
-                     marker=dict(color="#d62728"),
-                     hovertemplate=f"10-year average: %{{y:.2f}} {unit}<br>Year: %{{x}}<extra></extra>"
-                 )
-
-         # Indicator per year plot
-         fig.add_scatter(
-             x=years,
-             y=indicators,
-             name=f"Yearly {indicator_label}",
-             mode="lines",
-             marker=dict(color="#1f77b4"),
-             hovertemplate=f"{indicator_label}: %{{y:.2f}} {unit}<br>Year: %{{x}}<extra></extra>"
-         )
-         fig.update_layout(
-             title=f"Evolution of {indicator_label} in {location} ({model_label})",
-             xaxis_title="Year",
-             yaxis_title=f"{indicator_label} ({unit})",
-             template="plotly_white",
-             height=900,
-         )
-         return fig
-
-     return plot_data
-
-
- indicator_evolution_at_location: Plot = {
-     "name": "Indicator evolution at location",
-     "description": "Plot an evolution of the indicator at a certain location",
-     "params": ["indicator_column", "location", "model"],
-     "plot_function": plot_indicator_evolution_at_location,
-     "sql_query": indicator_per_year_at_location_query,
-     "plot_information": indicator_evolution_informations,
-     "short_name": "Evolution"
- }
-
-
- def plot_indicator_number_of_days_per_year_at_location(
-     params: dict,
- ) -> Callable[..., Figure]:
-     """Generates a function to plot the number of days per year for an indicator.
-
-     This function creates a bar chart showing the frequency of certain climate
-     events (like days above a temperature threshold) per year at a specific location.
-
-     Args:
-         params (dict): Dictionary containing:
-             - indicator_column (str): The column name for the indicator
-             - location (str): The location to plot
-             - model (str): The climate model to use
-
-     Returns:
-         Callable[..., Figure]: A function that takes a DataFrame and returns a plotly Figure
-     """
-     indicator = params["indicator_column"]
-     location = params["location"]
-     indicator_label = " ".join([word.capitalize() for word in indicator.split("_")])
-     unit = DRIAS_INDICATOR_TO_UNIT.get(indicator, "")
-
-     def plot_data(df: pd.DataFrame) -> Figure:
-         """Generates the figure from the dataframe.
-
-         Args:
-             df (pd.DataFrame): pandas dataframe with the required data
-
-         Returns:
-             Figure: Plotly figure
-         """
-         fig = go.Figure()
-         if df['model'].nunique() != 1:
-             df_avg = df.groupby("year", as_index=False)[indicator].mean()
-
-             # Transform to list to avoid pandas encoding
-             indicators = df_avg[indicator].astype(float).tolist()
-             years = df_avg["year"].astype(int).tolist()
-             model_label = "Model Average"
-
-         else:
-             df_model = df
-             # Transform to list to avoid pandas encoding
-             indicators = df_model[indicator].astype(float).tolist()
-             years = df_model["year"].astype(int).tolist()
-             model_label = f"Model : {df['model'].unique()[0]}"
-
-         # Bar plot
-         fig.add_trace(
-             go.Bar(
-                 x=years,
-                 y=indicators,
-                 width=0.5,
-                 marker=dict(color="#1f77b4"),
-                 hovertemplate=f"{indicator_label}: %{{y:.2f}} {unit}<br>Year: %{{x}}<extra></extra>"
-             )
-         )
-
-         fig.update_layout(
-             title=f"{indicator_label} in {location} ({model_label})",
-             xaxis_title="Year",
-             yaxis_title=f"{indicator_label} ({unit})",
-             yaxis=dict(range=[0, max(indicators)]),
-             bargap=0.5,
-             height=900,
-             template="plotly_white",
-         )
-
-         return fig
-
-     return plot_data
-
-
- indicator_number_of_days_per_year_at_location: Plot = {
-     "name": "Indicator number of days per year at location",
-     "description": "Plot a barchart of the number of days per year of a certain indicator at a certain location. It is appropriate for frequency indicators.",
-     "params": ["indicator_column", "location", "model"],
-     "plot_function": plot_indicator_number_of_days_per_year_at_location,
-     "sql_query": indicator_per_year_at_location_query,
-     "plot_information": indicator_number_of_days_per_year_informations,
-     "short_name": "Yearly Frequency",
- }
-
-
- def plot_distribution_of_indicator_for_given_year(
-     params: dict,
- ) -> Callable[..., Figure]:
-     """Generates a function to plot the distribution of an indicator for a year.
-
-     This function creates a histogram showing the distribution of a climate
-     indicator across different locations for a specific year.
-
-     Args:
-         params (dict): Dictionary containing:
-             - indicator_column (str): The column name for the indicator
-             - year (str): The year to plot
-             - model (str): The climate model to use
-
-     Returns:
-         Callable[..., Figure]: A function that takes a DataFrame and returns a plotly Figure
-     """
-     indicator = params["indicator_column"]
-     year = params["year"]
-     if year is None:
-         year = 2030
-     indicator_label = " ".join([word.capitalize() for word in indicator.split("_")])
-     unit = DRIAS_INDICATOR_TO_UNIT.get(indicator, "")
-
-     def plot_data(df: pd.DataFrame) -> Figure:
-         """Generates the figure from the dataframe.
-
-         Args:
-             df (pd.DataFrame): pandas dataframe with the required data
-
-         Returns:
-             Figure: Plotly figure
-         """
-         fig = go.Figure()
-         if df['model'].nunique() != 1:
-             df_avg = df.groupby(["latitude", "longitude"], as_index=False)[
-                 indicator
-             ].mean()
-
-             # Transform to list to avoid pandas encoding
-             indicators = df_avg[indicator].astype(float).tolist()
-             model_label = "Model Average"
-
-         else:
-             df_model = df
-
-             # Transform to list to avoid pandas encoding
-             indicators = df_model[indicator].astype(float).tolist()
-             model_label = f"Model : {df['model'].unique()[0]}"
-
-         fig.add_trace(
-             go.Histogram(
-                 x=indicators,
-                 opacity=0.8,
-                 histnorm="percent",
-                 marker=dict(color="#1f77b4"),
-                 hovertemplate=f"{indicator_label}: %{{x:.2f}} {unit}<br>Frequency: %{{y:.2f}}%<extra></extra>"
-             )
-         )
-
-         fig.update_layout(
-             title=f"Distribution of {indicator_label} in {year} ({model_label})",
-             xaxis_title=f"{indicator_label} ({unit})",
-             yaxis_title="Frequency (%)",
-             plot_bgcolor="rgba(0, 0, 0, 0)",
-             showlegend=False,
-             height=900,
-         )
-
-         return fig
-
-     return plot_data
-
-
- distribution_of_indicator_for_given_year: Plot = {
-     "name": "Distribution of an indicator for a given year",
-     "description": "Plot a histogram of the distribution of the values of an indicator for a given year",
-     "params": ["indicator_column", "model", "year"],
-     "plot_function": plot_distribution_of_indicator_for_given_year,
-     "sql_query": indicator_for_given_year_query,
-     "plot_information": distribution_of_indicator_for_given_year_informations,
-     "short_name": "Distribution"
- }
-
-
- def plot_map_of_france_of_indicator_for_given_year(
-     params: dict,
- ) -> Callable[..., Figure]:
-     """Generates a function to plot a map of France for an indicator.
-
-     This function creates a choropleth map of France showing the spatial
-     distribution of a climate indicator for a specific year.
-
-     Args:
-         params (dict): Dictionary containing:
-             - indicator_column (str): The column name for the indicator
-             - year (str): The year to plot
-             - model (str): The climate model to use
-
-     Returns:
-         Callable[..., Figure]: A function that takes a DataFrame and returns a plotly Figure
-     """
-     indicator = params["indicator_column"]
-     year = params["year"]
-     if year is None:
-         year = 2030
-     indicator_label = " ".join([word.capitalize() for word in indicator.split("_")])
-     unit = DRIAS_INDICATOR_TO_UNIT.get(indicator, "")
-
-     def plot_data(df: pd.DataFrame) -> Figure:
-         if df['model'].nunique() != 1:
-             df_avg = df.groupby(["latitude", "longitude"], as_index=False)[
-                 indicator
-             ].mean()
-
-             indicators = df_avg[indicator].astype(float).tolist()
-             latitudes = df_avg["latitude"].astype(float).tolist()
-             longitudes = df_avg["longitude"].astype(float).tolist()
-             model_label = "Model Average"
-
-         else:
-             df_model = df
-
-             # Transform to list to avoid pandas encoding
-             indicators = df_model[indicator].astype(float).tolist()
-             latitudes = df_model["latitude"].astype(float).tolist()
-             longitudes = df_model["longitude"].astype(float).tolist()
-             model_label = f"Model : {df['model'].unique()[0]}"
-
-         geojson_data = generate_geojson_polygons(latitudes, longitudes, indicators)
-
-         fig = go.Figure(go.Choroplethmapbox(
-             geojson=geojson_data,
-             locations=[str(i) for i in range(len(indicators))],
-             featureidkey="id",
-             z=indicators,
-             colorscale=DRIAS_INDICATOR_TO_COLORSCALE[indicator],
-             zmin=min(indicators),
-             zmax=max(indicators),
-             marker_opacity=0.7,
-             marker_line_width=0,
-             colorbar_title=f"{indicator_label} ({unit})",
-             text=[f"{indicator_label}: {value:.2f} {unit}" for value in indicators],  # Hover text showing the indicator value
-             hoverinfo="text"
-         ))
-
-         fig.update_layout(
-             mapbox_style="open-street-map",  # Use OpenStreetMap tiles
-             mapbox_zoom=5,
-             height=900,
-             mapbox_center={"lat": 46.6, "lon": 2.0},
-             coloraxis_colorbar=dict(title=f"{indicator_label} ({unit})"),  # Legend
-             title=f"{indicator_label} in {year} in France ({model_label})"
-         )
-         return fig
-
-     return plot_data
-
-
- map_of_france_of_indicator_for_given_year: Plot = {
-     "name": "Map of France of an indicator for a given year",
-     "description": "Heatmap on the map of France of the values of an indicator for a given year",
-     "params": ["indicator_column", "year", "model"],
-     "plot_function": plot_map_of_france_of_indicator_for_given_year,
-     "sql_query": indicator_for_given_year_query,
-     "plot_information": map_of_france_of_indicator_for_given_year_informations,
-     "short_name": "Map of France"
- }
-
- DRIAS_PLOTS = [
-     indicator_evolution_at_location,
-     indicator_number_of_days_per_year_at_location,
-     distribution_of_indicator_for_given_year,
-     map_of_france_of_indicator_for_given_year,
- ]
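Each `Plot` entry bundles a SQL builder, a plotting closure, and an explanatory text. The intended flow is roughly the sketch below, assuming duckdb executes the generated query and that the coordinates were resolved upstream (all parameter values here are illustrative):

```python
import duckdb

plot = indicator_evolution_at_location
params = {
    "indicator_column": "mean_annual_temperature",
    "location": "Paris",
    "latitude": "48.863",    # resolved upstream by find_location
    "longitude": "2.313",
    "model": "ALL",
}

sql = plot["sql_query"]("mean_annual_temperature", params)
df = duckdb.connect().sql(sql).fetchdf()        # network access required
fig = plot["plot_function"](params)(df)         # returns a plotly Figure
info = plot["plot_information"]("mean_annual_temperature", params)
```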
climateqa/engine/talk_to_data/drias/queries.py DELETED
@@ -1,83 +0,0 @@
- from typing import TypedDict
- from climateqa.engine.talk_to_data.config import DRIAS_DATASET_URL
-
- class IndicatorPerYearAtLocationQueryParams(TypedDict, total=False):
-     """Parameters for querying an indicator's values over time at a location.
-
-     This class defines the parameters needed to query climate indicator data
-     for a specific location over multiple years.
-
-     Attributes:
-         indicator_column (str): The column name for the climate indicator
-         latitude (str): The latitude coordinate of the location
-         longitude (str): The longitude coordinate of the location
-         model (str): The climate model to use (optional)
-     """
-     indicator_column: str
-     latitude: str
-     longitude: str
-     model: str
-
- def indicator_per_year_at_location_query(
-     table: str, params: IndicatorPerYearAtLocationQueryParams
- ) -> str:
-     """SQL query to get the evolution of an indicator per year at a given location.
-
-     Args:
-         table (str): SQL table of the indicator
-         params (IndicatorPerYearAtLocationQueryParams): dictionary with the required params for the query
-
-     Returns:
-         str: the SQL query
-     """
-     indicator_column = params.get("indicator_column")
-     latitude = params.get("latitude")
-     longitude = params.get("longitude")
-
-     if indicator_column is None or latitude is None or longitude is None:  # If any parameter is missing, return an empty query
-         return ""
-
-     table = f"'{DRIAS_DATASET_URL}/{table.lower()}.parquet'"
-
-     sql_query = f"SELECT year, {indicator_column}, model\nFROM {table}\nWHERE latitude = {latitude}\nAND longitude = {longitude}\nORDER BY year"
-
-     return sql_query
-
- class IndicatorForGivenYearQueryParams(TypedDict, total=False):
-     """Parameters for querying an indicator's values across locations for a year.
-
-     This class defines the parameters needed to query climate indicator data
-     across different locations for a specific year.
-
-     Attributes:
-         indicator_column (str): The column name for the climate indicator
-         year (str): The year to query
-         model (str): The climate model to use (optional)
-     """
-     indicator_column: str
-     year: str
-     model: str
-
- def indicator_for_given_year_query(
-     table: str, params: IndicatorForGivenYearQueryParams
- ) -> str:
-     """SQL query to get the values of an indicator with their latitudes, longitudes and models for a given year.
-
-     Args:
-         table (str): SQL table of the indicator
-         params (IndicatorForGivenYearQueryParams): dictionary with the required params for the query
-
-     Returns:
-         str: the SQL query
-     """
-     indicator_column = params.get("indicator_column")
-     if indicator_column is None:  # If the parameter is missing, return an empty query
-         return ""
-     year = params.get('year')
-     if year is None:
-         year = 2050
-
-     table = f"'{DRIAS_DATASET_URL}/{table.lower()}.parquet'"
-
-     sql_query = f"SELECT {indicator_column}, latitude, longitude, model\nFROM {table}\nWHERE year = {year}"
-     return sql_query
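For illustration, once the coordinates have been resolved to the nearest DRIAS grid point, the first builder yields a query like the following (coordinate values are made up):

```python
sql = indicator_per_year_at_location_query(
    "mean_annual_temperature",
    {
        "indicator_column": "mean_annual_temperature",
        "latitude": "48.863",
        "longitude": "2.313",
    },
)
print(sql)
# SELECT year, mean_annual_temperature, model
# FROM 'hf://datasets/timeki/drias_db/mean_annual_temperature.parquet'
# WHERE latitude = 48.863
# AND longitude = 2.313
# ORDER BY year
```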
climateqa/engine/talk_to_data/input_processing.py DELETED
@@ -1,257 +0,0 @@
- from typing import Any, Literal, Optional, cast
- import ast
- from langchain_core.prompts import ChatPromptTemplate
- from geopy.geocoders import Nominatim
- from climateqa.engine.llm import get_llm
- import duckdb
- import os
- from climateqa.engine.talk_to_data.config import DRIAS_MEAN_ANNUAL_TEMPERATURE_PATH, IPCC_COORDINATES_PATH
- from climateqa.engine.talk_to_data.objects.llm_outputs import ArrayOutput
- from climateqa.engine.talk_to_data.objects.location import Location
- from climateqa.engine.talk_to_data.objects.plot import Plot
- from climateqa.engine.talk_to_data.objects.states import State
-
- async def detect_location_with_openai(sentence: str) -> str:
-     """
-     Detects locations in a sentence using OpenAI's API via LangChain.
-     """
-     llm = get_llm()
-
-     prompt = f"""
-     Extract all locations (cities, countries, states, or geographical areas) mentioned in the following sentence.
-     Return the result as a Python list. If no locations are mentioned, return an empty list.
-
-     Sentence: "{sentence}"
-     """
-
-     response = await llm.ainvoke(prompt)
-     location_list = ast.literal_eval(response.content.strip("```python\n").strip())
-     if location_list:
-         return location_list[0]
-     else:
-         return ""
-
- def loc_to_coords(location: str) -> tuple[float, float]:
-     """Converts a location name to geographic coordinates.
-
-     This function uses the Nominatim geocoding service to convert
-     a location name (e.g., a city name) to its latitude and longitude.
-
-     Args:
-         location (str): The name of the location to geocode
-
-     Returns:
-         tuple[float, float]: A tuple containing (latitude, longitude)
-
-     Raises:
-         AttributeError: If the location cannot be found
-     """
-     geolocator = Nominatim(user_agent="city_to_latlong", timeout=5)
-     coords = geolocator.geocode(location)
-     return (coords.latitude, coords.longitude)
-
- def coords_to_country(coords: tuple[float, float]) -> tuple[str, str]:
-     """Converts geographic coordinates to a country name.
-
-     This function uses the Nominatim reverse geocoding service to convert
-     latitude and longitude coordinates to a country code and name.
-
-     Args:
-         coords (tuple[float, float]): A tuple containing (latitude, longitude)
-
-     Returns:
-         tuple[str, str]: A tuple containing (country_code, country_name)
-
-     Raises:
-         AttributeError: If the coordinates cannot be found
-     """
-     geolocator = Nominatim(user_agent="latlong_to_country")
-     location = geolocator.reverse(coords)
-     address = location.raw['address']
-     return address['country_code'].upper(), address['country']
-
- def nearest_neighbour_sql(location: tuple, mode: Literal['DRIAS', 'IPCC']) -> tuple[str, str, Optional[str]]:
-     long = round(location[1], 3)
-     lat = round(location[0], 3)
-     conn = duckdb.connect()
-
-     if mode == 'DRIAS':
-         table_path = f"'{DRIAS_MEAN_ANNUAL_TEMPERATURE_PATH}'"
-         results = conn.sql(
-             f"SELECT latitude, longitude FROM {table_path} WHERE latitude BETWEEN {lat - 0.3} AND {lat + 0.3} AND longitude BETWEEN {long - 0.3} AND {long + 0.3}"
-         ).fetchdf()
-     else:
-         table_path = f"'{IPCC_COORDINATES_PATH}'"
-         results = conn.sql(
-             f"SELECT latitude, longitude, admin1 FROM {table_path} WHERE latitude BETWEEN {lat - 0.5} AND {lat + 0.5} AND longitude BETWEEN {long - 0.5} AND {long + 0.5}"
-         ).fetchdf()
-
-     if len(results) == 0:
-         return "", "", ""
-
-     if 'admin1' in results.columns:
-         admin1 = results['admin1'].iloc[0]
-     else:
-         admin1 = None
-     return results['latitude'].iloc[0], results['longitude'].iloc[0], admin1
-
- async def detect_year_with_openai(sentence: str) -> str:
-     """
-     Detects years in a sentence using OpenAI's API via LangChain.
-     """
-     llm = get_llm()
-
-     prompt = """
-     Extract all years mentioned in the following sentence.
-     Return the result as a Python list. If no years are mentioned, return an empty list.
-
-     Sentence: "{sentence}"
-     """
-
-     prompt = ChatPromptTemplate.from_template(prompt)
-     structured_llm = llm.with_structured_output(ArrayOutput)
-     chain = prompt | structured_llm
-     response: ArrayOutput = await chain.ainvoke({"sentence": sentence})
-     years_list = ast.literal_eval(response['array'])
-     if len(years_list) > 0:
-         return years_list[0]
-     else:
-         return ""
-
-
- async def detect_relevant_tables(user_question: str, plot: Plot, llm, table_names_list: list[str]) -> list[str]:
-     """Identifies relevant tables for a plot based on user input.
-
-     This function uses an LLM to analyze the user's question and the plot
-     description to determine which tables in the DRIAS database would be
-     most relevant for generating the requested visualization.
-
-     Args:
-         user_question (str): The user's question about climate data
-         plot (Plot): The plot configuration object
-         llm: The language model instance to use for the analysis
-         table_names_list (list[str]): Names of the candidate tables
-
-     Returns:
-         list[str]: A list of table names that are relevant for the plot
-
-     Example:
-         >>> detect_relevant_tables(
-         ...     "What will the temperature be like in Paris?",
-         ...     indicator_evolution_at_location,
-         ...     llm
-         ... )
-         ['mean_annual_temperature', 'mean_summer_temperature']
-     """
-     prompt = (
-         f"You are helping to build a plot following this description : {plot['description']}."
-         f"You are given a list of tables and a user question."
-         f"Based on the description of the plot, decide which tables are appropriate for that kind of plot."
-         f"Write the 3 most relevant tables to use. Answer only a python list of table names."
-         f"### List of tables : {table_names_list}"
-         f"### User question : {user_question}"
-         f"### List of table names : "
-     )
-
-     table_names = ast.literal_eval(
-         (await llm.ainvoke(prompt)).content.strip("```python\n").strip()
-     )
-     return table_names
-
- async def detect_relevant_plots(user_question: str, llm, plot_list: list[Plot]) -> list[str]:
-     plots_description = ""
-     for plot in plot_list:
-         plots_description += "Name: " + plot["name"]
-         plots_description += " - Description: " + plot["description"] + "\n"
-
-     prompt = (
-         "You are helping to answer a question with insightful visualizations.\n"
-         "You are given a user question and a list of plots with their name and description.\n"
-         "Based on the descriptions of the plots, select ALL plots that could provide a useful answer to this question. "
-         "Include any plot that could show relevant information, even if their perspectives (such as time series or spatial distribution) are different.\n"
-         "For example, for a question like 'What will be the total rainfall in China in 2050?', both a time series plot and a spatial map plot could be relevant.\n"
-         "Return only a Python list of plot names sorted from the most relevant one to the least relevant one.\n"
-         f"### Descriptions of the plots : {plots_description}"
-         f"### User question : {user_question}\n"
-         f"### Names of the plots : "
-     )
-
-     plot_names = ast.literal_eval(
-         (await llm.ainvoke(prompt)).content.strip("```python\n").strip()
-     )
-     return plot_names
-
- async def find_location(user_input: str, mode: Literal['DRIAS', 'IPCC'] = 'DRIAS') -> Location:
-     print("---- Find location in user input ----")
-     location = await detect_location_with_openai(user_input)
-     output: Location = {
-         'location': location,
-         'longitude': None,
-         'latitude': None,
-         'country_code': None,
-         'country_name': None,
-         'admin1': None
-     }
-
-     if location:
-         coords = loc_to_coords(location)
-         country_code, country_name = coords_to_country(coords)
-         neighbour = nearest_neighbour_sql(coords, mode)
-         output.update({
-             "latitude": neighbour[0],
-             "longitude": neighbour[1],
-             "country_code": country_code,
-             "country_name": country_name,
-             "admin1": neighbour[2]
-         })
-     output = cast(Location, output)
-     return output
-
- async def find_year(user_input: str) -> str | None:
-     """Extracts year information from user input using an LLM.
-
-     This function uses an LLM to identify and extract year information from the
-     user's query, which is used to filter data in subsequent queries.
-
-     Args:
-         user_input (str): The user's query text
-
-     Returns:
-         str | None: The extracted year, or None if no year was found
-     """
-     print("---- Find year ----")
-     year = await detect_year_with_openai(user_input)
-     if year == "":
-         return None
-     return year
-
- async def find_relevant_plots(state: State, llm, plots: list[Plot]) -> list[str]:
-     print("---- Find relevant plots ----")
-     relevant_plots = await detect_relevant_plots(state['user_input'], llm, plots)
-     return relevant_plots
-
- async def find_relevant_tables_per_plot(state: State, plot: Plot, llm, tables: list[str]) -> list[str]:
-     print(f"---- Find relevant tables for {plot['name']} ----")
-     relevant_tables = await detect_relevant_tables(state['user_input'], plot, llm, tables)
-     return relevant_tables
-
- async def find_param(state: State, param_name: str, mode: Literal['DRIAS', 'IPCC'] = 'DRIAS') -> dict[str, Optional[str]] | Location | None:
-     """Calls the appropriate method to retrieve the desired parameter.
-
-     Args:
-         state (State): state of the workflow
-         param_name (str): name of the desired parameter
-         mode (Literal['DRIAS', 'IPCC']): which dataset to resolve the location against
-
-     Returns:
-         dict[str, Optional[str]] | Location | None
-     """
-     if param_name == 'location':
-         location = await find_location(state['user_input'], mode)
-         return location
-     if param_name == 'year':
-         year = await find_year(state['user_input'])
-         return {'year': year}
-     return None
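A sketch of the overall extraction flow this module supported, assuming an OpenAI key is configured for `get_llm()` and network access is available for Nominatim and the parquet lookup:

```python
import asyncio

from climateqa.engine.talk_to_data.input_processing import find_location, find_year

async def main():
    question = "What will the temperature be like in Paris in 2050?"
    location = await find_location(question, mode="DRIAS")  # Location TypedDict
    year = await find_year(question)                        # "2050" or None
    print(location["location"], location["latitude"], location["longitude"], year)

asyncio.run(main())
```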
climateqa/engine/talk_to_data/ipcc/config.py DELETED
@@ -1,98 +0,0 @@
- from climateqa.engine.talk_to_data.ui_config import PRECIPITATION_COLORSCALE, TEMPERATURE_COLORSCALE
- from climateqa.engine.talk_to_data.config import IPCC_DATASET_URL
-
-
- # IPCC_DATASET_URL = "hf://datasets/ekimetrics/ipcc-atlas"
- IPCC_TABLES = [
-     "mean_temperature",
-     "total_precipitation",
- ]
-
- IPCC_INDICATOR_COLUMNS_PER_TABLE = {
-     "mean_temperature": "mean_temperature",
-     "total_precipitation": "total_precipitation"
- }
-
- IPCC_INDICATOR_TO_UNIT = {
-     "mean_temperature": "°C",
-     "total_precipitation": "mm/day"
- }
-
- IPCC_SCENARIO = [
-     "historical",
-     "ssp126",
-     "ssp245",
-     "ssp370",
-     "ssp585",
- ]
-
- IPCC_MODELS = []
-
- IPCC_PLOT_PARAMETERS = [
-     'year',
-     'location'
- ]
-
- MACRO_COUNTRIES = [
-     'JP', 'IN', 'MH', 'PT', 'ID', 'SJ', 'MX', 'CN', 'GL', 'PN',
-     'AR', 'AQ', 'PF', 'BR', 'SH', 'GS', 'ZA', 'NZ', 'TF',
- ]
-
- HUGE_MACRO_COUNTRIES = [
-     'CL', 'CA', 'AU', 'US', 'RU'
- ]
-
- IPCC_INDICATOR_TO_COLORSCALE = {
-     "mean_temperature": TEMPERATURE_COLORSCALE,
-     "total_precipitation": PRECIPITATION_COLORSCALE
- }
-
- IPCC_UI_TEXT = """
- Hi, I'm **Talk to IPCC**, designed to answer your questions using [**IPCC - ATLAS**](https://interactive-atlas.ipcc.ch/regional-information#eyJ0eXBlIjoiQVRMQVMiLCJjb21tb25zIjp7ImxhdCI6OTc3MiwibG5nIjo0MDA2OTIsInpvb20iOjQsInByb2oiOiJFUFNHOjU0MDMwIiwibW9kZSI6ImNvbXBsZXRlX2F0bGFzIn0sInByaW1hcnkiOnsic2NlbmFyaW8iOiJzc3A1ODUiLCJwZXJpb2QiOiIyIiwic2Vhc29uIjoieWVhciIsImRhdGFzZXQiOiJDTUlQNiIsInZhcmlhYmxlIjoidGFzIiwidmFsdWVUeXBlIjoiQU5PTUFMWSIsImhhdGNoaW5nIjoiU0lNUExFIiwicmVnaW9uU2V0IjoiYXI2IiwiYmFzZWxpbmUiOiJwcmVJbmR1c3RyaWFsIiwicmVnaW9uc1NlbGVjdGVkIjpbXX0sInBsb3QiOnsiYWN0aXZlVGFiIjoicGx1bWUiLCJtYXNrIjoibm9uZSIsInNjYXR0ZXJZTWFnIjpudWxsLCJzY2F0dGVyWVZhciI6bnVsbCwic2hvd2luZyI6ZmFsc2V9fQ==) data.
- I'll answer by displaying a list of the SQL queries, graphs and data most relevant to your question.
-
- You can ask me anything about these climate indicators: **temperature** or **precipitation**.
- You can specify a **location** and/or a **year**.
- By default, we take the **median across the climate models**.
-
- Currently available charts:
- - Yearly evolution of an indicator at a specific location (historical + SSP projections)
- - Yearly spatial distribution of an indicator in a specific country
-
- Currently available indicators:
- - Mean temperature
- - Total precipitation
-
- For example, you can ask:
- - What will the temperature be like in Paris?
- - What will be the total rainfall in the USA in 2030?
- - How will the average temperature evolve in China?
-
- ⚠️ **Limitations**:
- - You can't ask anything that isn't related to **IPCC - ATLAS** data.
- - You cannot ask about **several locations at the same time**.
- - If you specify a year **before 1850 or after 2100**, there will be **no data**.
- - You **cannot compare two models**.
-
- 🛈 **Information**
- Please note that we **log your questions for meta-analysis purposes**, so avoid sharing any sensitive or personal information.
- """
climateqa/engine/talk_to_data/ipcc/plot_informations.py DELETED
@@ -1,50 +0,0 @@
- from climateqa.engine.talk_to_data.ipcc.config import IPCC_INDICATOR_TO_UNIT
-
- def indicator_evolution_informations(
-     indicator: str,
-     params: dict[str, str],
- ) -> str:
-     if "location" not in params:
-         raise ValueError('"location" must be provided in params')
-     location = params["location"]
-
-     unit = IPCC_INDICATOR_TO_UNIT[indicator]
-     return f"""
- This plot shows how the climate indicator **{indicator}** evolves over time in **{location}**.
-
- It combines historical observations (from 1950 to 2015) with future projections (from 2016 to 2100) for the different SSP climate scenarios (SSP126, SSP245, SSP370 and SSP585).
-
- The x-axis represents the years (from 1950 to 2100), and the y-axis shows the value of {indicator} ({unit}).
-
- Each line corresponds to a different scenario, allowing you to compare how {indicator} might change under various future conditions.
-
- **Data source:**
- - The data comes from the CMIP6 IPCC ATLAS data. It was initially extracted from [this referenced website](https://digital.csic.es/handle/10261/332744), then preprocessed to a tabular format and uploaded as parquet in this [Hugging Face dataset](https://huggingface.co/datasets/Ekimetrics/ipcc-atlas).
- - The underlying data is retrieved by aggregating yearly values of {indicator} for the selected location, across all available scenarios: for each year, the system collects the value of {indicator} in {location}, both for the historical period and for each scenario, to build the time series.
- - The coordinates used for {location} correspond to the closest available point in the IPCC database, which uses a regular grid with a spatial resolution of 1 degree.
- """
-
- def choropleth_map_informations(
-     indicator: str,
-     params: dict[str, str],
- ) -> str:
-     unit = IPCC_INDICATOR_TO_UNIT[indicator]
-     if "location" not in params:
-         raise ValueError('"location" must be provided in params')
-     location = params["location"]
-     country_name = params["country_name"]
-     year = params["year"]
-     if year is None:
-         year = 2050
-
-     return f"""
- This plot displays a choropleth map showing the spatial distribution of **{indicator}** across all regions of **{location}** ({country_name}) for the year **{year}** and the chosen scenario.
-
- Each grid point is colored according to the value of the indicator ({unit}), allowing you to visually compare how {indicator} varies geographically within the country for the selected year and scenario.
-
- **Data source:**
- - The data comes from the CMIP6 IPCC ATLAS data. It was initially extracted from [this referenced website](https://digital.csic.es/handle/10261/332744), then preprocessed to a tabular format and uploaded as parquet in this [Hugging Face dataset](https://huggingface.co/datasets/Ekimetrics/ipcc-atlas).
- - For each grid point of {location} ({country_name}), the value of {indicator} in {year} for the selected scenario is extracted and mapped to its geographic coordinates.
- - The grid points correspond to 1-degree squares centered on the grid points of the IPCC dataset. Each grid point has been mapped to a country using [**reverse_geocoder**](https://github.com/thampiman/reverse-geocoder).
- - The coordinates used for each region are those of the closest available grid point in the IPCC database, which uses a regular grid with a spatial resolution of 1 degree.
- """
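For context on what this deletion removes: the two helpers above only format markdown explanations. A minimal sketch of how they were presumably invoked before removal (indicator and location values are illustrative; the indicator must be a key of `IPCC_INDICATOR_TO_UNIT`):

```python
from climateqa.engine.talk_to_data.ipcc.plot_informations import (
    indicator_evolution_informations,
)

# Hypothetical values for illustration only.
markdown = indicator_evolution_informations(
    indicator="mean_temperature",
    params={"location": "France"},
)
print(markdown)  # Explanatory markdown rendered alongside the plot
```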
climateqa/engine/talk_to_data/ipcc/plots.py DELETED
@@ -1,189 +0,0 @@
- from typing import Callable
- from plotly.graph_objects import Figure
- import plotly.graph_objects as go
- import pandas as pd
- import geojson
-
- from climateqa.engine.talk_to_data.ipcc.config import IPCC_INDICATOR_TO_COLORSCALE, IPCC_INDICATOR_TO_UNIT, IPCC_SCENARIO
- from climateqa.engine.talk_to_data.ipcc.plot_informations import choropleth_map_informations, indicator_evolution_informations
- from climateqa.engine.talk_to_data.ipcc.queries import indicator_for_given_year_query, indicator_per_year_at_location_query
- from climateqa.engine.talk_to_data.objects.plot import Plot
-
- def generate_geojson_polygons(latitudes: list[float], longitudes: list[float], indicators: list[float]) -> geojson.FeatureCollection:
-     """Builds one 1-degree GeoJSON square per grid point, carrying the indicator value."""
-     features = [
-         geojson.Feature(
-             geometry=geojson.Polygon([[
-                 [lon - 0.5, lat - 0.5],
-                 [lon + 0.5, lat - 0.5],
-                 [lon + 0.5, lat + 0.5],
-                 [lon - 0.5, lat + 0.5],
-                 [lon - 0.5, lat - 0.5]
-             ]]),
-             properties={"value": val},
-             id=str(idx)
-         )
-         for idx, (lat, lon, val) in enumerate(zip(latitudes, longitudes, indicators))
-     ]
-
-     geojson_data = geojson.FeatureCollection(features)
-     return geojson_data
-
- def plot_indicator_evolution_at_location_historical_and_projections(
-     params: dict,
- ) -> Callable[[pd.DataFrame], Figure]:
-     """
-     Returns a function that generates a line plot showing the evolution of a climate indicator
-     (e.g., temperature, rainfall) over time at a specific location, including both historical data
-     and future projections for different climate scenarios.
-
-     Args:
-         params (dict): Dictionary with:
-             - indicator_column (str): Name of the climate indicator column to plot.
-             - location (str): Location (e.g., country, city) for which to plot the indicator.
-
-     Returns:
-         Callable[[pd.DataFrame], Figure]: Function that takes a DataFrame and returns a Plotly Figure
-         showing the indicator's evolution over time, with scenario lines and historical data.
-     """
-     indicator = params["indicator_column"]
-     location = params["location"]
-     indicator_label = " ".join(word.capitalize() for word in indicator.split("_"))
-     unit = IPCC_INDICATOR_TO_UNIT.get(indicator, "")
-
-     def plot_data(df: pd.DataFrame) -> Figure:
-         df = df.sort_values(by='year')
-         years = df['year'].astype(int).tolist()
-         indicators = df[indicator].astype(float).tolist()
-         scenarios = df['scenario'].astype(str).tolist()
-
-         # Find the last historical value so scenario lines connect to the historical curve
-         last_historical = [(y, v) for y, v, s in zip(years, indicators, scenarios) if s == 'historical']
-         last_historical_year, last_historical_indicator = last_historical[-1] if last_historical else (None, None)
-
-         fig = go.Figure()
-         for scenario in IPCC_SCENARIO:
-             x = [y for y, s in zip(years, scenarios) if s == scenario]
-             y = [v for v, s in zip(indicators, scenarios) if s == scenario]
-             # Connect historical to scenario
-             if scenario != 'historical' and last_historical_indicator is not None:
-                 x = [last_historical_year] + x
-                 y = [last_historical_indicator] + y
-             fig.add_trace(go.Scatter(
-                 x=x,
-                 y=y,
-                 mode='lines',
-                 name=scenario
-             ))
-
-         fig.update_layout(
-             title=f'Yearly Evolution of {indicator_label} in {location} (Historical + SSP Scenarios)',
-             xaxis_title='Year',
-             yaxis_title=f'{indicator_label} ({unit})',
-             legend_title='Scenario',
-             height=800,
-         )
-         return fig
-
-     return plot_data
-
- indicator_evolution_at_location_historical_and_projections: Plot = {
-     "name": "Indicator Evolution at Location (Historical + Projections)",
-     "description": (
-         "Shows how a climate indicator (e.g., rainfall, temperature) changes over time at a specific location, "
-         "including historical data and future projections. "
-         "Useful for questions about the value or trend of an indicator at a location for any year, "
-         "such as 'What will be the total rainfall in China in 2050?' or 'How does rainfall evolve in China over time?'. "
-         "Parameters: indicator_column (the climate variable), location (e.g., country, city)."
-     ),
-     "params": ["indicator_column", "location"],
-     "plot_function": plot_indicator_evolution_at_location_historical_and_projections,
-     "sql_query": indicator_per_year_at_location_query,
-     "plot_information": indicator_evolution_informations,
-     "short_name": "Evolution"
- }
-
- def plot_choropleth_map_of_country_indicator_for_specific_year(
-     params: dict,
- ) -> Callable[[pd.DataFrame], Figure]:
-     """
-     Returns a function that generates a choropleth map (heatmap) showing the spatial distribution
-     of a climate indicator (e.g., temperature, rainfall) across all regions of a country for a specific year.
-
-     Args:
-         params (dict): Dictionary with:
-             - indicator_column (str): Name of the climate indicator column to plot.
-             - year (str or int, optional): Year for which to plot the indicator (default: 2050).
-             - country_name (str): Name of the country.
-             - location (str): Location (country or region) for the map.
-
-     Returns:
-         Callable[[pd.DataFrame], Figure]: Function that takes a DataFrame and returns a Plotly Figure
-         showing the indicator's spatial distribution as a choropleth map for the specified year.
-     """
-     indicator = params["indicator_column"]
-     year = params.get('year')
-     if year is None:
-         year = 2050
-     country_name = params['country_name']
-     location = params['location']
-     indicator_label = " ".join(word.capitalize() for word in indicator.split("_"))
-     unit = IPCC_INDICATOR_TO_UNIT.get(indicator, "")
-
-     def plot_data(df: pd.DataFrame) -> Figure:
-         indicators = df[indicator].astype(float).tolist()
-         latitudes = df["latitude"].astype(float).tolist()
-         longitudes = df["longitude"].astype(float).tolist()
-
-         geojson_data = generate_geojson_polygons(latitudes, longitudes, indicators)
-
-         fig = go.Figure(go.Choroplethmapbox(
-             geojson=geojson_data,
-             locations=[str(i) for i in range(len(indicators))],
-             featureidkey="id",
-             z=indicators,
-             colorscale=IPCC_INDICATOR_TO_COLORSCALE[indicator],
-             zmin=min(indicators),
-             zmax=max(indicators),
-             marker_opacity=0.7,
-             marker_line_width=0,
-             colorbar_title=f"{indicator_label} ({unit})",
-             text=[f"{indicator_label}: {value:.2f} {unit}" for value in indicators],  # Hover text showing the indicator value
-             hoverinfo="text"
-         ))
-
-         fig.update_layout(
-             mapbox_style="open-street-map",
-             mapbox_zoom=2,
-             height=800,
-             mapbox_center={
-                 "lat": latitudes[len(latitudes)//2] if latitudes else 0,
-                 "lon": longitudes[len(longitudes)//2] if longitudes else 0
-             },
-             coloraxis_colorbar=dict(title=f"{indicator_label} ({unit})"),
-             title=f"{indicator_label} in {year} in {location} ({country_name})"
-         )
-         return fig
-
-     return plot_data
-
- choropleth_map_of_country_indicator_for_specific_year: Plot = {
-     "name": "Choropleth Map of a Country's Indicator Distribution for a Specific Year",
-     "description": (
-         "Displays a map showing the spatial distribution of a climate indicator (e.g., rainfall, temperature) "
-         "across all regions of a country for a specific year. "
-         "Can answer questions about the value of an indicator in a country or region for a given year, "
-         "such as 'What will be the total rainfall in China in 2050?' or 'How is rainfall distributed across China in 2050?'. "
-         "Parameters: indicator_column (the climate variable), year, location (country name)."
-     ),
-     "params": ["indicator_column", "year", "location"],
-     "plot_function": plot_choropleth_map_of_country_indicator_for_specific_year,
-     "sql_query": indicator_for_given_year_query,
-     "plot_information": choropleth_map_informations,
-     "short_name": "Map",
- }
-
- IPCC_PLOTS = [
-     indicator_evolution_at_location_historical_and_projections,
-     choropleth_map_of_country_indicator_for_specific_year
- ]
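The `Plot` entries above define a two-step protocol: `plot_function(params)` returns a renderer, which is then applied to the DataFrame produced by executing `sql_query`. A hedged sketch of that flow with toy data (params, values, and scenario labels are illustrative; the scenario strings must match the entries of `IPCC_SCENARIO` in the config):

```python
import pandas as pd
from climateqa.engine.talk_to_data.ipcc.plots import (
    indicator_evolution_at_location_historical_and_projections as evolution_plot,
)

# Hypothetical params; real values were resolved upstream from the user query.
params = {"indicator_column": "mean_temperature", "location": "France"}

# Toy query result in the shape the renderer expects (year / scenario / indicator).
df = pd.DataFrame({
    "year": [2014, 2015, 2016, 2016],
    "scenario": ["historical", "historical", "ssp126", "ssp585"],
    "mean_temperature": [11.2, 11.4, 11.5, 11.7],
})

# Step 1: params -> renderer; step 2: renderer(df) -> Plotly Figure.
fig = evolution_plot["plot_function"](params)(df)
fig.show()
```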
climateqa/engine/talk_to_data/ipcc/queries.py DELETED
@@ -1,144 +0,0 @@
- from typing import TypedDict, Optional
-
- from climateqa.engine.talk_to_data.ipcc.config import HUGE_MACRO_COUNTRIES, MACRO_COUNTRIES
- from climateqa.engine.talk_to_data.config import IPCC_DATASET_URL
-
- class IndicatorPerYearAtLocationQueryParams(TypedDict, total=False):
-     """
-     Parameters for querying the evolution of an indicator per year at a specific location.
-
-     Attributes:
-         indicator_column (str): Name of the climate indicator column.
-         latitude (str): Latitude of the location.
-         longitude (str): Longitude of the location.
-         country_code (str): Country code.
-         admin1 (str): Administrative region (optional).
-     """
-     indicator_column: str
-     latitude: str
-     longitude: str
-     country_code: str
-     admin1: Optional[str]
-
- def indicator_per_year_at_location_query(
-     table: str, params: IndicatorPerYearAtLocationQueryParams
- ) -> str:
-     """
-     Builds an SQL query to get the evolution of an indicator per year at a specific location.
-
-     Args:
-         table (str): SQL table of the indicator.
-         params (IndicatorPerYearAtLocationQueryParams): Dictionary with the required params for the query.
-
-     Returns:
-         str: The SQL query string, or an empty string if required parameters are missing.
-     """
-     indicator_column = params.get("indicator_column")
-     latitude = params.get("latitude")
-     longitude = params.get("longitude")
-     country_code = params.get("country_code")
-     admin1 = params.get("admin1")
-
-     if not all([indicator_column, latitude, longitude, country_code]):
-         return ""
-
-     if country_code in MACRO_COUNTRIES:
-         table_path = f"'{IPCC_DATASET_URL}/{table.lower()}/{country_code}_macro.parquet'"
-         sql_query = f"""
-         SELECT year, scenario, AVG({indicator_column}) as {indicator_column}
-         FROM {table_path}
-         WHERE latitude = {latitude} AND longitude = {longitude} AND year >= 1950
-         GROUP BY scenario, year
-         ORDER BY year, scenario
-         """
-     elif country_code in HUGE_MACRO_COUNTRIES:
-         table_path = f"'{IPCC_DATASET_URL}/{table.lower()}/{country_code}_macro.parquet'"
-         sql_query = f"""
-         SELECT year, scenario, {indicator_column}
-         FROM {table_path}
-         WHERE latitude = {latitude} AND longitude = {longitude} AND year >= 1950
-         ORDER BY year, scenario
-         """
-     else:
-         table_path = f"'{IPCC_DATASET_URL}/{table.lower()}/{country_code}.parquet'"
-         sql_query = f"""
-         WITH medians_per_month AS (
-             SELECT year, scenario, month, MEDIAN({indicator_column}) AS median_value
-             FROM {table_path}
-             WHERE latitude = {latitude} AND longitude = {longitude} AND year >= 1950
-             GROUP BY scenario, year, month
-         )
-         SELECT year, scenario, AVG(median_value) AS {indicator_column}
-         FROM medians_per_month
-         GROUP BY scenario, year
-         ORDER BY year, scenario
-         """
-     return sql_query.strip()
-
- class IndicatorForGivenYearQueryParams(TypedDict, total=False):
-     """
-     Parameters for querying an indicator's values across locations for a specific year.
-
-     Attributes:
-         indicator_column (str): The column name for the climate indicator.
-         year (str): The year to query.
-         country_code (str): The country code.
-     """
-     indicator_column: str
-     year: str
-     country_code: str
-
- def indicator_for_given_year_query(
-     table: str, params: IndicatorForGivenYearQueryParams
- ) -> str:
-     """
-     Builds an SQL query to get the values of an indicator with their latitudes, longitudes,
-     and scenarios for a given year.
-
-     Args:
-         table (str): SQL table of the indicator.
-         params (IndicatorForGivenYearQueryParams): Dictionary with the required params for the query.
-
-     Returns:
-         str: The SQL query string, or an empty string if required parameters are missing.
-     """
-     indicator_column = params.get("indicator_column")
-     year = params.get("year") or 2050
-     country_code = params.get("country_code")
-
-     if not all([indicator_column, year, country_code]):
-         return ""
-
-     if country_code in MACRO_COUNTRIES:
-         table_path = f"'{IPCC_DATASET_URL}/{table.lower()}/{country_code}_macro.parquet'"
-         sql_query = f"""
-         SELECT latitude, longitude, scenario, AVG({indicator_column}) as {indicator_column}
-         FROM {table_path}
-         WHERE year = {year}
-         GROUP BY latitude, longitude, scenario
-         ORDER BY latitude, longitude, scenario
-         """
-     elif country_code in HUGE_MACRO_COUNTRIES:
-         table_path = f"'{IPCC_DATASET_URL}/{table.lower()}/{country_code}_macro.parquet'"
-         sql_query = f"""
-         SELECT latitude, longitude, scenario, {indicator_column}
-         FROM {table_path}
-         WHERE year = {year}
-         ORDER BY latitude, longitude, scenario
-         """
-     else:
-         table_path = f"'{IPCC_DATASET_URL}/{table.lower()}/{country_code}.parquet'"
-         sql_query = f"""
-         WITH medians_per_month AS (
-             SELECT latitude, longitude, scenario, month, MEDIAN({indicator_column}) AS median_value
-             FROM {table_path}
-             WHERE year = {year}
-             GROUP BY latitude, longitude, scenario, month
-         )
-         SELECT latitude, longitude, scenario, AVG(median_value) AS {indicator_column}
-         FROM medians_per_month
-         GROUP BY latitude, longitude, scenario
-         ORDER BY latitude, longitude, scenario
-         """
-
-     return sql_query.strip()
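Both builders interpolate their parameters into SQL that targets remote parquet files, so they were presumably executed with DuckDB. A minimal sketch under that assumption (all parameter values are hypothetical; reading the remote parquet relies on DuckDB's httpfs support):

```python
import duckdb

from climateqa.engine.talk_to_data.ipcc.queries import indicator_per_year_at_location_query

sql = indicator_per_year_at_location_query(
    table="mean_temperature",  # illustrative table name
    params={
        "indicator_column": "mean_temperature",
        "latitude": "48.0",
        "longitude": "2.0",
        "country_code": "FRA",
    },
)

# An empty string signals missing parameters; otherwise run the query.
df = duckdb.sql(sql).df() if sql else None
```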
climateqa/engine/talk_to_data/main.py DELETED
@@ -1,124 +0,0 @@
- from climateqa.engine.talk_to_data.workflow.drias import drias_workflow
- from climateqa.engine.talk_to_data.workflow.ipcc import ipcc_workflow
- from climateqa.logging import log_drias_interaction_to_huggingface
-
- async def ask_drias(query: str, index_state: int = 0, user_id: str | None = None) -> tuple:
-     """Main function to process a DRIAS query and return results.
-
-     This function orchestrates the DRIAS workflow, processing a user query to generate
-     SQL queries, dataframes, and visualizations. It handles multiple results and allows
-     pagination through them.
-
-     Args:
-         query (str): The user's question about climate data
-         index_state (int, optional): The index of the result to return. Defaults to 0.
-         user_id (str, optional): User identifier, used only for logging.
-
-     Returns:
-         tuple: A tuple containing:
-             - sql_query (str): The SQL query for the current result
-             - dataframe (pd.DataFrame): The data for the current result
-             - figure (Figure): The rendered visualization for the current result
-             - plot_information (str): Explanation of the current plot
-             - sql_queries (list): All generated SQL queries
-             - result_dataframes (list): All resulting dataframes
-             - figures (list): All figure generation functions
-             - plot_informations (list): All plot explanations
-             - index_state (int): Current result index
-             - plot_title_list (list): Titles of all generated plots
-             - error (str): Error message if any
-     """
-     final_state = await drias_workflow(query)
-     sql_queries = []
-     result_dataframes = []
-     figures = []
-     plot_title_list = []
-     plot_informations = []
-
-     for output_title, output in final_state['outputs'].items():
-         if output['status'] == 'OK':
-             if output['table'] is not None:
-                 plot_title_list.append(output_title)
-             if output['plot_information'] is not None:
-                 plot_informations.append(output['plot_information'])
-             if output['sql_query'] is not None:
-                 sql_queries.append(output['sql_query'])
-             if output['dataframe'] is not None:
-                 result_dataframes.append(output['dataframe'])
-             if output['figure'] is not None:
-                 figures.append(output['figure'])
-
-     if "error" in final_state and final_state["error"] != "":
-         # On error, return empty placeholders matching the shape of the success tuple
-         return None, None, None, None, [], [], [], [], 0, [], final_state["error"]
-
-     sql_query = sql_queries[index_state]
-     dataframe = result_dataframes[index_state]
-     figure = figures[index_state](dataframe)
-     plot_information = plot_informations[index_state]
-
-     log_drias_interaction_to_huggingface(query, sql_query, user_id)
-
-     return sql_query, dataframe, figure, plot_information, sql_queries, result_dataframes, figures, plot_informations, index_state, plot_title_list, ""
-
-
- async def ask_ipcc(query: str, index_state: int = 0, user_id: str | None = None) -> tuple:
-     """Main function to process an IPCC query and return results.
-
-     This function orchestrates the IPCC workflow, processing a user query to generate
-     SQL queries, dataframes, and visualizations. It handles multiple results and allows
-     pagination through them.
-
-     Args:
-         query (str): The user's question about climate data
-         index_state (int, optional): The index of the result to return. Defaults to 0.
-         user_id (str, optional): User identifier, used only for logging.
-
-     Returns:
-         tuple: Same structure as ask_drias (see above).
-     """
-     final_state = await ipcc_workflow(query)
-     sql_queries = []
-     result_dataframes = []
-     figures = []
-     plot_title_list = []
-     plot_informations = []
-
-     for output_title, output in final_state['outputs'].items():
-         if output['status'] == 'OK':
-             if output['table'] is not None:
-                 plot_title_list.append(output_title)
-             if output['plot_information'] is not None:
-                 plot_informations.append(output['plot_information'])
-             if output['sql_query'] is not None:
-                 sql_queries.append(output['sql_query'])
-             if output['dataframe'] is not None:
-                 result_dataframes.append(output['dataframe'])
-             if output['figure'] is not None:
-                 figures.append(output['figure'])
-
-     if "error" in final_state and final_state["error"] != "":
-         # On error, return empty placeholders matching the shape of the success tuple
-         return None, None, None, None, [], [], [], [], 0, [], final_state["error"]
-
-     sql_query = sql_queries[index_state]
-     dataframe = result_dataframes[index_state]
-     figure = figures[index_state](dataframe)
-     plot_information = plot_informations[index_state]
-
-     # Interactions are logged via the DRIAS logger (no separate IPCC logger in this module)
-     log_drias_interaction_to_huggingface(query, sql_query, user_id)
-
-     return sql_query, dataframe, figure, plot_information, sql_queries, result_dataframes, figures, plot_informations, index_state, plot_title_list, ""
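A minimal usage sketch for the deleted `ask_ipcc` entry point, unpacking the eleven-element result tuple (the question string is illustrative):

```python
import asyncio

from climateqa.engine.talk_to_data.main import ask_ipcc

async def demo() -> None:
    (sql_query, dataframe, figure, plot_information,
     sql_queries, result_dataframes, figures, plot_informations,
     index_state, plot_title_list, error) = await ask_ipcc(
        "How will mean temperature evolve in France?", index_state=0
    )
    if error:
        print(error)
    else:
        figure.show()  # rendered Plotly figure for the current index

asyncio.run(demo())
```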
climateqa/engine/talk_to_data/myVanna.py DELETED
@@ -1,13 +0,0 @@
- from dotenv import load_dotenv
- from climateqa.engine.talk_to_data.vanna_class import MyCustomVectorDB
- from vanna.openai import OpenAI_Chat
- import os
-
- load_dotenv()
-
- # API key read from the THEO_API_KEY environment variable
- OPENAI_API_KEY = os.getenv('THEO_API_KEY')
-
- class MyVanna(MyCustomVectorDB, OpenAI_Chat):
-     def __init__(self, config=None):
-         MyCustomVectorDB.__init__(self, config=config)
-         OpenAI_Chat.__init__(self, config=config)
climateqa/engine/talk_to_data/objects/llm_outputs.py DELETED
@@ -1,13 +0,0 @@
- from typing import Annotated, TypedDict
-
-
- class ArrayOutput(TypedDict):
-     """Represents the output of a function that returns an array.
-
-     This class is used to type-hint functions that return arrays,
-     ensuring consistent return types across the codebase.
-
-     Attributes:
-         array (str): A syntactically valid Python array string
-     """
-     array: Annotated[str, "Syntactically valid python array."]
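Since the contract is only that `array` holds a syntactically valid Python list literal, `ast.literal_eval` is the natural way to decode it. A sketch with a hypothetical LLM response:

```python
import ast

from climateqa.engine.talk_to_data.objects.llm_outputs import ArrayOutput

output: ArrayOutput = {"array": "['mean_temperature', 'total_rainfall']"}  # hypothetical response
values = ast.literal_eval(output["array"])  # -> ['mean_temperature', 'total_rainfall']
```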
climateqa/engine/talk_to_data/objects/location.py DELETED
@@ -1,12 +0,0 @@
- from typing import Optional, TypedDict
-
-
- class Location(TypedDict):
-     location: str
-     latitude: Optional[str]
-     longitude: Optional[str]
-     country_code: Optional[str]
-     country_name: Optional[str]
-     admin1: Optional[str]
climateqa/engine/talk_to_data/objects/plot.py DELETED
@@ -1,23 +0,0 @@
- from typing import Callable, TypedDict
- from plotly.graph_objects import Figure
-
- class Plot(TypedDict):
-     """Represents a plot configuration in the DRIAS system.
-
-     This class defines the structure for configuring different types of plots
-     that can be generated from climate data.
-
-     Attributes:
-         name (str): The name of the plot type
-         description (str): A description of what the plot shows
-         params (list[str]): List of required parameters for the plot
-         plot_function (Callable[..., Callable[..., Figure]]): Function to generate the plot
-         sql_query (Callable[..., str]): Function to generate the SQL query for the plot
-         plot_information (Callable[..., str]): Function to generate the plot's explanatory text
-         short_name (str): A short display name for the plot type
-     """
-     name: str
-     description: str
-     params: list[str]
-     plot_function: Callable[..., Callable[..., Figure]]
-     sql_query: Callable[..., str]
-     plot_information: Callable[..., str]
-     short_name: str
climateqa/engine/talk_to_data/objects/states.py DELETED
@@ -1,19 +0,0 @@
- from typing import Callable, Optional, TypedDict
- from plotly.graph_objects import Figure
- import pandas as pd
- from climateqa.engine.talk_to_data.objects.plot import Plot
-
- class TTDOutput(TypedDict):
-     status: str
-     plot: Plot
-     table: str
-     sql_query: Optional[str]
-     dataframe: Optional[pd.DataFrame]
-     figure: Optional[Callable[..., Figure]]
-     plot_information: Optional[str]
-
- class State(TypedDict):
-     user_input: str
-     plots: list[str]
-     outputs: dict[str, TTDOutput]
-     error: Optional[str]
-