timeki committed
Commit fc994dc · 1 Parent(s): be494ba
app.py CHANGED
@@ -30,9 +30,17 @@ from climateqa.event_handler import (
     init_audience,
     handle_retrieved_documents,
     stream_answer,
-    handle_retrieved_owid_graphs
+    handle_retrieved_owid_graphs,
+    convert_to_docs_to_html
 )
 from utils import create_user_id
+from front.utils import make_html_source
+import logging
+
+logging.basicConfig(level=logging.WARNING)
+os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # Suppresses INFO and WARNING logs
+logging.getLogger().setLevel(logging.WARNING)
+

 # Load environment variables in local mode
 try:
@@ -41,6 +49,7 @@ try:
 except Exception as e:
     pass

+
 # Set up Gradio Theme
 theme = gr.themes.Base(
     primary_hue="blue",
@@ -108,7 +117,7 @@ vectorstore_region = get_pinecone_vectorstore(embeddings_function, index_name=os
 llm = get_llm(provider="openai",max_tokens = 1024,temperature = 0.0)
 reranker = get_reranker("nano")

-agent = make_graph_agent(llm=llm, vectorstore_ipcc=vectorstore, vectorstore_graphs=vectorstore_graphs, vectorstore_region = vectorstore_region, reranker=reranker)
+agent = make_graph_agent(llm=llm, vectorstore_ipcc=vectorstore, vectorstore_graphs=vectorstore_graphs, vectorstore_region = vectorstore_region, reranker=reranker, threshold_docs=0)#TODO put back default 0.2

 # Function to update modal visibility
 def update_config_modal_visibility(config_open):
@@ -170,6 +179,7 @@ async def chat(
     docs = []
     related_contents = []
     docs_html = ""
+    new_docs_html = ""
     output_query = ""
     output_language = ""
     output_keywords = ""
@@ -183,20 +193,26 @@ async def chat(
         "categorize_intent": ("🔄️ Analyzing user message", True),
         "transform_query": ("🔄️ Thinking step by step to answer the question", True),
         "retrieve_documents": ("🔄️ Searching in the knowledge base", False),
+        "retrieve_local_data": ("🔄️ Searching in the knowledge base", False),
     }

     try:
         # Process streaming events
         async for event in result:
+
             if "langgraph_node" in event["metadata"]:
                 node = event["metadata"]["langgraph_node"]

                 # Handle document retrieval
-                if event["event"] == "on_chain_end" and event["name"] == "retrieve_documents" and event["data"]["output"] != None:
-                    docs, docs_html, history, used_documents, related_contents = handle_retrieved_documents(
+                if event["event"] == "on_chain_end" and event["name"] in ["retrieve_documents","retrieve_local_data"] and event["data"]["output"] != None:
+                    history, used_documents = handle_retrieved_documents(
                         event, history, used_documents
                     )
-
+                if event["event"] == "on_chain_end" and event["name"] == "answer_search" :
+                    docs = event["data"]["input"]["documents"]
+                    docs_html = convert_to_docs_to_html(docs)
+                    related_contents = event["data"]["input"]["related_contents"]
+
                 # Handle intent categorization
                 elif (event["event"] == "on_chain_end" and
                     node == "categorize_intent" and
@@ -231,7 +247,7 @@ async def chat(
                 # Handle query transformation
                 if event["name"] == "transform_query" and event["event"] == "on_chain_end":
                     if hasattr(history[-1], "content"):
-                        sub_questions = [q["question"] for q in event["data"]["output"]["remaining_questions"]]
+                        sub_questions = [q["question"] for q in event["data"]["output"]["questions_list"]]
                         history[-1].content += "Decompose question into sub-questions:\n\n - " + "\n - ".join(sub_questions)

             yield history, docs_html, output_query, output_language, related_contents, graphs_html
@@ -493,9 +509,9 @@ def create_config_modal(config_open):
                 )

                 dropdown_external_sources = gr.CheckboxGroup(
-                    choices=["Figures (IPCC/IPBES)", "Papers (OpenAlex)", "Graphs (OurWorldInData)"],
+                    choices=["Figures (IPCC/IPBES)", "Papers (OpenAlex)", "Graphs (OurWorldInData)","POC region"],
                     label="Select database to search for relevant content",
-                    value=["Figures (IPCC/IPBES)"],
+                    value=["Figures (IPCC/IPBES)","POC region"],
                     interactive=True
                 )

@@ -565,6 +581,8 @@ with gr.Blocks(title="Climate Q&A", css_paths=os.getcwd()+ "/style.css", theme=t
     chat_completed_state = gr.State(0)
     current_graphs = gr.State([])
     saved_graphs = gr.State({})
+    new_sources_hmtl = gr.State([])
+
     config_open = gr.State(False)

     with gr.Tab("ClimateQ&A"):
@@ -584,6 +602,7 @@ with gr.Blocks(title="Climate Q&A", css_paths=os.getcwd()+ "/style.css", theme=t
         with gr.Tab("Sources", elem_id="tab-sources", id=1) as tab_sources:
             sources_textbox = gr.HTML(show_label=False, elem_id="sources-textbox")

+
         # Recommended content tab
         with gr.Tab("Recommended content", elem_id="tab-recommended_content", id=2) as tab_recommended_content:
             with gr.Tabs(elem_id="group-subtabs") as tabs_recommended_content:
@@ -641,7 +660,7 @@ with gr.Blocks(title="Climate Q&A", css_paths=os.getcwd()+ "/style.css", theme=t

     (textbox
         .submit(start_chat, [textbox, chatbot, search_only], [textbox, tabs, chatbot, sources_raw], queue=False, api_name="start_chat_textbox")
-        .then(chat, [textbox, chatbot, dropdown_audience, dropdown_sources, dropdown_reports, dropdown_external_sources, search_only], [chatbot, sources_textbox, output_query, output_language, new_figures, current_graphs], concurrency_limit=8, api_name="chat_textbox")
+        .then(chat, [textbox, chatbot, dropdown_audience, dropdown_sources, dropdown_reports, dropdown_external_sources, search_only], [chatbot, new_sources_hmtl, output_query, output_language, new_figures, current_graphs], concurrency_limit=8, api_name="chat_textbox")
         .then(finish_chat, None, [textbox], api_name="finish_chat_textbox")
     )

@@ -649,10 +668,16 @@ with gr.Blocks(title="Climate Q&A", css_paths=os.getcwd()+ "/style.css", theme=t

     (examples_hidden
         .change(start_chat, [examples_hidden, chatbot, search_only], [textbox, tabs, chatbot, sources_raw], queue=False, api_name="start_chat_examples")
-        .then(chat, [examples_hidden, chatbot, dropdown_audience, dropdown_sources, dropdown_reports, dropdown_external_sources, search_only], [chatbot, sources_textbox, output_query, output_language, new_figures, current_graphs], concurrency_limit=8, api_name="chat_textbox")
+        .then(chat, [examples_hidden, chatbot, dropdown_audience, dropdown_sources, dropdown_reports, dropdown_external_sources, search_only], [chatbot, new_sources_hmtl, output_query, output_language, new_figures, current_graphs], concurrency_limit=8, api_name="chat_textbox")
         .then(finish_chat, None, [textbox], api_name="finish_chat_examples")
     )
+    def process_new_docs_html(new_docs, docs):
+        if new_docs:
+            return docs + new_docs
+        return docs
+        # return docs + new_docs

+    new_sources_hmtl.change(lambda x : x, inputs = [new_sources_hmtl], outputs = [sources_textbox])
     new_figures.change(process_figures, inputs=[sources_raw, new_figures], outputs=[sources_raw, figures_cards, gallery_component])

     # Update sources numbers
climateqa/engine/chains/answer_rag.py CHANGED
@@ -65,6 +65,7 @@ def make_rag_node(llm,with_docs = True):
     async def answer_rag(state,config):
         print("---- Answer RAG ----")
         start_time = time.time()
+        print("Sources used : " + "\n".join([x.metadata["short_name"] + " - page " + str(x.metadata["page_number"]) for x in state["documents"]]))

         answer = await rag_chain.ainvoke(state,config)

climateqa/engine/chains/graph_retriever.py CHANGED
@@ -50,7 +50,9 @@ def make_graph_retriever_node(vectorstore, reranker, rerank_by_question=True, k_
         print("---- Retrieving graphs ----")

         POSSIBLE_SOURCES = ["IEA", "OWID"]
-        questions = state["remaining_questions"] if state["remaining_questions"] is not None and state["remaining_questions"]!=[] else [state["query"]]
+        # questions = state["remaining_questions"] if state["remaining_questions"] is not None and state["remaining_questions"]!=[] else [state["query"]]
+        questions = state["questions_list"] if state["questions_list"] is not None and state["questions_list"]!=[] else [state["query"]]
+
         # sources_input = state["sources_input"]
         sources_input = ["auto"]

climateqa/engine/chains/prompts.py CHANGED
@@ -37,7 +37,7 @@ You are given a question and extracted passages of the IPCC and/or IPBES reports


 answer_prompt_template = """
-You are ClimateQ&A, an AI Assistant created by Ekimetrics. You are given a question and extracted passages of the IPCC and/or IPBES reports. Provide a clear and structured answer based on the passages provided, the context and the guidelines.
+You are ClimateQ&A, an AI Assistant created by Ekimetrics. You are given a question and extracted passages of reports. Provide a clear and structured answer based on the passages provided, the context and the guidelines.

 Guidelines:
 - If the passages have useful facts or numbers, use them in your answer.
climateqa/engine/chains/query_transformation.py CHANGED
@@ -7,6 +7,57 @@ from langchain.prompts import ChatPromptTemplate
 from langchain_core.utils.function_calling import convert_to_openai_function
 from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser

+# OLD QUERY ANALYSIS
+# keywords: List[str] = Field(
+#     description="""
+#     Extract the keywords from the user query to feed a search engine as a list
+#     Maximum 3 keywords
+
+#     Examples:
+#     - "What is the impact of deep sea mining ?" -> deep sea mining
+#     - "How will El Nino be impacted by climate change" -> el nino;climate change
+#     - "Is climate change a hoax" -> climate change;hoax
+#     """
+# )
+
+# alternative_queries: List[str] = Field(
+#     description="""
+#     Generate alternative search questions from the user query to feed a search engine
+#     """
+# )
+
+# step_back_question: str = Field(
+#     description="""
+#     You are an expert at world knowledge. Your task is to step back and paraphrase a question to a more generic step-back question, which is easier to answer.
+#     This questions should help you get more context and information about the user query
+#     """
+# )
+# - OpenAlex is for any other questions that are not in the previous categories but could be found in the scientific litterature
+#
+
+
+# topics: List[Literal[
+#     "Climate change",
+#     "Biodiversity",
+#     "Energy",
+#     "Decarbonization",
+#     "Climate science",
+#     "Nature",
+#     "Climate policy and justice",
+#     "Oceans",
+#     "Deep sea mining",
+#     "ESG and regulations",
+#     "CSRD",
+# ]] = Field(
+#     ...,
+#     description = """
+#     Choose the topics that are most relevant to the user query, ex: Climate change, Energy, Biodiversity, ...
+#     """,
+# )
+# date: str = Field(description="The date or period mentioned, ex: 2050, between 2020 and 2050")
+# location:Location
+
+

 ROUTING_INDEX = {
     "Vector":["IPCC","IPBES","IPOS", "AcclimaTerra"],
@@ -25,7 +76,7 @@ class QueryDecomposition(BaseModel):

     questions: List[str] = Field(
         description="""
-        Think step by step to answer this question, and provide one or several search engine questions in English for knowledge that you need.
+        Think step by step to answer this question, and provide one or several search engine questions in the provided language for knowledge that you need.
         Suppose that the user is looking for information about climate change, energy, biodiversity, nature, and everything we can find the IPCC reports and scientific literature
         - If it's already a standalone and explicit question, just return the reformulated question for the search engine
         - If you need to decompose the question, output a list of maximum 2 to 3 questions
@@ -39,36 +90,14 @@ class Location(BaseModel):

 class QueryAnalysis(BaseModel):
     """
+    Analyze the user query to extract the relevant sources
+
+    Deprecated:
     Analyzing the user query to extract topics, sources and date
     Also do query expansion to get alternative search queries
     Also provide simple keywords to feed a search engine
     """

-    # keywords: List[str] = Field(
-    #     description="""
-    #     Extract the keywords from the user query to feed a search engine as a list
-    #     Maximum 3 keywords
-
-    #     Examples:
-    #     - "What is the impact of deep sea mining ?" -> deep sea mining
-    #     - "How will El Nino be impacted by climate change" -> el nino;climate change
-    #     - "Is climate change a hoax" -> climate change;hoax
-    #     """
-    # )
-
-    # alternative_queries: List[str] = Field(
-    #     description="""
-    #     Generate alternative search questions from the user query to feed a search engine
-    #     """
-    # )
-
-    # step_back_question: str = Field(
-    #     description="""
-    #     You are an expert at world knowledge. Your task is to step back and paraphrase a question to a more generic step-back question, which is easier to answer.
-    #     This questions should help you get more context and information about the user query
-    #     """
-    # )
-
     sources: List[Literal["IPCC", "IPBES", "IPOS", "AcclimaTerra"]] = Field( #,"OpenAlex"]] = Field(
         ...,
         description="""
@@ -78,31 +107,19 @@ class QueryAnalysis(BaseModel):
         - IPOS is for questions about the ocean and deep sea mining
         - AcclimaTerra is for questions about any specific place in, or close to, the french region "Nouvelle-Aquitaine"
         """,
-        # - OpenAlex is for any other questions that are not in the previous categories but could be found in the scientific litterature
     )
-    # topics: List[Literal[
-    #     "Climate change",
-    #     "Biodiversity",
-    #     "Energy",
-    #     "Decarbonization",
-    #     "Climate science",
-    #     "Nature",
-    #     "Climate policy and justice",
-    #     "Oceans",
-    #     "Deep sea mining",
-    #     "ESG and regulations",
-    #     "CSRD",
-    # ]] = Field(
-    #     ...,
-    #     description = """
-    #     Choose the topics that are most relevant to the user query, ex: Climate change, Energy, Biodiversity, ...
-    #     """,
-    # )
-    # date: str = Field(description="The date or period mentioned, ex: 2050, between 2020 and 2050")
-    # location:Location
+


 def make_query_decomposition_chain(llm):
+    """Chain to decompose a query into smaller parts to think step by step to answer this question
+
+    Args:
+        llm (_type_): _description_
+
+    Returns:
+        _type_: _description_
+    """

     openai_functions = [convert_to_openai_function(QueryDecomposition)]
     llm_with_functions = llm.bind(functions = openai_functions,function_call={"name":"QueryDecomposition"})
@@ -116,7 +133,8 @@ def make_query_decomposition_chain(llm):
     return chain


-def make_query_rewriter_chain(llm):
+def make_query_analysis_chain(llm):
+    """Analyze the user query to extract the relevant sources"""

     openai_functions = [convert_to_openai_function(QueryAnalysis)]
     llm_with_functions = llm.bind(functions = openai_functions,function_call={"name":"QueryAnalysis"})
@@ -153,7 +171,7 @@ def make_query_transform_node(llm,k_final=15):


     decomposition_chain = make_query_decomposition_chain(llm)
-    rewriter_chain = make_query_rewriter_chain(llm)
+    query_analysis_chain = make_query_analysis_chain(llm)

     def transform_query(state):
         print("---- Transform query ----")
@@ -172,14 +190,14 @@ def make_query_transform_node(llm,k_final=15):
         questions = []
         for question in new_state["questions"]:
             question_state = {"question":question}
-            analysis_output = rewriter_chain.invoke({"input":question})
+            query_analysis_output = query_analysis_chain.invoke({"input":question})

             # TODO WARNING llm should always return smthg
-            # The case when the llm does not return any sources
-            if not analysis_output["sources"] or not all(source in ["IPCC", "IPBS", "IPOS"] for source in analysis_output["sources"]):
-                analysis_output["sources"] = ["IPCC", "IPBES", "IPOS"]
+            # The case when the llm does not return any sources or wrong ouput
+            if not query_analysis_output["sources"] or not all(source in ["IPCC", "IPBS", "IPOS"] for source in query_analysis_output["sources"]):
+                query_analysis_output["sources"] = ["IPCC", "IPBES", "IPOS"]

-            question_state.update(analysis_output)
+            question_state.update(query_analysis_output)
             questions.append(question_state)

         # Explode the questions into multiple questions with different sources
@@ -206,8 +224,9 @@ def make_query_transform_node(llm,k_final=15):


         new_state = {
-            "remaining_questions":new_questions,
+            "questions_list":new_questions,
             "n_questions":len(new_questions),
+            "handled_questions_index":[],
         }
         return new_state

climateqa/engine/chains/retrieve_documents.py CHANGED
@@ -106,6 +106,17 @@ def _add_metadata_and_score(docs: List) -> Document:
         docs_with_metadata.append(doc)
     return docs_with_metadata

+def remove_duplicates_chunks(docs):
+    # Remove duplicates or almost duplicates
+    docs = sorted(docs,key=lambda x: x[1],reverse=True)
+    seen = set()
+    result = []
+    for doc in docs:
+        if doc[0].page_content not in seen:
+            seen.add(doc[0].page_content)
+            result.append(doc)
+    return result
+
 async def get_POC_relevant_documents(
     query: str,
     vectorstore:VectorStore,
@@ -116,14 +127,18 @@ async def get_POC_relevant_documents(
     threshold:float = 0.6,
     k_images: int = 5,
     reports:list = [],
+    min_size:int = 200,
 ) :
     # Prepare base search kwargs
     filters = {}
+    docs_question = []
+    docs_images = []

-    if len(reports) > 0:
-        filters["short_name"] = {"$in":reports}
-    else:
-        filters["source"] = { "$in": sources}
+    # TODO add source selection
+    # if len(reports) > 0:
+    #     filters["short_name"] = {"$in":reports}
+    # else:
+    #     filters["source"] = { "$in": sources}

     filters_text = {
         **filters,
@@ -132,6 +147,8 @@ async def get_POC_relevant_documents(
     }

     docs_question = vectorstore.similarity_search_with_score(query=query,filter = filters_text,k = k_documents)
+    # remove duplicates or almost duplicates
+    docs_question = remove_duplicates_chunks(docs_question)
     docs_question = [x for x in docs_question if x[1] > threshold]

     if search_figures:
@@ -141,6 +158,10 @@ async def get_POC_relevant_documents(
             "chunk_type":"image"
         }
         docs_images = vectorstore.similarity_search_with_score(query=query,filter = filters_image,k = k_images)
+
+    docs_question, docs_images = _add_metadata_and_score(docs_question), _add_metadata_and_score(docs_images)
+
+    docs_question = [x for x in docs_question if len(x.page_content) > min_size]

     return {
         "docs_question" : docs_question,
@@ -236,12 +257,13 @@

 def concatenate_documents(index, source_type, docs_question_dict, k_by_question, k_summary_by_question, k_images_by_question):
     # Keep the right number of documents - The k_summary documents from SPM are placed in front
-    if source_type == "Vector" :
-        docs_question = docs_question_dict["docs_summaries"][:k_summary_by_question] + docs_question_dict["docs_full"][:k_by_question - k_summary_by_question]
+    if source_type == "IPx":
+        docs_question = docs_question_dict["docs_summaries"][:k_summary_by_question] + docs_question_dict["docs_full"][:(k_by_question - k_summary_by_question)]
     elif source_type == "POC" :
         docs_question = docs_question_dict["docs_question"][:k_by_question]
     else :
-        docs_question = [doc for key in docs_question_dict.keys() for doc in docs_question_dict[key]]
+        raise ValueError("source_type should be either Vector or POC")
+        # docs_question = [doc for key in docs_question_dict.keys() for doc in docs_question_dict[key]][:(k_by_question)]

     images_question = docs_question_dict["docs_images"][:k_images_by_question]

@@ -278,8 +300,18 @@ async def retrieve_documents(state,config, source_type, vectorstore,reranker,llm
     reports = state["reports"]

     # Get the current question
-    current_question = state["remaining_questions"][0]
-    remaining_questions = state["remaining_questions"][1:]
+    # current_question = state["questions_list"][0]
+    # remaining_questions = state["remaining_questions"][1:]
+
+    current_question_id = None
+    print("Here", range(len(state["questions_list"])),state["handled_questions_index"])
+
+    for i in range(len(state["questions_list"])):
+        if i not in state["handled_questions_index"]:
+            current_question_id = i
+            break
+    current_question = state["questions_list"][current_question_id]
+    # TODO filter on source_type

     k_by_question = k_final // state["n_questions"]
     k_summary_by_question = _get_k_summary_by_question(state["n_questions"])
@@ -318,6 +350,9 @@ async def retrieve_documents(state,config, source_type, vectorstore,reranker,llm
                 threshold = 0.5,
                 search_only = search_only,
                 reports = reports,
+                min_size= 200,
+                k_documents= k_before_reranking,
+                k_images= k_by_question
             )


@@ -343,9 +378,12 @@ async def retrieve_documents(state,config, source_type, vectorstore,reranker,llm
         images_question = _add_sources_used_in_metadata(images_question,sources,question,index)

         # Add to the list of docs
-        docs.extend(docs_question)
-        related_content.extend(images_question)
-        new_state = {"documents":docs, "related_contents": related_content,"remaining_questions":remaining_questions}
+        # docs.extend(docs_question)
+        # related_content.extend(images_question)
+        docs = docs_question
+        related_content = images_question
+        new_state = {"documents":docs, "related_contents": related_content, "handled_questions_index": [current_question_id]}
+        print("Updated state with question ", current_question_id, " added ", len(docs), " documents")
     return new_state


@@ -355,7 +393,20 @@ def make_IPx_retriever_node(vectorstore,reranker,llm,rerank_by_question=True, k_
     @chain
     async def retrieve_IPx_docs(state, config):
         source_type = "IPx"
-        state = await retrieve_documents(state,config, source_type, vectorstore,reranker,llm,rerank_by_question, k_final, k_before_reranking, k_summary)
+        return {"documents":[], "related_contents": [], "handled_questions_index": list(range(len(state["questions_list"])))} # TODO Remove
+
+        state = await retrieve_documents(
+            state = state,
+            config= config,
+            source_type=source_type,
+            vectorstore=vectorstore,
+            reranker= reranker,
+            llm=llm,
+            rerank_by_question=rerank_by_question,
+            k_final=k_final,
+            k_before_reranking=k_before_reranking,
+            k_summary=k_summary
+        )
         return state

     return retrieve_IPx_docs
@@ -364,12 +415,23 @@ def make_IPx_retriever_node(vectorstore,reranker,llm,rerank_by_question=True, k_
 def make_POC_retriever_node(vectorstore,reranker,llm,rerank_by_question=True, k_final=15, k_before_reranking=100, k_summary=5):

     @chain
-    async def retrieve_IPx_docs(state, config):
+    async def retrieve_POC_docs_node(state, config):
         source_type = "POC"
-        state = await retrieve_documents(state,config, source_type, vectorstore,reranker,llm,rerank_by_question, k_final, k_before_reranking, k_summary)
+        state = await retrieve_documents(
+            state = state,
+            config= config,
+            source_type=source_type,
+            vectorstore=vectorstore,
+            reranker= reranker,
+            llm=llm,
+            rerank_by_question=rerank_by_question,
+            k_final=k_final,
+            k_before_reranking=k_before_reranking,
+            k_summary=k_summary
+        )
         return state

-    return retrieve_IPx_docs
+    return retrieve_POC_docs_node


climateqa/engine/graph.py CHANGED
@@ -9,6 +9,9 @@ from langchain_core.runnables.graph import CurveStyle, MermaidDrawMethod
 from typing_extensions import TypedDict
 from typing import List, Dict

+import operator
+from typing import Annotated
+
 from IPython.display import display, HTML, Image

 from .chains.answer_chitchat import make_chitchat_node
@@ -31,7 +34,8 @@ class GraphState(TypedDict):
     intent : str
     search_graphs_chitchat : bool
     query: str
-    remaining_questions : List[dict]
+    questions_list : List[dict]
+    handled_questions_index : Annotated[list[int], operator.add]
     n_questions : int
     answer: str
     audience: str = "experts"
@@ -40,20 +44,20 @@ class GraphState(TypedDict):
     sources_auto: bool = True
     min_year: int = 1960
     max_year: int = None
-    documents: List[Document]
-    related_contents : List[Document]
+    documents: Annotated[List[Document], operator.add]
+    related_contents : Annotated[List[Document], operator.add]
     recommended_content : List[Document]
     search_only : bool = False
     reports : List[str] = []

 def dummy(state):
-    return state
+    return

 def search(state): #TODO
-    return state
+    return

 def answer_search(state):#TODO
-    return state
+    return

 def route_intent(state):
     intent = state["intent"]
@@ -76,22 +80,40 @@ def route_translation(state):
     if state["language"].lower() == "english":
         return "transform_query"
     else:
-        return "translate_query"
+        return "transform_query"
+        # return "translate_query" #TODO : add translation
+

 def route_based_on_relevant_docs(state,threshold_docs=0.2):
     docs = [x for x in state["documents"] if x.metadata["reranking_score"] > threshold_docs]
+    print("Route : ", ["answer_rag" if len(docs) > 0 else "answer_rag_no_docs"])
     if len(docs) > 0:
         return "answer_rag"
     else:
         return "answer_rag_no_docs"

 def route_continue_retrieve_documents(state):
-    if len(state["remaining_questions"]) == 0 and state["search_only"] :
+    if len(state["questions_list"]) == len(state["handled_questions_index"]) and state["search_only"] :
         return END
-    elif len(state["remaining_questions"]) > 0:
+    elif len(state["questions_list"]) == len(state["handled_questions_index"]):
+        return "answer_search"
+    else :
         return "retrieve_documents"
-    else:
+
+def route_continue_retrieve_local_documents(state):
+    if len(state["questions_list"]) == len(state["handled_questions_index"]) and state["search_only"] :
+        return END
+    elif len(state["questions_list"]) == len(state["handled_questions_index"]):
         return "answer_search"
+    else :
+        return "retrieve_local_data"
+
+    # if len(state["remaining_questions"]) == 0 and state["search_only"] :
+    #     return END
+    # elif len(state["remaining_questions"]) > 0:
+    #     return "retrieve_documents"
+    # else:
+    #     return "answer_search"

 def route_retrieve_documents(state):
     sources_to_retrieve = []
@@ -167,6 +189,12 @@ def make_graph_agent(llm, vectorstore_ipcc, vectorstore_graphs, vectorstore_regi
         route_continue_retrieve_documents,
         make_id_dict([END,"retrieve_documents","answer_search"])
     )
+    workflow.add_conditional_edges(
+        "retrieve_local_data",
+        # lambda state : "retrieve_documents" if len(state["remaining_questions"]) > 0 else "answer_search",
+        route_continue_retrieve_local_documents,
+        make_id_dict([END,"retrieve_local_data","answer_search"])
+    )

     workflow.add_conditional_edges(
         "answer_search",
@@ -188,14 +216,15 @@ def make_graph_agent(llm, vectorstore_ipcc, vectorstore_graphs, vectorstore_regi

     # Define the edges
     workflow.add_edge("translate_query", "transform_query")
-    workflow.add_edge("transform_query", "retrieve_documents")
+    # workflow.add_edge("transform_query", "retrieve_documents") #TODO put back
+    workflow.add_edge("transform_query", END) # TODO remove

     workflow.add_edge("retrieve_graphs", END)
     workflow.add_edge("answer_rag", END)
     workflow.add_edge("answer_rag_no_docs", END)
     workflow.add_edge("answer_chitchat", "chitchat_categorize_intent")
     workflow.add_edge("retrieve_graphs_chitchat", END)
-    workflow.add_edge("retrieve_local_data", "answer_search")
+    # workflow.add_edge("retrieve_local_data", "answer_search")

     # Compile
     app = workflow.compile()
climateqa/event_handler.py CHANGED
@@ -15,6 +15,13 @@ def init_audience(audience :str) -> str:
         audience_prompt = audience_prompts["experts"]
     return audience_prompt

+def convert_to_docs_to_html(docs: list[dict]) -> str:
+    docs_html = []
+    for i, d in enumerate(docs, 1):
+        if d.metadata["chunk_type"] == "text":
+            docs_html.append(make_html_source(d, i))
+    return "".join(docs_html)
+
 def handle_retrieved_documents(event: StreamEvent, history : list[ChatMessage], used_documents : list[str]) -> tuple[str, list[ChatMessage], list[str]]:
     """
     Handles the retrieved documents and returns the HTML representation of the documents
@@ -27,26 +34,22 @@ def handle_retrieved_documents(event: StreamEvent, history : list[ChatMessage],
     Returns:
         tuple[str, list[ChatMessage], list[str]]: The updated HTML representation of the documents, the updated message history and the updated list of used documents
     """
+    if "documents" not in event["data"]["output"] or event["data"]["output"]["documents"] == []:
+        return history, used_documents, []
+
     try:
-        docs = event["data"]["output"]["documents"]
-        docs_html = []
-        textual_docs = [d for d in docs if d.metadata["chunk_type"] == "text"]
-        for i, d in enumerate(textual_docs, 1):
-            if d.metadata["chunk_type"] == "text":
-                docs_html.append(make_html_source(d, i))
+        docs = event["data"]["output"]["documents"]

         used_documents = used_documents + [f"{d.metadata['short_name']} - {d.metadata['name']}" for d in docs]
         if used_documents!=[]:
             history[-1].content = "Adding sources :\n\n - " + "\n - ".join(np.unique(used_documents))
-
-        docs_html = "".join(docs_html)
+
+        #TODO do the same for related contents

-        related_contents = event["data"]["output"]["related_contents"]
-
     except Exception as e:
         print(f"Error getting documents: {e}")
         print(event)
-    return docs, docs_html, history, used_documents, related_contents
+    return history, used_documents

 def stream_answer(history: list[ChatMessage], event : StreamEvent, start_streaming : bool, answer_message_content : str)-> tuple[list[ChatMessage], bool, str]:
     """
front/utils.py CHANGED
@@ -39,7 +39,11 @@ def parse_output_llm_with_sources(output:str)->str:
     content_parts = "".join(parts)
     return content_parts

+
+
 def process_figures(docs:list, new_figures:list)->tuple:
+    if new_figures == []:
+        return docs, "", []
     docs = docs + new_figures

     figures = '<div class="figures-container"><p></p> </div>'