|
import re |
|
import yaml |
|
from yaml import YAMLError |
|
|
|
import streamlit as st |
|
from streamlit.delta_generator import DeltaGenerator |
|
|
|
from client import get_client |
|
from conversation import postprocess_text, preprocess_text, Conversation, Role |
|
from tool_registry import dispatch_tool, get_tools |
|
|
|
# Forwarded as ``max_length`` to ``client.generate_stream`` in ``main``.
MAX_LENGTH = 8 * 1024

# Tool observations whose ``len`` exceeds this are cut to this length and
# suffixed with " [TRUNCATED]" before being appended to the chat history.
TRUNCATE_LENGTH = 2 ** 10
|
|
|
# Sample tool definition; ``main`` dumps it to YAML (key order preserved) as
# the default contents of the manual-mode tool editor.
EXAMPLE_TOOL = dict(
    name="get_current_weather",
    description="Get the current weather in a given location",
    parameters=dict(
        type="object",
        properties=dict(
            location=dict(
                type="string",
                description="The city and state, e.g. San Francisco, CA",
            ),
            unit=dict(type="string", enum=["celsius", "fahrenheit"]),
        ),
        required=["location"],
    ),
)
|
|
|
# Model client used by ``main`` to stream generations
# (``client.generate_stream``). Built by ``get_client()`` when this module is
# executed; how/whether ``get_client`` caches the underlying model is defined
# in ``client`` and not visible here.
client = get_client()
|
|
|
def tool_call(*args, **kwargs) -> dict:
    """Stand-in bound to the name ``tool_call`` when evaluating model code.

    ``main`` evaluates the code block the model emits for a tool call with
    this function as ``tool_call``, so the evaluation result is the call's
    keyword arguments. The captured arguments are echoed to stdout for
    debugging, and the session is flagged as waiting for a tool result.

    Returns:
        The keyword arguments of the call. Positional arguments are printed
        but not returned.
    """
    print("=== Tool call:")
    for captured in (args, kwargs):
        print(captured)
    # With this flag set, ``main`` records the user's next prompt with the
    # observation role instead of the user role.
    st.session_state.calling_tool = True
    return kwargs
|
|
|
def yaml_to_dict(tools: str) -> list[dict] | None:
    """Parse the user-edited YAML tool definitions.

    Args:
        tools: Raw YAML text. It is expected to be a sequence of mappings,
            one per tool (like the dumped ``[EXAMPLE_TOOL]``).

    Returns:
        A list of tool mappings. A document that is a single mapping is
        wrapped into a one-element list. Returns ``None`` when the text is
        not valid YAML or does not describe mappings (e.g. a bare scalar, or
        a list containing non-mappings).
    """
    try:
        parsed = yaml.safe_load(tools)
    except YAMLError:
        return None

    # ``safe_load`` returns whatever YAML value the text holds (scalar,
    # mapping, sequence, or None). Only mappings are usable tool definitions,
    # so normalise/reject here instead of passing e.g. a string to the client.
    if isinstance(parsed, dict):
        parsed = [parsed]
    if not isinstance(parsed, list) or not all(isinstance(item, dict) for item in parsed):
        return None
    return parsed
|
|
|
def extract_code(text: str) -> str:
    """Return the body of the last triple-backtick fenced block in *text*.

    The info string after the opening fence (e.g. ``python``) and the newline
    ending it are skipped. The returned body runs up to the closing fence and
    includes the newline that precedes it.

    Raises:
        IndexError: if *text* contains no fenced code block.
    """
    fence = re.compile(r'```([^\n]*)\n(.*?)```', re.DOTALL)
    bodies = [match.group(2) for match in fence.finditer(text)]
    # Indexing an empty list raises IndexError, which callers rely on.
    return bodies[-1]
|
|
|
|
|
def append_conversation(
    conversation: Conversation,
    history: list[Conversation],
    placeholder: DeltaGenerator | None = None,
) -> None:
    """Record *conversation* in *history* (in place), then render it.

    Args:
        conversation: The message to store and display.
        history: Running list of messages; mutated in place.
        placeholder: Streamlit element forwarded unchanged to
            ``Conversation.show``. How ``show`` treats ``None`` is defined in
            the ``conversation`` module and is not visible here.
    """
    history.extend((conversation,))
    conversation.show(placeholder)
