WIP add regional sources
- app.py +6 -7
- climateqa/engine/chains/query_transformation.py +24 -11
- climateqa/engine/chains/retrieve_documents.py +100 -25
- climateqa/engine/graph.py +35 -12
- climateqa/engine/reranker.py +5 -0
app.py
CHANGED
@@ -103,11 +103,12 @@ CITATION_TEXT = r"""@misc{climateqa,
 embeddings_function = get_embeddings_function()
 vectorstore = get_pinecone_vectorstore(embeddings_function, index_name=os.getenv("PINECONE_API_INDEX"))
 vectorstore_graphs = get_pinecone_vectorstore(embeddings_function, index_name=os.getenv("PINECONE_API_INDEX_OWID"), text_key="description")
+vectorstore_region = get_pinecone_vectorstore(embeddings_function, index_name=os.getenv("PINECONE_API_INDEX_REGION"))
 
 llm = get_llm(provider="openai",max_tokens = 1024,temperature = 0.0)
 reranker = get_reranker("nano")
 
-agent = make_graph_agent(llm=llm, vectorstore_ipcc=vectorstore, vectorstore_graphs=vectorstore_graphs, reranker=reranker)
+agent = make_graph_agent(llm=llm, vectorstore_ipcc=vectorstore, vectorstore_graphs=vectorstore_graphs, vectorstore_region = vectorstore_region, reranker=reranker)
 
 # Function to update modal visibility
 def update_config_modal_visibility(config_open):
@@ -149,7 +150,7 @@ async def chat(
     print(f">> NEW QUESTION ({date_now}) : {query}")
 
     audience_prompt = init_audience(audience)
-    sources = sources or ["IPCC", "IPBES"]
+    sources = sources or ["IPCC", "IPBES"]
     reports = reports or []
 
     # Prepare inputs for agent
@@ -606,9 +607,7 @@ with gr.Blocks(title="Climate Q&A", css_paths=os.getcwd()+ "/style.css", theme=t
         outputs=[graphs_container]
     )
 
-
-
-    # Other tabs
+
     with gr.Tab("About", elem_classes="max-height other-tabs"):
         with gr.Row():
             with gr.Column(scale=1):
@@ -629,10 +628,10 @@ with gr.Blocks(title="Climate Q&A", css_paths=os.getcwd()+ "/style.css", theme=t
                 show_copy_button=True,
                 lines=len(CITATION_TEXT.split('\n')),
             )
-
-    # Event handlers
+    # Configuration panel
    config_modal, dropdown_sources, dropdown_reports, dropdown_external_sources, search_only, dropdown_audience, after, output_query, output_language = create_config_modal(config_open)
 
+    # Event handlers
    config_button.click(
        fn=update_config_modal_visibility,
        inputs=[config_open],
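For context, here is the startup wiring after this change as one self-contained sketch. The import paths are assumptions (app.py's import block is outside this diff), and PINECONE_API_INDEX_REGION must name a Pinecone index already populated with the regional reports.

import os

# Assumed import paths; the real ones live in app.py's import block, not shown in this diff.
from climateqa.engine.embeddings import get_embeddings_function
from climateqa.engine.llm import get_llm
from climateqa.engine.reranker import get_reranker
from climateqa.engine.vectorstore import get_pinecone_vectorstore
from climateqa.engine.graph import make_graph_agent

embeddings_function = get_embeddings_function()

# One Pinecone index per content family; the regional index is the new one.
vectorstore = get_pinecone_vectorstore(embeddings_function, index_name=os.getenv("PINECONE_API_INDEX"))
vectorstore_graphs = get_pinecone_vectorstore(embeddings_function, index_name=os.getenv("PINECONE_API_INDEX_OWID"), text_key="description")
vectorstore_region = get_pinecone_vectorstore(embeddings_function, index_name=os.getenv("PINECONE_API_INDEX_REGION"))

llm = get_llm(provider="openai", max_tokens=1024, temperature=0.0)
reranker = get_reranker("nano")

# make_graph_agent now takes the regional store as an extra keyword argument.
agent = make_graph_agent(
    llm=llm,
    vectorstore_ipcc=vectorstore,
    vectorstore_graphs=vectorstore_graphs,
    vectorstore_region=vectorstore_region,
    reranker=reranker,
)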
climateqa/engine/chains/query_transformation.py
CHANGED
@@ -9,7 +9,7 @@ from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
 
 
 ROUTING_INDEX = {
-    "Vector":["IPCC","IPBES","IPOS"],
+    "Vector":["IPCC","IPBES","IPOS", "AcclimaTerra"],
     "OpenAlex":["OpenAlex"],
 }
 
@@ -69,13 +69,14 @@ class QueryAnalysis(BaseModel):
     # """
     # )
 
-    sources: List[Literal["IPCC", "IPBES", "IPOS"]] = Field( #,"OpenAlex"]] = Field(
+    sources: List[Literal["IPCC", "IPBES", "IPOS", "AcclimaTerra"]] = Field( #,"OpenAlex"]] = Field(
        ...,
        description="""
        Given a user question choose which documents would be most relevant for answering their question,
        - IPCC is for questions about climate change, energy, impacts, and everything we can find in the IPCC reports
        - IPBES is for questions about biodiversity and nature
        - IPOS is for questions about the ocean and deep sea mining
+       - AcclimaTerra is for questions about any specific place in, or close to, the French region "Nouvelle-Aquitaine"
        """,
        # - OpenAlex is for any other questions that are not in the previous categories but could be found in the scientific literature
    )
@@ -133,6 +134,23 @@ def make_query_rewriter_chain(llm):
 
 
 def make_query_transform_node(llm,k_final=15):
+    """
+    Creates a query transformation node that processes and transforms a given query state.
+    Args:
+        llm: The language model to be used for query decomposition and rewriting.
+        k_final (int, optional): The final number of questions to be generated. Defaults to 15.
+    Returns:
+        function: A function that takes a query state and returns a transformed state.
+    The returned function performs the following steps:
+    1. Checks if the query should be processed in auto mode based on the state.
+    2. Retrieves the input sources from the state or defaults to a predefined routing index.
+    3. Decomposes the query using the decomposition chain.
+    4. Analyzes each decomposed question using the rewriter chain.
+    5. Ensures that the sources returned by the language model are valid.
+    6. Explodes the questions into multiple questions with different sources based on the mode.
+    7. Constructs a new state with the transformed questions and their respective sources.
+    """
+
 
     decomposition_chain = make_query_decomposition_chain(llm)
     rewriter_chain = make_query_rewriter_chain(llm)
@@ -140,14 +158,9 @@
     def transform_query(state):
         print("---- Transform query ----")
 
-
-
-
-        else:
-            auto_mode = True
-
-        sources_input = state.get("sources_input")
-        if sources_input is None: sources_input = ROUTING_INDEX["Vector"]
+        auto_mode = state.get("sources_auto", False)
+        sources_input = state.get("sources_input", ROUTING_INDEX["Vector"])
+
 
         new_state = {}
 
@@ -159,7 +172,7 @@
         questions = []
         for question in new_state["questions"]:
             question_state = {"question":question}
-            analysis_output = rewriter_chain.invoke({"input":question})
+            analysis_output = rewriter_chain.invoke({"input":question})
 
             # TODO WARNING llm should always return something
             # The case when the llm does not return any sources
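To see what the updated routing table implies, here is a small self-contained sketch (a hypothetical helper, not part of the module) of how the sources chosen for one question bucket into retrieval indexes now that AcclimaTerra routes through "Vector":

from typing import Dict, List

# Copied from the diff above.
ROUTING_INDEX = {
    "Vector": ["IPCC", "IPBES", "IPOS", "AcclimaTerra"],
    "OpenAlex": ["OpenAlex"],
}

def group_sources_by_index(sources: List[str]) -> Dict[str, List[str]]:
    """Hypothetical helper: bucket one question's sources by retrieval index."""
    buckets = {}
    for index, known_sources in ROUTING_INDEX.items():
        matched = [s for s in sources if s in known_sources]
        if matched:
            buckets[index] = matched
    return buckets

# A question about the Nouvelle-Aquitaine coastline could now be routed like this:
print(group_sources_by_index(["IPCC", "AcclimaTerra"]))
# {'Vector': ['IPCC', 'AcclimaTerra']}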
climateqa/engine/chains/retrieve_documents.py
CHANGED
@@ -7,7 +7,7 @@ from langchain_core.runnables import chain
 from langchain_core.runnables import RunnableParallel, RunnablePassthrough
 from langchain_core.runnables import RunnableLambda
 
-from ..reranker import rerank_docs
+from ..reranker import rerank_docs, rerank_and_sort_docs
 # from ...knowledge.retriever import ClimateQARetriever
 from ...knowledge.openalex import OpenAlexRetriever
 from .keywords_extraction import make_keywords_extraction_chain
@@ -106,6 +106,48 @@ def _add_metadata_and_score(docs: List) -> Document:
         docs_with_metadata.append(doc)
     return docs_with_metadata
 
+async def get_POC_relevant_documents(
+    query: str,
+    vectorstore:VectorStore,
+    sources:list = ["Acclimaterra","PCAET","Plan Biodiversite"],
+    search_figures:bool = False,
+    search_only:bool = False,
+    k_documents:int = 10,
+    threshold:float = 0.6,
+    k_images: int = 5,
+    reports:list = [],
+) :
+    # Prepare base search kwargs
+    filters = {}
+
+    if len(reports) > 0:
+        filters["short_name"] = {"$in":reports}
+    else:
+        filters["source"] = { "$in": sources}
+
+    filters_text = {
+        **filters,
+        "chunk_type":"text",
+        # "report_type": {}, # TODO to be completed to choose the right documents / chapters according to the analysis of the question
+    }
+
+    docs_question = vectorstore.similarity_search_with_score(query=query,filter = filters_text,k = k_documents)
+    docs_question = [x for x in docs_question if x[1] > threshold]
+
+    if search_figures:
+        # Images
+        filters_image = {
+            **filters,
+            "chunk_type":"image"
+        }
+        docs_images = vectorstore.similarity_search_with_score(query=query,filter = filters_image,k = k_images)
+
+    return {
+        "docs_question" : docs_question,
+        "docs_images" : docs_images
+    }
+
+
 async def get_IPCC_relevant_documents(
     query: str,
     vectorstore:VectorStore,
@@ -191,12 +233,26 @@
     }
 
 
+
+def concatenate_documents(index, source_type, docs_question_dict, k_by_question, k_summary_by_question, k_images_by_question):
+    # Keep the right number of documents - The k_summary documents from SPM are placed in front
+    if source_type == "Vector" :
+        docs_question = docs_question_dict["docs_summaries"][:k_summary_by_question] + docs_question_dict["docs_full"][:k_by_question - k_summary_by_question]
+    elif source_type == "POC" :
+        docs_question = docs_question_dict["docs_question"][:k_by_question]
+    else :
+        docs_question = [doc for key in docs_question_dict.keys() for doc in docs_question_dict[key]]
+
+    images_question = docs_question_dict["docs_images"][:k_images_by_question]
+
+    return docs_question, images_question
+
 
 # The chain callback is not necessary, but it propagates the langchain callbacks to the astream_events logger to display intermediate results
 # @chain
-async def retrieve_documents(state,config, vectorstore,reranker,llm,rerank_by_question=True, k_final=15, k_before_reranking=100, k_summary=5, k_images=5):
+async def retrieve_documents(state,config, source_type, vectorstore,reranker,llm,rerank_by_question=True, k_final=15, k_before_reranking=100, k_summary=5, k_images=5):
     """
-
+    Unpack the first question of the remaining questions, and retrieve and rerank corresponding documents, based on the question and selected_sources
 
     Args:
         state (dict): The current state containing documents, related content, relevant content sources, remaining questions and n_questions.
@@ -212,7 +268,7 @@ async def retrieve_documents(state,config, vectorstore,reranker,llm,rerank_by_qu
     Returns:
         dict: The updated state containing the retrieved and reranked documents, related content, and remaining questions.
     """
-    print("---- Retrieve documents ----")
+    print(f"---- Retrieve documents from {source_type}----")
     docs = state.get("documents", [])
     related_content = state.get("related_content", [])
 
@@ -237,45 +293,51 @@
         await log_event({"question":question,"sources":sources,"index":index},"log_retriever",config)
 
 
-        if index == "Vector": # always true for now
+        # if index == "Vector": # always true for now #TODO rename to IPx
+        if source_type == "IPx": # always true for now #TODO rename to IPx
            docs_question_dict = await get_IPCC_relevant_documents(
                query = question,
                vectorstore=vectorstore,
                search_figures = search_figures,
                sources = sources,
                min_size = 200,
-                k_summary =
+                k_summary = k_before_reranking-1,
                k_total = k_before_reranking,
                k_images = k_images_by_question,
                threshold = 0.5,
                search_only = search_only,
                reports = reports,
            )
+
+        if source_type == "POC":
+            docs_question_dict = await get_POC_relevant_documents(
+                query = question,
+                vectorstore=vectorstore,
+                search_figures = search_figures,
+                sources = sources,
+                threshold = 0.5,
+                search_only = search_only,
+                reports = reports,
+            )
 
 
         # Rerank
-        if reranker is not None:
+        if reranker is not None and rerank_by_question:
            with suppress_output():
-                docs_question_summary_reranked = rerank_docs(reranker,docs_question_dict["docs_summaries"],question)
-                docs_question_fulltext_reranked = rerank_docs(reranker,docs_question_dict["docs_full"],question)
-                docs_question_images_reranked = rerank_docs(reranker,docs_question_dict["docs_images"],question)
-            if rerank_by_question:
-                docs_question_summary_reranked = sorted(docs_question_summary_reranked, key=lambda x: x.metadata["reranking_score"], reverse=True)
-                docs_question_fulltext_reranked = sorted(docs_question_fulltext_reranked, key=lambda x: x.metadata["reranking_score"], reverse=True)
-                docs_question_images_reranked = sorted(docs_question_images_reranked, key=lambda x: x.metadata["reranking_score"], reverse=True)
+                for key in docs_question_dict.keys():
+                    docs_question_dict[key] = rerank_and_sort_docs(reranker,docs_question_dict[key],question)
        else:
-            docs_question = docs_question_dict["docs_summaries"] + docs_question_dict["docs_full"]
            # Add a default reranking score
            for doc in docs_question:
                doc.metadata["reranking_score"] = doc.metadata["similarity_score"]
 
-
-        docs_question =
-
-
+        # Keep the right number of documents
+        docs_question, images_question = concatenate_documents(index, source_type, docs_question_dict, k_by_question, k_summary_by_question, k_images_by_question)
+
+        # Rerank the documents to put the most relevant in front
        if reranker is not None and rerank_by_question:
-            docs_question =
-
+            docs_question = rerank_and_sort_docs(reranker, docs_question, question)
+
        # Add sources used in the metadata
        docs_question = _add_sources_used_in_metadata(docs_question,sources,question,index)
        images_question = _add_sources_used_in_metadata(images_question,sources,question,index)
@@ -288,13 +350,26 @@ async def retrieve_documents(state,config, vectorstore,reranker,llm,rerank_by_qu
 
 
 
-def
+def make_IPx_retriever_node(vectorstore,reranker,llm,rerank_by_question=True, k_final=15, k_before_reranking=100, k_summary=5):
+
+    @chain
+    async def retrieve_IPx_docs(state, config):
+        source_type = "IPx"
+        state = await retrieve_documents(state,config, source_type, vectorstore,reranker,llm,rerank_by_question, k_final, k_before_reranking, k_summary)
+        return state
+
+    return retrieve_IPx_docs
+
+
+def make_POC_retriever_node(vectorstore,reranker,llm,rerank_by_question=True, k_final=15, k_before_reranking=100, k_summary=5):
+
     @chain
-    async def
-
+    async def retrieve_IPx_docs(state, config):
+        source_type = "POC"
+        state = await retrieve_documents(state,config, source_type, vectorstore,reranker,llm,rerank_by_question, k_final, k_before_reranking, k_summary)
        return state
 
-    return
+    return retrieve_IPx_docs
 
 
 
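The POC retriever differs from the IPx one mainly in the Pinecone metadata filter it builds. A minimal sketch of that filter logic, lifted from get_POC_relevant_documents above; the metadata fields short_name, source, and chunk_type are whatever the regional index was populated with, and the report name below is a hypothetical placeholder:

def build_poc_text_filter(sources, reports):
    """Sketch of the filter get_POC_relevant_documents builds for text chunks."""
    filters = {}
    if len(reports) > 0:
        # An explicit report selection takes precedence over source-level filtering.
        filters["short_name"] = {"$in": reports}
    else:
        filters["source"] = {"$in": sources}
    return {**filters, "chunk_type": "text"}

# Default behaviour: filter on the three regional source families.
print(build_poc_text_filter(["Acclimaterra", "PCAET", "Plan Biodiversite"], []))
# {'source': {'$in': ['Acclimaterra', 'PCAET', 'Plan Biodiversite']}, 'chunk_type': 'text'}

# With an explicit (hypothetical) report list, the source filter is dropped.
print(build_poc_text_filter(["Acclimaterra"], ["some_report"]))
# {'short_name': {'$in': ['some_report']}, 'chunk_type': 'text'}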
climateqa/engine/graph.py
CHANGED
@@ -16,7 +16,7 @@ from .chains.answer_ai_impact import make_ai_impact_node
 from .chains.query_transformation import make_query_transform_node
 from .chains.translation import make_translation_node
 from .chains.intent_categorization import make_intent_categorization_node
-from .chains.retrieve_documents import
+from .chains.retrieve_documents import make_IPx_retriever_node, make_POC_retriever_node
 from .chains.answer_rag import make_rag_node
 from .chains.graph_retriever import make_graph_retriever_node
 from .chains.chitchat_categorization import make_chitchat_intent_categorization_node
@@ -46,6 +46,9 @@ class GraphState(TypedDict):
     search_only : bool = False
     reports : List[str] = []
 
+def dummy(state):
+    return state
+
 def search(state): #TODO
     return state
 
@@ -60,7 +63,7 @@ def route_intent(state):
         # return "answer_ai_impact"
     else:
         # Search route
-        return "
+        return "answer_climate"
 
 def chitchat_route_intent(state):
     intent = state["search_graphs_chitchat"]
@@ -82,18 +85,29 @@ def route_based_on_relevant_docs(state,threshold_docs=0.2):
     else:
         return "answer_rag_no_docs"
 
-def
+def route_continue_retrieve_documents(state):
     if len(state["remaining_questions"]) == 0 and state["search_only"] :
         return END
     elif len(state["remaining_questions"]) > 0:
         return "retrieve_documents"
     else:
         return "answer_search"
+
+def route_retrieve_documents(state):
+    sources_to_retrieve = []
+
+    if "Graphs (OurWorldInData)" in state["relevant_content_sources_selection"] :
+        sources_to_retrieve.append("retrieve_graphs")
+    if "POC region" in state["relevant_content_sources_selection"] :
+        sources_to_retrieve.append("retrieve_local_data")
+    if sources_to_retrieve == []:
+        return END
+    return sources_to_retrieve
 
 def make_id_dict(values):
     return {k:k for k in values}
 
-def make_graph_agent(llm, vectorstore_ipcc, vectorstore_graphs, reranker, threshold_docs=0.2):
+def make_graph_agent(llm, vectorstore_ipcc, vectorstore_graphs, vectorstore_region, reranker, threshold_docs=0.2):
 
     workflow = StateGraph(GraphState)
 
@@ -103,8 +117,9 @@ def make_graph_agent(llm, vectorstore_ipcc, vectorstore_graphs, reranker, thresh
     translate_query = make_translation_node(llm)
     answer_chitchat = make_chitchat_node(llm)
     answer_ai_impact = make_ai_impact_node(llm)
-    retrieve_documents =
+    retrieve_documents = make_IPx_retriever_node(vectorstore_ipcc, reranker, llm)
     retrieve_graphs = make_graph_retriever_node(vectorstore_graphs, reranker)
+    retrieve_local_data = make_POC_retriever_node(vectorstore_region, reranker, llm)
     answer_rag = make_rag_node(llm, with_docs=True)
     answer_rag_no_docs = make_rag_node(llm, with_docs=False)
     chitchat_categorize_intent = make_chitchat_intent_categorization_node(llm)
@@ -112,13 +127,14 @@ def make_graph_agent(llm, vectorstore_ipcc, vectorstore_graphs, reranker, thresh
     # Define the nodes
     # workflow.add_node("set_defaults", set_defaults)
     workflow.add_node("categorize_intent", categorize_intent)
-    workflow.add_node("
+    workflow.add_node("answer_climate", dummy)
     workflow.add_node("answer_search", answer_search)
     workflow.add_node("transform_query", transform_query)
     workflow.add_node("translate_query", translate_query)
     workflow.add_node("answer_chitchat", answer_chitchat)
     workflow.add_node("chitchat_categorize_intent", chitchat_categorize_intent)
     workflow.add_node("retrieve_graphs", retrieve_graphs)
+    workflow.add_node("retrieve_local_data", retrieve_local_data)
     workflow.add_node("retrieve_graphs_chitchat", retrieve_graphs)
     workflow.add_node("retrieve_documents", retrieve_documents)
     workflow.add_node("answer_rag", answer_rag)
@@ -131,7 +147,7 @@ def make_graph_agent(llm, vectorstore_ipcc, vectorstore_graphs, reranker, thresh
     workflow.add_conditional_edges(
         "categorize_intent",
         route_intent,
-        make_id_dict(["answer_chitchat","
+        make_id_dict(["answer_chitchat","answer_climate"])
     )
 
     workflow.add_conditional_edges(
@@ -141,14 +157,14 @@ def make_graph_agent(llm, vectorstore_ipcc, vectorstore_graphs, reranker, thresh
     )
 
     workflow.add_conditional_edges(
-        "
+        "answer_climate",
         route_translation,
         make_id_dict(["translate_query","transform_query"])
     )
     workflow.add_conditional_edges(
         "retrieve_documents",
         # lambda state : "retrieve_documents" if len(state["remaining_questions"]) > 0 else "answer_search",
-
+        route_continue_retrieve_documents,
         make_id_dict([END,"retrieve_documents","answer_search"])
     )
 
@@ -159,9 +175,16 @@ def make_graph_agent(llm, vectorstore_ipcc, vectorstore_graphs, reranker, thresh
     )
     workflow.add_conditional_edges(
         "transform_query",
-
-        make_id_dict(["retrieve_graphs", END])
+        route_retrieve_documents,
+        make_id_dict(["retrieve_graphs","retrieve_local_data", END])
     )
+
+
+    # workflow.add_conditional_edges(
+    #     "transform_query",
+    #     lambda state : "retrieve_graphs" if "POC region" in state["relevant_content_sources_selection"] else END,
+    #     make_id_dict(["retrieve_local_data", END])
+    # )
 
     # Define the edges
     workflow.add_edge("translate_query", "transform_query")
@@ -172,7 +195,7 @@ def make_graph_agent(llm, vectorstore_ipcc, vectorstore_graphs, reranker, thresh
     workflow.add_edge("answer_rag_no_docs", END)
     workflow.add_edge("answer_chitchat", "chitchat_categorize_intent")
     workflow.add_edge("retrieve_graphs_chitchat", END)
-
+    workflow.add_edge("retrieve_local_data", "answer_search")
 
     # Compile
     app = workflow.compile()
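The new router can return several node names at once, which is how the graph is expected to fan out from transform_query to both the OWID graphs retriever and the regional retriever before retrieve_local_data feeds answer_search via the new edge above. A standalone check of that routing behaviour; END is langgraph's sentinel, stubbed here as its plain-string value so the snippet runs without the package:

END = "__end__"  # stub for langgraph.graph.END

def route_retrieve_documents(state):
    # Same logic as the function added in the diff above.
    sources_to_retrieve = []
    if "Graphs (OurWorldInData)" in state["relevant_content_sources_selection"]:
        sources_to_retrieve.append("retrieve_graphs")
    if "POC region" in state["relevant_content_sources_selection"]:
        sources_to_retrieve.append("retrieve_local_data")
    if sources_to_retrieve == []:
        return END
    return sources_to_retrieve

print(route_retrieve_documents({"relevant_content_sources_selection": ["POC region"]}))
# ['retrieve_local_data']
print(route_retrieve_documents({"relevant_content_sources_selection": ["Graphs (OurWorldInData)", "POC region"]}))
# ['retrieve_graphs', 'retrieve_local_data']
print(route_retrieve_documents({"relevant_content_sources_selection": []}))
# __end__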
climateqa/engine/reranker.py
CHANGED
@@ -47,4 +47,9 @@ def rerank_docs(reranker,docs,query):
         doc.metadata["reranking_score"] = result.score
         doc.metadata["query_used_for_retrieval"] = query
         docs_reranked.append(doc)
+    return docs_reranked
+
+def rerank_and_sort_docs(reranker, docs, query):
+    docs_reranked = rerank_docs(reranker,docs,query)
+    docs_reranked = sorted(docs_reranked, key=lambda x: x.metadata["reranking_score"], reverse=True)
     return docs_reranked
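rerank_and_sort_docs simply composes rerank_docs with a descending sort on the reranking_score that rerank_docs writes into each document's metadata. A minimal sketch of that sort contract, using langchain_core documents with stubbed scores in place of a real reranker:

from langchain_core.documents import Document

# Documents as rerank_docs would leave them: each carries a reranking_score.
docs = [
    Document(page_content="low", metadata={"reranking_score": 0.12}),
    Document(page_content="high", metadata={"reranking_score": 0.91}),
    Document(page_content="mid", metadata={"reranking_score": 0.55}),
]

# The sorting step of rerank_and_sort_docs: highest score first.
docs_sorted = sorted(docs, key=lambda x: x.metadata["reranking_score"], reverse=True)
print([d.page_content for d in docs_sorted])
# ['high', 'mid', 'low']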