timeki committed
Commit 1305361 · 1 Parent(s): 6093b14

small cleaning
app.py CHANGED
@@ -115,14 +115,13 @@ def update_config_modal_visibility(config_open):
     return gr.update(visible=new_config_visibility_status), new_config_visibility_status
 
 # Main chat function
-# async def chat(query, history, audience, sources, reports, relevant_content_sources, search_only):
 async def chat(
     query: str,
     history: list[ChatMessage],
     audience: str,
     sources: list[str],
     reports: list[str],
-    relevant_content_sources: list[str],
+    relevant_content_sources_selection: list[str],
     search_only: bool
 ) -> tuple[list, str, str, str, list, str]:
     """Process a chat query and return response with relevant sources and visualizations.
@@ -133,7 +132,7 @@ async def chat(
         audience (str): Target audience type
         sources (list): Knowledge base sources to search
         reports (list): Specific reports to search within sources
-        relevant_content_sources (list): Types of content to retrieve (figures, papers, etc)
+        relevant_content_sources_selection (list): Types of content to retrieve (figures, papers, etc)
         search_only (bool): Whether to only search without generating answer
 
     Yields:
@@ -158,7 +157,7 @@ async def chat(
         "user_input": query,
         "audience": audience_prompt,
         "sources_input": sources,
-        "relevant_content_sources": relevant_content_sources,
+        "relevant_content_sources_selection": relevant_content_sources_selection,
         "search_only": search_only,
         "reports": reports
     }
@@ -168,7 +167,6 @@ async def chat(
 
     # Initialize state variables
     docs = []
-    used_figures = []
     related_contents = []
     docs_html = ""
     output_query = ""
@@ -176,7 +174,6 @@ async def chat(
     output_keywords = ""
     start_streaming = False
     graphs_html = ""
-    figures = '<div class="figures-container"><p></p> </div>'
     used_documents = []
     answer_message_content = ""
 
@@ -236,7 +233,7 @@ async def chat(
                 sub_questions = [q["question"] for q in event["data"]["output"]["remaining_questions"]]
                 history[-1].content += "Decompose question into sub-questions:\n\n - " + "\n - ".join(sub_questions)
 
-                yield history, docs_html, output_query, output_language, related_contents, graphs_html #,output_query,output_keywords
+                yield history, docs_html, output_query, output_language, related_contents, graphs_html
 
         except Exception as e:
             print(f"Event {event} has failed")
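Note that the rename has to land atomically across every file in this commit: the agent state built here in app.py is addressed by plain string keys downstream, so a stale key is invisible to type checkers and only fails when a graph node runs. A minimal sketch with hypothetical values, not code from this repo:

# The state is a plain dict keyed by strings, so a stale key fails only at runtime.
state = {
    "relevant_content_sources_selection": ["Figures (IPCC/IPBES)"],
    "search_only": False,
}

print(state["relevant_content_sources_selection"])  # ok after the rename
print(state.get("relevant_content_sources"))        # None: the pre-rename key is gone;
                                                    # state["relevant_content_sources"] would raise KeyError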
climateqa/engine/chains/retrieve_documents.py CHANGED
@@ -213,20 +213,10 @@ async def retrieve_documents(state,config, vectorstore,reranker,llm,rerank_by_qu
         dict: The updated state containing the retrieved and reranked documents, related content, and remaining questions.
     """
     print("---- Retrieve documents ----")
-
-    # Get the documents from the state
-    if "documents" in state and state["documents"] is not None:
-        docs = state["documents"]
-    else:
-        docs = []
-
-    # Get the related_content from the state
-    if "related_content" in state and state["related_content"] is not None:
-        related_content = state["related_content"]
-    else:
-        related_content = []
-
-    search_figures = "Figures (IPCC/IPBES)" in state["relevant_content_sources"]
+    docs = state.get("documents", [])
+    related_content = state.get("related_content", [])
+
+    search_figures = "Figures (IPCC/IPBES)" in state["relevant_content_sources_selection"]
     search_only = state["search_only"]
 
     reports = state["reports"]
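One behavioural nuance of this cleanup: dict.get(key, default) falls back to the default only when the key is absent, whereas the deleted branches also coerced an explicit None value to []. A minimal sketch with a hypothetical state value:

# .get() vs. the removed is-None branches, when the key exists but holds None.
state = {"documents": None}

docs_old = state["documents"] if state.get("documents") is not None else []
docs_new = state.get("documents", [])

print(docs_old)  # [] -- the old code coerced None to an empty list
print(docs_new)  # None -- .get() falls back only when the key is missing

Writing state.get("documents") or [] would preserve the old behaviour exactly, if None values can occur.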
climateqa/engine/chains/retrieve_papers.py CHANGED
@@ -32,8 +32,8 @@ def generate_keywords(query):
     return keywords
 
 
-async def find_papers(query,after, relevant_content_sources, reranker= reranker):
-    if "Papers (OpenAlex)" in relevant_content_sources:
+async def find_papers(query,after, relevant_content_sources_selection, reranker= reranker):
+    if "Papers (OpenAlex)" in relevant_content_sources_selection:
         summary = ""
         keywords = generate_keywords(query)
         df_works = oa.search(keywords,after = after)
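Renaming a parameter is source-compatible for positional callers but breaks any keyword-argument call site. A runnable sketch using a hypothetical stand-in for find_papers, just to show the call-site impact:

import asyncio

# Hypothetical stub with the post-rename signature of find_papers.
async def find_papers(query, after, relevant_content_sources_selection, reranker=None):
    return "Papers (OpenAlex)" in relevant_content_sources_selection

async def main():
    # Positional callers are unaffected by the rename:
    print(await find_papers("heatwaves", 2015, ["Papers (OpenAlex)"]))
    # Keyword callers must switch to the new name; passing the old
    # relevant_content_sources=... would raise TypeError (unexpected keyword argument):
    print(await find_papers("heatwaves", 2015,
                            relevant_content_sources_selection=["Papers (OpenAlex)"]))

asyncio.run(main())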
climateqa/engine/graph.py CHANGED
@@ -36,12 +36,12 @@ class GraphState(TypedDict):
     answer: str
     audience: str = "experts"
     sources_input: List[str] = ["IPCC","IPBES"]
-    relevant_content_sources: List[str] = ["Figures (IPCC/IPBES)"]
+    relevant_content_sources_selection: List[str] = ["Figures (IPCC/IPBES)"]
     sources_auto: bool = True
     min_year: int = 1960
     max_year: int = None
     documents: List[Document]
-    related_contents : Dict[str,Document]
+    related_contents : List[Document]
     recommended_content : List[Document]
     search_only : bool = False
     reports : List[str] = []
@@ -159,7 +159,7 @@ def make_graph_agent(llm, vectorstore_ipcc, vectorstore_graphs, reranker, thresh
     )
     workflow.add_conditional_edges(
         "transform_query",
-        lambda state : "retrieve_graphs" if "Graphs (OurWorldInData)" in state["relevant_content_sources"] else END,
+        lambda state : "retrieve_graphs" if "Graphs (OurWorldInData)" in state["relevant_content_sources_selection"] else END,
         make_id_dict(["retrieve_graphs", END])
     )
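Alongside the rename, this hunk also corrects the related_contents annotation from Dict[str,Document] to List[Document]. Worth remembering that a TypedDict such as GraphState is a plain dict at runtime: renaming a field in the class body enforces nothing and populates nothing, so every construction and access site has to agree on the new key string. A minimal sketch reduced to two fields (total=False is an assumption here so the other fields can be omitted):

from typing import List, TypedDict

class GraphState(TypedDict, total=False):
    relevant_content_sources_selection: List[str]
    search_only: bool

# A TypedDict instance is just a dict; the class-level rename changes nothing
# at runtime until producers write, and consumers read, the new key.
state: GraphState = {"relevant_content_sources_selection": ["Graphs (OurWorldInData)"]}
route = "retrieve_graphs" if "Graphs (OurWorldInData)" in state["relevant_content_sources_selection"] else "END"
print(route)  # retrieve_graphs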