import re
import yaml
from yaml import YAMLError

import streamlit as st
from streamlit.delta_generator import DeltaGenerator

from client import get_client
from conversation import postprocess_text, preprocess_text, Conversation, Role
from tool_registry import dispatch_tool, get_tools

MAX_LENGTH = 8192       # generation length limit passed to the client as max_length
TRUNCATE_LENGTH = 1024  # tool observations longer than this are truncated

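# Example tool definition pre-filled into the manual-mode YAML editor; it uses
# the same JSON-Schema style parameter description expected for tool definitions.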
EXAMPLE_TOOL = {
    "name": "get_current_weather",
    "description": "Get the current weather in a given location",
    "parameters": {
        "type": "object",
        "properties": {
            "location": {
                "type": "string",
                "description": "The city and state, e.g. San Francisco, CA",
            },
            "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
        },
        "required": ["location"],
    }
}

client = get_client()

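# Stub that the model-generated `tool_call(...)` expression is eval'ed against
# (see the '<|observation|>' branch below): it flags that a call is in flight
# and returns the keyword arguments so they can be passed to the real tool.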
def tool_call(*args, **kwargs) -> dict:
    print("=== Tool call:")
    print(args)
    print(kwargs)
    st.session_state.calling_tool = True
    return kwargs

def yaml_to_dict(tools: str) -> list[dict] | None:
    try:
        return yaml.safe_load(tools)
    except YAMLError:
        return None

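# Return the body of the last fenced code block in `text`. The tool segment
# streamed by the model is expected to end with such a block containing a
# `tool_call(...)` expression.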
def extract_code(text: str) -> str:
    pattern = r'```([^\n]*)\n(.*?)```'
    matches = re.findall(pattern, text, re.DOTALL)
    return matches[-1][1]

# Append a conversation to the history and render it in a new markdown block
def append_conversation(
    conversation: Conversation,
    history: list[Conversation],
    placeholder: DeltaGenerator | None=None,
) -> None:
    history.append(conversation)
    conversation.show(placeholder)

def main(top_p: float, temperature: float, prompt_text: str):
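    """Run one turn of the tool-calling demo.

    `prompt_text` is the user's message (or, right after a manual tool call,
    the tool observation they typed in). It is appended to the session
    history, the model response is streamed, and any tool call the model
    emits is either dispatched through the tool registry or, in manual mode,
    handed back to the user to answer.
    """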
    manual_mode = st.toggle('Manual mode',
        help='Define your tools in YAML format. You need to supply tool call results manually.'
    )

    if manual_mode:
        with st.expander('Tools'):
            tools = st.text_area(
                'Define your tools in YAML format here:',
                yaml.safe_dump([EXAMPLE_TOOL], sort_keys=False),
                height=400,
            )
        tools = yaml_to_dict(tools)

        if not tools:
            st.error('YAML format error in tools definition')
    else:
        tools = get_tools()

    if 'tool_history' not in st.session_state:
        st.session_state.tool_history = []
    if 'calling_tool' not in st.session_state:
        st.session_state.calling_tool = False

    history: list[Conversation] = st.session_state.tool_history

    for conversation in history:
        conversation.show()

    if prompt_text:
        prompt_text = prompt_text.strip()
        role = Role.OBSERVATION if st.session_state.calling_tool else Role.USER
        append_conversation(Conversation(role, prompt_text), history)
        st.session_state.calling_tool = False

        input_text = preprocess_text(
            None,
            tools,
            history,
        )
        print("=== Input:")
        print(input_text)
        print("=== History:")
        print(history)

        placeholder = st.container()
        message_placeholder = placeholder.chat_message(name="assistant", avatar="assistant")
        markdown_placeholder = message_placeholder.empty()

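        # Allow up to five assistant/tool rounds for one user turn. Each round
        # streams tokens until a special role token tells us whether control
        # goes back to the user, a tool call starts, or an observation is due.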
        for _ in range(5):
            output_text = ''
            for response in client.generate_stream(
                system=None,
                tools=tools,
                history=history,
                do_sample=True,
                max_length=MAX_LENGTH,
                temperature=temperature,
                top_p=top_p,
                stop_sequences=[str(r) for r in (Role.USER, Role.OBSERVATION)],
            ):
                token = response.token
                if token.special:
                    print("=== Output:")
                    print(output_text)

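                    # A special token announces which role speaks next, and
                    # therefore how the text accumulated so far should be used.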
                    match token.text.strip():
                        case '<|user|>':
                            append_conversation(Conversation(
                                Role.ASSISTANT,
                                postprocess_text(output_text),
                            ), history, markdown_placeholder)
                            return
                        # Initiate tool call
                        case '<|assistant|>':
                            append_conversation(Conversation(
                                Role.ASSISTANT,
                                postprocess_text(output_text),
                            ), history, markdown_placeholder)
                            output_text = ''
                            message_placeholder = placeholder.chat_message(name="tool", avatar="assistant")
                            markdown_placeholder = message_placeholder.empty()
                            continue
                        case '<|observation|>':
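                            # The first line of the tool segment names the tool;
                            # the remainder (ending in a code block) is its body.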
                            tool, *output_text = output_text.strip().split('\n')
                            output_text = '\n'.join(output_text)
                            
                            append_conversation(Conversation(
                                Role.TOOL,
                                postprocess_text(output_text),
                                tool,
                            ), history, markdown_placeholder)
                            message_placeholder = placeholder.chat_message(name="observation", avatar="user")
                            markdown_placeholder = message_placeholder.empty()
                            
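                            # Evaluate the generated `tool_call(...)` expression
                            # with our stub bound, which captures the arguments
                            # the model wants to pass to the tool.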
                            try:
                                code = extract_code(output_text)
                                args = eval(code, {'tool_call': tool_call}, {})
                            except Exception:
                                st.error('Failed to parse tool call')
                                return
                            
                            output_text = ''
                            
                            if manual_mode:
                                st.info('Please provide tool call results below:')
                                return
                            else:
                                with markdown_placeholder:
                                    with st.spinner(f'Calling tool {tool}...'):
                                        observation = dispatch_tool(tool, args)

                                if len(observation) > TRUNCATE_LENGTH:
                                    observation = observation[:TRUNCATE_LENGTH] + ' [TRUNCATED]'
                                append_conversation(Conversation(
                                    Role.OBSERVATION, observation
                                ), history, markdown_placeholder)
                                message_placeholder = placeholder.chat_message(name="assistant", avatar="assistant")
                                markdown_placeholder = message_placeholder.empty()
                                st.session_state.calling_tool = False
                                break
                        case _:
                            st.error(f'Unexpected special token: {token.text.strip()}')
                            return
                output_text += token.text
                markdown_placeholder.markdown(postprocess_text(output_text + '▌'))
            else:
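                # Generation finished without a role-switch token; record the
                # accumulated text as the assistant reply and stop.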
                append_conversation(Conversation(
                    Role.ASSISTANT,
                    postprocess_text(output_text),
                ), history, markdown_placeholder)
                return