File size: 7,461 Bytes
afd4069 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 |
import re
import yaml
from yaml import YAMLError
import streamlit as st
from streamlit.delta_generator import DeltaGenerator
from client import get_client
from conversation import postprocess_text, preprocess_text, Conversation, Role
from tool_registry import dispatch_tool, get_tools
# Value passed as ``max_length`` to the generation client.
MAX_LENGTH = 8192

# Tool observations longer than this many characters are cut short and
# suffixed with a truncation marker before being stored in the history.
TRUNCATE_LENGTH = 1024

# Sample tool schema. It is dumped to YAML (key order preserved) and used as
# the initial contents of the manual-mode tools editor.
EXAMPLE_TOOL = {
    "name": "get_current_weather",
    "description": "Get the current weather in a given location",
    "parameters": {
        "type": "object",
        "properties": {
            "location": {
                "type": "string",
                "description": "The city and state, e.g. San Francisco, CA",
            },
            "unit": {
                "type": "string",
                "enum": ["celsius", "fahrenheit"],
            },
        },
        "required": ["location"],
    },
}
# Module-level generation client, created when this module is executed;
# main() streams model output through its generate_stream() method.
client = get_client()
def tool_call(*args, **kwargs) -> dict:
    """Log a tool invocation, mark a tool call as pending, and return its kwargs.

    This function is exposed under the name ``tool_call`` to the code that
    ``main`` evaluates from the model output, so calling it is how that code
    hands its arguments back.

    Args:
        *args: Positional arguments of the call; only logged.
        **kwargs: Keyword arguments of the call; logged and returned.

    Returns:
        The keyword arguments, unchanged.
    """
    # One print with a newline separator emits the same three lines as
    # printing the header, the args and the kwargs individually.
    print("=== Tool call:", args, kwargs, sep="\n")
    st.session_state.calling_tool = True
    return kwargs
def yaml_to_dict(tools: str) -> list[dict] | None:
    """Parse a YAML tools definition into a list of tool dictionaries.

    Args:
        tools: YAML text expected to describe a sequence of mappings.

    Returns:
        The parsed list of dicts, or ``None`` when the text is not valid
        YAML or does not parse to a list whose items are all mappings
        (this includes empty input, which parses to ``None``).
    """
    try:
        parsed = yaml.safe_load(tools)
    except YAMLError:
        return None

    # safe_load accepts any YAML document (a bare scalar, a single mapping,
    # ...). Only a sequence of mappings matches the declared return type, so
    # anything else is reported the same way as a syntax error instead of
    # leaking a str/int/dict to the caller.
    if not isinstance(parsed, list):
        return None
    if not all(isinstance(item, dict) for item in parsed):
        return None
    return parsed
def extract_code(text: str) -> str:
    """Return the body of the last fenced code block found in ``text``.

    A fenced block is an opening triple backtick followed by an optional
    info string and a newline, then the body, then a closing triple backtick.

    Args:
        text: Text that may contain one or more fenced code blocks.

    Returns:
        The content between the fences of the last block, without the
        info string.

    Raises:
        IndexError: If ``text`` contains no fenced code block.
    """
    fence = re.compile(r'```([^\n]*)\n(.*?)```', re.DOTALL)
    blocks = fence.findall(text)
    # Each entry is (info_string, body); indexing an empty list raises.
    _info, body = blocks[-1]
    return body
def append_conversation(
    conversation: Conversation,
    history: list[Conversation],
    placeholder: DeltaGenerator | None = None,
) -> None:
    """Append a conversation to the history and display it.

    Args:
        conversation: The conversation entry to record and show.
        history: List that receives the entry; mutated in place.
        placeholder: Passed to ``conversation.show``; ``None`` by default.
    """
    # Record first, then render, so the entry is already in the history
    # when show() runs.
    history.append(conversation)
    conversation.show(placeholder)
def main(top_p: float, temperature: float, prompt_text: str):
manual_mode = st.toggle('Manual mode',
help='Define your tools in YAML format. You need to supply tool call results manually.'
)
if manual_mode:
with st.expander('Tools'):
tools = st.text_area(
'Define your tools in YAML format here:',
yaml.safe_dump([EXAMPLE_TOOL], sort_keys=False),
height=400,
)
tools = yaml_to_dict(tools)
if not tools:
st.error('YAML format error in tools definition')
else:
tools = get_tools()
if 'tool_history' not in st.session_state:
st.session_state.tool_history = []
if 'calling_tool' not in st.session_state:
st.session_state.calling_tool = False
history: list[Conversation] = st.session_state.tool_history
for conversation in history:
conversation.show()
if prompt_text:
prompt_text = prompt_text.strip()
role = st.session_state.calling_tool and Role.OBSERVATION or Role.USER
append_conversation(Conversation(role, prompt_text), history)
st.session_state.calling_tool = False
input_text = preprocess_text(
None,
tools,
history,
)
print("=== Input:")
print(input_text)
print("=== History:")
print(history)
placeholder = st.container()
message_placeholder = placeholder.chat_message(name="assistant", avatar="assistant")
markdown_placeholder = message_placeholder.empty()
for _ in range(5):
output_text = ''
for response in client.generate_stream(
system=None,
tools=tools,
history=history,
do_sample=True,
max_length=MAX_LENGTH,
temperature=temperature,
top_p=top_p,
stop_sequences=[str(r) for r in (Role.USER, Role.OBSERVATION)],
):
token = response.token
if response.token.special:
print("=== Output:")
print(output_text)
match token.text.strip():
case '<|user|>':
append_conversation(Conversation(
Role.ASSISTANT,
postprocess_text(output_text),
), history, markdown_placeholder)
return
# Initiate tool call
case '<|assistant|>':
append_conversation(Conversation(
Role.ASSISTANT,
postprocess_text(output_text),
), history, markdown_placeholder)
output_text = ''
message_placeholder = placeholder.chat_message(name="tool", avatar="assistant")
markdown_placeholder = message_placeholder.empty()
continue
case '<|observation|>':
tool, *output_text = output_text.strip().split('\n')
output_text = '\n'.join(output_text)
append_conversation(Conversation(
Role.TOOL,
postprocess_text(output_text),
tool,
), history, markdown_placeholder)
message_placeholder = placeholder.chat_message(name="observation", avatar="user")
markdown_placeholder = message_placeholder.empty()
try:
code = extract_code(output_text)
args = eval(code, {'tool_call': tool_call}, {})
except:
st.error('Failed to parse tool call')
return
output_text = ''
if manual_mode:
st.info('Please provide tool call results below:')
return
else:
with markdown_placeholder:
with st.spinner(f'Calling tool {tool}...'):
observation = dispatch_tool(tool, args)
if len(observation) > TRUNCATE_LENGTH:
observation = observation[:TRUNCATE_LENGTH] + ' [TRUNCATED]'
append_conversation(Conversation(
Role.OBSERVATION, observation
), history, markdown_placeholder)
message_placeholder = placeholder.chat_message(name="assistant", avatar="assistant")
markdown_placeholder = message_placeholder.empty()
st.session_state.calling_tool = False
break
case _:
st.error(f'Unexpected special token: {token.text.strip()}')
return
output_text += response.token.text
markdown_placeholder.markdown(postprocess_text(output_text + '▌'))
else:
append_conversation(Conversation(
Role.ASSISTANT,
postprocess_text(output_text),
), history, markdown_placeholder)
return
|