tmzh commited on
Commit
6c9edd5
·
1 Parent(s): 3a35a20
Files changed (1) hide show
  1. agent.py +120 -161
agent.py CHANGED
@@ -2,6 +2,7 @@ import functools
2
  import json
3
  import os
4
  import logging
 
5
  from groq import Groq
6
  import functions
7
  from utils import python_type, raise_error
@@ -13,25 +14,19 @@ logging.basicConfig(level=logging.DEBUG)
13
  client = Groq(api_key=os.environ["GROQ_API_KEY"])
14
 
15
  MODEL = "llama3-groq-70b-8192-tool-use-preview"
16
- all_functions = [func for func in dir(functions) if callable(
17
- getattr(functions, func)) and not func.startswith("__")]
18
- names_to_functions = {func: functools.partial(
19
- getattr(functions, func)) for func in all_functions}
20
-
21
-
22
- def create_message(prompt, message_type):
23
- logging.debug(
24
- f"Creating message with prompt: {prompt} and message type: {message_type}")
25
- system_message = ""
26
-
27
- if message_type == "reasoning_chain":
28
- system_message = (
29
  "You are a movie search assistant bot who uses TMDB to help users "
30
  "find movies. Think step by step and identify the sequence of "
31
  "reasoning steps that will help to answer the user's query."
32
- )
33
- elif message_type == "function_call":
34
- system_message = (
35
  "You are a movie search assistant bot that utilizes TMDB to help users find movies. "
36
  "Approach each query step by step, determining the sequence of function calls needed to gather the necessary information. "
37
  "Execute functions sequentially, using the output from one function to inform the next function call when required. "
@@ -40,180 +35,144 @@ def create_message(prompt, message_type):
40
  "finalize your process with a discover_movie function call that returns a list of movie IDs. "
41
  "Ensure that this call includes all necessary parameters to accurately filter the movies."
42
  )
43
- else:
44
- raise ValueError(
45
- "Invalid message type. Expected 'reasoning_chain' or 'function_call'")
46
-
 
47
  return [
48
- {
49
- "role": "system",
50
- "content": system_message,
51
- },
52
- {
53
- "role": "user",
54
- "content": prompt,
55
- },
56
  ]
57
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
- def get_response(client, model, messages, tool_choice="auto"):
60
- logging.debug(
61
- f"Getting response with model: {model}, \nmessages: {json.dumps(messages, indent=2)}, \ntool_choice: {tool_choice}")
62
- response = client.chat.completions.create(
63
- model=model,
64
- messages=messages,
65
- tools=tools,
66
- tool_choice=tool_choice,
67
- temperature=0,
68
- max_tokens=4096,
69
- )
70
- logging.debug(f"Response: {response}")
71
- return response
72
-
73
-
74
- def generate_reasoning_chain(user_prompt):
75
  messages = create_message(user_prompt, "reasoning_chain")
76
  logging.debug(f"Generating reasoning chain with messages: {messages}")
77
- cot_response = get_response(client, MODEL, messages, tool_choice="none")
78
- logging.info(f"COT response: {cot_response.choices[0].message.content}")
79
- if cot_response.choices[0].finish_reason == "stop":
80
- return cot_response.choices[0]
81
- else:
82
- raise_error("Failed to generate reasoning chain. Got response: " +
83
- str(cot_response), Exception)
84
-
85
-
86
- def validate_params(tool_params, param_name, param_value):
87
- """
88
- Checks if the parameter value matches with the one defined in tools.py
89
- """
90
- logging.debug(
91
- f"Validating parameter: {param_name} with value: {param_value}")
92
- param_def = tool_params.get(param_name, None)
93
  if param_def is None:
94
- logging.error(
95
- f"Parameter {param_name} not found in tools. Dropping this tool call.")
96
  return False
97
  try:
98
- param_value = python_type(param_def["type"])(param_value)
 
99
  except ValueError:
100
- logging.error(
101
- f"Parameter {param_name} value cannot be cast to {param_def['type']}. Dropping this tool call.")
102
  return False
103
- return True
104
-
105
 
106
- def is_tool_valid(tool_name):
107
- """Check if the tool name is valid and return its definition."""
108
  return next((tool for tool in tools if tool["function"]["name"] == tool_name), None)
109
 
110
-
111
- def validate_tool_parameters(tool_def, tool_args):
112
- """Validate the parameters of the tool against its definition."""
113
  tool_params = tool_def.get("function", {}).get("parameters", {}).get("properties", {})
