File size: 2,480 Bytes
9d2ee94
 
 
 
 
2b93ce3
 
 
 
 
9d2ee94
2b93ce3
 
 
9d2ee94
 
 
2b93ce3
9d2ee94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
055132b
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import streamlit as st
from huggingface_hub import InferenceClient

# Shared hosted-inference client for the Mixtral-8x7B instruct model;
# every call to text generation in this script goes through this object.
client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")

# Define the AI Planner System Prompt
ai_planner_system_prompt = """
As the AI Planner, your primary task is to assist in the development of a coherent and engaging book. You will be responsible for organizing the overall structure, defining the plot or narrative, and outlining the chapters or sections. To accomplish this, you will need to use your understanding of storytelling principles and genre conventions, as well as any specific information provided by the user, to create a well-structured framework for the book.
"""

def format_prompt(message, history):
    """Build a Mixtral-instruct prompt string.

    The prompt opens with the system prompt, replays each prior
    (user_prompt, bot_response) exchange in ``[INST] ... [/INST]`` form,
    and ends with the current user *message* awaiting a completion.
    """
    pieces = [f"<s>{ai_planner_system_prompt}"]
    pieces.extend(
        f"[INST] {user_turn} [/INST] {assistant_turn}</s> "
        for user_turn, assistant_turn in history
    )
    pieces.append(f"[INST] {message} [/INST]")
    return "".join(pieces)

def generate(
    prompt,
    history,
    *,
    temperature=0.9,
    max_new_tokens=256,
    top_p=0.95,
    repetition_penalty=1.0,
    seed=42,
):
    """Stream a completion from the Mixtral endpoint and return it as text.

    Args:
        prompt: The current user message.
        history: Iterable of (user_prompt, bot_response) pairs for prior turns.
        temperature, max_new_tokens, top_p, repetition_penalty, seed:
            Sampling parameters forwarded to ``client.text_generation``.
            Keyword-only with the same defaults the original hard-coded,
            so existing positional callers are unaffected.

    Returns:
        The concatenated generated text.
    """
    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=seed,
    )

    formatted_prompt = format_prompt(prompt, history)

    # Stream token-by-token and join at the end; joining a generator avoids
    # quadratic string concatenation on long outputs.
    stream = client.text_generation(
        formatted_prompt,
        **generate_kwargs,
        stream=True,
        details=True,
        return_full_text=False,
    )
    return "".join(response.token.text for response in stream)

# Streamlit UI
st.title("Chat with Mixtral-8x7B")

# Chat history persists across Streamlit reruns via session state.
# Entries are (speaker_label, text) tuples alternating
# ("You", ...) / ("Mixtral-8x7B", ...), which the display loop expects.
if 'history' not in st.session_state:
    st.session_state['history'] = []

user_input = st.text_input("You:", key="user_input")
send_button = st.button("Send")

if send_button and user_input:
    # BUG FIX: format_prompt() expects (user_prompt, bot_response) pairs,
    # but session history stores (speaker_label, text) tuples — the original
    # code passed the labels ("You") to the model as the user prompts.
    # Re-pair the completed exchanges before calling generate(); the current
    # message is supplied separately, so it is not added to history first.
    past = st.session_state.history
    exchanges = [
        (past[i][1], past[i + 1][1])
        for i in range(0, len(past) - 1, 2)
    ]
    response = generate(user_input, exchanges)
    st.session_state.history.append(("You", user_input))
    st.session_state.history.append(("Mixtral-8x7B", response))

# Display the conversation
for index, (user, text) in enumerate(st.session_state.history):
    # Combine the index with a text prefix so widget keys remain unique
    # even when identical messages occur.
    unique_key = f"{index}_{text[:10]}"
    st.text_area(label=user, value=text, height=100, key=unique_key)