timeki committed
Commit be494ba · 1 Parent(s): d09f2e9

WIP add regional sources

app.py CHANGED
@@ -103,11 +103,12 @@ CITATION_TEXT = r"""@misc{climateqa,
 embeddings_function = get_embeddings_function()
 vectorstore = get_pinecone_vectorstore(embeddings_function, index_name=os.getenv("PINECONE_API_INDEX"))
 vectorstore_graphs = get_pinecone_vectorstore(embeddings_function, index_name=os.getenv("PINECONE_API_INDEX_OWID"), text_key="description")
+ vectorstore_region = get_pinecone_vectorstore(embeddings_function, index_name=os.getenv("PINECONE_API_INDEX_REGION"))
 
 llm = get_llm(provider="openai",max_tokens = 1024,temperature = 0.0)
 reranker = get_reranker("nano")
 
- agent = make_graph_agent(llm=llm, vectorstore_ipcc=vectorstore, vectorstore_graphs=vectorstore_graphs, reranker=reranker)
+ agent = make_graph_agent(llm=llm, vectorstore_ipcc=vectorstore, vectorstore_graphs=vectorstore_graphs, vectorstore_region = vectorstore_region, reranker=reranker)
 
 # Function to update modal visibility
 def update_config_modal_visibility(config_open):
@@ -149,7 +150,7 @@ async def chat(
 print(f">> NEW QUESTION ({date_now}) : {query}")
 
 audience_prompt = init_audience(audience)
- sources = sources or ["IPCC", "IPBES", "IPOS"]
+ sources = sources or ["IPCC", "IPBES"]
 reports = reports or []
 
 # Prepare inputs for agent
@@ -606,9 +607,7 @@ with gr.Blocks(title="Climate Q&A", css_paths=os.getcwd()+ "/style.css", theme=t
 outputs=[graphs_container]
 )
 
-
-
- # Other tabs
+
 with gr.Tab("About", elem_classes="max-height other-tabs"):
 with gr.Row():
 with gr.Column(scale=1):
@@ -629,10 +628,10 @@ with gr.Blocks(title="Climate Q&A", css_paths=os.getcwd()+ "/style.css", theme=t
 show_copy_button=True,
 lines=len(CITATION_TEXT.split('\n')),
 )
-
- # Event handlers
+ # Configuration pannel
 config_modal, dropdown_sources, dropdown_reports, dropdown_external_sources, search_only, dropdown_audience, after, output_query, output_language = create_config_modal(config_open)
 
+ # Event handlers
 config_button.click(
 fn=update_config_modal_visibility,
 inputs=[config_open],
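
Note on the default-source change in chat(): when no sources are passed, the fallback is now IPCC and IPBES only, so IPOS (and any regional source) must be selected explicitly. A standalone sketch of that fallback behaviour (the helper name is illustrative, not part of the commit):

def default_sources(sources):
    # None or an empty list falls back to the new default pair.
    return sources or ["IPCC", "IPBES"]

assert default_sources(None) == ["IPCC", "IPBES"]
assert default_sources([]) == ["IPCC", "IPBES"]
assert default_sources(["AcclimaTerra"]) == ["AcclimaTerra"]
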
climateqa/engine/chains/query_transformation.py CHANGED
@@ -9,7 +9,7 @@ from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
 
 
 ROUTING_INDEX = {
- "Vector":["IPCC","IPBES","IPOS"],
+ "Vector":["IPCC","IPBES","IPOS", "AcclimaTerra"],
 "OpenAlex":["OpenAlex"],
 }
 
@@ -69,13 +69,14 @@ class QueryAnalysis(BaseModel):
 # """
 # )
 
- sources: List[Literal["IPCC", "IPBES", "IPOS"]] = Field( #,"OpenAlex"]] = Field(
+ sources: List[Literal["IPCC", "IPBES", "IPOS", "AcclimaTerra"]] = Field( #,"OpenAlex"]] = Field(
 ...,
 description="""
 Given a user question choose which documents would be most relevant for answering their question,
 - IPCC is for questions about climate change, energy, impacts, and everything we can find the IPCC reports
 - IPBES is for questions about biodiversity and nature
 - IPOS is for questions about the ocean and deep sea mining
+ - AcclimaTerra is for questions about any specific place in, or close to, the french region "Nouvelle-Aquitaine"
 """,
 # - OpenAlex is for any other questions that are not in the previous categories but could be found in the scientific litterature
 )
@@ -133,6 +134,23 @@ def make_query_rewriter_chain(llm):
 
 
 def make_query_transform_node(llm,k_final=15):
+ """
+ Creates a query transformation node that processes and transforms a given query state.
+ Args:
+ llm: The language model to be used for query decomposition and rewriting.
+ k_final (int, optional): The final number of questions to be generated. Defaults to 15.
+ Returns:
+ function: A function that takes a query state and returns a transformed state.
+ The returned function performs the following steps:
+ 1. Checks if the query should be processed in auto mode based on the state.
+ 2. Retrieves the input sources from the state or defaults to a predefined routing index.
+ 3. Decomposes the query using the decomposition chain.
+ 4. Analyzes each decomposed question using the rewriter chain.
+ 5. Ensures that the sources returned by the language model are valid.
+ 6. Explodes the questions into multiple questions with different sources based on the mode.
+ 7. Constructs a new state with the transformed questions and their respective sources.
+ """
+
 
 decomposition_chain = make_query_decomposition_chain(llm)
 rewriter_chain = make_query_rewriter_chain(llm)
@@ -140,14 +158,9 @@ def make_query_transform_node(llm,k_final=15):
 def transform_query(state):
 print("---- Transform query ----")
 
-
- if "sources_auto" not in state or state["sources_auto"] is None or state["sources_auto"] is False:
- auto_mode = False
- else:
- auto_mode = True
-
- sources_input = state.get("sources_input")
- if sources_input is None: sources_input = ROUTING_INDEX["Vector"]
+ auto_mode = state.get("sources_auto", False)
+ sources_input = state.get("sources_input", ROUTING_INDEX["Vector"])
+
 
 new_state = {}
@@ -159,7 +172,7 @@ def make_query_transform_node(llm,k_final=15):
 questions = []
 for question in new_state["questions"]:
 question_state = {"question":question}
- analysis_output = rewriter_chain.invoke({"input":question})
+ analysis_output = rewriter_chain.invoke({"input":question})
 
 # TODO WARNING llm should always return smthg
 # The case when the llm does not return any sources
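
The new AcclimaTerra literal is routed through the existing "Vector" index. A small self-contained sketch of how a selected source maps back to its routing index (the lookup helper is illustrative, not part of the commit):

ROUTING_INDEX = {
    "Vector": ["IPCC", "IPBES", "IPOS", "AcclimaTerra"],
    "OpenAlex": ["OpenAlex"],
}

def index_for_source(source: str) -> str:
    # Return the first routing index whose source list contains the given source.
    for index, sources in ROUTING_INDEX.items():
        if source in sources:
            return index
    raise ValueError(f"Unknown source: {source}")

assert index_for_source("AcclimaTerra") == "Vector"
assert index_for_source("OpenAlex") == "OpenAlex"
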
climateqa/engine/chains/retrieve_documents.py CHANGED
@@ -7,7 +7,7 @@ from langchain_core.runnables import chain
 from langchain_core.runnables import RunnableParallel, RunnablePassthrough
 from langchain_core.runnables import RunnableLambda
 
- from ..reranker import rerank_docs
+ from ..reranker import rerank_docs, rerank_and_sort_docs
 # from ...knowledge.retriever import ClimateQARetriever
 from ...knowledge.openalex import OpenAlexRetriever
 from .keywords_extraction import make_keywords_extraction_chain
@@ -106,6 +106,48 @@ def _add_metadata_and_score(docs: List) -> Document:
 docs_with_metadata.append(doc)
 return docs_with_metadata
 
+ async def get_POC_relevant_documents(
+ query: str,
+ vectorstore:VectorStore,
+ sources:list = ["Acclimaterra","PCAET","Plan Biodiversite"],
+ search_figures:bool = False,
+ search_only:bool = False,
+ k_documents:int = 10,
+ threshold:float = 0.6,
+ k_images: int = 5,
+ reports:list = [],
+ ) :
+ # Prepare base search kwargs
+ filters = {}
+
+ if len(reports) > 0:
+ filters["short_name"] = {"$in":reports}
+ else:
+ filters["source"] = { "$in": sources}
+
+ filters_text = {
+ **filters,
+ "chunk_type":"text",
+ # "report_type": {}, # TODO to be completed to choose the right documents / chapters according to the analysis of the question
+ }
+
+ docs_question = vectorstore.similarity_search_with_score(query=query,filter = filters_text,k = k_documents)
+ docs_question = [x for x in docs_question if x[1] > threshold]
+
+ if search_figures:
+ # Images
+ filters_image = {
+ **filters,
+ "chunk_type":"image"
+ }
+ docs_images = vectorstore.similarity_search_with_score(query=query,filter = filters_image,k = k_images)
+
+ return {
+ "docs_question" : docs_question,
+ "docs_images" : docs_images
+ }
+
+
 async def get_IPCC_relevant_documents(
 query: str,
 vectorstore:VectorStore,
@@ -191,12 +233,26 @@ async def get_IPCC_relevant_documents(
 }
 
 
+
+ def concatenate_documents(index, source_type, docs_question_dict, k_by_question, k_summary_by_question, k_images_by_question):
+ # Keep the right number of documents - The k_summary documents from SPM are placed in front
+ if source_type == "Vector" :
+ docs_question = docs_question_dict["docs_summaries"][:k_summary_by_question] + docs_question_dict["docs_full"][:k_by_question - k_summary_by_question]
+ elif source_type == "POC" :
+ docs_question = docs_question_dict["docs_question"][:k_by_question]
+ else :
+ docs_question = [doc for key in docs_question_dict.keys() for doc in docs_question_dict[key]]
+
+ images_question = docs_question_dict["docs_images"][:k_images_by_question]
+
+ return docs_question, images_question
+
 
 # The chain callback is not necessary, but it propagates the langchain callbacks to the astream_events logger to display intermediate results
 # @chain
- async def retrieve_documents(state,config, vectorstore,reranker,llm,rerank_by_question=True, k_final=15, k_before_reranking=100, k_summary=5, k_images=5):
+ async def retrieve_documents(state,config, source_type, vectorstore,reranker,llm,rerank_by_question=True, k_final=15, k_before_reranking=100, k_summary=5, k_images=5):
 """
- Retrieve and rerank documents based on the current question in the state.
+ Unpack the first question of the remaining questions, and retrieve and rerank corresponding documents, based on the question and selected_sources
 
 Args:
 state (dict): The current state containing documents, related content, relevant content sources, remaining questions and n_questions.
@@ -212,7 +268,7 @@ async def retrieve_documents(state,config, vectorstore,reranker,llm,rerank_by_qu
 Returns:
 dict: The updated state containing the retrieved and reranked documents, related content, and remaining questions.
 """
- print("---- Retrieve documents ----")
+ print(f"---- Retrieve documents from {source_type}----")
 docs = state.get("documents", [])
 related_content = state.get("related_content", [])
 
@@ -237,45 +293,51 @@ async def retrieve_documents(state,config, vectorstore,reranker,llm,rerank_by_qu
 await log_event({"question":question,"sources":sources,"index":index},"log_retriever",config)
 
 
- if index == "Vector": # always true for now
+ # if index == "Vector": # always true for now #TODO rename to IPx
+ if source_type == "IPx": # always true for now #TODO rename to IPx
 docs_question_dict = await get_IPCC_relevant_documents(
 query = question,
 vectorstore=vectorstore,
 search_figures = search_figures,
 sources = sources,
 min_size = 200,
- k_summary = k_summary_by_question,
+ k_summary = k_before_reranking-1,
 k_total = k_before_reranking,
 k_images = k_images_by_question,
 threshold = 0.5,
 search_only = search_only,
 reports = reports,
 )
+
+ if source_type == "POC":
+ docs_question_dict = await get_POC_relevant_documents(
+ query = question,
+ vectorstore=vectorstore,
+ search_figures = search_figures,
+ sources = sources,
+ threshold = 0.5,
+ search_only = search_only,
+ reports = reports,
+ )
 
 
 # Rerank
- if reranker is not None:
+ if reranker is not None and rerank_by_question:
 with suppress_output():
- docs_question_summary_reranked = rerank_docs(reranker,docs_question_dict["docs_summaries"],question)
- docs_question_fulltext_reranked = rerank_docs(reranker,docs_question_dict["docs_full"],question)
- docs_question_images_reranked = rerank_docs(reranker,docs_question_dict["docs_images"],question)
- if rerank_by_question:
- docs_question_summary_reranked = sorted(docs_question_summary_reranked, key=lambda x: x.metadata["reranking_score"], reverse=True)
- docs_question_fulltext_reranked = sorted(docs_question_fulltext_reranked, key=lambda x: x.metadata["reranking_score"], reverse=True)
- docs_question_images_reranked = sorted(docs_question_images_reranked, key=lambda x: x.metadata["reranking_score"], reverse=True)
+ for key in docs_question_dict.keys():
+ docs_question_dict[key] = rerank_and_sort_docs(reranker,docs_question_dict[key],question)
 else:
- docs_question = docs_question_dict["docs_summaries"] + docs_question_dict["docs_full"]
 # Add a default reranking score
 for doc in docs_question:
 doc.metadata["reranking_score"] = doc.metadata["similarity_score"]
 
- docs_question = docs_question_summary_reranked + docs_question_fulltext_reranked
- docs_question = docs_question[:k_by_question]
- images_question = docs_question_images_reranked[:k_images]
-
+ # Keep the right number of documents
+ docs_question, images_question = concatenate_documents(index, source_type, docs_question_dict, k_by_question, k_summary_by_question, k_images_by_question)
+
+ # Rerank the documents to put the most relevant in front
 if reranker is not None and rerank_by_question:
- docs_question = sorted(docs_question, key=lambda x: x.metadata["reranking_score"], reverse=True)
-
+ docs_question = rerank_and_sort_docs(reranker, docs_question, question)
+
 # Add sources used in the metadata
 docs_question = _add_sources_used_in_metadata(docs_question,sources,question,index)
 images_question = _add_sources_used_in_metadata(images_question,sources,question,index)
@@ -288,13 +350,26 @@ async def retrieve_documents(state,config, vectorstore,reranker,llm,rerank_by_qu
 
 
 
- def make_retriever_node(vectorstore,reranker,llm,rerank_by_question=True, k_final=15, k_before_reranking=100, k_summary=5):
+ def make_IPx_retriever_node(vectorstore,reranker,llm,rerank_by_question=True, k_final=15, k_before_reranking=100, k_summary=5):
+
+ @chain
+ async def retrieve_IPx_docs(state, config):
+ source_type = "IPx"
+ state = await retrieve_documents(state,config, source_type, vectorstore,reranker,llm,rerank_by_question, k_final, k_before_reranking, k_summary)
+ return state
+
+ return retrieve_IPx_docs
+
+
+ def make_POC_retriever_node(vectorstore,reranker,llm,rerank_by_question=True, k_final=15, k_before_reranking=100, k_summary=5):
+
 @chain
- async def retrieve_docs(state, config):
- state = await retrieve_documents(state,config, vectorstore,reranker,llm,rerank_by_question, k_final, k_before_reranking, k_summary)
+ async def retrieve_IPx_docs(state, config):
+ source_type = "POC"
+ state = await retrieve_documents(state,config, source_type, vectorstore,reranker,llm,rerank_by_question, k_final, k_before_reranking, k_summary)
 return state
 
- return retrieve_docs
+ return retrieve_IPx_docs
 
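
To make the selection step in concatenate_documents concrete, here is a toy run with plain lists instead of LangChain Documents (a sketch only; the real code slices reranked Document objects). For the "Vector" source type, k_summary_by_question summaries go in front and the rest is filled with full-text chunks; the "POC" source type simply truncates its single candidate list.

docs_question_dict = {
    "docs_summaries": ["spm_1", "spm_2"],
    "docs_full": ["full_1", "full_2", "full_3"],
    "docs_images": ["img_1", "img_2"],
    "docs_question": ["poc_1", "poc_2", "poc_3"],
}
k_by_question, k_summary_by_question, k_images_by_question = 3, 1, 1

# "Vector": summaries first, then full-text chunks up to k_by_question in total.
vector_docs = (docs_question_dict["docs_summaries"][:k_summary_by_question]
               + docs_question_dict["docs_full"][:k_by_question - k_summary_by_question])
assert vector_docs == ["spm_1", "full_1", "full_2"]

# "POC": a single candidate list, simply truncated.
poc_docs = docs_question_dict["docs_question"][:k_by_question]
assert poc_docs == ["poc_1", "poc_2", "poc_3"]

# Images are truncated the same way for both source types.
assert docs_question_dict["docs_images"][:k_images_by_question] == ["img_1"]
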
climateqa/engine/graph.py CHANGED
@@ -16,7 +16,7 @@ from .chains.answer_ai_impact import make_ai_impact_node
 from .chains.query_transformation import make_query_transform_node
 from .chains.translation import make_translation_node
 from .chains.intent_categorization import make_intent_categorization_node
- from .chains.retrieve_documents import make_retriever_node
+ from .chains.retrieve_documents import make_IPx_retriever_node, make_POC_retriever_node
 from .chains.answer_rag import make_rag_node
 from .chains.graph_retriever import make_graph_retriever_node
 from .chains.chitchat_categorization import make_chitchat_intent_categorization_node
@@ -46,6 +46,9 @@ class GraphState(TypedDict):
 search_only : bool = False
 reports : List[str] = []
 
+ def dummy(state):
+ return state
+
 def search(state): #TODO
 return state
 
@@ -60,7 +63,7 @@ def route_intent(state):
 # return "answer_ai_impact"
 else:
 # Search route
- return "search"
+ return "answer_climate"
 
 def chitchat_route_intent(state):
 intent = state["search_graphs_chitchat"]
@@ -82,18 +85,29 @@ def route_based_on_relevant_docs(state,threshold_docs=0.2):
 else:
 return "answer_rag_no_docs"
 
- def route_retrieve_documents(state):
+ def route_continue_retrieve_documents(state):
 if len(state["remaining_questions"]) == 0 and state["search_only"] :
 return END
 elif len(state["remaining_questions"]) > 0:
 return "retrieve_documents"
 else:
 return "answer_search"
+
+ def route_retrieve_documents(state):
+ sources_to_retrieve = []
+
+ if "Graphs (OurWorldInData)" in state["relevant_content_sources_selection"] :
+ sources_to_retrieve.append("retrieve_graphs")
+ if "POC region" in state["relevant_content_sources_selection"] :
+ sources_to_retrieve.append("retrieve_local_data")
+ if sources_to_retrieve == []:
+ return END
+ return sources_to_retrieve
 
 def make_id_dict(values):
 return {k:k for k in values}
 
- def make_graph_agent(llm, vectorstore_ipcc, vectorstore_graphs, reranker, threshold_docs=0.2):
+ def make_graph_agent(llm, vectorstore_ipcc, vectorstore_graphs, vectorstore_region, reranker, threshold_docs=0.2):
 
 workflow = StateGraph(GraphState)
 
@@ -103,8 +117,9 @@ def make_graph_agent(llm, vectorstore_ipcc, vectorstore_graphs, reranker, thresh
 translate_query = make_translation_node(llm)
 answer_chitchat = make_chitchat_node(llm)
 answer_ai_impact = make_ai_impact_node(llm)
- retrieve_documents = make_retriever_node(vectorstore_ipcc, reranker, llm)
+ retrieve_documents = make_IPx_retriever_node(vectorstore_ipcc, reranker, llm)
 retrieve_graphs = make_graph_retriever_node(vectorstore_graphs, reranker)
+ retrieve_local_data = make_POC_retriever_node(vectorstore_region, reranker, llm)
 answer_rag = make_rag_node(llm, with_docs=True)
 answer_rag_no_docs = make_rag_node(llm, with_docs=False)
 chitchat_categorize_intent = make_chitchat_intent_categorization_node(llm)
@@ -112,13 +127,14 @@ def make_graph_agent(llm, vectorstore_ipcc, vectorstore_graphs, reranker, thresh
 # Define the nodes
 # workflow.add_node("set_defaults", set_defaults)
 workflow.add_node("categorize_intent", categorize_intent)
- workflow.add_node("search", search)
+ workflow.add_node("answer_climate", dummy)
 workflow.add_node("answer_search", answer_search)
 workflow.add_node("transform_query", transform_query)
 workflow.add_node("translate_query", translate_query)
 workflow.add_node("answer_chitchat", answer_chitchat)
 workflow.add_node("chitchat_categorize_intent", chitchat_categorize_intent)
 workflow.add_node("retrieve_graphs", retrieve_graphs)
+ workflow.add_node("retrieve_local_data", retrieve_local_data)
 workflow.add_node("retrieve_graphs_chitchat", retrieve_graphs)
 workflow.add_node("retrieve_documents", retrieve_documents)
 workflow.add_node("answer_rag", answer_rag)
@@ -131,7 +147,7 @@ def make_graph_agent(llm, vectorstore_ipcc, vectorstore_graphs, reranker, thresh
 workflow.add_conditional_edges(
 "categorize_intent",
 route_intent,
- make_id_dict(["answer_chitchat","search"])
+ make_id_dict(["answer_chitchat","answer_climate"])
 )
 
 workflow.add_conditional_edges(
@@ -141,14 +157,14 @@ def make_graph_agent(llm, vectorstore_ipcc, vectorstore_graphs, reranker, thresh
 )
 
 workflow.add_conditional_edges(
- "search",
+ "answer_climate",
 route_translation,
 make_id_dict(["translate_query","transform_query"])
 )
 workflow.add_conditional_edges(
 "retrieve_documents",
 # lambda state : "retrieve_documents" if len(state["remaining_questions"]) > 0 else "answer_search",
- route_retrieve_documents,
+ route_continue_retrieve_documents,
 make_id_dict([END,"retrieve_documents","answer_search"])
 )
 
@@ -159,9 +175,16 @@ def make_graph_agent(llm, vectorstore_ipcc, vectorstore_graphs, reranker, thresh
 )
 workflow.add_conditional_edges(
 "transform_query",
- lambda state : "retrieve_graphs" if "Graphs (OurWorldInData)" in state["relevant_content_sources_selection"] else END,
- make_id_dict(["retrieve_graphs", END])
+ route_retrieve_documents,
+ make_id_dict(["retrieve_graphs","retrieve_local_data", END])
 )
+
+
+ # workflow.add_conditional_edges(
+ # "transform_query",
+ # lambda state : "retrieve_graphs" if "POC region" in state["relevant_content_sources_selection"] else END,
+ # make_id_dict(["retrieve_local_data", END])
+ # )
 
 # Define the edges
 workflow.add_edge("translate_query", "transform_query")
@@ -172,7 +195,7 @@ def make_graph_agent(llm, vectorstore_ipcc, vectorstore_graphs, reranker, thresh
 workflow.add_edge("answer_rag_no_docs", END)
 workflow.add_edge("answer_chitchat", "chitchat_categorize_intent")
 workflow.add_edge("retrieve_graphs_chitchat", END)
-
+ workflow.add_edge("retrieve_local_data", "answer_search")
 
 # Compile
 app = workflow.compile()
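
The new router behind the transform_query edge returns a list of node names; a LangGraph conditional-edge function may return several destinations, which is what lets graph retrieval and the regional (POC) retrieval run side by side. A standalone sketch of the routing decision (END is the LangGraph sentinel; the state dicts here are hand-built):

from langgraph.graph import END

def route_retrieve_documents(state):
    # Collect every retrieval node whose source was selected in the UI.
    sources_to_retrieve = []
    if "Graphs (OurWorldInData)" in state["relevant_content_sources_selection"]:
        sources_to_retrieve.append("retrieve_graphs")
    if "POC region" in state["relevant_content_sources_selection"]:
        sources_to_retrieve.append("retrieve_local_data")
    return sources_to_retrieve or END

assert route_retrieve_documents({"relevant_content_sources_selection": []}) == END
assert route_retrieve_documents(
    {"relevant_content_sources_selection": ["Graphs (OurWorldInData)", "POC region"]}
) == ["retrieve_graphs", "retrieve_local_data"]
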
climateqa/engine/reranker.py CHANGED
@@ -47,4 +47,9 @@ def rerank_docs(reranker,docs,query):
 doc.metadata["reranking_score"] = result.score
 doc.metadata["query_used_for_retrieval"] = query
 docs_reranked.append(doc)
+ return docs_reranked
+
+ def rerank_and_sort_docs(reranker, docs, query):
+ docs_reranked = rerank_docs(reranker,docs,query)
+ docs_reranked = sorted(docs_reranked, key=lambda x: x.metadata["reranking_score"], reverse=True)
 return docs_reranked
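
rerank_and_sort_docs wraps the existing rerank_docs call and then orders the results by the reranking_score written into each document's metadata. A toy check of that ordering step with stub documents (a sketch; the real function receives LangChain Documents and a live reranker):

from dataclasses import dataclass, field

@dataclass
class StubDoc:
    metadata: dict = field(default_factory=dict)

docs = [StubDoc({"reranking_score": 0.2}), StubDoc({"reranking_score": 0.9}), StubDoc({"reranking_score": 0.5})]
# Same sort key as rerank_and_sort_docs: highest reranking_score first.
docs_sorted = sorted(docs, key=lambda d: d.metadata["reranking_score"], reverse=True)
assert [d.metadata["reranking_score"] for d in docs_sorted] == [0.9, 0.5, 0.2]
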