# Hugging Face Space source (page status when captured: "Runtime error").
import streamlit as st
from huggingface_hub import InferenceClient
# Model served through the Hugging Face Inference API.
MODEL_ID = "mistralai/Mixtral-8x7B-Instruct-v0.1"

# Shared client used by generate() for all text-generation requests.
client = InferenceClient(MODEL_ID)
# System prompt that frames the model as a book-planning assistant. It is
# inserted verbatim at the start of every prompt built by format_prompt().
ai_planner_system_prompt = """
As the AI Planner, your primary task is to assist in the development of a coherent and engaging book. You will be responsible for organizing the overall structure, defining the plot or narrative, and outlining the chapters or sections. To accomplish this, you will need to use your understanding of storytelling principles and genre conventions, as well as any specific information provided by the user, to create a well-structured framework for the book.
"""
def format_prompt(message, history, system_prompt=None):
    """Build the instruct-style prompt string sent to the model.

    The prompt is the opening ``<s>`` token followed by the system prompt,
    then every completed exchange rendered as
    ``[INST] user [/INST] reply</s> ``, and finally the current message
    wrapped in ``[INST] ... [/INST]``.

    Args:
        message: The user's current message (str).
        history: Iterable of ``(user_prompt, bot_response)`` pairs for the
            completed exchanges, oldest first.
        system_prompt: Text placed directly after ``<s>``. When ``None``
            the module-level ``ai_planner_system_prompt`` is used.

    Returns:
        The assembled prompt as a single string.
    """
    if system_prompt is None:
        system_prompt = ai_planner_system_prompt

    # Collect the pieces and join once instead of repeated string +=.
    parts = [f"<s>{system_prompt}"]
    for user_prompt, bot_response in history:
        parts.append(f"[INST] {user_prompt} [/INST]")
        parts.append(f" {bot_response}</s> ")
    parts.append(f"[INST] {message} [/INST]")
    return "".join(parts)
def generate(prompt, history, *, temperature=0.9, max_new_tokens=256,
             top_p=0.95, repetition_penalty=1.0, seed=42):
    """Generate the model's reply to ``prompt`` given the prior exchanges.

    Args:
        prompt: The user's current message (str).
        history: Iterable of ``(user_prompt, bot_response)`` pairs, passed
            through to ``format_prompt``.
        temperature: Sampling temperature; values below 0.01 are raised to
            0.01 so sampling always gets a strictly positive value.
        max_new_tokens: Upper bound on the number of generated tokens.
        top_p: Nucleus-sampling probability mass.
        repetition_penalty: Penalty applied to repeated tokens (1.0 = none).
        seed: Sampling seed; a fixed value makes replies reproducible.

    Returns:
        The generated text with special tokens (e.g. the end-of-sequence
        marker) removed.
    """
    generate_kwargs = dict(
        temperature=max(float(temperature), 1e-2),
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=seed,
    )
    formatted_prompt = format_prompt(prompt, history)
    stream = client.text_generation(
        formatted_prompt,
        **generate_kwargs,
        stream=True,
        details=True,
        return_full_text=False,
    )
    # Skip special tokens: otherwise the closing "</s>" ends up in the reply
    # shown to the user and is duplicated when the reply is fed back through
    # format_prompt() as history.
    return "".join(
        chunk.token.text for chunk in stream if not chunk.token.special
    )
# Streamlit UI
st.title("Chat with Mixtral-8x7B")

# Chat history lives in session state so it persists across reruns. Each
# entry is a (speaker, text) tuple; a user entry is always followed by the
# model's reply because both are appended together below.
if 'history' not in st.session_state:
    st.session_state['history'] = []

user_input = st.text_input("You:", key="user_input")
send_button = st.button("Send")

if send_button and user_input:
    # generate()/format_prompt() expect (user_prompt, bot_response) pairs,
    # not the (speaker, text) entries stored for display, so fold the stored
    # entries into pairs. The current message is passed separately and only
    # recorded after generation, so it is not sent to the model twice.
    turns = st.session_state.history
    past_exchanges = [
        (question, answer)
        for (_, question), (_, answer) in zip(turns[::2], turns[1::2])
    ]
    response = generate(user_input, past_exchanges)
    st.session_state.history.append(("You", user_input))
    st.session_state.history.append(("Mixtral-8x7B", response))

# Display the conversation
for index, (user, text) in enumerate(st.session_state.history):
    # The index makes each widget key unique; the text prefix is kept from
    # the original key format.
    unique_key = f"{index}_{text[:10]}"
    st.text_area(label=user, value=text, height=100, key=unique_key)