timeki commited on
Commit
6f5e8e0
·
1 Parent(s): 55db9a8

Take the last question from the chat history into account when interpreting the current question

Browse files
climateqa/chat.py CHANGED
@@ -101,6 +101,7 @@ async def chat_stream(
101
  audience_prompt = init_audience(audience)
102
  sources = sources or ["IPCC", "IPBES"]
103
  reports = reports or []
 
104
 
105
  # Prepare inputs for agent
106
  inputs = {
@@ -109,7 +110,8 @@ async def chat_stream(
109
  "sources_input": sources,
110
  "relevant_content_sources_selection": relevant_content_sources_selection,
111
  "search_only": search_only,
112
- "reports": reports
 
113
  }
114
 
115
  # Get streaming events from agent
 
101
  audience_prompt = init_audience(audience)
102
  sources = sources or ["IPCC", "IPBES"]
103
  reports = reports or []
104
+ relevant_history_discussion = history[-2:] if len(history) > 1 else []
105
 
106
  # Prepare inputs for agent
107
  inputs = {
 
110
  "sources_input": sources,
111
  "relevant_content_sources_selection": relevant_content_sources_selection,
112
  "search_only": search_only,
113
+ "reports": reports,
114
+ "chat_history": relevant_history_discussion,
115
  }
116
 
117
  # Get streaming events from agent
climateqa/engine/chains/answer_rag.py CHANGED
@@ -65,6 +65,7 @@ def make_rag_node(llm,with_docs = True):
65
  async def answer_rag(state,config):
66
  print("---- Answer RAG ----")
67
  start_time = time.time()
 
68
  print("Sources used : " + "\n".join([x.metadata["short_name"] + " - page " + str(x.metadata["page_number"]) for x in state["documents"]]))
69
 
70
  answer = await rag_chain.ainvoke(state,config)
@@ -73,9 +74,10 @@ def make_rag_node(llm,with_docs = True):
73
  elapsed_time = end_time - start_time
74
  print("RAG elapsed time: ", elapsed_time)
75
  print("Answer size : ", len(answer))
76
- # print(f"\n\nAnswer:\n{answer}")
77
 
78
- return {"answer":answer}
 
 
79
 
80
  return answer_rag
81
 
 
65
  async def answer_rag(state,config):
66
  print("---- Answer RAG ----")
67
  start_time = time.time()
68
+ chat_history = state.get("chat_history",[])
69
  print("Sources used : " + "\n".join([x.metadata["short_name"] + " - page " + str(x.metadata["page_number"]) for x in state["documents"]]))
70
 
71
  answer = await rag_chain.ainvoke(state,config)
 
74
  elapsed_time = end_time - start_time
75
  print("RAG elapsed time: ", elapsed_time)
76
  print("Answer size : ", len(answer))
 
77
 
78
+ chat_history.append({"question":state["query"],"answer":answer})
79
+
80
+ return {"answer":answer,"chat_history": chat_history}
81
 
82
  return answer_rag
83
 
climateqa/engine/chains/standalone_question.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from langchain.prompts import ChatPromptTemplate

def make_standalone_question_chain(llm):
    """Create a runnable that rewrites a follow-up question as a standalone one.

    The returned chain expects two inputs, ``chat_history`` and ``question``,
    and asks the LLM to fold any context from the history into a
    self-contained question that can be understood on its own.
    """
    system_message = """You are a helpful assistant that transforms user questions into standalone questions
    by incorporating context from the chat history if needed. The output should be a self-contained
    question that can be understood without any additional context.

    Examples:
    Chat History: "Let's talk about renewable energy"
    User Input: "What about solar?"
    Output: "What are the key aspects of solar energy as a renewable energy source?"

    Chat History: "What causes global warming?"
    User Input: "And what are its effects?"
    Output: "What are the effects of global warming on the environment and society?"
    """

    user_message = """Chat History: {chat_history}
    User Question: {question}

    Transform this into a standalone question:"""

    prompt = ChatPromptTemplate.from_messages(
        [("system", system_message), ("user", user_message)]
    )
    return prompt | llm
26
+
27
def make_standalone_question_node(llm):
    """Build a graph node that rewrites ``user_input`` as a standalone question.

    Returns a state-transform function. When the state carries a non-empty
    ``chat_history``, the LLM rewrites ``user_input`` so it can be understood
    without that history; otherwise the state passes through untouched —
    this skips a pointless LLM call (the caller passes an empty history on
    the first turn) and avoids rephrasing a question that is already
    standalone.
    """
    standalone_chain = make_standalone_question_chain(llm)

    def transform_to_standalone(state):
        # chat.py populates this as a list of recent turns; it may be absent
        # or empty on the first question of a conversation.
        chat_history = state.get("chat_history", "")
        if not chat_history:
            # Nothing to disambiguate against — keep the question as-is.
            return state
        output = standalone_chain.invoke({
            "chat_history": chat_history,
            "question": state["user_input"],
        })
        state["user_input"] = output.content
        return state

    return transform_to_standalone
climateqa/engine/graph.py CHANGED
@@ -23,13 +23,14 @@ from .chains.retrieve_documents import make_IPx_retriever_node, make_POC_retriev
23
  from .chains.answer_rag import make_rag_node
24
  from .chains.graph_retriever import make_graph_retriever_node
25
  from .chains.chitchat_categorization import make_chitchat_intent_categorization_node
26
- # from .chains.set_defaults import set_defaults
27
 
28
  class GraphState(TypedDict):
29
  """
30
  Represents the state of our graph.
31
  """
32
  user_input : str
 
33
  language : str
34
  intent : str
35
  search_graphs_chitchat : bool
@@ -128,6 +129,7 @@ def make_graph_agent(llm, vectorstore_ipcc, vectorstore_graphs, vectorstore_regi
128
  workflow = StateGraph(GraphState)
129
 
130
  # Define the node functions
 
131
  categorize_intent = make_intent_categorization_node(llm)
132
  transform_query = make_query_transform_node(llm)
133
  translate_query = make_translation_node(llm)
@@ -142,6 +144,7 @@ def make_graph_agent(llm, vectorstore_ipcc, vectorstore_graphs, vectorstore_regi
142
 
143
  # Define the nodes
144
  # workflow.add_node("set_defaults", set_defaults)
 
145
  workflow.add_node("categorize_intent", categorize_intent)
146
  workflow.add_node("answer_climate", dummy)
147
  workflow.add_node("answer_search", answer_search)
@@ -157,7 +160,7 @@ def make_graph_agent(llm, vectorstore_ipcc, vectorstore_graphs, vectorstore_regi
157
  workflow.add_node("answer_rag_no_docs", answer_rag_no_docs)
158
 
159
  # Entry point
160
- workflow.set_entry_point("categorize_intent")
161
 
162
  # CONDITIONAL EDGES
163
  workflow.add_conditional_edges(
@@ -190,6 +193,7 @@ def make_graph_agent(llm, vectorstore_ipcc, vectorstore_graphs, vectorstore_regi
190
  )
191
 
192
  # Define the edges
 
193
  workflow.add_edge("translate_query", "transform_query")
194
  workflow.add_edge("transform_query", "retrieve_documents") #TODO put back
195
  # workflow.add_edge("transform_query", "retrieve_local_data")
@@ -228,6 +232,8 @@ def make_graph_agent_poc(llm, vectorstore_ipcc, vectorstore_graphs, vectorstore_
228
  workflow = StateGraph(GraphState)
229
 
230
  # Define the node functions
 
 
231
  categorize_intent = make_intent_categorization_node(llm)
232
  transform_query = make_query_transform_node(llm)
233
  translate_query = make_translation_node(llm)
@@ -243,6 +249,7 @@ def make_graph_agent_poc(llm, vectorstore_ipcc, vectorstore_graphs, vectorstore_
243
 
244
  # Define the nodes
245
  # workflow.add_node("set_defaults", set_defaults)
 
246
  workflow.add_node("categorize_intent", categorize_intent)
247
  workflow.add_node("answer_climate", dummy)
248
  workflow.add_node("answer_search", answer_search)
@@ -260,7 +267,7 @@ def make_graph_agent_poc(llm, vectorstore_ipcc, vectorstore_graphs, vectorstore_
260
  workflow.add_node("answer_rag_no_docs", answer_rag_no_docs)
261
 
262
  # Entry point
263
- workflow.set_entry_point("categorize_intent")
264
 
265
  # CONDITIONAL EDGES
266
  workflow.add_conditional_edges(
@@ -293,6 +300,7 @@ def make_graph_agent_poc(llm, vectorstore_ipcc, vectorstore_graphs, vectorstore_
293
  )
294
 
295
  # Define the edges
 
296
  workflow.add_edge("translate_query", "transform_query")
297
  workflow.add_edge("transform_query", "retrieve_documents") #TODO put back
298
  workflow.add_edge("transform_query", "retrieve_local_data")
 
23
  from .chains.answer_rag import make_rag_node
24
  from .chains.graph_retriever import make_graph_retriever_node
25
  from .chains.chitchat_categorization import make_chitchat_intent_categorization_node
26
+ from .chains.standalone_question import make_standalone_question_node
27
 
28
  class GraphState(TypedDict):
29
  """
30
  Represents the state of our graph.
31
  """
32
  user_input : str
33
+ chat_history : str
34
  language : str
35
  intent : str
36
  search_graphs_chitchat : bool
 
129
  workflow = StateGraph(GraphState)
130
 
131
  # Define the node functions
132
+ standalone_question_node = make_standalone_question_node(llm)
133
  categorize_intent = make_intent_categorization_node(llm)
134
  transform_query = make_query_transform_node(llm)
135
  translate_query = make_translation_node(llm)
 
144
 
145
  # Define the nodes
146
  # workflow.add_node("set_defaults", set_defaults)
147
+ workflow.add_node("standalone_question", standalone_question_node)
148
  workflow.add_node("categorize_intent", categorize_intent)
149
  workflow.add_node("answer_climate", dummy)
150
  workflow.add_node("answer_search", answer_search)
 
160
  workflow.add_node("answer_rag_no_docs", answer_rag_no_docs)
161
 
162
  # Entry point
163
+ workflow.set_entry_point("standalone_question")
164
 
165
  # CONDITIONAL EDGES
166
  workflow.add_conditional_edges(
 
193
  )
194
 
195
  # Define the edges
196
+ workflow.add_edge("standalone_question", "categorize_intent")
197
  workflow.add_edge("translate_query", "transform_query")
198
  workflow.add_edge("transform_query", "retrieve_documents") #TODO put back
199
  # workflow.add_edge("transform_query", "retrieve_local_data")
 
232
  workflow = StateGraph(GraphState)
233
 
234
  # Define the node functions
235
+ standalone_question_node = make_standalone_question_node(llm)
236
+
237
  categorize_intent = make_intent_categorization_node(llm)
238
  transform_query = make_query_transform_node(llm)
239
  translate_query = make_translation_node(llm)
 
249
 
250
  # Define the nodes
251
  # workflow.add_node("set_defaults", set_defaults)
252
+ workflow.add_node("standalone_question", standalone_question_node)
253
  workflow.add_node("categorize_intent", categorize_intent)
254
  workflow.add_node("answer_climate", dummy)
255
  workflow.add_node("answer_search", answer_search)
 
267
  workflow.add_node("answer_rag_no_docs", answer_rag_no_docs)
268
 
269
  # Entry point
270
+ workflow.set_entry_point("standalone_question")
271
 
272
  # CONDITIONAL EDGES
273
  workflow.add_conditional_edges(
 
300
  )
301
 
302
  # Define the edges
303
+ workflow.add_edge("standalone_question", "categorize_intent")
304
  workflow.add_edge("translate_query", "transform_query")
305
  workflow.add_edge("transform_query", "retrieve_documents") #TODO put back
306
  workflow.add_edge("transform_query", "retrieve_local_data")