lfoppiano committed
Commit d428544
1 Parent(s): 3b9ffd5

Some revisions and refactoring of the interface

Files changed (2)
  1. requirements.txt +1 -1
  2. streamlit_app.py +63 -114
requirements.txt CHANGED

@@ -7,7 +7,7 @@ grobid_tei_xml==0.1.3
 tqdm==4.66.2
 pyyaml==6.0.1
 pytest==8.1.1
-streamlit==1.33.0
+streamlit==1.36.0
 lxml
 Beautifulsoup4
 python-dotenv
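
Only one dependency changes: the Streamlit pin moves from 1.33.0 to 1.36.0. A quick way to confirm the new pin is what actually got installed after `pip install -r requirements.txt` — a minimal sketch using only the standard library:

```python
# Minimal sketch: confirm the environment matches the updated pin.
from importlib.metadata import version

installed = version("streamlit")
assert installed == "1.36.0", f"expected streamlit 1.36.0, got {installed}"
print(f"streamlit {installed} OK")
```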
streamlit_app.py CHANGED

@@ -42,8 +42,6 @@ OPEN_EMBEDDINGS = {
     'Salesforce/SFR-Embedding-Mistral': 'Salesforce/SFR-Embedding-Mistral'
 }
 
-DISABLE_MEMORY = ['zephyr-7b-beta']
-
 if 'rqa' not in st.session_state:
     st.session_state['rqa'] = {}
 
@@ -108,36 +106,6 @@ st.set_page_config(
     }
 )
 
-css_modify_left_column = '''
-<style>
-[data-testid="stHorizontalBlock"] > div:nth-child(1) {
-    overflow: hidden;
-    background-color: red;
-    height: 70vh;
-}
-</style>
-'''
-css_modify_right_column = '''
-<style>
-[data-testid="stHorizontalBlock"]> div:first-child {
-    background-color: red;
-    position: fixed
-    height: 70vh;
-}
-</style>
-'''
-css_disable_scrolling_container = '''
-<style>
-[data-testid="ScrollToBottomContainer"] {
-    overflow: hidden;
-}
-</style>
-'''
-
-
-# st.markdown(css_lock_column_fixed, unsafe_allow_html=True)
-# st.markdown(css2, unsafe_allow_html=True)
-
 
 def new_file():
     st.session_state['loaded_embeddings'] = None
@@ -188,7 +156,7 @@ def init_qa(model, embeddings_name=None, api_key=None):
         )
         embeddings = HuggingFaceEmbeddings(
             model_name=OPEN_EMBEDDINGS[embeddings_name])
-        st.session_state['memory'] = ConversationBufferWindowMemory(k=4) if model not in DISABLE_MEMORY else None
+        # st.session_state['memory'] = ConversationBufferWindowMemory(k=4) if model not in DISABLE_MEMORY else None
     else:
         st.error("The model was not loaded properly. Try reloading. ")
         st.stop()
@@ -233,23 +201,27 @@ def get_file_hash(fname):
     return hash_md5.hexdigest()
 
 
-def play_old_messages():
+def play_old_messages(container):
     if st.session_state['messages']:
         for message in st.session_state['messages']:
            if message['role'] == 'user':
-                with st.chat_message("user"):
-                    st.markdown(message['content'])
+                container.chat_message("user").markdown(message['content'])
            elif message['role'] == 'assistant':
-                with st.chat_message("assistant"):
-                    if mode == "LLM":
-                        st.markdown(message['content'], unsafe_allow_html=True)
-                    else:
-                        st.write(message['content'])
+                if mode == "LLM":
+                    container.chat_message("assistant").markdown(message['content'], unsafe_allow_html=True)
+                else:
+                    container.chat_message("assistant").write(message['content'])
 
 
 # is_api_key_provided = st.session_state['api_key']
 
 with st.sidebar:
+    st.title("📝 Scientific Document Insights Q/A")
+    st.subheader("Upload a scientific article in PDF, ask questions, get insights.")
+    st.markdown(
+        ":warning: [Usage disclaimer](https://github.com/lfoppiano/document-qa?tab=readme-ov-file#disclaimer-on-data-security-and-privacy-%EF%B8%8F) :warning: ")
+
+    st.divider()
     st.session_state['model'] = model = st.selectbox(
         "Model:",
         options=OPENAI_MODELS + list(OPEN_MODELS.keys()),
@@ -305,22 +277,18 @@ with st.sidebar:
     # else:
     #     is_api_key_provided = st.session_state['api_key']
 
-    st.button(
-        'Reset chat memory.',
-        key="reset-memory-button",
-        on_click=clear_memory,
-        help="Clear the conversational memory. Currently implemented to retrain the 4 most recent messages.",
-        disabled=model in st.session_state['rqa'] and st.session_state['rqa'][model].memory is None)
+    # st.button(
+    #     'Reset chat memory.',
+    #     key="reset-memory-button",
+    #     on_click=clear_memory,
+    #     help="Clear the conversational memory. Currently implemented to retrain the 4 most recent messages.",
+    #     disabled=model in st.session_state['rqa'] and st.session_state['rqa'][model].memory is None)
 
 left_column, right_column = st.columns([1, 1])
+right_column = right_column.container(height=600, border=False)
+left_column = left_column.container(height=600, border=False)
 
 with right_column:
-    st.title("📝 Scientific Document Insights Q/A")
-    st.subheader("Upload a scientific article in PDF, ask questions, get insights.")
-
-    st.markdown(
-        ":warning: [Usage disclaimer](https://github.com/lfoppiano/document-qa?tab=readme-ov-file#disclaimer-on-data-security-and-privacy-%EF%B8%8F) :warning: ")
-
     uploaded_file = st.file_uploader(
         "Upload an article",
         type=("pdf", "txt"),
@@ -330,11 +298,14 @@ with right_column:
         help="The full-text is extracted using Grobid."
     )
 
-    question = st.chat_input(
-        "Ask something about the article",
-        # placeholder="Can you give me a short summary?",
-        disabled=not uploaded_file
-    )
+    placeholder = st.empty()
+    messages = st.container(height=300, border=False)
+
+    question = st.chat_input(
+        "Ask something about the article",
+        # placeholder="Can you give me a short summary?",
+        disabled=not uploaded_file
+    )
 
     query_modes = {
         "llm": "LLM Q/A",
@@ -355,6 +326,10 @@ with st.sidebar:
        "relevant paragraphs to the question in the paper. "
        "Question coefficient attempt to estimate how effective the question will be answered."
     )
+    st.session_state['ner_processing'] = st.checkbox(
+        "Identify materials and properties.",
+        help='The LLM responses undergo post-processing to extract physical quantities, measurements, and materials mentions.'
+    )
 
     # Add a checkbox for showing annotations
     # st.session_state['show_annotations'] = st.checkbox("Show annotations", value=True)
@@ -372,11 +347,6 @@ with st.sidebar:
         help="Number of chunks to consider when answering a question",
         disabled=not uploaded_file)
 
-    st.session_state['ner_processing'] = st.checkbox("Identify materials and properties.")
-    st.markdown(
-        'The LLM responses undergo post-processing to extract <span style="color:orange">physical quantities, measurements</span>, and <span style="color:green">materials</span> mentions.',
-        unsafe_allow_html=True)
-
     st.divider()
 
     st.header("Documentation")
@@ -403,7 +373,7 @@ if uploaded_file and not st.session_state.loaded_embeddings:
         st.error("Before uploading a document, you must enter the API key. ")
         st.stop()
 
-    with right_column:
+    with left_column:
        with st.spinner('Reading file, calling Grobid, and creating memory embeddings...'):
            binary = uploaded_file.getvalue()
            tmp_file = NamedTemporaryFile()
@@ -416,8 +386,6 @@ if uploaded_file and not st.session_state.loaded_embeddings:
     st.session_state['loaded_embeddings'] = True
     st.session_state.messages = []
 
-    # timestamp = datetime.utcnow()
-
 
 def rgb_to_hex(rgb):
     return "#{:02x}{:02x}{:02x}".format(*rgb)
@@ -439,41 +407,21 @@ def generate_color_gradient(num_elements):
 
 
 with right_column:
-    # css = '''
-    # <style>
-    # [data-testid="column"] {
-    #     overflow: auto;
-    #     height: 70vh;
-    # }
-    # </style>
-    # '''
-    # st.markdown(css, unsafe_allow_html=True)
-
-    # st.markdown(
-    #     """
-    #     <script>
-    #     document.querySelectorAll('[data-testid="column"]').scrollIntoView({behavior: "smooth"});
-    #     </script>
-    #     """,
-    #     unsafe_allow_html=True,
-    # )
-
     if st.session_state.loaded_embeddings and question and len(question) > 0 and st.session_state.doc_id:
         for message in st.session_state.messages:
-            with st.chat_message(message["role"]):
+            with messages.chat_message(message["role"]):
                if message['mode'] == "llm":
-                    st.markdown(message["content"], unsafe_allow_html=True)
+                    messages.chat_message(message["role"]).markdown(message["content"], unsafe_allow_html=True)
                elif message['mode'] == "embeddings":
-                    st.write(message["content"])
+                    messages.chat_message(message["role"]).write(message["content"])
                if message['mode'] == "question_coefficient":
-                    st.markdown(message["content"], unsafe_allow_html=True)
+                    messages.chat_message(message["role"]).markdown(message["content"], unsafe_allow_html=True)
         if model not in st.session_state['rqa']:
            st.error("The API Key for the " + model + " is missing. Please add it before sending any query. `")
            st.stop()
 
-        with st.chat_message("user"):
-            st.markdown(question)
-        st.session_state.messages.append({"role": "user", "mode": mode, "content": question})
+        messages.chat_message("user").markdown(question)
+        st.session_state.messages.append({"role": "user", "mode": mode, "content": question})
 
         text_response = None
         if mode == "embeddings":
@@ -484,12 +432,13 @@ with right_column:
                    context_size=context_size
                )
         elif mode == "llm":
-            with st.spinner("Generating LLM response..."):
-                _, text_response, coordinates = st.session_state['rqa'][model].query_document(
-                    question,
-                    st.session_state.doc_id,
-                    context_size=context_size
-                )
+            with placeholder:
+                with st.spinner("Generating LLM response..."):
+                    _, text_response, coordinates = st.session_state['rqa'][model].query_document(
+                        question,
+                        st.session_state.doc_id,
+                        context_size=context_size
+                    )
 
         elif mode == "question_coefficient":
            with st.spinner("Estimate question/context relevancy..."):
@@ -511,28 +460,28 @@ with right_column:
         if not text_response:
            st.error("Something went wrong. Contact Luca Foppiano ([email protected]) to report the issue.")
 
-        with st.chat_message("assistant"):
-            if mode == "llm":
-                if st.session_state['ner_processing']:
-                    with st.spinner("Processing NER on LLM response..."):
-                        entities = gqa.process_single_text(text_response)
-                        decorated_text = decorate_text_with_annotations(text_response.strip(), entities)
-                        decorated_text = decorated_text.replace('class="label material"', 'style="color:green"')
-                        decorated_text = re.sub(r'class="label[^"]+"', 'style="color:orange"', decorated_text)
-                        text_response = decorated_text
-                st.markdown(text_response, unsafe_allow_html=True)
-            else:
-                st.write(text_response)
-            st.session_state.messages.append({"role": "assistant", "mode": mode, "content": text_response})
+        if mode == "llm":
+            if st.session_state['ner_processing']:
+                with st.spinner("Processing NER on LLM response..."):
+                    entities = gqa.process_single_text(text_response)
+                    decorated_text = decorate_text_with_annotations(text_response.strip(), entities)
+                    decorated_text = decorated_text.replace('class="label material"', 'style="color:green"')
+                    decorated_text = re.sub(r'class="label[^"]+"', 'style="color:orange"', decorated_text)
+                    text_response = decorated_text
+            messages.chat_message("assistant").markdown(text_response, unsafe_allow_html=True)
+        else:
+            messages.chat_message("assistant").write(text_response)
+        st.session_state.messages.append({"role": "assistant", "mode": mode, "content": text_response})
 
     elif st.session_state.loaded_embeddings and st.session_state.doc_id:
-        play_old_messages()
+        play_old_messages(messages)
 
 with left_column:
     if st.session_state['binary']:
         pdf_viewer(
            input=st.session_state['binary'],
+            annotation_outline_size=2,
            annotations=st.session_state['annotations'],
-            render_text=True
+            render_text=True,
+            height=700
        )
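
The heart of the refactor: the title and disclaimer move into the sidebar, both columns become fixed-height scrollable containers, chat messages are written through an explicit `messages` container by chaining off `chat_message(...)` instead of nesting `with st.chat_message(...):` blocks, and an `st.empty()` placeholder hosts the LLM spinner. A standalone sketch of that pattern, assuming a recent Streamlit release (the echoed "answer" is invented, not this app's logic):

```python
import streamlit as st

left, right = st.columns([1, 1])
# Fixed-height containers make each column scroll independently
# instead of stretching the whole page.
left = left.container(height=600, border=False)
right = right.container(height=600, border=False)

with right:
    placeholder = st.empty()  # the spinner renders here, then clears
    messages = st.container(height=300, border=False)
    question = st.chat_input("Ask something")

if question:
    # chat_message() returns a container, so rendering chains directly.
    messages.chat_message("user").markdown(question)
    with placeholder, st.spinner("Generating response..."):
        answer = f"You asked: {question}"  # stand-in for the real model call
    messages.chat_message("assistant").write(answer)

with left:
    st.write("The PDF viewer occupies this column.")
```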
 
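On the viewer side, annotation outlines are thickened from 1 to 2, the selectable text layer stays enabled, and the component gets a fixed height of 700 px so it scrolls inside the left column. A standalone usage sketch of the same `pdf_viewer` call (the file path and the empty annotation list are placeholders, not values from the app):

```python
# Sketch of the updated streamlit-pdf-viewer call; parameters mirror the diff above.
from streamlit_pdf_viewer import pdf_viewer

with open("sample.pdf", "rb") as f:  # hypothetical local PDF
    binary = f.read()

pdf_viewer(
    input=binary,
    annotation_outline_size=2,  # thicker outlines (previously 1)
    annotations=[],             # no annotations in this standalone demo
    render_text=True,           # selectable/searchable text layer
    height=700,                 # fixed height; the viewer scrolls internally
)
```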