File size: 9,707 Bytes
ebf567b
95bf5e4
 
 
 
ebf567b
9f83bcc
6660737
 
95bf5e4
 
 
 
 
 
ebf567b
 
 
 
 
 
90d520c
 
 
 
 
 
 
 
 
 
 
 
 
9f83bcc
 
 
 
 
 
 
90d520c
 
 
 
 
95bf5e4
 
 
90d520c
95bf5e4
 
 
 
 
 
6660737
 
95bf5e4
9f83bcc
 
90d520c
95bf5e4
 
 
 
 
 
 
90d520c
 
6660737
 
90d520c
 
95bf5e4
 
90d520c
95bf5e4
90d520c
ebf567b
9f83bcc
 
ebf567b
 
9f83bcc
 
 
 
95bf5e4
9f83bcc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90d520c
9f83bcc
 
 
 
 
 
 
 
95bf5e4
 
9f83bcc
 
 
95bf5e4
9f83bcc
 
 
 
90d520c
 
95bf5e4
9f83bcc
90d520c
9f83bcc
95bf5e4
 
 
 
 
 
 
 
 
 
9f83bcc
 
 
90d520c
 
 
ebf567b
95bf5e4
 
9f83bcc
95bf5e4
 
 
 
 
 
90d520c
 
 
 
 
95bf5e4
 
 
 
 
9f83bcc
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
import functools
import json
import os
import logging
from groq import Groq
import functions
from utils import python_type, raise_error
from tools import tools

# Configure the root logger at DEBUG level for the whole script.
logging.basicConfig(level=logging.DEBUG)

# Groq API client; the key is read from the GROQ_API_KEY environment variable
# (a missing variable raises KeyError at import time).
client = Groq(api_key=os.environ["GROQ_API_KEY"])

# Model identifier passed to every chat-completion request in this file.
MODEL = "llama3-groq-70b-8192-tool-use-preview"

# Names of every callable attribute of the `functions` module whose name does
# not start with a double underscore.
all_functions = [
    attr_name
    for attr_name in dir(functions)
    if callable(getattr(functions, attr_name)) and not attr_name.startswith("__")
]

# Lookup table used by execute_tool: tool name -> partial wrapping the callable.
names_to_functions = {
    attr_name: functools.partial(getattr(functions, attr_name))
    for attr_name in all_functions
}


def create_message(prompt, message_type):
    """
    Build the initial two-message conversation (system + user) for a request.

    :param prompt: User query; used verbatim as the content of the user message.
    :param message_type: Either 'reasoning_chain' or 'function_call'; selects
        which system prompt is used.
    :return: List of two message dicts: the system message, then the user message.
    :raises ValueError: If message_type is neither of the two supported values.
    """
    logging.debug(
        f"Creating message with prompt: {prompt} and message type: {message_type}")

    # (message type, system prompt) pairs; matched with == below.
    system_prompts = (
        (
            "reasoning_chain",
            "You are a movie search assistant bot who uses TMDB to help users "
            "find movies. Think step by step and identify the sequence of "
            "reasoning steps that will help to answer the user's query.",
        ),
        (
            "function_call",
            "You are a movie search assistant bot that utilizes TMDB to help users find movies. "
            "Approach each query step by step, determining the sequence of function calls needed to gather the necessary information. "
            "Execute functions sequentially, using the output from one function to inform the next function call when required. "
            "Only call multiple functions simultaneously when they can run independently of each other. "
            "Once you have identified all the required parameters from previous calls, "
            "finalize your process with a discover_movie function call that returns a list of movie IDs. "
            "Ensure that this call includes all necessary parameters to accurately filter the movies.",
        ),
    )

    selected_prompt = next(
        (text for kind, text in system_prompts if message_type == kind), None)
    if selected_prompt is None:
        raise ValueError(
            "Invalid message type. Expected 'reasoning_chain' or 'function_call'")

    system_entry = {"role": "system", "content": selected_prompt}
    user_entry = {"role": "user", "content": prompt}
    return [system_entry, user_entry]


def get_response(client, model, messages, tool_choice="auto"):
    """
    Send a chat-completion request and return the raw response object.

    The request always carries the module-level `tools` definitions, uses
    temperature 0 and caps the completion at 4096 tokens.

    :param client: Client object exposing `chat.completions.create`.
    :param model: Model identifier forwarded to the API.
    :param messages: Conversation so far; must be JSON-serializable because it
        is dumped into the log line.
    :param tool_choice: Tool-selection mode forwarded to the API.
    :return: Whatever `client.chat.completions.create` returns.
    """
    pretty_messages = json.dumps(messages, indent=2)
    logging.info(
        f"Getting response with model: {model}, \nmessages: {pretty_messages}, \ntool_choice: {tool_choice}")

    request_options = {
        "model": model,
        "messages": messages,
        "tools": tools,
        "tool_choice": tool_choice,
        "temperature": 0,
        "max_tokens": 4096,
    }
    completion = client.chat.completions.create(**request_options)

    logging.debug(f"Response: {completion}")
    return completion