|
|
|
def main(top_p: float, temperature: float, prompt_text: str): |
|
manual_mode = st.toggle('Manual mode', |
|
help='Define your tools in YAML format. You need to supply tool call results manually.' |
|
) |
|
|
|
if manual_mode: |
|
with st.expander('Tools'): |
|
tools = st.text_area( |
|
'Define your tools in YAML format here:', |
|
yaml.safe_dump([EXAMPLE_TOOL], sort_keys=False), |
|
height=400, |
|
) |
|
tools = yaml_to_dict(tools) |
|
|
|
if not tools: |
|
st.error('YAML format error in tools definition') |
|
else: |
|
tools = get_tools() |
|
|
|
if 'tool_history' not in st.session_state: |
|
st.session_state.tool_history = [] |
|
if 'calling_tool' not in st.session_state: |
|
st.session_state.calling_tool = False |
|
|
|
history: list[Conversation] = st.session_state.tool_history |
|
|
|
for conversation in history: |
|
conversation.show() |
|
|
|
if prompt_text: |
|
prompt_text = prompt_text.strip() |
|
role = st.session_state.calling_tool and Role.OBSERVATION or Role.USER |
|
append_conversation(Conversation(role, prompt_text), history) |
|
st.session_state.calling_tool = False |
|
|
|
input_text = preprocess_text( |
|
None, |
|
tools, |
|
history, |
|
) |
|
print("=== Input:") |
|
print(input_text) |
|
print("=== History:") |
|
print(history) |
|
|
|
placeholder = st.container() |
|
message_placeholder = placeholder.chat_message(name="assistant", avatar="assistant") |
|
markdown_placeholder = message_placeholder.empty() |
|
|
|
for _ in range(5): |
|
output_text = '' |
|
for response in client.generate_stream( |
|
system=None, |
|
tools=tools, |
|
history=history, |
|
do_sample=True, |
|
max_length=MAX_LENGTH, |
|
temperature=temperature, |
|
top_p=top_p, |
|
stop_sequences=[str(r) for r in (Role.USER, Role.OBSERVATION)], |
|
): |
|
token = response.token |
|
if response.token.special: |
|
print("=== Output:") |
|
print(output_text) |
|
|
|
match token.text.strip(): |
|
case '<|user|>': |
|
append_conversation(Conversation( |
|
Role.ASSISTANT, |
|
postprocess_text(output_text), |
|
), history, markdown_placeholder) |
|
return |
|
|
|
case '<|assistant|>': |
|
append_conversation(Conversation( |
|
Role.ASSISTANT, |
|
postprocess_text(output_text), |
|
), history, markdown_placeholder) |
|
output_text = '' |
|
message_placeholder = placeholder.chat_message(name="tool", avatar="assistant") |
|
markdown_placeholder = message_placeholder.empty() |
|
continue |
|
case '<|observation|>': |
|
tool, *output_text = output_text.strip().split('\n') |
|
output_text = '\n'.join(output_text) |
|
|
|
append_conversation(Conversation( |
|
Role.TOOL, |
|
postprocess_text(output_text), |
|
tool, |
|
), history, markdown_placeholder) |
|
message_placeholder = placeholder.chat_message(name="observation", avatar="user") |
|
markdown_placeholder = message_placeholder.empty() |
|
|
|
try: |
|
code = extract_code(output_text) |
|
args = eval(code, {'tool_call': tool_call}, {}) |
|
except: |
|
st.error('Failed to parse tool call') |
|
return |
|
|
|
output_text = '' |
|
|
|
if manual_mode: |
|
st.info('Please provide tool call results below:') |
|
return |
|
else: |
|
with markdown_placeholder: |
|
with st.spinner(f'Calling tool {tool}...'): |
|
observation = dispatch_tool(tool, args) |
|
|
|
if len(observation) > TRUNCATE_LENGTH: |
|
observation = observation[:TRUNCATE_LENGTH] + ' [TRUNCATED]' |
|
append_conversation(Conversation( |
|
Role.OBSERVATION, observation |
|
), history, markdown_placeholder) |
|
message_placeholder = placeholder.chat_message(name="assistant", avatar="assistant") |
|
markdown_placeholder = message_placeholder.empty() |
|
st.session_state.calling_tool = False |
|
break |
|
case _: |
|
st.error(f'Unexpected special token: {token.text.strip()}') |
|
return |
|
output_text += response.token.text |
|
markdown_placeholder.markdown(postprocess_text(output_text + '▌')) |
|
else: |
|
append_conversation(Conversation( |
|
Role.ASSISTANT, |
|
postprocess_text(output_text), |
|
), history, markdown_placeholder) |
|
return |
|
|