File size: 4,005 Bytes
afd4069 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 |
from dataclasses import dataclass
from enum import auto, Enum
import json
from PIL.Image import Image
import streamlit as st
from streamlit.delta_generator import DeltaGenerator
# Preamble that preprocess_text() places after the system tag instead of the
# system message whenever tools are supplied; the tools' JSON follows it.
TOOL_PROMPT = (
    "Answer the following questions as best as you can. "
    "You have access to the following tools:\n"
)
class Role(Enum):
SYSTEM = auto()
USER = auto()
ASSISTANT = auto()
TOOL = auto()
INTERPRETER = auto()
OBSERVATION = auto()
def __str__(self):
match self:
case Role.SYSTEM:
return "<|system|>"
case Role.USER:
return "<|user|>"
case Role.ASSISTANT | Role.TOOL | Role.INTERPRETER:
return "<|assistant|>"
case Role.OBSERVATION:
return "<|observation|>"
# Get the message block for the given role
def get_message(self):
# Compare by value here, because the enum object in the session state
# is not the same as the enum cases here, due to streamlit's rerunning
# behavior.
match self.value:
case Role.SYSTEM.value:
return
case Role.USER.value:
return st.chat_message(name="user", avatar="user")
case Role.ASSISTANT.value:
return st.chat_message(name="assistant", avatar="assistant")
case Role.TOOL.value:
return st.chat_message(name="tool", avatar="assistant")
case Role.INTERPRETER.value:
return st.chat_message(name="interpreter", avatar="assistant")
case Role.OBSERVATION.value:
return st.chat_message(name="observation", avatar="user")
case _:
st.error(f'Unexpected role: {self}')
@dataclass
class Conversation:
role: Role
content: str
tool: str | None = None
image: Image | None = None
def __str__(self) -> str:
print(self.role, self.content, self.tool)
match self.role:
case Role.SYSTEM | Role.USER | Role.ASSISTANT | Role.OBSERVATION:
return f'{self.role}\n{self.content}'
case Role.TOOL:
return f'{self.role}{self.tool}\n{self.content}'
case Role.INTERPRETER:
return f'{self.role}interpreter\n{self.content}'
# Human readable format
def get_text(self) -> str:
text = postprocess_text(self.content)
match self.role.value:
case Role.TOOL.value:
text = f'Calling tool `{self.tool}`:\n{text}'
case Role.INTERPRETER.value:
text = f'{text}'
case Role.OBSERVATION.value:
text = f'Observation:\n```\n{text}\n```'
return text
# Display as a markdown block
def show(self, placeholder: DeltaGenerator | None=None) -> str:
if placeholder:
message = placeholder
else:
message = self.role.get_message()
if self.image:
message.image(self.image)
else:
text = self.get_text()
message.markdown(text)
def preprocess_text(
    system: str | None,
    tools: list[dict] | None,
    history: list[Conversation],
) -> str:
    """Build the raw model prompt from a system message, tools and history.

    Layout: the system tag and a newline; then either *system* (when no tools
    are given) or TOOL_PROMPT followed by the compact JSON of *tools*; then
    every turn of *history* in its ``str()`` form; finally the assistant tag
    and a newline to cue the model's reply.

    Args:
        system: System message, used only when *tools* is empty or None.
            None is treated as an empty message.
        tools: Tool descriptions to serialize into the prompt, or None.
        history: Previous turns, oldest first.

    Returns:
        The assembled prompt string.
    """
    parts = [f"{Role.SYSTEM}\n"]
    if tools:
        # With tools, the fixed tool preamble replaces the system message.
        parts.append(TOOL_PROMPT)
        parts.append(json.dumps(tools, ensure_ascii=False))
    else:
        # `system` is allowed to be None; treat that as an empty message.
        parts.append(system or "")
    parts.extend(str(conversation) for conversation in history)
    parts.append(f"{Role.ASSISTANT}\n")
    return "".join(parts)
def postprocess_text(text: str) -> str:
    r"""Clean raw model output for display as streamlit markdown.

    - LaTeX inline delimiters ``\(`` and ``\)`` become ``$``.
    - LaTeX display delimiters ``\[`` and ``\]`` become ``$$``.
    - The tags ``<|assistant|>``, ``<|observation|>``, ``<|system|>`` and
      ``<|user|>`` are removed.
    - Leading and trailing whitespace is stripped.

    Replacements are applied one after another in the order listed.
    """
    # Raw strings: "\(" and friends are invalid escape sequences (a
    # SyntaxWarning since Python 3.12); r"\(" denotes the same two characters.
    replacements = (
        (r"\(", "$"),
        (r"\)", "$"),
        (r"\[", "$$"),
        (r"\]", "$$"),
        ("<|assistant|>", ""),
        ("<|observation|>", ""),
        ("<|system|>", ""),
        ("<|user|>", ""),
    )
    for old, new in replacements:
        text = text.replace(old, new)
    return text.strip()