114
  if not tool_params:
115
- # No parameters defined for this tool, so all arguments are valid
116
  return True
117
- for param_name, param_value in tool_args.items():
118
- if not validate_params(tool_params, param_name, param_value):
119
- logging.error(
120
- f"Invalid parameter {param_name} for tool {tool_def['function']['name']}. Dropping this tool call.")
121
- return False
122
- return True
123
-
124
-
125
- def are_arguments_valid(tool_args, user_query_values, previous_values):
126
- """Check if all argument values are valid."""
127
- arg_values = tool_args.values()
128
- return all(str(value) in user_query_values or value in previous_values for value in arg_values)
129
-
130
-
131
- def verify_tool_calls(tool_calls):
132
- """
133
- Verify tool calls based on user query and previous tool outputs.
134
-
135
- :param tool_calls: List of tool calls with arguments.
136
- :return: List of valid tool calls.
137
- """
138
 
 
139
  valid_tool_calls = []
140
-
141
  for tool_call in tool_calls:
142
  tool_name = tool_call.function.name
143
  tool_args = json.loads(tool_call.function.arguments)
144
-
145
  tool_def = is_tool_valid(tool_name)
146
- if tool_def:
147
- if validate_tool_parameters(tool_def, tool_args):
148
- valid_tool_calls.append(tool_call)
149
  else:
150
- logging.error(
151
- f"Tool {tool_name} not found in tools. Dropping this tool call.")
152
-
153
- tool_calls_str = json.dumps(
154
- [tool_call.__dict__ for tool_call in valid_tool_calls], default=str, indent=2)
155
- logging.info(
156
- 'Tool calls validated successfully. Valid tool calls are: %s', tool_calls_str)
157
  return valid_tool_calls
158
 
 
 
 
 
 
159
 
160
- def gather_movie_data(messages):
161
  logging.debug(f"Gathering movie data with messages: {messages}")
162
- response = get_response(client, MODEL, messages, tool_choice="required")
163
- logging.debug(f"Calling tools based on the response: {response}")
164
- if response.choices[0].finish_reason == "tool_calls":
 
 
 
 
165
  tool_calls = response.choices[0].message.tool_calls
166
- # validate tool calls
167
  valid_tool_calls = verify_tool_calls(tool_calls)
168
- # valid_tool_calls = tool_calls
169
- updated_messages = messages.copy()
170
- tool_messages_count = len(
171
- [msg for msg in messages if msg["role"] == "tool"])
172
- if tool_messages_count <= 3 and valid_tool_calls:
173
- tool_call = valid_tool_calls[0] # Run one tool call at a time
174
- logging.debug(
175
- f"Tool call: {tool_call.function.name}, Tool call parameters: {tool_call.function.arguments}")
176
- tool_output = execute_tool(tool_call)
177
- logging.debug(
178
- f"Tool call output: {json.dumps(tool_output, indent=2)}")
179
- if tool_call.function.name == "discover_movie" or tool_messages_count > 3:
180
- return tool_output["results"] # A list of movies
181
- else:
182
- updated_messages.append(
183
- {
184
- "tool_call_id": tool_call.id,
185
- "role": "tool",
186
- "name": tool_call.function.name,
187
- "content": str(tool_output),
188
- }
189
- )
190
- return gather_movie_data(updated_messages)
191
- else:
192
- return "No results found"
193
- else:
194
- raise Exception(
195
- "Failed to gather movie data. Got response: ", response)
196
-
197
-
198
- def execute_tool(tool_call):
199
- logging.info(
200
- f"Executing tool: \n Name: {tool_call.function.name}\n Parameters: {tool_call.function.arguments}")
201
-
202
- function_to_call = names_to_functions[tool_call.function.name]
203
- function_args = json.loads(tool_call.function.arguments)
204
- return function_to_call(**function_args)
205
-
206
-
207
- def chatbot(user_prompt):
208
- cot_response_choice = generate_reasoning_chain(user_prompt)
209
- cot = create_message(user_prompt, "function_call")
210
- cot.append({
211
- 'role': cot_response_choice.message.role,
212
- 'content': cot_response_choice.message.content})
213
- movie_list = gather_movie_data(cot)
214
- return movie_list
215
-
216
 
217
  if __name__ == "__main__":
