import streamlit as st from streamlit.delta_generator import DeltaGenerator from client import get_client from conversation import postprocess_text, preprocess_text, Conversation, Role MAX_LENGTH = 8192 client = get_client() # Append a conversation into history, while show it in a new markdown block def append_conversation( conversation: Conversation, history: list[Conversation], placeholder: DeltaGenerator | None=None, ) -> None: history.append(conversation) conversation.show(placeholder) def main(top_p: float, temperature: float, system_prompt: str, prompt_text: str): placeholder = st.empty() with placeholder.container(): if 'chat_history' not in st.session_state: st.session_state.chat_history = [] history: list[Conversation] = st.session_state.chat_history for conversation in history: conversation.show() if prompt_text: prompt_text = prompt_text.strip() append_conversation(Conversation(Role.USER, prompt_text), history) input_text = preprocess_text( system_prompt, tools=None, history=history, ) print("=== Input:") print(input_text) print("=== History:") print(history) placeholder = st.empty() message_placeholder = placeholder.chat_message(name="assistant", avatar="assistant") markdown_placeholder = message_placeholder.empty() output_text = '' for response in client.generate_stream( system_prompt, tools=None, history=history, do_sample=True, max_length=MAX_LENGTH, temperature=temperature, top_p=top_p, stop_sequences=[str(Role.USER)], ): token = response.token if response.token.special: print("=== Output:") print(output_text) match token.text.strip(): case '<|user|>': break case _: st.error(f'Unexpected special token: {token.text.strip()}') break output_text += response.token.text markdown_placeholder.markdown(postprocess_text(output_text + '▌')) append_conversation(Conversation( Role.ASSISTANT, postprocess_text(output_text), ), history, markdown_placeholder)