Spaces:
Sleeping
Sleeping
Refactor to extract token stream handler
Browse files- app.py +29 -38
- token_stream_handler.py +13 -0
app.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
import os
|
2 |
import streamlit as st
|
3 |
-
|
|
|
4 |
from langchain.chains import ConversationalRetrievalChain
|
5 |
from langchain.schema import ChatMessage
|
6 |
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
@@ -10,21 +11,6 @@ from langchain_openai import ChatOpenAI, OpenAIEmbeddings
|
|
10 |
|
11 |
st.set_page_config(page_title="InkChatGPT", page_icon="π")
|
12 |
|
13 |
-
__import__("pysqlite3")
|
14 |
-
import sys
|
15 |
-
|
16 |
-
sys.modules["sqlite3"] = sys.modules.pop("pysqlite3")
|
17 |
-
|
18 |
-
|
19 |
-
class StreamHandler(BaseCallbackHandler):
|
20 |
-
def __init__(self, container, initial_text=""):
|
21 |
-
self.container = container
|
22 |
-
self.text = initial_text
|
23 |
-
|
24 |
-
def on_llm_new_token(self, token: str, **kwargs) -> None:
|
25 |
-
self.text += token
|
26 |
-
self.container.markdown(self.text)
|
27 |
-
|
28 |
|
29 |
def load_and_process_file(file_data):
|
30 |
"""
|
@@ -156,38 +142,43 @@ def clear_history():
|
|
156 |
del st.session_state["history"]
|
157 |
|
158 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
159 |
def build_sidebar():
|
160 |
with st.sidebar:
|
161 |
st.title("π InkChatGPT")
|
162 |
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
|
|
|
|
167 |
|
168 |
-
|
169 |
-
|
|
|
170 |
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
|
175 |
-
|
176 |
-
add_file = st.button(
|
177 |
"Process File",
|
178 |
-
on_click=
|
179 |
-
|
|
|
180 |
)
|
181 |
|
182 |
-
if uploaded_file and add_file:
|
183 |
-
with st.spinner("π Thinking..."):
|
184 |
-
vector_store = load_and_process_file(uploaded_file)
|
185 |
-
|
186 |
-
if vector_store:
|
187 |
-
crc = initialize_chat_model(vector_store)
|
188 |
-
st.session_state.crc = crc
|
189 |
-
st.success("File processed successfully!")
|
190 |
-
|
191 |
|
192 |
if __name__ == "__main__":
|
193 |
build_sidebar()
|
|
|
1 |
import os
|
2 |
import streamlit as st
|
3 |
+
|
4 |
+
from token_stream_handler import StreamHandler
|
5 |
from langchain.chains import ConversationalRetrievalChain
|
6 |
from langchain.schema import ChatMessage
|
7 |
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
|
|
11 |
|
12 |
st.set_page_config(page_title="InkChatGPT", page_icon="π")
|
13 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
|
15 |
def load_and_process_file(file_data):
|
16 |
"""
|
|
|
142 |
del st.session_state["history"]
|
143 |
|
144 |
|
145 |
+
def process_data(uploaded_file, openai_api_key):
|
146 |
+
if uploaded_file and openai_api_key.startswith("sk-"):
|
147 |
+
with st.spinner("π Thinking..."):
|
148 |
+
vector_store = load_and_process_file(uploaded_file)
|
149 |
+
|
150 |
+
if vector_store:
|
151 |
+
crc = initialize_chat_model(vector_store)
|
152 |
+
st.session_state.crc = crc
|
153 |
+
st.success(f"File: `{uploaded_file.name}`, processed successfully!")
|
154 |
+
|
155 |
+
|
156 |
def build_sidebar():
|
157 |
with st.sidebar:
|
158 |
st.title("π InkChatGPT")
|
159 |
|
160 |
+
with st.form(key="input_form"):
|
161 |
+
openai_api_key = st.text_input(
|
162 |
+
"OpenAI API Key",
|
163 |
+
type="password",
|
164 |
+
placeholder="Enter your OpenAI API key",
|
165 |
+
)
|
166 |
|
167 |
+
st.session_state.api_key = openai_api_key
|
168 |
+
if not openai_api_key:
|
169 |
+
st.info("Please add your OpenAI API key to continue.")
|
170 |
|
171 |
+
uploaded_file = st.file_uploader(
|
172 |
+
"Select a file", type=["pdf", "docx", "txt"], key="file_uploader"
|
173 |
+
)
|
174 |
|
175 |
+
st.form_submit_button(
|
|
|
176 |
"Process File",
|
177 |
+
on_click=process_data(
|
178 |
+
uploaded_file=uploaded_file, openai_api_key=openai_api_key
|
179 |
+
),
|
180 |
)
|
181 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
182 |
|
183 |
if __name__ == "__main__":
|
184 |
build_sidebar()
|
token_stream_handler.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
from langchain.callbacks.base import BaseCallbackHandler
|
4 |
+
|
5 |
+
|
6 |
+
class StreamHandler(BaseCallbackHandler):
|
7 |
+
def __init__(self, container, initial_text=""):
|
8 |
+
self.container = container
|
9 |
+
self.text = initial_text
|
10 |
+
|
11 |
+
def on_llm_new_token(self, token: str, **kwargs) -> None:
|
12 |
+
self.text += token
|
13 |
+
self.container.markdown(self.text)
|