218
- # print(json.dumps(chatbot("Movies of Tom Hanks that were released in 1993"), indent=2))
219
- print(json.dumps(chatbot("List some movies of Tom Cruise"), indent=2))
 
2
  import json
3
  import os
4
  import logging
5
+ from typing import List, Dict, Any, Optional
6
  from groq import Groq
7
  import functions
8
  from utils import python_type, raise_error
 
14
  client = Groq(api_key=os.environ["GROQ_API_KEY"])
15
 
16
  MODEL = "llama3-groq-70b-8192-tool-use-preview"
17
+ ALL_FUNCTIONS = [func for func in dir(functions) if callable(getattr(functions, func)) and not func.startswith("__")]
18
+ NAMES_TO_FUNCTIONS = {func: functools.partial(getattr(functions, func)) for func in ALL_FUNCTIONS}
19
+
20
+ def create_message(prompt: str, message_type: str) -> List[Dict[str, str]]:
21
+ logging.debug(f"Creating message with prompt: {prompt} and message type: {message_type}")
22
+
23
+ system_messages = {
24
+ "reasoning_chain": (
 
 
 
 
 
25
  "You are a movie search assistant bot who uses TMDB to help users "
26
  "find movies. Think step by step and identify the sequence of "
27
  "reasoning steps that will help to answer the user's query."
28
+ ),
29
+ "function_call": (
 
30
  "You are a movie search assistant bot that utilizes TMDB to help users find movies. "
31
  "Approach each query step by step, determining the sequence of function calls needed to gather the necessary information. "
32
  "Execute functions sequentially, using the output from one function to inform the next function call when required. "
 
35
  "finalize your process with a discover_movie function call that returns a list of movie IDs. "
36
  "Ensure that this call includes all necessary parameters to accurately filter the movies."
37
  )
38
+ }
39
+
40
+ if message_type not in system_messages:
41
+ raise ValueError("Invalid message type. Expected 'reasoning_chain' or 'function_call'")
42
+
43
  return [
44
+ {"role": "system", "content": system_messages[message_type]},
45
+ {"role": "user", "content": prompt},
 
 
 
 
 
 
46
  ]
47
 
48
+ def get_groq_response(messages: List[Dict[str, str]], tool_choice: str = "auto") -> Any:
49
+ logging.debug(f"Getting response with model: {MODEL}, \nmessages: {json.dumps(messages, indent=2)}, \ntool_choice: {tool_choice}")
50
+ try:
51
+ response = client.chat.completions.create(
52
+ model=MODEL,
53
+ messages=messages,
54
+ tools=tools,
55
+ tool_choice=tool_choice,
56
+ temperature=0,
57
+ max_tokens=4096,
58
+ )
59
+ logging.debug(f"Response: {response}")
60
+ return response
61
+ except Exception as e:
62
+ logging.error(f"Error getting response from Groq: {str(e)}")
63
+ raise
64
 
65
+ def generate_reasoning_chain(user_prompt: str) -> Any:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  messages = create_message(user_prompt, "reasoning_chain")
67
  logging.debug(f"Generating reasoning chain with messages: {messages}")
68
+ try:
69
+ cot_response = get_groq_response(messages, tool_choice="none")
70
+ logging.info(f"COT response: {cot_response.choices[0].message.content}")
71
+ if cot_response.choices[0].finish_reason == "stop":
72
+ return cot_response.choices[0]
73
+ else:
74
+ raise_error("Failed to generate reasoning chain. Got response: " + str(cot_response), Exception)
75
+ except Exception as e:
76
+ logging.error(f"Error generating reasoning chain: {str(e)}")
77
+ raise
78
+
79
+ def validate_parameter(param_name: str, param_value: Any, tool_params: Dict[str, Any]) -> bool:
80
+ logging.debug(f"Validating parameter: {param_name} with value: {param_value}")
81
+ param_def = tool_params.get(param_name)
 
 
82
  if param_def is None:
83
+ logging.error(f"Parameter {param_name} not found in tools. Dropping this tool call.")
 
84
  return False
85
  try:
86
+ python_type(param_def["type"])(param_value)
87
+ return True
88
  except ValueError:
89
+ logging.error(f"Parameter {param_name} value cannot be cast to {param_def['type']}. Dropping this tool call.")
 
90
  return False
 
 
91
 
92
+ def is_tool_valid(tool_name: str) -> Optional[Dict[str, Any]]:
 
93
  return next((tool for tool in tools if tool["function"]["name"] == tool_name), None)
