Spaces:
Running
Running
small improvements
Browse files- streamlit_app.py +19 -16
streamlit_app.py
CHANGED
|
@@ -134,13 +134,13 @@ model = st.sidebar.radio("Model (cannot be changed after selection or upload)",
|
|
| 134 |
|
| 135 |
if not st.session_state['api_key']:
|
| 136 |
if model == 'mistral-7b-instruct-v0.1' or model == 'llama-2-70b-chat':
|
| 137 |
-
api_key = st.sidebar.text_input('Huggingface API Key')# if 'HUGGINGFACEHUB_API_TOKEN' not in os.environ else os.environ['HUGGINGFACEHUB_API_TOKEN']
|
| 138 |
if api_key:
|
| 139 |
st.session_state['api_key'] = is_api_key_provided = True
|
| 140 |
os.environ["HUGGINGFACEHUB_API_TOKEN"] = api_key
|
| 141 |
st.session_state['rqa'] = init_qa(model)
|
| 142 |
elif model == 'chatgpt-3.5-turbo':
|
| 143 |
-
api_key = st.sidebar.text_input('OpenAI API Key') #if 'OPENAI_API_KEY' not in os.environ else os.environ['OPENAI_API_KEY']
|
| 144 |
if api_key:
|
| 145 |
st.session_state['api_key'] = is_api_key_provided = True
|
| 146 |
os.environ['OPENAI_API_KEY'] = api_key
|
|
@@ -211,31 +211,34 @@ if st.session_state.loaded_embeddings and question and len(question) > 0 and st.
|
|
| 211 |
elif message['mode'] == "Embeddings":
|
| 212 |
st.write(message["content"])
|
| 213 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 214 |
text_response = None
|
| 215 |
if mode == "Embeddings":
|
| 216 |
-
|
|
|
|
| 217 |
context_size=context_size)
|
| 218 |
elif mode == "LLM":
|
| 219 |
-
|
|
|
|
| 220 |
context_size=context_size)
|
| 221 |
|
| 222 |
if not text_response:
|
| 223 |
st.error("Something went wrong. Contact Luca Foppiano ([email protected]) to report the issue.")
|
| 224 |
|
| 225 |
-
with st.chat_message("user"):
|
| 226 |
-
st.markdown(question)
|
| 227 |
-
st.session_state.messages.append({"role": "user", "mode": mode, "content": question})
|
| 228 |
-
|
| 229 |
with st.chat_message("assistant"):
|
| 230 |
if mode == "LLM":
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
|
|
|
| 239 |
else:
|
| 240 |
st.write(text_response)
|
| 241 |
st.session_state.messages.append({"role": "assistant", "mode": mode, "content": text_response})
|
|
|
|
| 134 |
|
| 135 |
if not st.session_state['api_key']:
|
| 136 |
if model == 'mistral-7b-instruct-v0.1' or model == 'llama-2-70b-chat':
|
| 137 |
+
api_key = st.sidebar.text_input('Huggingface API Key', type="password")# if 'HUGGINGFACEHUB_API_TOKEN' not in os.environ else os.environ['HUGGINGFACEHUB_API_TOKEN']
|
| 138 |
if api_key:
|
| 139 |
st.session_state['api_key'] = is_api_key_provided = True
|
| 140 |
os.environ["HUGGINGFACEHUB_API_TOKEN"] = api_key
|
| 141 |
st.session_state['rqa'] = init_qa(model)
|
| 142 |
elif model == 'chatgpt-3.5-turbo':
|
| 143 |
+
api_key = st.sidebar.text_input('OpenAI API Key', type="password") #if 'OPENAI_API_KEY' not in os.environ else os.environ['OPENAI_API_KEY']
|
| 144 |
if api_key:
|
| 145 |
st.session_state['api_key'] = is_api_key_provided = True
|
| 146 |
os.environ['OPENAI_API_KEY'] = api_key
|
|
|
|
| 211 |
elif message['mode'] == "Embeddings":
|
| 212 |
st.write(message["content"])
|
| 213 |
|
| 214 |
+
with st.chat_message("user"):
|
| 215 |
+
st.markdown(question)
|
| 216 |
+
st.session_state.messages.append({"role": "user", "mode": mode, "content": question})
|
| 217 |
+
|
| 218 |
text_response = None
|
| 219 |
if mode == "Embeddings":
|
| 220 |
+
with st.spinner("Generating LLM response..."):
|
| 221 |
+
text_response = st.session_state['rqa'].query_storage(question, st.session_state.doc_id,
|
| 222 |
context_size=context_size)
|
| 223 |
elif mode == "LLM":
|
| 224 |
+
with st.spinner("Generating response..."):
|
| 225 |
+
_, text_response = st.session_state['rqa'].query_document(question, st.session_state.doc_id,
|
| 226 |
context_size=context_size)
|
| 227 |
|
| 228 |
if not text_response:
|
| 229 |
st.error("Something went wrong. Contact Luca Foppiano ([email protected]) to report the issue.")
|
| 230 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 231 |
with st.chat_message("assistant"):
|
| 232 |
if mode == "LLM":
|
| 233 |
+
with st.spinner("Processing NER on LLM response..."):
|
| 234 |
+
entities = gqa.process_single_text(text_response)
|
| 235 |
+
# for entity in entities:
|
| 236 |
+
# entity
|
| 237 |
+
decorated_text = decorate_text_with_annotations(text_response.strip(), entities)
|
| 238 |
+
decorated_text = decorated_text.replace('class="label material"', 'style="color:green"')
|
| 239 |
+
decorated_text = re.sub(r'class="label[^"]+"', 'style="color:orange"', decorated_text)
|
| 240 |
+
st.markdown(decorated_text, unsafe_allow_html=True)
|
| 241 |
+
text_response = decorated_text
|
| 242 |
else:
|
| 243 |
st.write(text_response)
|
| 244 |
st.session_state.messages.append({"role": "assistant", "mode": mode, "content": text_response})
|