timeki committed
Commit 6093b14 · 1 Parent(s): 3240e5c

Merge branch 'main' into feature/clean_code

app.py CHANGED
@@ -104,7 +104,7 @@ embeddings_function = get_embeddings_function()
 vectorstore = get_pinecone_vectorstore(embeddings_function, index_name=os.getenv("PINECONE_API_INDEX"))
 vectorstore_graphs = get_pinecone_vectorstore(embeddings_function, index_name=os.getenv("PINECONE_API_INDEX_OWID"), text_key="description")
 
-llm = get_llm(provider="openai", max_tokens=1024, temperature=0.0)
+llm = get_llm(provider="openai",max_tokens = 1024,temperature = 0.0)
 reranker = get_reranker("nano")
 
 agent = make_graph_agent(llm=llm, vectorstore_ipcc=vectorstore, vectorstore_graphs=vectorstore_graphs, reranker=reranker)
@@ -159,7 +159,8 @@ async def chat(
 "audience": audience_prompt,
 "sources_input": sources,
 "relevant_content_sources": relevant_content_sources,
-"search_only": search_only
+"search_only": search_only,
+"reports": reports
 }
 
 # Get streaming events from agent
@@ -193,7 +194,7 @@ async def chat(
 node = event["metadata"]["langgraph_node"]
 
 # Handle document retrieval
-if event["event"] == "on_chain_end" and event["name"] == "retrieve_documents":
+if event["event"] == "on_chain_end" and event["name"] == "retrieve_documents" and event["data"]["output"] != None:
 docs, docs_html, history, used_documents, related_contents = handle_retrieved_documents(
 event, history, used_documents
 )
@@ -220,7 +221,7 @@ async def chat(
 # Handle answer streaming
 elif (event["name"] != "transform_query" and
 event["event"] == "on_chat_model_stream" and
-node in ["answer_rag", "answer_search", "answer_chitchat"]):
+node in ["answer_rag","answer_rag_no_docs", "answer_search", "answer_chitchat"]):
 history, start_streaming, answer_message_content = stream_answer(
 history, event, start_streaming, answer_message_content
 )
@@ -348,9 +349,9 @@ def change_sample_questions(key):
 def start_chat(query, history, search_only):
 history = history + [ChatMessage(role="user", content=query)]
 if not search_only:
-return (gr.update(interactive=False), gr.update(selected=1), history)
+return (gr.update(interactive=False), gr.update(selected=1), history, [])
 else:
-return (gr.update(interactive=False), gr.update(selected=2), history)
+return (gr.update(interactive=False), gr.update(selected=2), history, [])
 
 def finish_chat():
 return gr.update(interactive=True, value="")
@@ -378,7 +379,7 @@ def create_chat_interface():
 textbox = gr.Textbox(
 placeholder="Ask me anything here!",
 show_label=False,
-scale=7,
+scale=12,
 lines=1,
 interactive=True,
 elem_id="input-textbox"
@@ -417,6 +418,8 @@ def create_examples_tab():
 
 def create_figures_tab():
 sources_raw = gr.State()
+new_figures = gr.State([])
+used_figures = gr.State([])
 
 with Modal(visible=False, elem_id="modal_figure_galery") as figure_modal:
 gallery_component = gr.Gallery(
@@ -438,7 +441,7 @@ def create_figures_tab():
 
 figures_cards = gr.HTML(show_label=False, elem_id="sources-figures")
 
-return sources_raw, gallery_component, figures_cards, figure_modal
+return sources_raw, new_figures, used_figures, gallery_component, figures_cards, figure_modal
 
 def create_papers_tab():
 with gr.Accordion(
@@ -492,9 +495,9 @@ def create_config_modal(config_open):
 )
 
 dropdown_external_sources = gr.CheckboxGroup(
-choices=["IPCC figures", "OpenAlex", "OurWorldInData"],
+choices=["Figures (IPCC/IPBES)", "Papers (OpenAlex)", "Graphs (OurWorldInData)"],
 label="Select database to search for relevant content",
-value=["IPCC figures"],
+value=["Figures (IPCC/IPBES)"],
 interactive=True
 )
 
@@ -543,7 +546,7 @@ def create_config_modal(config_open):
 )
 
 dropdown_external_sources.change(
-lambda x: gr.update(visible="OpenAlex" in x),
+lambda x: gr.update(visible="Papers (OpenAlex)" in x),
 inputs=[dropdown_external_sources],
 outputs=[after]
 )
@@ -588,7 +591,7 @@ with gr.Blocks(title="Climate Q&A", css_paths=os.getcwd()+ "/style.css", theme=t
 with gr.Tabs(elem_id="group-subtabs") as tabs_recommended_content:
 # Figures subtab
 with gr.Tab("Figures", elem_id="tab-figures", id=3) as tab_figures:
-sources_raw, gallery_component, figures_cards, figure_modal = create_figures_tab()
+sources_raw, new_figures, used_figures, gallery_component, figures_cards, figure_modal = create_figures_tab()
 
 # Papers subtab
 with gr.Tab("Papers", elem_id="tab-citations", id=4) as tab_papers:
@@ -641,18 +644,20 @@ with gr.Blocks(title="Climate Q&A", css_paths=os.getcwd()+ "/style.css", theme=t
 
 
 (textbox
-.submit(start_chat, [textbox, chatbot, search_only], [textbox, tabs, chatbot], queue=False, api_name="start_chat_textbox")
-.then(chat, [textbox, chatbot, dropdown_audience, dropdown_sources, dropdown_reports, dropdown_external_sources, search_only], [chatbot, sources_textbox, output_query, output_language, sources_raw, current_graphs], concurrency_limit=8, api_name="chat_textbox")
+.submit(start_chat, [textbox, chatbot, search_only], [textbox, tabs, chatbot, sources_raw], queue=False, api_name="start_chat_textbox")
+.then(chat, [textbox, chatbot, dropdown_audience, dropdown_sources, dropdown_reports, dropdown_external_sources, search_only], [chatbot, sources_textbox, output_query, output_language, new_figures, current_graphs], concurrency_limit=8, api_name="chat_textbox")
 .then(finish_chat, None, [textbox], api_name="finish_chat_textbox")
 )
 
+
+
 (examples_hidden
-.change(start_chat, [examples_hidden, chatbot, search_only], [textbox, tabs, chatbot], queue=False, api_name="start_chat_examples")
-.then(chat, [examples_hidden, chatbot, dropdown_audience, dropdown_sources, dropdown_reports, dropdown_external_sources, search_only], [chatbot, sources_textbox, output_query, output_language, sources_raw, current_graphs], concurrency_limit=8, api_name="chat_textbox")
+.change(start_chat, [examples_hidden, chatbot, search_only], [textbox, tabs, chatbot, sources_raw], queue=False, api_name="start_chat_examples")
+.then(chat, [examples_hidden, chatbot, dropdown_audience, dropdown_sources, dropdown_reports, dropdown_external_sources, search_only], [chatbot, sources_textbox, output_query, output_language, new_figures, current_graphs], concurrency_limit=8, api_name="chat_textbox")
 .then(finish_chat, None, [textbox], api_name="finish_chat_examples")
 )
 
-sources_raw.change(process_figures, inputs=[sources_raw], outputs=[figures_cards, gallery_component])
+new_figures.change(process_figures, inputs=[sources_raw, new_figures], outputs=[sources_raw, figures_cards, gallery_component])
 
 # Update sources numbers
 sources_textbox.change(update_sources_number_display, [sources_textbox, figures_cards, current_graphs, papers_html], [tab_recommended_content, tab_sources, tab_figures, tab_graphs, tab_papers])
climateqa/constants.py CHANGED
@@ -1,4 +1,6 @@
 POSSIBLE_REPORTS = [
+"IPBES IABWFH SPM",
+"IPBES CBL SPM",
 "IPCC AR6 WGI SPM",
 "IPCC AR6 WGI FR",
 "IPCC AR6 WGI TS",
climateqa/engine/chains/retrieve_documents.py CHANGED
@@ -87,7 +87,7 @@ def _get_k_images_by_question(n_questions):
 elif n_questions == 2:
 return 5
 elif n_questions == 3:
-return 2
+return 3
 else:
 return 1
 
@@ -98,7 +98,10 @@ def _add_metadata_and_score(docs: List) -> Document:
 doc.page_content = doc.page_content.replace("\r\n"," ")
 doc.metadata["similarity_score"] = score
 doc.metadata["content"] = doc.page_content
-doc.metadata["page_number"] = int(doc.metadata["page_number"]) + 1
+if doc.metadata["page_number"] != "N/A":
+doc.metadata["page_number"] = int(doc.metadata["page_number"]) + 1
+else:
+doc.metadata["page_number"] = 1
 # doc.page_content = f"""Doc {i+1} - {doc.metadata['short_name']}: {doc.page_content}"""
 docs_with_metadata.append(doc)
 return docs_with_metadata
@@ -216,14 +219,17 @@ async def retrieve_documents(state,config, vectorstore,reranker,llm,rerank_by_qu
 docs = state["documents"]
 else:
 docs = []
+
 # Get the related_content from the state
 if "related_content" in state and state["related_content"] is not None:
 related_content = state["related_content"]
 else:
 related_content = []
 
-search_figures = "IPCC figures" in state["relevant_content_sources"]
+search_figures = "Figures (IPCC/IPBES)" in state["relevant_content_sources"]
 search_only = state["search_only"]
+
+reports = state["reports"]
 
 # Get the current question
 current_question = state["remaining_questions"][0]
@@ -253,6 +259,7 @@ async def retrieve_documents(state,config, vectorstore,reranker,llm,rerank_by_qu
 k_images = k_images_by_question,
 threshold = 0.5,
 search_only = search_only,
+reports = reports,
 )
 
 
climateqa/engine/chains/retrieve_papers.py CHANGED
@@ -33,7 +33,7 @@ def generate_keywords(query):
 
 
 async def find_papers(query,after, relevant_content_sources, reranker= reranker):
-if "OpenAlex" in relevant_content_sources:
+if "Papers (OpenAlex)" in relevant_content_sources:
 summary = ""
 keywords = generate_keywords(query)
 df_works = oa.search(keywords,after = after)
climateqa/engine/graph.py CHANGED
@@ -36,7 +36,7 @@ class GraphState(TypedDict):
 answer: str
 audience: str = "experts"
 sources_input: List[str] = ["IPCC","IPBES"]
-relevant_content_sources: List[str] = ["IPCC figures"]
+relevant_content_sources: List[str] = ["Figures (IPCC/IPBES)"]
 sources_auto: bool = True
 min_year: int = 1960
 max_year: int = None
@@ -44,6 +44,7 @@ class GraphState(TypedDict):
 related_contents : Dict[str,Document]
 recommended_content : List[Document]
 search_only : bool = False
+reports : List[str] = []
 
 def search(state): #TODO
 return state
@@ -82,7 +83,7 @@ def route_based_on_relevant_docs(state,threshold_docs=0.2):
 return "answer_rag_no_docs"
 
 def route_retrieve_documents(state):
-if state["search_only"] :
+if len(state["remaining_questions"]) == 0 and state["search_only"] :
 return END
 elif len(state["remaining_questions"]) > 0:
 return "retrieve_documents"
@@ -158,7 +159,7 @@ def make_graph_agent(llm, vectorstore_ipcc, vectorstore_graphs, reranker, thresh
 )
 workflow.add_conditional_edges(
 "transform_query",
-lambda state : "retrieve_graphs" if "OurWorldInData" in state["relevant_content_sources"] else END,
+lambda state : "retrieve_graphs" if "Graphs (OurWorldInData)" in state["relevant_content_sources"] else END,
 make_id_dict(["retrieve_graphs", END])
 )
 
front/utils.py CHANGED
@@ -39,23 +39,29 @@ def parse_output_llm_with_sources(output:str)->str:
 content_parts = "".join(parts)
 return content_parts
 
-def process_figures(docs:list)->tuple:
-gallery=[]
-used_figures =[]
+def process_figures(docs:list, new_figures:list)->tuple:
+docs = docs + new_figures
+
 figures = '<div class="figures-container"><p></p> </div>'
+gallery = []
+used_figures = []
+
+if docs == []:
+return docs, figures, gallery
+
+
 docs_figures = [d for d in docs if d.metadata["chunk_type"] == "image"]
-for i, doc in enumerate(docs_figures):
-if doc.metadata["chunk_type"] == "image":
-if doc.metadata["figure_code"] != "N/A":
-title = f"{doc.metadata['figure_code']} - {doc.metadata['short_name']}"
-else:
-title = f"{doc.metadata['short_name']}"
+for i_doc, doc in enumerate(docs_figures):
+if doc.metadata["chunk_type"] == "image":
+path = doc.metadata["image_path"]
 
 
-if title not in used_figures:
-used_figures.append(title)
+if path not in used_figures:
+used_figures.append(path)
+figure_number = len(used_figures)
+
 try:
-key = f"Image {i+1}"
+key = f"Image {figure_number}"
 
 image_path = doc.metadata["image_path"].split("documents/")[1]
 img = get_image_from_azure_blob_storage(image_path)
@@ -68,12 +74,12 @@ def process_figures(docs:list)->tuple:
 
 img_str = base64.b64encode(buffered.getvalue()).decode()
 
-figures = figures + make_html_figure_sources(doc, i, img_str)
+figures = figures + make_html_figure_sources(doc, figure_number, img_str)
 gallery.append(img)
 except Exception as e:
-print(f"Skipped adding image {i} because of {e}")
+print(f"Skipped adding image {figure_number} because of {e}")
 
-return figures, gallery
+return docs, figures, gallery
 
 
 def generate_html_graphs(graphs:list)->str:
style.css CHANGED
@@ -24,18 +24,11 @@ main.flex.flex-1.flex-col {
 }
 
 #group-subtabs {
-width: 100%;
-position: sticky;
+/* display: block; */
+position : sticky;
 }
 
-#group-subtabs .tab-container {
-display: flex;
-text-align: center;
-width: 100%;
-}
 
-#group-subtabs .tab-container button {
-flex: 1;
 }
 
 .tab-nav {