def generate_reasoning_chain(user_prompt):
    """
    Ask the model for a step-by-step reasoning chain, with tool calls disabled.

    :param user_prompt: The user's query.
    :return: The first choice of the completion when its finish_reason is
        "stop". Otherwise `raise_error` is invoked with the response text and
        `Exception` (presumably raising; confirm in utils) and None is
        returned if that call returns normally.
    """
    reasoning_messages = create_message(user_prompt, "reasoning_chain")
    logging.debug(
        f"Generating reasoning chain with messages: {reasoning_messages}")

    cot_response = get_response(
        client, MODEL, reasoning_messages, tool_choice="none")
    first_choice = cot_response.choices[0]
    logging.info(f"COT response: {first_choice.message.content}")

    if first_choice.finish_reason != "stop":
        failure_text = ("Failed to generate reasoning chain. Got response: " +
                        str(cot_response))
        raise_error(failure_text, Exception)
        return None
    return first_choice


def validate_params(tool_params, param_name, param_value):
    """
    Check that one tool-call argument matches the definition in tools.py.

    The parameter must exist in `tool_params`, and its value must be castable
    by the callable that `python_type` returns for the declared "type".

    :param tool_params: Mapping of parameter name to its definition dict (the
        "properties" section of a tool definition); each definition must have
        a "type" key.
    :param param_name: Name of the argument supplied by the model.
    :param param_value: Value of the argument supplied by the model.
    :return: True if the parameter is known and the cast succeeds, else False.
    """
    logging.debug(
        f"Validating parameter: {param_name} with value: {param_value}")
    param_def = tool_params.get(param_name)
    if param_def is None:
        logging.error(
            f"Parameter {param_name} not found in tools. Dropping this tool call.")
        return False
    try:
        # Only the success of the cast matters; the converted value is unused.
        python_type(param_def["type"])(param_value)
    except (ValueError, TypeError):
        # TypeError covers values such as None or a list handed to int()/float();
        # those must drop the tool call rather than crash the validation.
        logging.error(
            f"Parameter {param_name} value cannot be cast to {param_def['type']}. Dropping this tool call.")
        return False
    return True


def extract_leaf_values(json_obj):
    """
    Flatten a JSON-like value into the list of its leaf values.

    Strings are first tried as JSON documents: if parsing succeeds the parsed
    value is flattened instead, otherwise the string itself is the single
    leaf. This applies to nested strings as well. Dict values and list items
    are flattened recursively, in iteration order; anything else is a leaf.

    :param json_obj: A dict, list, string or scalar.
    :return: List of leaf values.
    """
    if isinstance(json_obj, str):
        try:
            decoded = json.loads(json_obj)
        except json.JSONDecodeError:
            # Not a JSON document: the raw string is the leaf.
            return [json_obj]
        json_obj = decoded

    if isinstance(json_obj, dict):
        children = json_obj.values()
    elif isinstance(json_obj, list):
        children = json_obj
    else:
        return [json_obj]

    return [leaf for child in children for leaf in extract_leaf_values(child)]


def is_tool_valid(tool_name):
    """
    Look up a tool definition by function name.

    :param tool_name: Function name to search for in the module-level `tools`.
    :return: The first tool definition whose function name matches, or None
        if there is no match.
    """
    for candidate in tools:
        if candidate["function"]["name"] == tool_name:
            return candidate
    return None


def validate_tool_parameters(tool_def, tool_args):
    """
    Check every supplied argument of a tool call against the tool definition.

    Stops at the first argument rejected by `validate_params` and logs it.

    :param tool_def: Tool definition containing
        function -> parameters -> properties.
    :param tool_args: Mapping of argument name to value from the tool call.
    :return: True if every argument passes (or none are supplied), else False.
    """
    declared_params = tool_def["function"]["parameters"]["properties"]
    for arg_name, arg_value in tool_args.items():
        if validate_params(declared_params, arg_name, arg_value):
            continue
        logging.error(
            f"Invalid parameter {arg_name} for tool {tool_def['function']['name']}. Dropping this tool call.")
        return False
    return True


def are_arguments_valid(tool_args, user_query_values, previous_values):
    """
    Check that every argument value can be traced to the query or prior outputs.

    A value is accepted when its string form is contained in
    `user_query_values` (a substring test when that is a string) or when the
    value itself is a member of `previous_values`.

    :param tool_args: Mapping of argument name to value.
    :param user_query_values: User query text (or a collection of strings);
        None is treated as empty.
    :param previous_values: Values extracted from earlier tool outputs; None is
        treated as empty.
    :return: True if all values are accepted (vacuously True for no
        arguments), else False.
    """
    # A missing query or history must not raise TypeError on the `in` tests.
    query_text = "" if user_query_values is None else user_query_values
    known_values = () if previous_values is None else previous_values
    return all(
        str(value) in query_text or value in known_values
        for value in tool_args.values()
    )


