File size: 2,487 Bytes
afd4069 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 |
import streamlit as st
from streamlit.delta_generator import DeltaGenerator
from client import get_client
from conversation import postprocess_text, preprocess_text, Conversation, Role
# Value passed as `max_length` to `client.generate_stream` in `main`.
MAX_LENGTH = 8192
# Client returned by `get_client()` when this module is executed; `main`
# streams model output from it via `client.generate_stream`.
client = get_client()
# Append a conversation to the history and show it in a new markdown block
def append_conversation(
    conversation: Conversation,
    history: list[Conversation],
    placeholder: DeltaGenerator | None = None,
) -> None:
    """Store ``conversation`` in ``history`` and render it.

    Args:
        conversation: The message to record and display.
        history: List mutated in place; ``conversation`` is appended to its end.
        placeholder: Forwarded unchanged to ``conversation.show``. Defaults to
            ``None``; how ``show`` treats ``None`` is defined by ``Conversation``.
    """
    # Record first, then render, so the history already contains the message
    # by the time it is drawn.
    history.append(conversation)
    conversation.show(placeholder)
def main(top_p: float, temperature: float, system_prompt: str, prompt_text: str):
    """Render the saved chat history and, if a prompt is given, stream a reply.

    The chat history is kept in ``st.session_state.chat_history`` (created as
    an empty list on first use). Every stored conversation is re-rendered; when
    ``prompt_text`` contains non-whitespace text it is appended as a user
    message, the model is queried through ``client.generate_stream`` and the
    streamed reply is rendered incrementally, then appended to the history as
    an assistant message.

    Args:
        top_p: Forwarded to ``client.generate_stream`` as ``top_p``.
        temperature: Forwarded to ``client.generate_stream`` as ``temperature``.
        system_prompt: Passed as the first argument to both ``preprocess_text``
            and ``client.generate_stream``.
        prompt_text: The user's new message. Empty, ``None`` or whitespace-only
            values are ignored: the history is rendered and nothing else happens.
    """
    placeholder = st.empty()
    with placeholder.container():
        if 'chat_history' not in st.session_state:
            st.session_state.chat_history = []

        history: list[Conversation] = st.session_state.chat_history

        for conversation in history:
            conversation.show()

    # Strip BEFORE testing for emptiness so that a whitespace-only prompt is
    # not stored as an empty user turn and sent to the model.
    prompt_text = prompt_text.strip() if prompt_text else ''
    if not prompt_text:
        return

    append_conversation(Conversation(Role.USER, prompt_text), history)

    # Debug output only: the rendered model input is printed, not reused below.
    input_text = preprocess_text(
        system_prompt,
        tools=None,
        history=history,
    )
    print("=== Input:")
    print(input_text)
    print("=== History:")
    print(history)

    # Fresh placeholder for the assistant's reply; the markdown slot inside it
    # is overwritten on every streamed token.
    placeholder = st.empty()
    message_placeholder = placeholder.chat_message(name="assistant", avatar="assistant")
    markdown_placeholder = message_placeholder.empty()

    output_text = ''
    for response in client.generate_stream(
        system_prompt,
        tools=None,
        history=history,
        do_sample=True,
        max_length=MAX_LENGTH,
        temperature=temperature,
        top_p=top_p,
        stop_sequences=[str(Role.USER)],
    ):
        token = response.token
        if token.special:
            print("=== Output:")
            print(output_text)

            # Any special token ends the stream; '<|user|>' is the expected
            # terminator, anything else is reported to the user first.
            special_text = token.text.strip()
            if special_text != '<|user|>':
                st.error(f'Unexpected special token: {special_text}')
            break

        output_text += token.text
        # The trailing block character is a typing cursor shown while streaming.
        markdown_placeholder.markdown(postprocess_text(output_text + '▌'))

    # Final render without the cursor, reusing the streaming slot, and persist
    # the reply in the history.
    append_conversation(Conversation(
        Role.ASSISTANT,
        postprocess_text(output_text),
    ), history, markdown_placeholder)