timeki committed on
Commit
b7b39b4
·
1 Parent(s): fc994dc

Multiple sources OK

Browse files
climateqa/engine/chains/answer_rag.py CHANGED
@@ -11,7 +11,7 @@ import time
11
  from ..utils import rename_chain, pass_values
12
 
13
 
14
- DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="{page_content}")
15
 
16
  def _combine_documents(
17
  docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, sep="\n\n"
 
11
  from ..utils import rename_chain, pass_values
12
 
13
 
14
+ DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="Source : {source} - {page_content}")
15
 
16
  def _combine_documents(
17
  docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, sep="\n\n"
climateqa/engine/chains/prompts.py CHANGED
@@ -36,6 +36,30 @@ You are given a question and extracted passages of the IPCC and/or IPBES reports
36
  """
37
 
38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  answer_prompt_template = """
40
  You are ClimateQ&A, an AI Assistant created by Ekimetrics. You are given a question and extracted passages of reports. Provide a clear and structured answer based on the passages provided, the context and the guidelines.
41
 
@@ -50,6 +74,8 @@ Guidelines:
50
  - If the documents do not have the information needed to answer the question, just say you do not have enough information.
51
  - Consider by default that the question is about the past century unless it is specified otherwise.
52
  - If the passage is the caption of a picture, you can still use it as part of your answer as any other document.
 
 
53
 
54
  -----------------------
55
  Passages:
@@ -60,7 +86,6 @@ Question: {query} - Explained to {audience}
60
  Answer in {language} with the passages citations:
61
  """
62
 
63
-
64
  papers_prompt_template = """
65
  You are ClimateQ&A, an AI Assistant created by Ekimetrics. You are given a question and extracted abstracts of scientific papers. Provide a clear and structured answer based on the abstracts provided, the context and the guidelines.
66
 
 
36
  """
37
 
38
 
39
+ # answer_prompt_template_old = """
40
+ # You are ClimateQ&A, an AI Assistant created by Ekimetrics. You are given a question and extracted passages of reports. Provide a clear and structured answer based on the passages provided, the context and the guidelines.
41
+
42
+ # Guidelines:
43
+ # - If the passages have useful facts or numbers, use them in your answer.
44
+ # - When you use information from a passage, mention where it came from by using [Doc i] at the end of the sentence. i stands for the number of the document.
45
+ # - Do not use the sentence 'Doc i says ...' to say where information came from.
46
+ # - If the same thing is said in more than one document, you can mention all of them like this: [Doc i, Doc j, Doc k]
47
+ # - Do not just summarize each passage one by one. Group your summaries to highlight the key parts in the explanation.
48
+ # - If it makes sense, use bullet points and lists to make your answers easier to understand.
49
+ # - You do not need to use every passage. Only use the ones that help answer the question.
50
+ # - If the documents do not have the information needed to answer the question, just say you do not have enough information.
51
+ # - Consider by default that the question is about the past century unless it is specified otherwise.
52
+ # - If the passage is the caption of a picture, you can still use it as part of your answer as any other document.
53
+
54
+ # -----------------------
55
+ # Passages:
56
+ # {context}
57
+
58
+ # -----------------------
59
+ # Question: {query} - Explained to {audience}
60
+ # Answer in {language} with the passages citations:
61
+ # """
62
+
63
  answer_prompt_template = """
64
  You are ClimateQ&A, an AI Assistant created by Ekimetrics. You are given a question and extracted passages of reports. Provide a clear and structured answer based on the passages provided, the context and the guidelines.
65
 
 
74
  - If the documents do not have the information needed to answer the question, just say you do not have enough information.
75
  - Consider by default that the question is about the past century unless it is specified otherwise.
76
  - If the passage is the caption of a picture, you can still use it as part of your answer as any other document.
77
+ - If you receive passages from different reports, eg IPCC and PPCP, make separate paragraphs and specify the source of the information in your answer, eg "According to IPCC, ...".
78
+ - The different sources are IPCC, IPBES, PPCP (for Plan Climat Air Energie Territorial de Paris), PBDP (for Plan Biodiversité de Paris), Acclimaterra.
79
 
80
  -----------------------
81
  Passages:
 
86
  Answer in {language} with the passages citations:
87
  """
88
 
 
89
  papers_prompt_template = """
90
  You are ClimateQ&A, an AI Assistant created by Ekimetrics. You are given a question and extracted abstracts of scientific papers. Provide a clear and structured answer based on the abstracts provided, the context and the guidelines.
91
 
climateqa/engine/chains/query_transformation.py CHANGED
@@ -60,7 +60,8 @@ from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
60
 
61
 
62
  ROUTING_INDEX = {
63
- "Vector":["IPCC","IPBES","IPOS", "AcclimaTerra"],
 
64
  "OpenAlex":["OpenAlex"],
65
  }
66
 
@@ -88,6 +89,17 @@ class Location(BaseModel):
88
  country:str = Field(...,description="The country if directly mentioned or inferred from the location (cities, regions, adresses), ex: France, USA, ...")
89
  location:str = Field(...,description="The specific place if mentioned (cities, regions, addresses), ex: Marseille, New York, Wisconsin, ...")
90
 
 
 
 
 
 
 
 
 
 
 
 
91
  class QueryAnalysis(BaseModel):
92
  """
93
  Analyze the user query to extract the relevant sources
@@ -98,14 +110,16 @@ class QueryAnalysis(BaseModel):
98
  Also provide simple keywords to feed a search engine
99
  """
100
 
101
- sources: List[Literal["IPCC", "IPBES", "IPOS", "AcclimaTerra"]] = Field( #,"OpenAlex"]] = Field(
102
  ...,
103
  description="""
104
  Given a user question choose which documents would be most relevant for answering their question,
105
  - IPCC is for questions about climate change, energy, impacts, and everything we can find the IPCC reports
106
  - IPBES is for questions about biodiversity and nature
107
  - IPOS is for questions about the ocean and deep sea mining
108
- - AcclimaTerra is for questions about any specific place in, or close to, the french region "Nouvelle-Aquitaine"
 
 
109
  """,
110
  )
111
 
@@ -142,7 +156,25 @@ def make_query_analysis_chain(llm):
142
 
143
 
144
  prompt = ChatPromptTemplate.from_messages([
145
- ("system", "You are a helpful assistant, you will analyze, translate and reformulate the user input message using the function provided"),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
  ("user", "input: {input}")
147
  ])
148
 
@@ -150,6 +182,16 @@ def make_query_analysis_chain(llm):
150
  chain = prompt | llm_with_functions | JsonOutputFunctionsParser()
151
  return chain
152
 
 
 
 
 
 
 
 
 
 
 
153
 
154
  def make_query_transform_node(llm,k_final=15):
155
  """
@@ -172,12 +214,13 @@ def make_query_transform_node(llm,k_final=15):
172
 
173
  decomposition_chain = make_query_decomposition_chain(llm)
174
  query_analysis_chain = make_query_analysis_chain(llm)
 
175
 
176
  def transform_query(state):
177
  print("---- Transform query ----")
178
 
179
- auto_mode = state.get("sources_auto", False)
180
- sources_input = state.get("sources_input", ROUTING_INDEX["Vector"])
181
 
182
 
183
  new_state = {}
@@ -186,6 +229,7 @@ def make_query_transform_node(llm,k_final=15):
186
  decomposition_output = decomposition_chain.invoke({"input":state["query"]})
187
  new_state.update(decomposition_output)
188
 
 
189
  # Query Analysis
190
  questions = []
191
  for question in new_state["questions"]:
@@ -194,16 +238,32 @@ def make_query_transform_node(llm,k_final=15):
194
 
195
  # TODO WARNING llm should always return something
196
  # The case when the llm does not return any sources or wrong output
197
- if not query_analysis_output["sources"] or not all(source in ["IPCC", "IPBS", "IPOS"] for source in query_analysis_output["sources"]):
198
  query_analysis_output["sources"] = ["IPCC", "IPBES", "IPOS"]
199
 
200
- question_state.update(query_analysis_output)
201
- questions.append(question_state)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
202
 
203
  # Explode the questions into multiple questions with different sources
204
  new_questions = []
205
  for q in questions:
206
- question,sources = q["question"],q["sources"]
207
 
208
  # If not auto mode we take the configuration
209
  if not auto_mode:
@@ -212,7 +272,7 @@ def make_query_transform_node(llm,k_final=15):
212
  for index,index_sources in ROUTING_INDEX.items():
213
  selected_sources = list(set(sources).intersection(index_sources))
214
  if len(selected_sources) > 0:
215
- new_questions.append({"question":question,"sources":selected_sources,"index":index})
216
 
217
  # # Add the number of questions to search
218
  # k_by_question = k_final // len(new_questions)
@@ -222,11 +282,16 @@ def make_query_transform_node(llm,k_final=15):
222
  # new_state["questions"] = new_questions
223
  # new_state["remaining_questions"] = new_questions
224
 
 
 
 
 
 
225
 
226
  new_state = {
227
  "questions_list":new_questions,
228
- "n_questions":len(new_questions),
229
- "handled_questions_index":[],
230
  }
231
  return new_state
232
 
 
60
 
61
 
62
  ROUTING_INDEX = {
63
+ "IPx":["IPCC", "IPBS", "IPOS"],
64
+ "POC": ["AcclimaTerra", "PCAET","Biodiv"],
65
  "OpenAlex":["OpenAlex"],
66
  }
67
 
 
89
  country:str = Field(...,description="The country if directly mentioned or inferred from the location (cities, regions, adresses), ex: France, USA, ...")
90
  location:str = Field(...,description="The specific place if mentioned (cities, regions, addresses), ex: Marseille, New York, Wisconsin, ...")
91
 
92
+ class QueryTranslation(BaseModel):
93
+ """Translate the query into a given language"""
94
+
95
+ question : str = Field(
96
+ description="""
97
+ Translate the questions into the given language
98
+ If the question is alrealdy in the given language, just return the same question
99
+ """,
100
+ )
101
+
102
+
103
  class QueryAnalysis(BaseModel):
104
  """
105
  Analyze the user query to extract the relevant sources
 
110
  Also provide simple keywords to feed a search engine
111
  """
112
 
113
+ sources: List[Literal["IPCC", "IPBES", "IPOS", "AcclimaTerra", "PCAET","Biodiv"]] = Field( #,"OpenAlex"]] = Field(
114
  ...,
115
  description="""
116
  Given a user question choose which documents would be most relevant for answering their question,
117
  - IPCC is for questions about climate change, energy, impacts, and everything we can find the IPCC reports
118
  - IPBES is for questions about biodiversity and nature
119
  - IPOS is for questions about the ocean and deep sea mining
120
+ - AcclimaTerra is for questions about any specific place in, or close to, the french region "Nouvelle-Aquitaine"
121
+ - PCAET is the Plan Climat Eneregie Territorial for the city of Paris
122
+ - Biodiv is the Biodiversity plan for the city of Paris
123
  """,
124
  )
125
 
 
156
 
157
 
158
  prompt = ChatPromptTemplate.from_messages([
159
+ ("system", "You are a helpful assistant, you will analyze the user input message using the function provided"),
160
+ ("user", "input: {input}")
161
+ ])
162
+
163
+
164
+ chain = prompt | llm_with_functions | JsonOutputFunctionsParser()
165
+ return chain
166
+
167
+
168
+ def make_query_translation_chain(llm):
169
+ """Analyze the user query to extract the relevant sources"""
170
+
171
+ openai_functions = [convert_to_openai_function(QueryTranslation)]
172
+ llm_with_functions = llm.bind(functions = openai_functions,function_call={"name":"QueryTranslation"})
173
+
174
+
175
+
176
+ prompt = ChatPromptTemplate.from_messages([
177
+ ("system", "You are a helpful assistant, translate the question into {language}"),
178
  ("user", "input: {input}")
179
  ])
180
 
 
182
  chain = prompt | llm_with_functions | JsonOutputFunctionsParser()
183
  return chain
184
 
185
+ def group_by_sources_types(sources):
186
+ sources_types = {}
187
+ IPx_sources = ["IPCC", "IPBES", "IPOS"]
188
+ local_sources = ["AcclimaTerra", "PCAET","Biodiv"]
189
+ if any(source in IPx_sources for source in sources):
190
+ sources_types["IPx"] = list(set(sources).intersection(IPx_sources))
191
+ if any(source in local_sources for source in sources):
192
+ sources_types["POC"] = list(set(sources).intersection(local_sources))
193
+ return sources_types
194
+
195
 
196
  def make_query_transform_node(llm,k_final=15):
197
  """
 
214
 
215
  decomposition_chain = make_query_decomposition_chain(llm)
216
  query_analysis_chain = make_query_analysis_chain(llm)
217
+ query_translation_chain = make_query_translation_chain(llm)
218
 
219
  def transform_query(state):
220
  print("---- Transform query ----")
221
 
222
+ auto_mode = state.get("sources_auto", True)
223
+ sources_input = state.get("sources_input", ROUTING_INDEX["IPx"])
224
 
225
 
226
  new_state = {}
 
229
  decomposition_output = decomposition_chain.invoke({"input":state["query"]})
230
  new_state.update(decomposition_output)
231
 
232
+
233
  # Query Analysis
234
  questions = []
235
  for question in new_state["questions"]:
 
238
 
239
  # TODO WARNING llm should always return something
240
  # The case when the llm does not return any sources or wrong output
241
+ if not query_analysis_output["sources"] or not all(source in ["IPCC", "IPBS", "IPOS","AcclimaTerra", "PCAET","Biodiv"] for source in query_analysis_output["sources"]):
242
  query_analysis_output["sources"] = ["IPCC", "IPBES", "IPOS"]
243
 
244
+ sources_types = group_by_sources_types(query_analysis_output["sources"])
245
+ for source_type,sources in sources_types.items():
246
+ question_state = {
247
+ "question":question,
248
+ "sources":sources,
249
+ "source_type":source_type
250
+ }
251
+
252
+ questions.append(question_state)
253
+
254
+ # Translate question into the document language
255
+ for q in questions:
256
+ if q["source_type"]=="IPx":
257
+ translation_output = query_translation_chain.invoke({"input":q["question"],"language":"English"})
258
+ q["question"] = translation_output["question"]
259
+ elif q["source_type"]=="POC":
260
+ translation_output = query_translation_chain.invoke({"input":q["question"],"language":"French"})
261
+ q["question"] = translation_output["question"]
262
 
263
  # Explode the questions into multiple questions with different sources
264
  new_questions = []
265
  for q in questions:
266
+ question,sources,source_type = q["question"],q["sources"], q["source_type"]
267
 
268
  # If not auto mode we take the configuration
269
  if not auto_mode:
 
272
  for index,index_sources in ROUTING_INDEX.items():
273
  selected_sources = list(set(sources).intersection(index_sources))
274
  if len(selected_sources) > 0:
275
+ new_questions.append({"question":question,"sources":selected_sources,"index":index, "source_type":source_type})
276
 
277
  # # Add the number of questions to search
278
  # k_by_question = k_final // len(new_questions)
 
282
  # new_state["questions"] = new_questions
283
  # new_state["remaining_questions"] = new_questions
284
 
285
+ n_questions = {
286
+ "total":len(new_questions),
287
+ "IPx":len([q for q in new_questions if q["index"] == "IPx"]),
288
+ "POC":len([q for q in new_questions if q["index"] == "POC"]),
289
+ }
290
 
291
  new_state = {
292
  "questions_list":new_questions,
293
+ "n_questions":n_questions,
294
+ "handled_questions_index":[],
295
  }
296
  return new_state
297
 
climateqa/engine/chains/retrieve_documents.py CHANGED
@@ -290,7 +290,7 @@ async def retrieve_documents(state,config, source_type, vectorstore,reranker,llm
290
  Returns:
291
  dict: The updated state containing the retrieved and reranked documents, related content, and remaining questions.
292
  """
293
- print(f"---- Retrieve documents from {source_type}----")
294
  docs = state.get("documents", [])
295
  related_content = state.get("related_content", [])
296
 
@@ -304,26 +304,30 @@ async def retrieve_documents(state,config, source_type, vectorstore,reranker,llm
304
  # remaining_questions = state["remaining_questions"][1:]
305
 
306
  current_question_id = None
307
- print("Here", range(len(state["questions_list"])),state["handled_questions_index"])
308
 
309
  for i in range(len(state["questions_list"])):
310
- if i not in state["handled_questions_index"]:
 
 
311
  current_question_id = i
312
  break
313
- current_question = state["questions_list"][current_question_id]
314
  # TODO filter on source_type
315
 
316
- k_by_question = k_final // state["n_questions"]
317
- k_summary_by_question = _get_k_summary_by_question(state["n_questions"])
318
- k_images_by_question = _get_k_images_by_question(state["n_questions"])
319
 
320
  sources = current_question["sources"]
321
  question = current_question["question"]
322
  index = current_question["index"]
 
323
 
324
  print(f"Retrieve documents for question: {question}")
325
  await log_event({"question":question,"sources":sources,"index":index},"log_retriever",config)
326
 
 
327
 
328
  # if index == "Vector": # always true for now #TODO rename to IPx
329
  if source_type == "IPx": # always true for now #TODO rename to IPx
@@ -393,7 +397,7 @@ def make_IPx_retriever_node(vectorstore,reranker,llm,rerank_by_question=True, k_
393
  @chain
394
  async def retrieve_IPx_docs(state, config):
395
  source_type = "IPx"
396
- return {"documents":[], "related_contents": [], "handled_questions_index": list(range(len(state["questions_list"])))} # TODO Remove
397
 
398
  state = await retrieve_documents(
399
  state = state,
 
290
  Returns:
291
  dict: The updated state containing the retrieved and reranked documents, related content, and remaining questions.
292
  """
293
+ # TODO split les questions selon le type de sources dans le state question + conditions sur le nombre de questions traités par type de source
294
  docs = state.get("documents", [])
295
  related_content = state.get("related_content", [])
296
 
 
304
  # remaining_questions = state["remaining_questions"][1:]
305
 
306
  current_question_id = None
307
+ print("Questions Indexs", list(range(len(state["questions_list"]))), "- Handled questions : " ,state["handled_questions_index"])
308
 
309
  for i in range(len(state["questions_list"])):
310
+ current_question = state["questions_list"][i]
311
+
312
+ if i not in state["handled_questions_index"] and current_question["source_type"] == source_type:
313
  current_question_id = i
314
  break
315
+
316
  # TODO filter on source_type
317
 
318
+ k_by_question = k_final // state["n_questions"]["total"]
319
+ k_summary_by_question = _get_k_summary_by_question(state["n_questions"]["total"])
320
+ k_images_by_question = _get_k_images_by_question(state["n_questions"]["total"])
321
 
322
  sources = current_question["sources"]
323
  question = current_question["question"]
324
  index = current_question["index"]
325
+ source_type = current_question["source_type"]
326
 
327
  print(f"Retrieve documents for question: {question}")
328
  await log_event({"question":question,"sources":sources,"index":index},"log_retriever",config)
329
 
330
+ print(f"""---- Retrieve documents from {current_question["source_type"]}----""")
331
 
332
  # if index == "Vector": # always true for now #TODO rename to IPx
333
  if source_type == "IPx": # always true for now #TODO rename to IPx
 
397
  @chain
398
  async def retrieve_IPx_docs(state, config):
399
  source_type = "IPx"
400
+ # return {"documents":[], "related_contents": [], "handled_questions_index": list(range(len(state["questions_list"])))} # TODO Remove
401
 
402
  state = await retrieve_documents(
403
  state = state,
climateqa/engine/graph.py CHANGED
@@ -93,21 +93,40 @@ def route_based_on_relevant_docs(state,threshold_docs=0.2):
93
  return "answer_rag_no_docs"
94
 
95
  def route_continue_retrieve_documents(state):
96
- if len(state["questions_list"]) == len(state["handled_questions_index"]) and state["search_only"] :
 
 
97
  return END
98
- elif len(state["questions_list"]) == len(state["handled_questions_index"]):
99
- return "answer_search"
100
- else :
101
  return "retrieve_documents"
 
 
 
 
 
 
 
 
102
 
103
  def route_continue_retrieve_local_documents(state):
104
- if len(state["questions_list"]) == len(state["handled_questions_index"]) and state["search_only"] :
 
 
105
  return END
106
- elif len(state["questions_list"]) == len(state["handled_questions_index"]):
107
  return "answer_search"
108
- else :
109
  return "retrieve_local_data"
110
 
 
 
 
 
 
 
 
111
  # if len(state["remaining_questions"]) == 0 and state["search_only"] :
112
  # return END
113
  # elif len(state["remaining_questions"]) > 0:
@@ -216,8 +235,8 @@ def make_graph_agent(llm, vectorstore_ipcc, vectorstore_graphs, vectorstore_regi
216
 
217
  # Define the edges
218
  workflow.add_edge("translate_query", "transform_query")
219
- # workflow.add_edge("transform_query", "retrieve_documents") #TODO put back
220
- workflow.add_edge("transform_query", END) # TODO remove
221
 
222
  workflow.add_edge("retrieve_graphs", END)
223
  workflow.add_edge("answer_rag", END)
 
93
  return "answer_rag_no_docs"
94
 
95
  def route_continue_retrieve_documents(state):
96
+ index_question_ipx = [i for i, x in enumerate(state["questions_list"]) if x["source_type"] == "IPx"]
97
+ questions_ipx_finished = all(elem in state["handled_questions_index"] for elem in index_question_ipx)
98
+ if questions_ipx_finished and state["search_only"]:
99
  return END
100
+ elif questions_ipx_finished:
101
+ return "answer_search"
102
+ else:
103
  return "retrieve_documents"
104
+
105
+
106
+ # if state["n_questions"]["IPx"] == len(state["handled_questions_index"]) and state["search_only"] :
107
+ # return END
108
+ # elif state["n_questions"]["IPx"] == len(state["handled_questions_index"]):
109
+ # return "answer_search"
110
+ # else :
111
+ # return "retrieve_documents"
112
 
113
  def route_continue_retrieve_local_documents(state):
114
+ index_question_poc = [i for i, x in enumerate(state["questions_list"]) if x["source_type"] == "POC"]
115
+ questions_poc_finished = all(elem in state["handled_questions_index"] for elem in index_question_poc)
116
+ if questions_poc_finished and state["search_only"]:
117
  return END
118
+ elif questions_poc_finished:
119
  return "answer_search"
120
+ else:
121
  return "retrieve_local_data"
122
 
123
+ # if state["n_questions"]["POC"] == len(state["handled_questions_index"]) and state["search_only"] :
124
+ # return END
125
+ # elif state["n_questions"]["POC"] == len(state["handled_questions_index"]):
126
+ # return "answer_search"
127
+ # else :
128
+ # return "retrieve_local_data"
129
+
130
  # if len(state["remaining_questions"]) == 0 and state["search_only"] :
131
  # return END
132
  # elif len(state["remaining_questions"]) > 0:
 
235
 
236
  # Define the edges
237
  workflow.add_edge("translate_query", "transform_query")
238
+ workflow.add_edge("transform_query", "retrieve_documents") #TODO put back
239
+ # workflow.add_edge("transform_query", END) # TODO remove
240
 
241
  workflow.add_edge("retrieve_graphs", END)
242
  workflow.add_edge("answer_rag", END)
climateqa/event_handler.py CHANGED
@@ -35,7 +35,7 @@ def handle_retrieved_documents(event: StreamEvent, history : list[ChatMessage],
35
  tuple[str, list[ChatMessage], list[str]]: The updated HTML representation of the documents, the updated message history and the updated list of used documents
36
  """
37
  if "documents" not in event["data"]["output"] or event["data"]["output"]["documents"] == []:
38
- return history, used_documents, []
39
 
40
  try:
41
  docs = event["data"]["output"]["documents"]
 
35
  tuple[str, list[ChatMessage], list[str]]: The updated HTML representation of the documents, the updated message history and the updated list of used documents
36
  """
37
  if "documents" not in event["data"]["output"] or event["data"]["output"]["documents"] == []:
38
+ return history, used_documents
39
 
40
  try:
41
  docs = event["data"]["output"]["documents"]