def verify_tool_calls(tool_calls, messages):
    """
    Filter tool calls down to those that match a known tool definition.

    A call is kept when its arguments decode to a JSON object, its function
    name is found by `is_tool_valid`, and every argument passes
    `validate_tool_parameters`. Rejected calls are logged and dropped.

    NOTE: the arguments are not cross-checked against the user query or
    previous tool outputs; `messages` is currently unused and kept only so the
    signature stays stable for callers.

    :param tool_calls: Iterable of tool-call objects exposing `function.name`
        and `function.arguments` (a JSON string).
    :param messages: Conversation so far (unused).
    :return: List of valid tool calls, in their original order.
    """
    valid_tool_calls = []

    for tool_call in tool_calls:
        tool_name = tool_call.function.name

        # The model writes the arguments, so they may be malformed or missing;
        # drop such a call instead of aborting the whole validation.
        try:
            tool_args = json.loads(tool_call.function.arguments)
        except (json.JSONDecodeError, TypeError):
            logging.error(
                f"Arguments of tool {tool_name} are not valid JSON. Dropping this tool call.")
            continue
        if not isinstance(tool_args, dict):
            logging.error(
                f"Arguments of tool {tool_name} are not a JSON object. Dropping this tool call.")
            continue

        tool_def = is_tool_valid(tool_name)
        if not tool_def:
            logging.error(
                f"Tool {tool_name} not found in tools. Dropping this tool call.")
            continue

        if validate_tool_parameters(tool_def, tool_args):
            valid_tool_calls.append(tool_call)

    tool_calls_str = [json.dumps(tool_call.__dict__, default=str)
                      for tool_call in valid_tool_calls]
    logging.info(
        'Tool calls validated successfully. Valid tool calls are: %s', tool_calls_str)
    return valid_tool_calls


def gather_movie_data(messages):
    """
    Drive the tool-calling loop until a discover_movie call yields results.

    Each round asks the model for a tool call (tool_choice="required"),
    validates the proposed calls, executes the first valid one and, unless it
    was discover_movie, appends its output as a tool message and recurses with
    the extended conversation.

    :param messages: Conversation so far, as a list of message dicts with a
        "role" key.
    :return: The "results" entry of the discover_movie output, or the string
        "No results found" when no valid tool call was proposed or the
        conversation already holds more than three tool messages.
    :raises Exception: If the model response did not finish with tool calls.
    """
    logging.debug(f"Gathering movie data with messages: {messages}")
    response = get_response(client, MODEL, messages, tool_choice="required")
    logging.debug(f"Calling tools based on the response: {response}")

    first_choice = response.choices[0]
    if first_choice.finish_reason != "tool_calls":
        # Single formatted message, so str(exc) is readable.
        raise Exception(
            f"Failed to gather movie data. Got response: {response}")

    valid_tool_calls = verify_tool_calls(
        first_choice.message.tool_calls, messages)

    # Stop once more than three tool results have been collected, or when the
    # model proposed nothing usable.
    tool_messages_count = sum(1 for msg in messages if msg["role"] == "tool")
    if tool_messages_count > 3 or not valid_tool_calls:
        return "No results found"

    tool_call = valid_tool_calls[0]  # Run one tool call at a time
    logging.info(
        f"Tool call: {tool_call.function.name}, Tool call parameters: {tool_call.function.arguments}")
    tool_output = execute_tool(tool_call)
    logging.debug(
        f"Tool call output: {json.dumps(tool_output, indent=2)}")

    if tool_call.function.name == "discover_movie":
        return tool_output["results"]  # A list of movies

    tool_message = {
        "tool_call_id": tool_call.id,
        "role": "tool",
        "name": tool_call.function.name,
        # Send real JSON rather than a Python repr so the content can be
        # parsed back (e.g. by extract_leaf_values).
        "content": json.dumps(tool_output, ensure_ascii=False),
    }
    return gather_movie_data([*messages, tool_message])


def execute_tool(tool_call):
    """
    Run the function named by a tool call with its decoded arguments.

    :param tool_call: Tool-call object exposing `function.name` and
        `function.arguments` (a JSON string decoding to keyword arguments).
    :return: Whatever the looked-up function returns.
    :raises KeyError: If the name is not registered in `names_to_functions`.
    """
    requested_name = tool_call.function.name
    logging.info(f"Executing tool: {requested_name}")
    handler = names_to_functions[requested_name]
    keyword_args = json.loads(tool_call.function.arguments)
    return handler(**keyword_args)


def chatbot(user_prompt):
    """
    Answer a movie query: reason first, then run the tool-calling loop.

    The reasoning-chain reply is appended to the function-call conversation
    (system + user messages) before the loop starts.

    :param user_prompt: The user's query.
    :return: Whatever `gather_movie_data` returns for that conversation.
    """
    reasoning_choice = generate_reasoning_chain(user_prompt)
    conversation = create_message(user_prompt, "function_call")

    reasoning_reply = reasoning_choice.message
    conversation.append({
        'role': reasoning_reply.role,
        'content': reasoning_reply.content})

    return gather_movie_data(conversation)


if __name__ == "__main__":
    # Manual smoke run: print the movies found for a sample query as JSON.
    sample_results = chatbot("Movies of Tom Hanks that were released in 1993")
    print(json.dumps(sample_results, indent=2))
    # Alternative sample query: chatbot("List some movies of Tom Cruise")