tmzh commited on
Commit
9f83bcc
·
1 Parent(s): 90d520c

call one tool at a time

Browse files
Files changed (2) hide show
  1. agent.py +134 -24
  2. utils.py +23 -0
agent.py CHANGED
@@ -4,6 +4,7 @@ import os
4
  import logging
5
  from groq import Groq
6
  import functions
 
7
  from tools import tools
8
 
9
  # Set up logging
@@ -31,14 +32,13 @@ def create_message(prompt, message_type):
31
  )
32
  elif message_type == "function_call":
33
  system_message = (
34
- "You are a movie search assistant bot who uses TMDB to help users "
35
- "find movies. Think step by step and identify the sequence of "
36
- "function calls that will help to answer the user's query. Use the "
37
- "available functions to gather the necessary data. "
38
- "Do not call multiple functions when they need to be executed in sequence. "
39
- "Only call multiple functions when they can be executed in parallel. "
40
- "Stop with a discover_movie function call that returns a list of movie ids. "
41
- "Ensure the discover_movie function call includes all the necessary parameters to filter the movies accurately."
42
  )
43
  else:
44
  raise ValueError(
@@ -57,8 +57,8 @@ def create_message(prompt, message_type):
57
 
58
 
59
  def get_response(client, model, messages, tool_choice="auto"):
60
- logging.debug(
61
- f"Getting response with model: {model}, messages: {messages}, tool_choice: {tool_choice}")
62
  response = client.chat.completions.create(
63
  model=model,
64
  messages=messages,
@@ -79,27 +79,135 @@ def generate_reasoning_chain(user_prompt):
79
  if cot_response.choices[0].finish_reason == "stop":
80
  return cot_response.choices[0]
81
  else:
82
- logging.error(
83
- "Failed to generate reasoning chain. Got response: ", cot_response)
84
- raise Exception("Failed to generate reasoning chain")
85
 
86
 
87
- def gather_movie_data(messages, iteration=0, max_iterations=2):
 
 
 
88
  logging.debug(
89
- f"Gathering movie data with messages: {messages}, iteration: {iteration}")
90
- response = get_response(client, MODEL, messages, tool_choice="required")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
  logging.info(
92
- f"Gathering movie data response: {response}")
 
 
 
 
 
 
 
93
  if response.choices[0].finish_reason == "tool_calls":
94
  tool_calls = response.choices[0].message.tool_calls
 
 
 
95
  updated_messages = messages.copy()
96
- for tool_call in tool_calls:
 
 
 
97
  logging.info(
98
  f"Tool call: {tool_call.function.name}, Tool call parameters: {tool_call.function.arguments}")
99
  tool_output = execute_tool(tool_call)
100
- logging.info(
101
  f"Tool call output: {json.dumps(tool_output, indent=2)}")
102
- if tool_call.function.name == "discover_movie":
103
  return tool_output["results"] # A list of movies
104
  else:
105
  updated_messages.append(
@@ -110,15 +218,16 @@ def gather_movie_data(messages, iteration=0, max_iterations=2):
110
  "content": str(tool_output),
111
  }
112
  )
113
- if iteration < max_iterations:
114
- return gather_movie_data(updated_messages, iteration + 1)
 
115
  else:
116
  raise Exception(
117
  "Failed to gather movie data. Got response: ", response)
118
 
119
 
120
  def execute_tool(tool_call):
121
- logging.debug(f"Executing tool: {tool_call.function.name}")
122
  function_to_call = names_to_functions[tool_call.function.name]
123
  function_args = json.loads(tool_call.function.arguments)
124
  return function_to_call(**function_args)
@@ -135,4 +244,5 @@ def chatbot(user_prompt):
135
 
136
 
137
  if __name__ == "__main__":
138
- print(json.dumps(chatbot("List comedy movies with tom cruise in it"), indent=2))
 
 
4
  import logging
5
  from groq import Groq
6
  import functions
7
+ from utils import python_type, raise_error
8
  from tools import tools
9
 
10
  # Set up logging
 
32
  )
33
  elif message_type == "function_call":