94
 
95
+ def validate_tool_parameters(tool_def: Dict[str, Any], tool_args: Dict[str, Any]) -> bool:
 
 
96
  tool_params = tool_def.get("function", {}).get("parameters", {}).get("properties", {})
97
  if not tool_params:
 
98
  return True
99
+ return all(validate_parameter(param_name, param_value, tool_params) for param_name, param_value in tool_args.items())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
 
101
+ def verify_tool_calls(tool_calls: List[Any]) -> List[Any]:
102
  valid_tool_calls = []
 
103
  for tool_call in tool_calls:
104
  tool_name = tool_call.function.name
105
  tool_args = json.loads(tool_call.function.arguments)
 
106
  tool_def = is_tool_valid(tool_name)
107
+ if tool_def and validate_tool_parameters(tool_def, tool_args):
108
+ valid_tool_calls.append(tool_call)
 
109
  else:
110
+ logging.error(f"Invalid tool call: {tool_name}. Dropping this tool call.")
111
+
112
+ tool_calls_str = json.dumps([tool_call.__dict__ for tool_call in valid_tool_calls], default=str, indent=2)
113
+ logging.info('Tool calls validated successfully. Valid tool calls are: %s', tool_calls_str)
 
 
 
114
  return valid_tool_calls
115
 
116
+ def execute_tool(tool_call: Any) -> Dict[str, Any]:
117
+ logging.info(f"Executing tool: \n Name: {tool_call.function.name}\n Parameters: {tool_call.function.arguments}")
118
+ function_to_call = NAMES_TO_FUNCTIONS[tool_call.function.name]
119
+ function_args = json.loads(tool_call.function.arguments)
120
+ return function_to_call(**function_args)
121
 
122
+ def gather_movie_data(messages: List[Dict[str, str]], max_tool_calls: int = 3) -> Optional[List[Dict[str, Any]]]:
123
  logging.debug(f"Gathering movie data with messages: {messages}")
124
+ try:
125
+ response = get_groq_response(messages, tool_choice="required")
126
+ logging.debug(f"Calling tools based on the response: {response}")
127
+
128
+ if response.choices[0].finish_reason != "tool_calls":
129
+ raise Exception("Failed to gather movie data. Got response: " + str(response))
130
+
131
  tool_calls = response.choices[0].message.tool_calls
 
132
  valid_tool_calls = verify_tool_calls(tool_calls)
133
+
134
+ tool_messages_count = len([msg for msg in messages if msg["role"] == "tool"])
135
+
136
+ if tool_messages_count >= max_tool_calls or not valid_tool_calls:
137
+ return None # No results found or max tool calls reached
138
+
139
+ tool_call = valid_tool_calls[0] # Run one tool call at a time
140
+ logging.debug(f"Tool call: {tool_call.function.name}, Tool call parameters: {tool_call.function.arguments}")
141
+
142
+ tool_output = execute_tool(tool_call)
143
+ logging.debug(f"Tool call output: {json.dumps(tool_output, indent=2)}")
144
+
145
+ if tool_call.function.name == "discover_movie":
146
+ return tool_output["results"] # A list of movies
147
+
148
+ updated_messages = messages + [
149
+ {
150
+ "tool_call_id": tool_call.id,
151
+ "role": "tool",
152
+ "name": tool_call.function.name,
153
+ "content": str(tool_output),
154
+ }
155
+ ]
156
+ return gather_movie_data(updated_messages, max_tool_calls)
157
+
158
+ except Exception as e:
159
+ logging.error(f"Error gathering movie data: {str(e)}")
160
+ return None
161
+
162
+ def chatbot(user_prompt: str) -> Optional[List[Dict[str, Any]]]:
163
+ try:
164
+ cot_response_choice = generate_reasoning_chain(user_prompt)
165
+ cot = create_message(user_prompt, "function_call")
166
+ cot.append({
167
+ 'role': cot_response_choice.message.role,
168
+ 'content': cot_response_choice.message.content
169
+ })
170
+ movie_list = gather_movie_data(cot)
171
+ return movie_list
172
+ except Exception as e:
173
+ logging.error(f"Error in chatbot: {str(e)}")
174
+ return None
 
 
 
 
 
 
175
 
176
  if __name__ == "__main__":
177
+ result = chatbot("List some movies of Tom Cruise")
178
+ print(json.dumps(result, indent=2) if result else "No results found")