timeki commited on
Commit
cf2e6bf
·
1 Parent(s): 6f5e8e0

Add follow up questions

Browse files
app.py CHANGED
@@ -176,6 +176,8 @@ def event_handling(
176
  tab_graphs = main_tab_components.tab_graphs
177
  tab_papers = main_tab_components.tab_papers
178
  graphs_container = main_tab_components.graph_container
 
 
179
 
180
  dropdown_sources = config_components.dropdown_sources
181
  dropdown_reports = config_components.dropdown_reports
@@ -196,15 +198,20 @@ def event_handling(
196
  # Event for textbox
197
  (textbox
198
  .submit(start_chat, [textbox, chatbot, search_only], [textbox, tabs, chatbot, sources_raw], queue=False, api_name=f"start_chat_{textbox.elem_id}")
199
- .then(chat, [textbox, chatbot, dropdown_audience, dropdown_sources, dropdown_reports, dropdown_external_sources, search_only], [chatbot, new_sources_hmtl, output_query, output_language, new_figures, current_graphs], concurrency_limit=8, api_name=f"chat_{textbox.elem_id}")
200
  .then(finish_chat, None, [textbox], api_name=f"finish_chat_{textbox.elem_id}")
201
  )
202
  # Event for examples_hidden
203
  (examples_hidden
204
  .change(start_chat, [examples_hidden, chatbot, search_only], [examples_hidden, tabs, chatbot, sources_raw], queue=False, api_name=f"start_chat_{examples_hidden.elem_id}")
205
- .then(chat, [examples_hidden, chatbot, dropdown_audience, dropdown_sources, dropdown_reports, dropdown_external_sources, search_only], [chatbot, new_sources_hmtl, output_query, output_language, new_figures, current_graphs], concurrency_limit=8, api_name=f"chat_{examples_hidden.elem_id}")
206
  .then(finish_chat, None, [textbox], api_name=f"finish_chat_{examples_hidden.elem_id}")
207
  )
 
 
 
 
 
208
 
209
  elif tab_name == "Beta - POC Adapt'Action":
210
  print("chat poc - message sent")
 
176
  tab_graphs = main_tab_components.tab_graphs
177
  tab_papers = main_tab_components.tab_papers
178
  graphs_container = main_tab_components.graph_container
179
+ follow_up_examples = main_tab_components.follow_up_examples
180
+ follow_up_examples_hidden = main_tab_components.follow_up_examples_hidden
181
 
182
  dropdown_sources = config_components.dropdown_sources
183
  dropdown_reports = config_components.dropdown_reports
 
198
  # Event for textbox
199
  (textbox
200
  .submit(start_chat, [textbox, chatbot, search_only], [textbox, tabs, chatbot, sources_raw], queue=False, api_name=f"start_chat_{textbox.elem_id}")
201
+ .then(chat, [textbox, chatbot, dropdown_audience, dropdown_sources, dropdown_reports, dropdown_external_sources, search_only], [chatbot, new_sources_hmtl, output_query, output_language, new_figures, current_graphs, follow_up_examples.dataset], concurrency_limit=8, api_name=f"chat_{textbox.elem_id}")
202
  .then(finish_chat, None, [textbox], api_name=f"finish_chat_{textbox.elem_id}")
203
  )
204
  # Event for examples_hidden
205
  (examples_hidden
206
  .change(start_chat, [examples_hidden, chatbot, search_only], [examples_hidden, tabs, chatbot, sources_raw], queue=False, api_name=f"start_chat_{examples_hidden.elem_id}")
207
+ .then(chat, [examples_hidden, chatbot, dropdown_audience, dropdown_sources, dropdown_reports, dropdown_external_sources, search_only], [chatbot, new_sources_hmtl, output_query, output_language, new_figures, current_graphs,follow_up_examples.dataset], concurrency_limit=8, api_name=f"chat_{examples_hidden.elem_id}")
208
  .then(finish_chat, None, [textbox], api_name=f"finish_chat_{examples_hidden.elem_id}")
209
  )
210
+ (follow_up_examples_hidden
211
+ .change(start_chat, [examples_hidden, chatbot, search_only], [follow_up_examples_hidden, tabs, chatbot, sources_raw], queue=False, api_name=f"start_chat_{examples_hidden.elem_id}")
212
+ .then(chat, [examples_hidden, chatbot, dropdown_audience, dropdown_sources, dropdown_reports, dropdown_external_sources, search_only], [chatbot, new_sources_hmtl, output_query, output_language, new_figures, current_graphs,follow_up_examples.dataset], concurrency_limit=8, api_name=f"chat_{examples_hidden.elem_id}")
213
+ .then(finish_chat, None, [textbox], api_name=f"finish_chat_{follow_up_examples_hidden.elem_id}")
214
+ )
215
 
216
  elif tab_name == "Beta - POC Adapt'Action":
217
  print("chat poc - message sent")
climateqa/chat.py CHANGED
@@ -131,6 +131,7 @@ async def chat_stream(
131
  retrieved_contents = []
132
  answer_message_content = ""
133
  vanna_data = {}
 
134
 
135
  # Define processing steps
136
  steps_display = {
@@ -202,7 +203,12 @@ async def chat_stream(
202
  sub_questions = [q["question"] + "-> relevant sources : " + str(q["sources"]) for q in event["data"]["output"]["questions_list"]]
203
  history[-1].content += "Decompose question into sub-questions:\n\n - " + "\n - ".join(sub_questions)
204
 
205
- yield history, docs_html, output_query, output_language, related_contents, graphs_html#, vanna_data
 
 
 
 
 
206
 
207
  except Exception as e:
208
  print(f"Event {event} has failed")
@@ -213,4 +219,4 @@ async def chat_stream(
213
  # Call the function to log interaction
214
  log_interaction_to_azure(history, output_query, sources, docs, share_client, user_id)
215
 
216
- yield history, docs_html, output_query, output_language, related_contents, graphs_html#, vanna_data
 
131
  retrieved_contents = []
132
  answer_message_content = ""
133
  vanna_data = {}
134
+ follow_up_examples = gr.Dataset(samples=[])
135
 
136
  # Define processing steps
137
  steps_display = {
 
203
  sub_questions = [q["question"] + "-> relevant sources : " + str(q["sources"]) for q in event["data"]["output"]["questions_list"]]
204
  history[-1].content += "Decompose question into sub-questions:\n\n - " + "\n - ".join(sub_questions)
205
 
206
+ # Handle follow up questions
207
+ if event["name"] == "generate_follow_up" and event["event"] == "on_chain_end":
208
+ follow_up_examples = event["data"]["output"].get("follow_up_questions", [])
209
+ follow_up_examples = gr.Dataset(samples= [ [question] for question in follow_up_examples ])
210
+
211
+ yield history, docs_html, output_query, output_language, related_contents, graphs_html, follow_up_examples#, vanna_data
212
 
213
  except Exception as e:
214
  print(f"Event {event} has failed")
 
219
  # Call the function to log interaction
220
  log_interaction_to_azure(history, output_query, sources, docs, share_client, user_id)
221
 
222
+ yield history, docs_html, output_query, output_language, related_contents, graphs_html, follow_up_examples#, vanna_data
climateqa/engine/chains/follow_up.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+ from langchain.prompts import ChatPromptTemplate
3
+
4
+
5
+ FOLLOW_UP_TEMPLATE = """Based on the previous question and answer, generate 2-3 relevant follow-up questions that would help explore the topic further.
6
+
7
+ Previous Question: {user_input}
8
+ Previous Answer: {answer}
9
+
10
+ Generate short, concise, focused follow-up questions
11
+ You don't need a full question as it will be reformulated later as a standalone question with the context. Eg. "Details the first point"
12
+ """
13
+
14
+ def make_follow_up_node(llm):
15
+ prompt = ChatPromptTemplate.from_template(FOLLOW_UP_TEMPLATE)
16
+
17
+ def generate_follow_up(state):
18
+ if not state.get("answer"):
19
+ return state
20
+
21
+ response = llm.invoke(prompt.format(
22
+ user_input=state["user_input"],
23
+ answer=state["answer"]
24
+ ))
25
+
26
+ # Extract questions from response
27
+ follow_ups = [q.strip() for q in response.content.split("\n") if q.strip()]
28
+ state["follow_up_questions"] = follow_ups
29
+
30
+ return state
31
+
32
+ return generate_follow_up
climateqa/engine/graph.py CHANGED
@@ -24,6 +24,7 @@ from .chains.answer_rag import make_rag_node
24
  from .chains.graph_retriever import make_graph_retriever_node
25
  from .chains.chitchat_categorization import make_chitchat_intent_categorization_node
26
  from .chains.standalone_question import make_standalone_question_node
 
27
 
28
  class GraphState(TypedDict):
29
  """
@@ -50,6 +51,7 @@ class GraphState(TypedDict):
50
  recommended_content : List[Document] # OWID Graphs # TODO merge with related_contents
51
  search_only : bool = False
52
  reports : List[str] = []
 
53
 
54
  def dummy(state):
55
  return
@@ -121,6 +123,11 @@ def route_retrieve_documents(state):
121
  return END
122
  return sources_to_retrieve
123
 
 
 
 
 
 
124
  def make_id_dict(values):
125
  return {k:k for k in values}
126
 
@@ -141,6 +148,7 @@ def make_graph_agent(llm, vectorstore_ipcc, vectorstore_graphs, vectorstore_regi
141
  answer_rag = make_rag_node(llm, with_docs=True)
142
  answer_rag_no_docs = make_rag_node(llm, with_docs=False)
143
  chitchat_categorize_intent = make_chitchat_intent_categorization_node(llm)
 
144
 
145
  # Define the nodes
146
  # workflow.add_node("set_defaults", set_defaults)
@@ -158,6 +166,8 @@ def make_graph_agent(llm, vectorstore_ipcc, vectorstore_graphs, vectorstore_regi
158
  workflow.add_node("retrieve_documents", retrieve_documents)
159
  workflow.add_node("answer_rag", answer_rag)
160
  workflow.add_node("answer_rag_no_docs", answer_rag_no_docs)
 
 
161
 
162
  # Entry point
163
  workflow.set_entry_point("standalone_question")
@@ -192,6 +202,12 @@ def make_graph_agent(llm, vectorstore_ipcc, vectorstore_graphs, vectorstore_regi
192
  make_id_dict(["retrieve_graphs", END])
193
  )
194
 
 
 
 
 
 
 
195
  # Define the edges
196
  workflow.add_edge("standalone_question", "categorize_intent")
197
  workflow.add_edge("translate_query", "transform_query")
@@ -200,13 +216,17 @@ def make_graph_agent(llm, vectorstore_ipcc, vectorstore_graphs, vectorstore_regi
200
  # workflow.add_edge("transform_query", END) # TODO remove
201
 
202
  workflow.add_edge("retrieve_graphs", END)
203
- workflow.add_edge("answer_rag", END)
204
- workflow.add_edge("answer_rag_no_docs", END)
 
 
205
  workflow.add_edge("answer_chitchat", "chitchat_categorize_intent")
206
  workflow.add_edge("retrieve_graphs_chitchat", END)
207
 
208
  # workflow.add_edge("retrieve_local_data", "answer_search")
209
  workflow.add_edge("retrieve_documents", "answer_search")
 
 
210
 
211
  # Compile
212
  app = workflow.compile()
@@ -246,6 +266,7 @@ def make_graph_agent_poc(llm, vectorstore_ipcc, vectorstore_graphs, vectorstore_
246
  answer_rag = make_rag_node(llm, with_docs=True)
247
  answer_rag_no_docs = make_rag_node(llm, with_docs=False)
248
  chitchat_categorize_intent = make_chitchat_intent_categorization_node(llm)
 
249
 
250
  # Define the nodes
251
  # workflow.add_node("set_defaults", set_defaults)
@@ -265,6 +286,8 @@ def make_graph_agent_poc(llm, vectorstore_ipcc, vectorstore_graphs, vectorstore_
265
  workflow.add_node("retrieve_documents", retrieve_documents)
266
  workflow.add_node("answer_rag", answer_rag)
267
  workflow.add_node("answer_rag_no_docs", answer_rag_no_docs)
 
 
268
 
269
  # Entry point
270
  workflow.set_entry_point("standalone_question")
@@ -299,6 +322,12 @@ def make_graph_agent_poc(llm, vectorstore_ipcc, vectorstore_graphs, vectorstore_
299
  make_id_dict(["retrieve_graphs", END])
300
  )
301
 
 
 
 
 
 
 
302
  # Define the edges
303
  workflow.add_edge("standalone_question", "categorize_intent")
304
  workflow.add_edge("translate_query", "transform_query")
@@ -307,6 +336,8 @@ def make_graph_agent_poc(llm, vectorstore_ipcc, vectorstore_graphs, vectorstore_
307
  # workflow.add_edge("transform_query", END) # TODO remove
308
 
309
  workflow.add_edge("retrieve_graphs", END)
 
 
310
  workflow.add_edge("answer_rag", END)
311
  workflow.add_edge("answer_rag_no_docs", END)
312
  workflow.add_edge("answer_chitchat", "chitchat_categorize_intent")
@@ -314,10 +345,7 @@ def make_graph_agent_poc(llm, vectorstore_ipcc, vectorstore_graphs, vectorstore_
314
 
315
  workflow.add_edge("retrieve_local_data", "answer_search")
316
  workflow.add_edge("retrieve_documents", "answer_search")
317
-
318
- # workflow.add_edge("transform_query", "retrieve_drias_data")
319
- # workflow.add_edge("retrieve_drias_data", END)
320
-
321
 
322
  # Compile
323
  app = workflow.compile()
 
24
  from .chains.graph_retriever import make_graph_retriever_node
25
  from .chains.chitchat_categorization import make_chitchat_intent_categorization_node
26
  from .chains.standalone_question import make_standalone_question_node
27
+ from .chains.follow_up import make_follow_up_node # Add this import
28
 
29
  class GraphState(TypedDict):
30
  """
 
51
  recommended_content : List[Document] # OWID Graphs # TODO merge with related_contents
52
  search_only : bool = False
53
  reports : List[str] = []
54
+ follow_up_questions: List[str] = []
55
 
56
  def dummy(state):
57
  return
 
123
  return END
124
  return sources_to_retrieve
125
 
126
+ def route_follow_up(state):
127
+ if state["follow_up_questions"]:
128
+ return "process_follow_up"
129
+ return END
130
+
131
  def make_id_dict(values):
132
  return {k:k for k in values}
133
 
 
148
  answer_rag = make_rag_node(llm, with_docs=True)
149
  answer_rag_no_docs = make_rag_node(llm, with_docs=False)
150
  chitchat_categorize_intent = make_chitchat_intent_categorization_node(llm)
151
+ generate_follow_up = make_follow_up_node(llm)
152
 
153
  # Define the nodes
154
  # workflow.add_node("set_defaults", set_defaults)
 
166
  workflow.add_node("retrieve_documents", retrieve_documents)
167
  workflow.add_node("answer_rag", answer_rag)
168
  workflow.add_node("answer_rag_no_docs", answer_rag_no_docs)
169
+ workflow.add_node("generate_follow_up", generate_follow_up)
170
+ # workflow.add_node("process_follow_up", standalone_question_node)
171
 
172
  # Entry point
173
  workflow.set_entry_point("standalone_question")
 
202
  make_id_dict(["retrieve_graphs", END])
203
  )
204
 
205
+ # workflow.add_conditional_edges(
206
+ # "generate_follow_up",
207
+ # route_follow_up,
208
+ # make_id_dict(["process_follow_up", END])
209
+ # )
210
+
211
  # Define the edges
212
  workflow.add_edge("standalone_question", "categorize_intent")
213
  workflow.add_edge("translate_query", "transform_query")
 
216
  # workflow.add_edge("transform_query", END) # TODO remove
217
 
218
  workflow.add_edge("retrieve_graphs", END)
219
+ workflow.add_edge("answer_rag", "generate_follow_up")
220
+ workflow.add_edge("answer_rag_no_docs", "generate_follow_up")
221
+ # workflow.add_edge("answer_rag", END)
222
+ # workflow.add_edge("answer_rag_no_docs", END)
223
  workflow.add_edge("answer_chitchat", "chitchat_categorize_intent")
224
  workflow.add_edge("retrieve_graphs_chitchat", END)
225
 
226
  # workflow.add_edge("retrieve_local_data", "answer_search")
227
  workflow.add_edge("retrieve_documents", "answer_search")
228
+ workflow.add_edge("generate_follow_up",END)
229
+ # workflow.add_edge("process_follow_up", "categorize_intent")
230
 
231
  # Compile
232
  app = workflow.compile()
 
266
  answer_rag = make_rag_node(llm, with_docs=True)
267
  answer_rag_no_docs = make_rag_node(llm, with_docs=False)
268
  chitchat_categorize_intent = make_chitchat_intent_categorization_node(llm)
269
+ generate_follow_up = make_follow_up_node(llm)
270
 
271
  # Define the nodes
272
  # workflow.add_node("set_defaults", set_defaults)
 
286
  workflow.add_node("retrieve_documents", retrieve_documents)
287
  workflow.add_node("answer_rag", answer_rag)
288
  workflow.add_node("answer_rag_no_docs", answer_rag_no_docs)
289
+ workflow.add_node("generate_follow_up", generate_follow_up)
290
+ workflow.add_node("process_follow_up", standalone_question_node)
291
 
292
  # Entry point
293
  workflow.set_entry_point("standalone_question")
 
322
  make_id_dict(["retrieve_graphs", END])
323
  )
324
 
325
+ workflow.add_conditional_edges(
326
+ "generate_follow_up",
327
+ route_follow_up,
328
+ make_id_dict(["process_follow_up", END])
329
+ )
330
+
331
  # Define the edges
332
  workflow.add_edge("standalone_question", "categorize_intent")
333
  workflow.add_edge("translate_query", "transform_query")
 
336
  # workflow.add_edge("transform_query", END) # TODO remove
337
 
338
  workflow.add_edge("retrieve_graphs", END)
339
+ workflow.add_edge("answer_rag", "generate_follow_up")
340
+ workflow.add_edge("answer_rag_no_docs", "generate_follow_up")
341
  workflow.add_edge("answer_rag", END)
342
  workflow.add_edge("answer_rag_no_docs", END)
343
  workflow.add_edge("answer_chitchat", "chitchat_categorize_intent")
 
345
 
346
  workflow.add_edge("retrieve_local_data", "answer_search")
347
  workflow.add_edge("retrieve_documents", "answer_search")
348
+ workflow.add_edge("process_follow_up", "categorize_intent")
 
 
 
349
 
350
  # Compile
351
  app = workflow.compile()
front/tabs/chat_interface.py CHANGED
@@ -54,7 +54,10 @@ def create_chat_interface(tab):
54
  max_height="80vh",
55
  height="100vh"
56
  )
57
-
 
 
 
58
  with gr.Row(elem_id="input-message"):
59
 
60
  textbox = gr.Textbox(
@@ -68,7 +71,7 @@ def create_chat_interface(tab):
68
 
69
  config_button = gr.Button("", elem_id="config-button")
70
 
71
- return chatbot, textbox, config_button
72
 
73
 
74
 
 
54
  max_height="80vh",
55
  height="100vh"
56
  )
57
+ with gr.Row(elem_id="follow-up-examples"):
58
+ follow_up_examples_hidden = gr.Textbox(visible=False, elem_id="follow-up-hidden")
59
+ follow_up_examples = gr.Examples(examples=["sample_1","sample_2"], label="Follow up questions", inputs= [follow_up_examples_hidden], elem_id="follow-up-button", run_on_click=False)
60
+
61
  with gr.Row(elem_id="input-message"):
62
 
63
  textbox = gr.Textbox(
 
71
 
72
  config_button = gr.Button("", elem_id="config-button")
73
 
74
+ return chatbot, textbox, config_button, follow_up_examples, follow_up_examples_hidden
75
 
76
 
77
 
front/tabs/main_tab.py CHANGED
@@ -29,6 +29,8 @@ class MainTabPanel:
29
  tab_graphs: gr.Tab
30
  tab_papers: gr.Tab
31
  graph_container: gr.HTML
 
 
32
 
33
  def cqa_tab(tab_name):
34
  # State variables
@@ -37,7 +39,7 @@ def cqa_tab(tab_name):
37
  with gr.Row(elem_id="chatbot-row"):
38
  # Left column - Chat interface
39
  with gr.Column(scale=2):
40
- chatbot, textbox, config_button = create_chat_interface(tab_name)
41
 
42
  # Right column - Content panels
43
  with gr.Column(scale=2, variant="panel", elem_id="right-panel"):
@@ -91,5 +93,7 @@ def cqa_tab(tab_name):
91
  tab_figures=tab_figures,
92
  tab_graphs=tab_graphs,
93
  tab_papers=tab_papers,
94
- graph_container=graphs_container
 
 
95
  )
 
29
  tab_graphs: gr.Tab
30
  tab_papers: gr.Tab
31
  graph_container: gr.HTML
32
+ follow_up_examples : gr.Examples
33
+ follow_up_examples_hidden : gr.Textbox
34
 
35
  def cqa_tab(tab_name):
36
  # State variables
 
39
  with gr.Row(elem_id="chatbot-row"):
40
  # Left column - Chat interface
41
  with gr.Column(scale=2):
42
+ chatbot, textbox, config_button, follow_up_examples, follow_up_examples_hidden = create_chat_interface(tab_name)
43
 
44
  # Right column - Content panels
45
  with gr.Column(scale=2, variant="panel", elem_id="right-panel"):
 
93
  tab_figures=tab_figures,
94
  tab_graphs=tab_graphs,
95
  tab_papers=tab_papers,
96
+ graph_container=graphs_container,
97
+ follow_up_examples= follow_up_examples,
98
+ follow_up_examples_hidden = follow_up_examples_hidden
99
  )
style.css CHANGED
@@ -115,6 +115,11 @@ main.flex.flex-1.flex-col {
115
  border-radius: 40px;
116
  padding-left: 30px;
117
  resize: none;
 
 
 
 
 
118
  }
119
 
120
  #input-message > div {
@@ -474,6 +479,18 @@ a {
474
  text-decoration: none !important;
475
  }
476
 
 
 
 
 
 
 
 
 
 
 
 
 
477
  /* Media Queries */
478
  /* Desktop Media Query */
479
  @media screen and (min-width: 1024px) {
@@ -496,6 +513,15 @@ a {
496
  overflow-y: scroll !important;
497
  }
498
 
 
 
 
 
 
 
 
 
 
499
  div#chatbot-row {
500
  max-height: calc(100vh - 90px) !important;
501
  }
@@ -514,7 +540,11 @@ a {
514
  /* Mobile Media Query */
515
  @media screen and (max-width: 767px) {
516
  div#chatbot {
517
- height: 500px !important;
 
 
 
 
518
  }
519
 
520
  #submit-button {
 
115
  border-radius: 40px;
116
  padding-left: 30px;
117
  resize: none;
118
+ background-color: #f0f8ff; /* Light blue background */
119
+ border: 2px solid #4b8ec3; /* Blue border */
120
+ font-size: 16px; /* Increase font size */
121
+ color: #333; /* Text color */
122
+ box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); /* Add shadow */
123
  }
124
 
125
  #input-message > div {
 
479
  text-decoration: none !important;
480
  }
481
 
482
+ /* Follow-up Examples Styles */
483
+ #follow-up-examples {
484
+ height: 15vh;
485
+ overflow-y: auto;
486
+ padding: 10px 0;
487
+ }
488
+
489
+ #follow-up-button {
490
+ height: 100%;
491
+ overflow-y: auto;
492
+ }
493
+
494
  /* Media Queries */
495
  /* Desktop Media Query */
496
  @media screen and (min-width: 1024px) {
 
513
  overflow-y: scroll !important;
514
  }
515
 
516
+ div#chatbot-row {
517
+ max-height: calc(100vh - 200px) !important;
518
+ }
519
+
520
+ div#chatbot {
521
+ height: 65vh !important;
522
+ max-height: 65vh !important;
523
+ }
524
+
525
  div#chatbot-row {
526
  max-height: calc(100vh - 90px) !important;
527
  }
 
540
  /* Mobile Media Query */
541
  @media screen and (max-width: 767px) {
542
  div#chatbot {
543
+ height: 400px !important; /* Reduced from 500px */
544
+ }
545
+
546
+ #follow-up-examples {
547
+ height: 100px;
548
  }
549
 
550
  #submit-button {