34
  system_message = (
35
+ "You are a movie search assistant bot that utilizes TMDB to help users find movies. "
36
+ "Approach each query step by step, determining the sequence of function calls needed to gather the necessary information. "
37
+ "Execute functions sequentially, using the output from one function to inform the next function call when required. "
38
+ "Only call multiple functions simultaneously when they can run independently of each other. "
39
+ "Once you have identified all the required parameters from previous calls, "
40
+ "finalize your process with a discover_movie function call that returns a list of movie IDs. "
41
+ "Ensure that this call includes all necessary parameters to accurately filter the movies."
 
42
  )
43
  else:
44
  raise ValueError(
 
57
 
58
 
59
  def get_response(client, model, messages, tool_choice="auto"):
60
+ logging.info(
61
+ f"Getting response with model: {model}, \nmessages: {json.dumps(messages, indent=2)}, \ntool_choice: {tool_choice}")
62
  response = client.chat.completions.create(
63
  model=model,
64
  messages=messages,
 
79
  if cot_response.choices[0].finish_reason == "stop":
80
  return cot_response.choices[0]
81
  else:
82
+ raise_error("Failed to generate reasoning chain. Got response: " +
83
+ str(cot_response), Exception)
 
84
 
85
 
86
+ def validate_params(tool_params, param_name, param_value):
87
+ """
88
+ Checks if the parameter value matches with the one defined in tools.py
89
+ """
90
  logging.debug(
91
+ f"Validating parameter: {param_name} with value: {param_value}")
92
+ param_def = tool_params.get(param_name, None)
93
+ if param_def is None:
94
+ logging.error(
95
+ f"Parameter {param_name} not found in tools. Dropping this tool call.")
96
+ return False
97
+ try:
98
+ param_value = python_type(param_def["type"])(param_value)
99
+ except ValueError:
100
+ logging.error(
101
+ f"Parameter {param_name} value cannot be cast to {param_def['type']}. Dropping this tool call.")
102
+ return False
103
+ return True
104
+
105
+
106
+ def extract_leaf_values(json_obj):
107
+ """Recursively extract leaf values from a JSON object or string."""
108
+ # Check if the input is a string and try to parse it
109
+ if isinstance(json_obj, str):
110
+ try:
111
+ json_obj = json.loads(json_obj)
112
+ except json.JSONDecodeError:
113
+ return [json_obj] # Return the string if it's not valid JSON
114
+
115
+ if isinstance(json_obj, dict):
116
+ values = []
117
+ for value in json_obj.values():
118
+ values.extend(extract_leaf_values(value))
119
+ return values
120
+ elif isinstance(json_obj, list):
121
+ values = []
122
+ for item in json_obj:
123
+ values.extend(extract_leaf_values(item))
124
+ return values
125
+ else:
126
+ return [json_obj]
127
+
128
+
129
+ def is_tool_valid(tool_name):
130
+ """Check if the tool name is valid and return its definition."""
131
+ return next((tool for tool in tools if tool["function"]["name"] == tool_name), None)
132
+
133
+
134
+ def validate_tool_parameters(tool_def, tool_args):
135
+ """Validate the parameters of the tool against its definition."""
136
+ tool_params = tool_def["function"]["parameters"]["properties"]
137
+ for param_name, param_value in tool_args.items():
138
+ if not validate_params(tool_params, param_name, param_value):
139
+ logging.error(
140
+ f"Invalid parameter {param_name} for tool {tool_def['function']['name']}. Dropping this tool call.")
141
+ return False
142
+ return True
143
+
144
+
145
+ def are_arguments_valid(tool_args, user_query_values, previous_values):
146
+ """Check if all argument values are valid."""
147
+ arg_values = tool_args.values()
148
+ return all(str(value) in user_query_values or value in previous_values for value in arg_values)
149
+
150
+
151
+ def verify_tool_calls(tool_calls, messages):
152
+ """
153
+ Verify tool calls based on user query and previous tool outputs.
154
+
155
+ :param tool_calls: List of tool calls with arguments.
156
+ :param messages: List containing user query and previous tool outputs.
157
+ :return: List of valid tool calls.
158
+ """
159
+ # Extract user query from the first message with role 'user'
160
+ user_query_values = next((msg["content"]
161
+ for msg in messages if msg["role"] == "user"), None)
162
+
163
+ # Extract previous tool outputs from messages with role 'tool'
164
+ previous_tool_outputs = [msg["content"]
165
+ for msg in messages if msg["role"] == "tool"]
166
+
167
+ previous_values = [
168
+ value for output in previous_tool_outputs for value in extract_leaf_values(output)]
169
+
170
+ valid_tool_calls = []
171
+
172
+ for tool_call in tool_calls:
173
+ tool_name = tool_call.function.name
174
+ tool_args = json.loads(tool_call.function.arguments)
175
+
176
+ tool_def = is_tool_valid(tool_name)
177
+ if tool_def:
178
+ if validate_tool_parameters(tool_def, tool_args):
179
+ valid_tool_calls.append(tool_call)
180
+ else:
181
+ logging.error(
182
+ f"Tool {tool_name} not found in tools. Dropping this tool call.")
183
+
184
+ tool_calls_str = [json.dumps(tool_call.__dict__, default=str)
185
+ for tool_call in valid_tool_calls]
186
  logging.info(
187
+ 'Tool calls validated successfully. Valid tool calls are: %s', tool_calls_str)
188
+ return valid_tool_calls
189
+
190
+
191
+ def gather_movie_data(messages):
192
+ logging.debug(f"Gathering movie data with messages: {messages}")
193
+ response = get_response(client, MODEL, messages, tool_choice="required")
194
+ logging.debug(f"Calling tools based on the response: {response}")
195
  if response.choices[0].finish_reason == "tool_calls":
196
  tool_calls = response.choices[0].message.tool_calls
197
+ # validate tool calls
198
+ valid_tool_calls = verify_tool_calls(tool_calls, messages)
199
+ # valid_tool_calls = tool_calls
200
  updated_messages = messages.copy()
201
+ tool_messages_count = len(
202
+ [msg for msg in messages if msg["role"] == "tool"])
203
+ if tool_messages_count <= 3 and valid_tool_calls:
204
+ tool_call = valid_tool_calls[0] # Run one tool call at a time
205
  logging.info(
206
  f"Tool call: {tool_call.function.name}, Tool call parameters: {tool_call.function.arguments}")
207
  tool_output = execute_tool(tool_call)
208
+ logging.debug(
209
  f"Tool call output: {json.dumps(tool_output, indent=2)}")
210
+ if tool_call.function.name == "discover_movie" or tool_messages_count > 3:
211
  return tool_output["results"] # A list of movies
212
  else:
213
  updated_messages.append(
 
218
  "content": str(tool_output),
219
  }
220
  )
221
+ return gather_movie_data(updated_messages)
222
+ else:
223
+ return "No results found"
224
  else:
225
  raise Exception(
226
  "Failed to gather movie data. Got response: ", response)
227
 
228
 
229
  def execute_tool(tool_call):
230
+ logging.info(f"Executing tool: {tool_call.function.name}")
231
  function_to_call = names_to_functions[tool_call.function.name]
232
  function_args = json.loads(tool_call.function.arguments)
233
  return function_to_call(**function_args)
 
244
 
245
 
246
  if __name__ == "__main__":
247
+ print(json.dumps(chatbot("Movies of Tom Hanks that were released in 1993"), indent=2))
248
+ # print(json.dumps(chatbot("List some movies of Tom Cruise"), indent=2))
utils.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+
3
+
4
+ def raise_error(error_message, error_type=Exception):
5
+ logging.error(error_message)
6
+ raise error_type(error_message)
7
+
8
+
9
+ def python_type(type_str):
10
+ if type_str == "string":
11
+ return str
12
+ elif type_str == "integer":
13
+ return int
14
+ elif type_str == "number":
15
+ return float
16
+ elif type_str == "boolean":
17
+ return bool
18
+ elif type_str == "array":
19
+ return list
20
+ elif type_str == "object":
21
+ return dict
22
+ else:
23
+ raise ValueError(f"Unknown type: {type_str}")