# ChatGLM3 / composite_demo /demo_tool.py
# kakuguo's picture
# Upload 52 files
# afd4069
# raw
# history blame
# 7.46 kB
import re
import yaml
from yaml import YAMLError
import streamlit as st
from streamlit.delta_generator import DeltaGenerator
from client import get_client
from conversation import postprocess_text, preprocess_text, Conversation, Role
from tool_registry import dispatch_tool, get_tools
# Passed as ``max_length`` to the streaming generation call in main().
MAX_LENGTH = 8192

# Tool observations longer than this many characters are cut before being
# appended to the conversation history.
TRUNCATE_LENGTH = 1024

# Sample tool definition; main() dumps it to YAML as the initial content of
# the manual-mode tool editor. Key order matters because the dump is unsorted.
EXAMPLE_TOOL = {
    "name": "get_current_weather",
    "description": "Get the current weather in a given location",
    "parameters": {
        "type": "object",
        "properties": {
            "location": {
                "type": "string",
                "description": "The city and state, e.g. San Francisco, CA",
            },
            "unit": {
                "type": "string",
                "enum": ["celsius", "fahrenheit"],
            },
        },
        "required": ["location"],
    },
}
# Module-level model client, created once at import time via get_client();
# main() streams generations from it with client.generate_stream(...).
client = get_client()
def tool_call(*args, **kwargs) -> dict:
    """Log a tool invocation, mark a tool call as pending, and return its kwargs.

    main() exposes this function as the only global when it ``eval``s the
    code snippet emitted by the model, so evaluating ``tool_call(a=1)``
    yields ``{'a': 1}``.

    Args:
        *args: Positional arguments; only logged, never returned.
        **kwargs: Keyword arguments of the requested tool call.

    Returns:
        The keyword arguments as a dict.
    """
    # Same three output lines as separate print() calls would produce.
    print("=== Tool call:", args, kwargs, sep="\n")
    # Flag read by main() to treat the next prompt as a tool observation.
    st.session_state.calling_tool = True
    return kwargs
def yaml_to_dict(tools: str) -> list[dict] | None:
    """Parse a YAML tool specification into a list of tool dicts.

    Args:
        tools: YAML text expected to describe a sequence of mappings.

    Returns:
        The parsed list of dicts, or ``None`` when the text is not valid
        YAML or does not decode to a list of mappings.
    """
    try:
        parsed = yaml.safe_load(tools)
    except YAMLError:
        return None

    # Valid YAML may still decode to a scalar, a single mapping, or None;
    # only a list of mappings satisfies the declared return type, so anything
    # else is reported to the caller the same way as a syntax error.
    if not isinstance(parsed, list):
        return None
    if not all(isinstance(entry, dict) for entry in parsed):
        return None
    return parsed
def extract_code(text: str) -> str:
    """Return the body of the last fenced code block found in *text*.

    A block is an opening triple-backtick line (optionally carrying a
    language tag), followed by the body, up to the next triple backtick.

    Args:
        text: Text that contains at least one fenced code block.

    Returns:
        The body of the last block, without the fence or language tag.

    Raises:
        IndexError: If *text* contains no fenced code block.
    """
    fenced_blocks = re.findall(r'```([^\n]*)\n(.*?)```', text, flags=re.DOTALL)
    # Each match is a (language_tag, body) pair; only the last body is wanted.
    _language_tag, body = fenced_blocks[-1]
    return body
def append_conversation(
    conversation: Conversation,
    history: list[Conversation],
    placeholder: DeltaGenerator | None = None,
) -> None:
    """Append a conversation to the history and display it.

    Args:
        conversation: The conversation entry to record and show.
        history: The list that receives the entry; mutated in place.
        placeholder: Optional Streamlit element forwarded to
            ``conversation.show``; ``None`` is forwarded when omitted.
    """
    history.append(conversation)
    conversation.show(placeholder)
def main(top_p: float, temperature: float, prompt_text: str):
manual_mode = st.toggle('Manual mode',
help='Define your tools in YAML format. You need to supply tool call results manually.'
)
if manual_mode:
with st.expander('Tools'):
tools = st.text_area(
'Define your tools in YAML format here:',
yaml.safe_dump([EXAMPLE_TOOL], sort_keys=False),
height=400,
)
tools = yaml_to_dict(tools)
if not tools:
st.error('YAML format error in tools definition')
else:
tools = get_tools()
if 'tool_history' not in st.session_state:
st.session_state.tool_history = []
if 'calling_tool' not in st.session_state:
st.session_state.calling_tool = False
history: list[Conversation] = st.session_state.tool_history
for conversation in history:
conversation.show()
if prompt_text:
prompt_text = prompt_text.strip()
role = st.session_state.calling_tool and Role.OBSERVATION or Role.USER
append_conversation(Conversation(role, prompt_text), history)
st.session_state.calling_tool = False
input_text = preprocess_text(
None,
tools,
history,
)
print("=== Input:")
print(input_text)
print("=== History:")
print(history)
placeholder = st.container()
message_placeholder = placeholder.chat_message(name="assistant", avatar="assistant")
markdown_placeholder = message_placeholder.empty()
for _ in range(5):
output_text = ''
for response in client.generate_stream(
system=None,
tools=tools,
history=history,
do_sample=True,
max_length=MAX_LENGTH,
temperature=temperature,
top_p=top_p,
stop_sequences=[str(r) for r in (Role.USER, Role.OBSERVATION)],
):
token = response.token
if response.token.special:
print("=== Output:")
print(output_text)
match token.text.strip():
case '<|user|>':
append_conversation(Conversation(
Role.ASSISTANT,
postprocess_text(output_text),
), history, markdown_placeholder)
return
# Initiate tool call
case '<|assistant|>':
append_conversation(Conversation(
Role.ASSISTANT,
postprocess_text(output_text),
), history, markdown_placeholder)
output_text = ''
message_placeholder = placeholder.chat_message(name="tool", avatar="assistant")
markdown_placeholder = message_placeholder.empty()
continue
case '<|observation|>':
tool, *output_text = output_text.strip().split('\n')
output_text = '\n'.join(output_text)
append_conversation(Conversation(
Role.TOOL,
postprocess_text(output_text),
tool,
), history, markdown_placeholder)
message_placeholder = placeholder.chat_message(name="observation", avatar="user")
markdown_placeholder = message_placeholder.empty()
try:
code = extract_code(output_text)
args = eval(code, {'tool_call': tool_call}, {})
except:
st.error('Failed to parse tool call')
return
output_text = ''
if manual_mode:
st.info('Please provide tool call results below:')
return
else:
with markdown_placeholder:
with st.spinner(f'Calling tool {tool}...'):
observation = dispatch_tool(tool, args)
if len(observation) > TRUNCATE_LENGTH:
observation = observation[:TRUNCATE_LENGTH] + ' [TRUNCATED]'
append_conversation(Conversation(
Role.OBSERVATION, observation
), history, markdown_placeholder)
message_placeholder = placeholder.chat_message(name="assistant", avatar="assistant")
markdown_placeholder = message_placeholder.empty()
st.session_state.calling_tool = False
break
case _:
st.error(f'Unexpected special token: {token.text.strip()}')
return
output_text += response.token.text
markdown_placeholder.markdown(postprocess_text(output_text + '▌'))
else:
append_conversation(Conversation(
Role.ASSISTANT,
postprocess_text(output_text),
), history, markdown_placeholder)
return