timeki committed
Commit 175604a · 1 Parent(s): bf59b4c

Duplicates workflow to separate POC from Prod and simplify retrieval

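In short: the commit keeps one production agent for the ClimateQ&A tab and adds a separate POC agent for the Beta tab, each with its own chat entry point. A minimal, self-contained sketch of that dispatch pattern (the `chat_stream` stub and agent names below are illustrative stand-ins, not the repo's real signatures):

```python
import asyncio

async def chat_stream(agent: str, query: str):
    # Stand-in for climateqa.chat.chat_stream, which streams events from an agent.
    yield f"[{agent}] answer to: {query!r}"

async def chat(query: str):
    # Production tab ("ClimateQ&A") streams from the prod agent.
    async for event in chat_stream("prod-agent", query):
        yield event

async def chat_poc(query: str):
    # Beta tab ("Beta - POC Adapt'Action") streams from the POC agent.
    async for event in chat_stream("poc-agent", query):
        yield event

async def main():
    async for e in chat("sea level rise"):
        print(e)
    async for e in chat_poc("sea level rise"):
        print(e)

asyncio.run(main())
```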
app.py CHANGED
@@ -9,7 +9,7 @@ from climateqa.engine.embeddings import get_embeddings_function
 from climateqa.engine.llm import get_llm
 from climateqa.engine.vectorstore import get_pinecone_vectorstore
 from climateqa.engine.reranker import get_reranker
-from climateqa.engine.graph import make_graph_agent
+from climateqa.engine.graph import make_graph_agent,make_graph_agent_poc
 from climateqa.engine.chains.retrieve_papers import find_papers
 from climateqa.chat import start_chat, chat_stream, finish_chat
 
@@ -69,12 +69,19 @@ vectorstore_region = get_pinecone_vectorstore(embeddings_function, index_name=os
 llm = get_llm(provider="openai",max_tokens = 1024,temperature = 0.0)
 reranker = get_reranker("nano")
 
-agent = make_graph_agent(llm=llm, vectorstore_ipcc=vectorstore, vectorstore_graphs=vectorstore_graphs, vectorstore_region = vectorstore_region, reranker=reranker, threshold_docs=0)#TODO put back default 0.2
+agent = make_graph_agent(llm=llm, vectorstore_ipcc=vectorstore, vectorstore_graphs=vectorstore_graphs, vectorstore_region = vectorstore_region, reranker=reranker, threshold_docs=0.2)
+agent_poc = make_graph_agent_poc(llm=llm, vectorstore_ipcc=vectorstore, vectorstore_graphs=vectorstore_graphs, vectorstore_region = vectorstore_region, reranker=reranker, threshold_docs=0)#TODO put back default 0.2
 
 
 async def chat(query, history, audience, sources, reports, relevant_content_sources_selection, search_only):
+    print("chat cqa - message received")
     async for event in chat_stream(agent, query, history, audience, sources, reports, relevant_content_sources_selection, search_only, share_client, user_id):
         yield event
+
+async def chat_poc(query, history, audience, sources, reports, relevant_content_sources_selection, search_only):
+    print("chat poc - message received")
+    async for event in chat_stream(agent_poc, query, history, audience, sources, reports, relevant_content_sources_selection, search_only, share_client, user_id):
+        yield event
 
 
 # --------------------------------------------------------------------
@@ -205,7 +212,7 @@ def event_handling(
 
     new_sources_hmtl = gr.State([])
 
-
+    print("textbox id : ", textbox.elem_id)
 
     for button in [config_button, close_config_modal]:
        button.click(
@@ -213,18 +220,38 @@ def event_handling(
            inputs=[config_open],
            outputs=[config_modal, config_open]
        )
-    # Event for textbox
-    (textbox
-        .submit(start_chat, [textbox, chatbot, search_only], [textbox, tabs, chatbot, sources_raw], queue=False, api_name=f"start_chat_{textbox.elem_id}")
-        .then(chat, [textbox, chatbot, dropdown_audience, dropdown_sources, dropdown_reports, dropdown_external_sources, search_only], [chatbot, new_sources_hmtl, output_query, output_language, new_figures, current_graphs], concurrency_limit=8, api_name=f"chat_{textbox.elem_id}")
-        .then(finish_chat, None, [textbox], api_name=f"finish_chat_{textbox.elem_id}")
-    )
-    # Event for examples_hidden
-    (examples_hidden
-        .change(start_chat, [examples_hidden, chatbot, search_only], [examples_hidden, tabs, chatbot, sources_raw], queue=False, api_name=f"start_chat_{examples_hidden.elem_id}")
-        .then(chat, [examples_hidden, chatbot, dropdown_audience, dropdown_sources, dropdown_reports, dropdown_external_sources, search_only], [chatbot, new_sources_hmtl, output_query, output_language, new_figures, current_graphs], concurrency_limit=8, api_name=f"chat_{examples_hidden.elem_id}")
-        .then(finish_chat, None, [examples_hidden], api_name=f"finish_chat_{examples_hidden.elem_id}")
-    )
+
+    if tab_name == "ClimateQ&A":
+        print("chat cqa - message sent")
+
+        # Event for textbox
+        (textbox
+            .submit(start_chat, [textbox, chatbot, search_only], [textbox, tabs, chatbot, sources_raw], queue=False, api_name=f"start_chat_{textbox.elem_id}")
+            .then(chat, [textbox, chatbot, dropdown_audience, dropdown_sources, dropdown_reports, dropdown_external_sources, search_only], [chatbot, new_sources_hmtl, output_query, output_language, new_figures, current_graphs], concurrency_limit=8, api_name=f"chat_{textbox.elem_id}")
+            .then(finish_chat, None, [textbox], api_name=f"finish_chat_{textbox.elem_id}")
+        )
+        # Event for examples_hidden
+        (examples_hidden
+            .change(start_chat, [examples_hidden, chatbot, search_only], [examples_hidden, tabs, chatbot, sources_raw], queue=False, api_name=f"start_chat_{examples_hidden.elem_id}")
+            .then(chat, [examples_hidden, chatbot, dropdown_audience, dropdown_sources, dropdown_reports, dropdown_external_sources, search_only], [chatbot, new_sources_hmtl, output_query, output_language, new_figures, current_graphs], concurrency_limit=8, api_name=f"chat_{examples_hidden.elem_id}")
+            .then(finish_chat, None, [textbox], api_name=f"finish_chat_{examples_hidden.elem_id}")
+        )
+
+    elif tab_name == "Beta - POC Adapt'Action":
+        print("chat poc - message sent")
+        # Event for textbox
+        (textbox
+            .submit(start_chat, [textbox, chatbot, search_only], [textbox, tabs, chatbot, sources_raw], queue=False, api_name=f"start_chat_{textbox.elem_id}")
+            .then(chat_poc, [textbox, chatbot, dropdown_audience, dropdown_sources, dropdown_reports, dropdown_external_sources, search_only], [chatbot, new_sources_hmtl, output_query, output_language, new_figures, current_graphs], concurrency_limit=8, api_name=f"chat_{textbox.elem_id}")
+            .then(finish_chat, None, [textbox], api_name=f"finish_chat_{textbox.elem_id}")
+        )
+        # Event for examples_hidden
+        (examples_hidden
+            .change(start_chat, [examples_hidden, chatbot, search_only], [examples_hidden, tabs, chatbot, sources_raw], queue=False, api_name=f"start_chat_{examples_hidden.elem_id}")
+            .then(chat_poc, [examples_hidden, chatbot, dropdown_audience, dropdown_sources, dropdown_reports, dropdown_external_sources, search_only], [chatbot, new_sources_hmtl, output_query, output_language, new_figures, current_graphs], concurrency_limit=8, api_name=f"chat_{examples_hidden.elem_id}")
+            .then(finish_chat, None, [textbox], api_name=f"finish_chat_{examples_hidden.elem_id}")
+        )
+
 
     new_sources_hmtl.change(lambda x : x, inputs = [new_sources_hmtl], outputs = [sources_textbox])
     current_graphs.change(lambda x: x, inputs=[current_graphs], outputs=[graphs_container])
@@ -234,10 +261,12 @@ def event_handling(
     for component in [sources_textbox, figures_cards, current_graphs, papers_html]:
         component.change(update_sources_number_display, [sources_textbox, figures_cards, current_graphs, papers_html], [tab_recommended_content, tab_sources, tab_figures, tab_graphs, tab_papers])
 
-
     # Search for papers
     for component in [textbox, examples_hidden]:
         component.submit(find_papers, [component, after, dropdown_external_sources], [papers_html, citations_network, papers_summary])
+
+
+
 
 def main_ui():
     # config_open = gr.State(True)
@@ -246,12 +275,12 @@ def main_ui():
 
     with gr.Tabs():
         cqa_components = cqa_tab(tab_name = "ClimateQ&A")
-        # local_cqa_components = cqa_tab(tab_name = "Beta - POC Adapt'Action")
+        local_cqa_components = cqa_tab(tab_name = "Beta - POC Adapt'Action")
 
     create_about_tab()
 
     event_handling(cqa_components, config_components, tab_name = 'ClimateQ&A')
-    # event_handling(local_cqa_components, config_components, tab_name = 'Beta - POC Adapt\'Action')
+    event_handling(local_cqa_components, config_components, tab_name = 'Beta - POC Adapt\'Action')
 
     demo.queue()
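The wiring above dispatches on `tab_name` so each tab's textbox drives its own agent. A trimmed sketch of the idea, assuming a recent Gradio (the components and handlers here are simplified stand-ins for the real `cqa_tab`/`event_handling`):

```python
import gradio as gr

def answer_prod(q): return f"prod agent: {q}"
def answer_poc(q): return f"poc agent: {q}"

def make_tab(tab_name: str):
    with gr.Tab(tab_name):
        textbox = gr.Textbox(label="Question")
        output = gr.Textbox(label="Answer")
    return textbox, output

def wire_events(tab_name: str, textbox, output):
    # Same dispatch as event_handling(): pick the handler from the tab name.
    fn = answer_prod if tab_name == "ClimateQ&A" else answer_poc
    textbox.submit(fn, [textbox], [output])

with gr.Blocks() as demo:
    with gr.Tabs():
        prod = make_tab("ClimateQ&A")
        poc = make_tab("Beta - POC Adapt'Action")
    wire_events("ClimateQ&A", *prod)
    wire_events("Beta - POC Adapt'Action", *poc)

if __name__ == "__main__":
    demo.launch()
```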
climateqa/chat.py CHANGED
@@ -119,6 +119,7 @@ async def chat_stream(
     start_streaming = False
     graphs_html = ""
     used_documents = []
+    retrieved_contents = []
     answer_message_content = ""
 
     # Define processing steps
@@ -138,8 +139,8 @@
 
         # Handle document retrieval
        if event["event"] == "on_chain_end" and event["name"] in ["retrieve_documents","retrieve_local_data"] and event["data"]["output"] != None:
-            history, used_documents = handle_retrieved_documents(
-                event, history, used_documents
+            history, used_documents, retrieved_contents = handle_retrieved_documents(
+                event, history, used_documents, retrieved_contents
            )
        if event["event"] == "on_chain_end" and event["name"] == "answer_search" :
            docs = event["data"]["input"]["documents"]
@@ -180,7 +181,7 @@
        # Handle query transformation
        if event["name"] == "transform_query" and event["event"] == "on_chain_end":
            if hasattr(history[-1], "content"):
-                sub_questions = [q["question"] for q in event["data"]["output"]["questions_list"]]
+                sub_questions = [q["question"] + "-> relevant sources : " + str(q["sources"]) for q in event["data"]["output"]["questions_list"]]
                history[-1].content += "Decompose question into sub-questions:\n\n - " + "\n - ".join(sub_questions)
 
        yield history, docs_html, output_query, output_language, related_contents, graphs_html
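The `transform_query` handler now labels each sub-question with the sources chosen for it. Rendered on sample data (dict shape taken from this diff):

```python
# Self-contained example of the new sub-question formatting.
questions_list = [
    {"question": "What does the IPCC say about sea level rise?", "sources": ["IPCC"]},
    {"question": "Which regional plans cover flooding?", "sources": ["PCAET", "Biodiv"]},
]
sub_questions = [
    q["question"] + "-> relevant sources : " + str(q["sources"])
    for q in questions_list
]
print("Decompose question into sub-questions:\n\n - " + "\n - ".join(sub_questions))
```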
climateqa/engine/chains/answer_rag.py CHANGED
@@ -61,7 +61,7 @@ def make_rag_node(llm,with_docs = True):
         rag_chain = make_rag_chain(llm)
     else:
         rag_chain = make_rag_chain_without_docs(llm)
-
+    
     async def answer_rag(state,config):
         print("---- Answer RAG ----")
         start_time = time.time()
climateqa/engine/chains/query_transformation.py CHANGED
@@ -60,7 +60,7 @@ from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
 
 
 ROUTING_INDEX = {
-    "IPx":["IPCC", "IPBS", "IPOS"],
+    "IPx":["IPCC", "IPBES", "IPOS"],
     "POC": ["AcclimaTerra", "PCAET","Biodiv"],
     "OpenAlex":["OpenAlex"],
 }
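`ROUTING_INDEX` maps each retrieval route to the report sources it serves; the fix restores the canonical `IPBES` spelling. A quick self-contained sanity check of the reverse lookup this enables:

```python
ROUTING_INDEX = {
    "IPx": ["IPCC", "IPBES", "IPOS"],
    "POC": ["AcclimaTerra", "PCAET", "Biodiv"],
    "OpenAlex": ["OpenAlex"],
}
# Invert the index: which route serves a given source?
SOURCE_TO_ROUTE = {src: route for route, sources in ROUTING_INDEX.items() for src in sources}
assert SOURCE_TO_ROUTE["IPBES"] == "IPx"  # would KeyError under the old "IPBS" typo
```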
climateqa/engine/chains/retrieve_documents.py CHANGED
@@ -15,7 +15,9 @@ from ..utils import log_event
 from langchain_core.vectorstores import VectorStore
 from typing import List
 from langchain_core.documents.base import Document
+import asyncio
 
+from typing import Any, Dict, List, Tuple
 
 
 def divide_into_parts(target, parts):
@@ -272,12 +274,27 @@ def concatenate_documents(index, source_type, docs_question_dict, k_by_question,
 
 # The chain callback is not necessary, but it propagates the langchain callbacks to the astream_events logger to display intermediate results
 # @chain
-async def retrieve_documents(state,config, source_type, vectorstore,reranker,llm,rerank_by_question=True, k_final=15, k_before_reranking=100, k_summary=5, k_images=5):
+async def retrieve_documents(
+    current_question: Dict[str, Any],
+    config: Dict[str, Any],
+    source_type: str,
+    vectorstore: VectorStore,
+    reranker: Any,
+    search_figures: bool = False,
+    search_only: bool = False,
+    reports: list = [],
+    rerank_by_question: bool = True,
+    k_images_by_question: int = 5,
+    k_before_reranking: int = 100,
+    k_by_question: int = 5,
+    k_summary_by_question: int = 3
+) -> Tuple[List[Document], List[Document]]:
     """
     Unpack the first question of the remaining questions, and retrieve and rerank corresponding documents, based on the question and selected_sources
 
     Args:
         state (dict): The current state containing documents, related content, relevant content sources, remaining questions and n_questions.
+        current_question (dict): The current question being processed.
         config (dict): Configuration settings for logging and other purposes.
         vectorstore (object): The vector store used to retrieve relevant documents.
         reranker (object): The reranker used to rerank the retrieved documents.
@@ -290,35 +307,6 @@ async def retrieve_documents(state,config, source_type, vectorstore,reranker,llm
     Returns:
         dict: The updated state containing the retrieved and reranked documents, related content, and remaining questions.
     """
-    # TODO split the questions by source type in the state questions + conditions on the number of questions handled per source type
-    docs = state.get("documents", [])
-    related_content = state.get("related_content", [])
-
-    search_figures = "Figures (IPCC/IPBES)" in state["relevant_content_sources_selection"]
-    search_only = state["search_only"]
-
-    reports = state["reports"]
-
-    # Get the current question
-    # current_question = state["questions_list"][0]
-    # remaining_questions = state["remaining_questions"][1:]
-
-    current_question_id = None
-    print("Questions Indexs", list(range(len(state["questions_list"]))), "- Handled questions : " ,state["handled_questions_index"])
-
-    for i in range(len(state["questions_list"])):
-        current_question = state["questions_list"][i]
-
-        if i not in state["handled_questions_index"] and current_question["source_type"] == source_type:
-            current_question_id = i
-            break
-
-    # TODO filter on source_type
-
-    k_by_question = k_final // state["n_questions"]["total"]
-    k_summary_by_question = _get_k_summary_by_question(state["n_questions"]["total"])
-    k_images_by_question = _get_k_images_by_question(state["n_questions"]["total"])
-
     sources = current_question["sources"]
     question = current_question["question"]
     index = current_question["index"]
@@ -329,8 +317,7 @@ async def retrieve_documents(state,config, source_type, vectorstore,reranker,llm
 
     print(f"""---- Retrieve documents from {current_question["source_type"]}----""")
 
-    # if index == "Vector": # always true for now #TODO rename to IPx
-    if source_type == "IPx": # always true for now #TODO rename to IPx
+    if source_type == "IPx":
         docs_question_dict = await get_IPCC_relevant_documents(
             query = question,
             vectorstore=vectorstore,
@@ -359,7 +346,6 @@ async def retrieve_documents(state,config, source_type, vectorstore,reranker,llm
            k_images= k_by_question
        )
 
-
     # Rerank
     if reranker is not None and rerank_by_question:
         with suppress_output():
@@ -381,35 +367,72 @@ async def retrieve_documents(state,config, source_type, vectorstore,reranker,llm
     docs_question = _add_sources_used_in_metadata(docs_question,sources,question,index)
     images_question = _add_sources_used_in_metadata(images_question,sources,question,index)
 
-    # Add to the list of docs
-    # docs.extend(docs_question)
-    # related_content.extend(images_question)
-    docs = docs_question
-    related_content = images_question
-    new_state = {"documents":docs, "related_contents": related_content, "handled_questions_index": [current_question_id]}
-    print("Updated state with question ", current_question_id, " added ", len(docs), " documents")
-    return new_state
+    return docs_question, images_question
 
 
+async def retrieve_documents_for_all_questions(state, config, source_type, to_handle_questions_index, vectorstore, reranker, rerank_by_question=True, k_final=15, k_before_reranking=100):
+    """
+    Retrieve documents in parallel for all questions.
+    """
+    # to_handle_questions_index = [x for x in state["questions_list"] if x["source_type"] == "IPx"]
+
+    # TODO split the questions by source type in the state questions + conditions on the number of questions handled per source type
+    docs = state.get("documents", [])
+    related_content = state.get("related_content", [])
+    search_figures = "Figures (IPCC/IPBES)" in state["relevant_content_sources_selection"]
+    search_only = state["search_only"]
+    reports = state["reports"]
+
+    k_by_question = k_final // state["n_questions"]["total"]
+    k_summary_by_question = _get_k_summary_by_question(state["n_questions"]["total"])
+    k_images_by_question = _get_k_images_by_question(state["n_questions"]["total"])
+    k_before_reranking=100
+
+    tasks = [
+        retrieve_documents(
+            current_question=question,
+            config=config,
+            source_type=source_type,
+            vectorstore=vectorstore,
+            reranker=reranker,
+            search_figures=search_figures,
+            search_only=search_only,
+            reports=reports,
+            rerank_by_question=rerank_by_question,
+            k_images_by_question=k_images_by_question,
+            k_before_reranking=k_before_reranking,
+            k_by_question=k_by_question,
+            k_summary_by_question=k_summary_by_question
+        )
+        for i, question in enumerate(state["questions_list"]) if i in to_handle_questions_index
+    ]
+    results = await asyncio.gather(*tasks)
+    # Combine results
+    new_state = {"documents": [], "related_contents": [], "handled_questions_index": to_handle_questions_index}
+    for docs_question, images_question in results:
+        new_state["documents"].extend(docs_question)
+        new_state["related_contents"].extend(images_question)
+    return new_state
 
 def make_IPx_retriever_node(vectorstore,reranker,llm,rerank_by_question=True, k_final=15, k_before_reranking=100, k_summary=5):
 
     @chain
     async def retrieve_IPx_docs(state, config):
         source_type = "IPx"
+        IPx_questions_index = [i for i, x in enumerate(state["questions_list"]) if x["source_type"] == "IPx"]
+
         # return {"documents":[], "related_contents": [], "handled_questions_index": list(range(len(state["questions_list"])))} # TODO Remove
 
-        state = await retrieve_documents(
-            state = state,
-            config= config,
+        state = await retrieve_documents_for_all_questions(
+            state=state,
+            config=config,
             source_type=source_type,
+            to_handle_questions_index=IPx_questions_index,
            vectorstore=vectorstore,
-            reranker= reranker,
-            llm=llm,
+            reranker=reranker,
            rerank_by_question=rerank_by_question,
-            k_final=k_final,
-            k_before_reranking=k_before_reranking,
-            k_summary=k_summary
+            k_final=k_final,
+            k_before_reranking=k_before_reranking,
        )
        return state
 
@@ -420,19 +443,23 @@ def make_POC_retriever_node(vectorstore,reranker,llm,rerank_by_question=True, k_
 
     @chain
     async def retrieve_POC_docs_node(state, config):
+        if "POC region" not in state["relevant_content_sources_selection"] :
+            return {}
+
         source_type = "POC"
-        state = await retrieve_documents(
-            state = state,
-            config= config,
+        POC_questions_index = [i for i, x in enumerate(state["questions_list"]) if x["source_type"] == "POC"]
+
+        state = await retrieve_documents_for_all_questions(
+            state=state,
+            config=config,
            source_type=source_type,
+            to_handle_questions_index=POC_questions_index,
            vectorstore=vectorstore,
-            reranker= reranker,
-            llm=llm,
+            reranker=reranker,
            rerank_by_question=rerank_by_question,
-            k_final=k_final,
-            k_before_reranking=k_before_reranking,
-            k_summary=k_summary
-        )
+            k_final=k_final,
+            k_before_reranking=k_before_reranking,
+        )
        return state
 
     return retrieve_POC_docs_node
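The retrieval rewrite replaces the one-question-per-graph-pass loop with a single parallel fan-out per source type: build one coroutine per unhandled question, `asyncio.gather` them, then merge the results into the state update. The core shape, self-contained with retrieval stubbed out (the stub stands in for the vectorstore + reranker calls):

```python
import asyncio

async def retrieve_one(question: dict) -> tuple[list, list]:
    # Stand-in for retrieve_documents(): returns (docs, images) for one question.
    await asyncio.sleep(0.1)
    return [f"doc for {question['question']}"], []

async def retrieve_all(questions: list[dict], to_handle: list[int]) -> dict:
    # One task per question we own; all run concurrently.
    tasks = [retrieve_one(q) for i, q in enumerate(questions) if i in to_handle]
    results = await asyncio.gather(*tasks)
    # Merge per-question results into a single state update.
    state = {"documents": [], "related_contents": [], "handled_questions_index": to_handle}
    for docs, images in results:
        state["documents"].extend(docs)
        state["related_contents"].extend(images)
    return state

print(asyncio.run(retrieve_all([{"question": "q1"}, {"question": "q2"}], [0, 1])))
```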
climateqa/engine/graph.py CHANGED
@@ -95,10 +95,10 @@ def route_based_on_relevant_docs(state,threshold_docs=0.2):
 def route_continue_retrieve_documents(state):
     index_question_ipx = [i for i, x in enumerate(state["questions_list"]) if x["source_type"] == "IPx"]
     questions_ipx_finished = all(elem in state["handled_questions_index"] for elem in index_question_ipx)
-    if questions_ipx_finished and state["search_only"]:
-        return END
-    elif questions_ipx_finished:
-        return "answer_search"
+    # if questions_ipx_finished and state["search_only"]:
+    #     return END
+    if questions_ipx_finished:
+        return "end_retrieve_IPx_documents"
     else:
         return "retrieve_documents"
 
@@ -113,10 +113,10 @@ def route_continue_retrieve_documents(state):
 def route_continue_retrieve_local_documents(state):
     index_question_poc = [i for i, x in enumerate(state["questions_list"]) if x["source_type"] == "POC"]
     questions_poc_finished = all(elem in state["handled_questions_index"] for elem in index_question_poc)
-    if questions_poc_finished and state["search_only"]:
-        return END
-    elif questions_poc_finished:
-        return "answer_search"
+    # if questions_poc_finished and state["search_only"]:
+    #     return END
+    if questions_poc_finished or ("POC region" not in state["relevant_content_sources_selection"]):
+        return "end_retrieve_local_documents"
     else:
         return "retrieve_local_data"
 
@@ -139,8 +139,7 @@ def route_retrieve_documents(state):
 
     if "Graphs (OurWorldInData)" in state["relevant_content_sources_selection"] :
         sources_to_retrieve.append("retrieve_graphs")
-    if "POC region" in state["relevant_content_sources_selection"] :
-        sources_to_retrieve.append("retrieve_local_data")
+
     if sources_to_retrieve == []:
         return END
     return sources_to_retrieve
@@ -160,7 +159,7 @@ def make_graph_agent(llm, vectorstore_ipcc, vectorstore_graphs, vectorstore_regi
     answer_ai_impact = make_ai_impact_node(llm)
     retrieve_documents = make_IPx_retriever_node(vectorstore_ipcc, reranker, llm)
     retrieve_graphs = make_graph_retriever_node(vectorstore_graphs, reranker)
-    retrieve_local_data = make_POC_retriever_node(vectorstore_region, reranker, llm)
+    # retrieve_local_data = make_POC_retriever_node(vectorstore_region, reranker, llm)
     answer_rag = make_rag_node(llm, with_docs=True)
     answer_rag_no_docs = make_rag_node(llm, with_docs=False)
     chitchat_categorize_intent = make_chitchat_intent_categorization_node(llm)
@@ -175,7 +174,7 @@ def make_graph_agent(llm, vectorstore_ipcc, vectorstore_graphs, vectorstore_regi
     workflow.add_node("answer_chitchat", answer_chitchat)
     workflow.add_node("chitchat_categorize_intent", chitchat_categorize_intent)
     workflow.add_node("retrieve_graphs", retrieve_graphs)
-    workflow.add_node("retrieve_local_data", retrieve_local_data)
+    # workflow.add_node("retrieve_local_data", retrieve_local_data)
     workflow.add_node("retrieve_graphs_chitchat", retrieve_graphs)
     workflow.add_node("retrieve_documents", retrieve_documents)
     workflow.add_node("answer_rag", answer_rag)
@@ -202,17 +201,92 @@ def make_graph_agent(llm, vectorstore_ipcc, vectorstore_graphs, vectorstore_regi
         route_translation,
         make_id_dict(["translate_query","transform_query"])
     )
+
+    workflow.add_conditional_edges(
+        "answer_search",
+        lambda x : route_based_on_relevant_docs(x,threshold_docs=threshold_docs),
+        make_id_dict(["answer_rag","answer_rag_no_docs"])
+    )
+    workflow.add_conditional_edges(
+        "transform_query",
+        route_retrieve_documents,
+        make_id_dict(["retrieve_graphs", END])
+    )
+
+    # Define the edges
+    workflow.add_edge("translate_query", "transform_query")
+    workflow.add_edge("transform_query", "retrieve_documents") #TODO put back
+    # workflow.add_edge("transform_query", "retrieve_local_data")
+    # workflow.add_edge("transform_query", END) # TODO remove
+
+    workflow.add_edge("retrieve_graphs", END)
+    workflow.add_edge("answer_rag", END)
+    workflow.add_edge("answer_rag_no_docs", END)
+    workflow.add_edge("answer_chitchat", "chitchat_categorize_intent")
+    workflow.add_edge("retrieve_graphs_chitchat", END)
+
+    # workflow.add_edge("retrieve_local_data", "answer_search")
+    workflow.add_edge("retrieve_documents", "answer_search")
+
+    # Compile
+    app = workflow.compile()
+    return app
+
+def make_graph_agent_poc(llm, vectorstore_ipcc, vectorstore_graphs, vectorstore_region, reranker, threshold_docs=0.2):
+
+    workflow = StateGraph(GraphState)
+
+    # Define the node functions
+    categorize_intent = make_intent_categorization_node(llm)
+    transform_query = make_query_transform_node(llm)
+    translate_query = make_translation_node(llm)
+    answer_chitchat = make_chitchat_node(llm)
+    answer_ai_impact = make_ai_impact_node(llm)
+    retrieve_documents = make_IPx_retriever_node(vectorstore_ipcc, reranker, llm)
+    retrieve_graphs = make_graph_retriever_node(vectorstore_graphs, reranker)
+    retrieve_local_data = make_POC_retriever_node(vectorstore_region, reranker, llm)
+    answer_rag = make_rag_node(llm, with_docs=True)
+    answer_rag_no_docs = make_rag_node(llm, with_docs=False)
+    chitchat_categorize_intent = make_chitchat_intent_categorization_node(llm)
+
+    # Define the nodes
+    # workflow.add_node("set_defaults", set_defaults)
+    workflow.add_node("categorize_intent", categorize_intent)
+    workflow.add_node("answer_climate", dummy)
+    workflow.add_node("answer_search", answer_search)
+    # workflow.add_node("end_retrieve_local_documents", dummy)
+    # workflow.add_node("end_retrieve_IPx_documents", dummy)
+    workflow.add_node("transform_query", transform_query)
+    workflow.add_node("translate_query", translate_query)
+    workflow.add_node("answer_chitchat", answer_chitchat)
+    workflow.add_node("chitchat_categorize_intent", chitchat_categorize_intent)
+    workflow.add_node("retrieve_graphs", retrieve_graphs)
+    workflow.add_node("retrieve_local_data", retrieve_local_data)
+    workflow.add_node("retrieve_graphs_chitchat", retrieve_graphs)
+    workflow.add_node("retrieve_documents", retrieve_documents)
+    workflow.add_node("answer_rag", answer_rag)
+    workflow.add_node("answer_rag_no_docs", answer_rag_no_docs)
+
+    # Entry point
+    workflow.set_entry_point("categorize_intent")
+
+    # CONDITIONAL EDGES
+    workflow.add_conditional_edges(
+        "categorize_intent",
+        route_intent,
+        make_id_dict(["answer_chitchat","answer_climate"])
+    )
+
     workflow.add_conditional_edges(
-        "retrieve_documents",
-        # lambda state : "retrieve_documents" if len(state["remaining_questions"]) > 0 else "answer_search",
-        route_continue_retrieve_documents,
-        make_id_dict([END,"retrieve_documents","answer_search"])
+        "chitchat_categorize_intent",
+        chitchat_route_intent,
+        make_id_dict(["retrieve_graphs_chitchat", END])
     )
+
     workflow.add_conditional_edges(
-        "retrieve_local_data",
-        # lambda state : "retrieve_documents" if len(state["remaining_questions"]) > 0 else "answer_search",
-        route_continue_retrieve_local_documents,
-        make_id_dict([END,"retrieve_local_data","answer_search"])
+        "answer_climate",
+        route_translation,
+        make_id_dict(["translate_query","transform_query"])
     )
 
     workflow.add_conditional_edges(
@@ -223,19 +297,13 @@ def make_graph_agent(llm, vectorstore_ipcc, vectorstore_graphs, vectorstore_regi
     workflow.add_conditional_edges(
         "transform_query",
         route_retrieve_documents,
-        make_id_dict(["retrieve_graphs","retrieve_local_data", END])
+        make_id_dict(["retrieve_graphs", END])
     )
-
-
-    # workflow.add_conditional_edges(
-    #     "transform_query",
-    #     lambda state : "retrieve_graphs" if "POC region" in state["relevant_content_sources_selection"] else END,
-    #     make_id_dict(["retrieve_local_data", END])
-    # )
 
     # Define the edges
     workflow.add_edge("translate_query", "transform_query")
     workflow.add_edge("transform_query", "retrieve_documents") #TODO put back
+    workflow.add_edge("transform_query", "retrieve_local_data")
     # workflow.add_edge("transform_query", END) # TODO remove
 
     workflow.add_edge("retrieve_graphs", END)
@@ -243,7 +311,9 @@ def make_graph_agent(llm, vectorstore_ipcc, vectorstore_graphs, vectorstore_regi
     workflow.add_edge("answer_rag_no_docs", END)
     workflow.add_edge("answer_chitchat", "chitchat_categorize_intent")
     workflow.add_edge("retrieve_graphs_chitchat", END)
-    # workflow.add_edge("retrieve_local_data", "answer_search")
+
+    workflow.add_edge("retrieve_local_data", "answer_search")
+    workflow.add_edge("retrieve_documents", "answer_search")
 
     # Compile
     app = workflow.compile()
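Structurally, the POC graph now fans out from `transform_query` to both retrievers and joins on `answer_search`, instead of looping back through conditional edges until every question is handled. A minimal runnable sketch of that topology against langgraph 0.2.x (state and node bodies are stubs; the list reducer is an assumption about how the parallel branches should merge):

```python
import operator
from typing import Annotated, TypedDict
from langgraph.graph import StateGraph, END

class State(TypedDict):
    # Reducer merges updates from parallel branches instead of overwriting.
    documents: Annotated[list, operator.add]

g = StateGraph(State)
g.add_node("transform_query", lambda s: {})
g.add_node("retrieve_documents", lambda s: {"documents": ["IPx doc"]})
g.add_node("retrieve_local_data", lambda s: {"documents": ["POC doc"]})
g.add_node("answer_search", lambda s: {})

g.set_entry_point("transform_query")
g.add_edge("transform_query", "retrieve_documents")   # fan out to both retrievers...
g.add_edge("transform_query", "retrieve_local_data")
g.add_edge(["retrieve_documents", "retrieve_local_data"], "answer_search")  # ...join here
g.add_edge("answer_search", END)

# documents ends up containing both branches' docs, merged by the reducer.
print(g.compile().invoke({"documents": []}))
```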
climateqa/handle_stream_events.py CHANGED
@@ -22,7 +22,7 @@ def convert_to_docs_to_html(docs: list[dict]) -> str:
         docs_html.append(make_html_source(d, i))
     return "".join(docs_html)
 
-def handle_retrieved_documents(event: StreamEvent, history : list[ChatMessage], used_documents : list[str]) -> tuple[str, list[ChatMessage], list[str]]:
+def handle_retrieved_documents(event: StreamEvent, history : list[ChatMessage], used_documents : list[str],related_content:list[str]) -> tuple[str, list[ChatMessage], list[str]]:
     """
     Handles the retrieved documents and returns the HTML representation of the documents
 
@@ -35,7 +35,7 @@ def handle_retrieved_documents(event: StreamEvent, history : list[ChatMessage],
         tuple[str, list[ChatMessage], list[str]]: The updated HTML representation of the documents, the updated message history and the updated list of used documents
     """
     if "documents" not in event["data"]["output"] or event["data"]["output"]["documents"] == []:
-        return history, used_documents
+        return history, used_documents, related_content
 
     try:
         docs = event["data"]["output"]["documents"]
@@ -49,7 +49,7 @@ def handle_retrieved_documents(event: StreamEvent, history : list[ChatMessage],
     except Exception as e:
         print(f"Error getting documents: {e}")
         print(event)
-        return history, used_documents
+        return history, used_documents, related_content
 
 def stream_answer(history: list[ChatMessage], event : StreamEvent, start_streaming : bool, answer_message_content : str)-> tuple[list[ChatMessage], bool, str]:
     """
front/tabs/chat_interface.py CHANGED
@@ -44,7 +44,7 @@ def create_chat_interface():
         scale=12,
         lines=1,
         interactive=True,
-        elem_id="input-textbox"
+        elem_id=f"input-textbox"
     )
 
     config_button = gr.Button("", elem_id="config-button")
front/tabs/tab_examples.py CHANGED
@@ -3,7 +3,7 @@ from climateqa.sample_questions import QUESTIONS
 
 
 def create_examples_tab():
-    examples_hidden = gr.Textbox(visible=False)
+    examples_hidden = gr.Textbox(visible=False, elem_id=f"examples-hidden")
     first_key = list(QUESTIONS.keys())[0]
     dropdown_samples = gr.Dropdown(
         choices=QUESTIONS.keys(),
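Both front-end tweaks exist because the event wiring in app.py derives API names from `elem_id` (e.g. `api_name=f"chat_{textbox.elem_id}"`); a component created without an `elem_id` would yield names like `chat_None`. Illustration with a stand-in component class:

```python
class Component:
    # Stand-in for a Gradio component; only elem_id matters here.
    def __init__(self, elem_id=None):
        self.elem_id = elem_id

for box in (Component(), Component(elem_id="examples-hidden")):
    print(f"chat_{box.elem_id}")  # -> chat_None, then chat_examples-hidden
```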
requirements.txt CHANGED
@@ -4,7 +4,7 @@ azure-storage-blob
 python-dotenv==1.0.0
 langchain==0.2.1
 langchain_openai==0.1.7
-langgraph==0.0.55
+langgraph==0.2.70
 pinecone-client==4.1.0
 sentence-transformers==2.6.0
 huggingface-hub