Spaces:
Paused
Paused
tmzh
committed on
Commit
·
6c9edd5
1
Parent(s):
3a35a20
refactor
Browse files
agent.py
CHANGED
@@ -2,6 +2,7 @@ import functools
|
|
2 |
import json
|
3 |
import os
|
4 |
import logging
|
|
|
5 |
from groq import Groq
|
6 |
import functions
|
7 |
from utils import python_type, raise_error
|
@@ -13,25 +14,19 @@ logging.basicConfig(level=logging.DEBUG)
|
|
13 |
client = Groq(api_key=os.environ["GROQ_API_KEY"])
|
14 |
|
15 |
MODEL = "llama3-groq-70b-8192-tool-use-preview"
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
f"Creating message with prompt: {prompt} and message type: {message_type}")
|
25 |
-
system_message = ""
|
26 |
-
|
27 |
-
if message_type == "reasoning_chain":
|
28 |
-
system_message = (
|
29 |
"You are a movie search assistant bot who uses TMDB to help users "
|
30 |
"find movies. Think step by step and identify the sequence of "
|
31 |
"reasoning steps that will help to answer the user's query."
|
32 |
-
)
|
33 |
-
|
34 |
-
system_message = (
|
35 |
"You are a movie search assistant bot that utilizes TMDB to help users find movies. "
|
36 |
"Approach each query step by step, determining the sequence of function calls needed to gather the necessary information. "
|
37 |
"Execute functions sequentially, using the output from one function to inform the next function call when required. "
|
@@ -40,180 +35,144 @@ def create_message(prompt, message_type):
|
|
40 |
"finalize your process with a discover_movie function call that returns a list of movie IDs. "
|
41 |
"Ensure that this call includes all necessary parameters to accurately filter the movies."
|
42 |
)
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
|
|
47 |
return [
|
48 |
-
{
|
49 |
-
|
50 |
-
"content": system_message,
|
51 |
-
},
|
52 |
-
{
|
53 |
-
"role": "user",
|
54 |
-
"content": prompt,
|
55 |
-
},
|
56 |
]
|
57 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
58 |
|
59 |
-
def
|
60 |
-
logging.debug(
|
61 |
-
f"Getting response with model: {model}, \nmessages: {json.dumps(messages, indent=2)}, \ntool_choice: {tool_choice}")
|
62 |
-
response = client.chat.completions.create(
|
63 |
-
model=model,
|
64 |
-
messages=messages,
|
65 |
-
tools=tools,
|
66 |
-
tool_choice=tool_choice,
|
67 |
-
temperature=0,
|
68 |
-
max_tokens=4096,
|
69 |
-
)
|
70 |
-
logging.debug(f"Response: {response}")
|
71 |
-
return response
|
72 |
-
|
73 |
-
|
74 |
-
def generate_reasoning_chain(user_prompt):
|
75 |
messages = create_message(user_prompt, "reasoning_chain")
|
76 |
logging.debug(f"Generating reasoning chain with messages: {messages}")
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
""
|
90 |
-
|
91 |
-
f"Validating parameter: {param_name} with value: {param_value}")
|
92 |
-
param_def = tool_params.get(param_name, None)
|
93 |
if param_def is None:
|
94 |
-
logging.error(
|
95 |
-
f"Parameter {param_name} not found in tools. Dropping this tool call.")
|
96 |
return False
|
97 |
try:
|
98 |
-
|
|
|
99 |
except ValueError:
|
100 |
-
logging.error(
|
101 |
-
f"Parameter {param_name} value cannot be cast to {param_def['type']}. Dropping this tool call.")
|
102 |
return False
|
103 |
-
return True
|
104 |
-
|
105 |
|
106 |
-
def is_tool_valid(tool_name):
|
107 |
-
"""Check if the tool name is valid and return its definition."""
|
108 |
return next((tool for tool in tools if tool["function"]["name"] == tool_name), None)
|
109 |
|
110 |
-
|
111 |
-
def validate_tool_parameters(tool_def, tool_args):
|
112 |
-
"""Validate the parameters of the tool against its definition."""
|
113 |
tool_params = tool_def.get("function", {}).get("parameters", {}).get("properties", {})
|
114 |
if not tool_params:
|
115 |
-
# No parameters defined for this tool, so all arguments are valid
|
116 |
return True
|
117 |
-
for param_name, param_value in tool_args.items()
|
118 |
-
if not validate_params(tool_params, param_name, param_value):
|
119 |
-
logging.error(
|
120 |
-
f"Invalid parameter {param_name} for tool {tool_def['function']['name']}. Dropping this tool call.")
|
121 |
-
return False
|
122 |
-
return True
|
123 |
-
|
124 |
-
|
125 |
-
def are_arguments_valid(tool_args, user_query_values, previous_values):
|
126 |
-
"""Check if all argument values are valid."""
|
127 |
-
arg_values = tool_args.values()
|
128 |
-
return all(str(value) in user_query_values or value in previous_values for value in arg_values)
|
129 |
-
|
130 |
-
|
131 |
-
def verify_tool_calls(tool_calls):
|
132 |
-
"""
|
133 |
-
Verify tool calls based on user query and previous tool outputs.
|
134 |
-
|
135 |
-
:param tool_calls: List of tool calls with arguments.
|
136 |
-
:return: List of valid tool calls.
|
137 |
-
"""
|
138 |
|
|
|
139 |
valid_tool_calls = []
|
140 |
-
|
141 |
for tool_call in tool_calls:
|
142 |
tool_name = tool_call.function.name
|
143 |
tool_args = json.loads(tool_call.function.arguments)
|
144 |
-
|
145 |
tool_def = is_tool_valid(tool_name)
|
146 |
-
if tool_def:
|
147 |
-
|
148 |
-
valid_tool_calls.append(tool_call)
|
149 |
else:
|
150 |
-
logging.error(
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
[tool_call.__dict__ for tool_call in valid_tool_calls], default=str, indent=2)
|
155 |
-
logging.info(
|
156 |
-
'Tool calls validated successfully. Valid tool calls are: %s', tool_calls_str)
|
157 |
return valid_tool_calls
|
158 |
|
|
|
|
|
|
|
|
|
|
|
159 |
|
160 |
-
def gather_movie_data(messages):
|
161 |
logging.debug(f"Gathering movie data with messages: {messages}")
|
162 |
-
|
163 |
-
|
164 |
-
|
|
|
|
|
|
|
|
|
165 |
tool_calls = response.choices[0].message.tool_calls
|
166 |
-
# validate tool calls
|
167 |
valid_tool_calls = verify_tool_calls(tool_calls)
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
cot.append({
|
211 |
-
'role': cot_response_choice.message.role,
|
212 |
-
'content': cot_response_choice.message.content})
|
213 |
-
movie_list = gather_movie_data(cot)
|
214 |
-
return movie_list
|
215 |
-
|
216 |
|
217 |
if __name__ == "__main__":
|
218 |
-
|
219 |
-
print(json.dumps(
|
|
|
2 |
import json
|
3 |
import os
|
4 |
import logging
|
5 |
+
from typing import List, Dict, Any, Optional
|
6 |
from groq import Groq
|
7 |
import functions
|
8 |
from utils import python_type, raise_error
|
|
|
# Groq client is created at import time; GROQ_API_KEY must be set in the
# environment or this module fails to import with a KeyError.
client = Groq(api_key=os.environ["GROQ_API_KEY"])

# Tool-use-capable Groq model used for every completion request below.
MODEL = "llama3-groq-70b-8192-tool-use-preview"
# Every public callable exposed by the local `functions` module, discovered by
# reflection so new tool implementations are picked up automatically.
ALL_FUNCTIONS = [func for func in dir(functions) if callable(getattr(functions, func)) and not func.startswith("__")]
# Name -> callable registry consumed by execute_tool.
# NOTE(review): functools.partial with no bound arguments is a no-op wrapper —
# presumably kept so defaults can be pre-bound later; confirm before simplifying.
NAMES_TO_FUNCTIONS = {func: functools.partial(getattr(functions, func)) for func in ALL_FUNCTIONS}
19 |
+
|
20 |
+
def create_message(prompt: str, message_type: str) -> List[Dict[str, str]]:
|
21 |
+
logging.debug(f"Creating message with prompt: {prompt} and message type: {message_type}")
|
22 |
+
|
23 |
+
system_messages = {
|
24 |
+
"reasoning_chain": (
|
|
|
|
|
|
|
|
|
|
|
25 |
"You are a movie search assistant bot who uses TMDB to help users "
|
26 |
"find movies. Think step by step and identify the sequence of "
|
27 |
"reasoning steps that will help to answer the user's query."
|
28 |
+
),
|
29 |
+
"function_call": (
|
|
|
30 |
"You are a movie search assistant bot that utilizes TMDB to help users find movies. "
|
31 |
"Approach each query step by step, determining the sequence of function calls needed to gather the necessary information. "
|
32 |
"Execute functions sequentially, using the output from one function to inform the next function call when required. "
|
|
|
35 |
"finalize your process with a discover_movie function call that returns a list of movie IDs. "
|
36 |
"Ensure that this call includes all necessary parameters to accurately filter the movies."
|
37 |
)
|
38 |
+
}
|
39 |
+
|
40 |
+
if message_type not in system_messages:
|
41 |
+
raise ValueError("Invalid message type. Expected 'reasoning_chain' or 'function_call'")
|
42 |
+
|
43 |
return [
|
44 |
+
{"role": "system", "content": system_messages[message_type]},
|
45 |
+
{"role": "user", "content": prompt},
|
|
|
|
|
|
|
|
|
|
|
|
|
46 |
]
|
47 |
|
def get_groq_response(messages: List[Dict[str, str]], tool_choice: str = "auto") -> Any:
    """Send *messages* to the Groq chat-completions endpoint and return the raw response.

    :param messages: Chat history as role/content dicts.
    :param tool_choice: Forwarded to the API ("auto", "none", or "required").
    :return: The untouched completion object from the client.
    :raises Exception: Any API error is logged, then re-raised unchanged.
    """
    logging.debug(f"Getting response with model: {MODEL}, \nmessages: {json.dumps(messages, indent=2)}, \ntool_choice: {tool_choice}")
    try:
        completion = client.chat.completions.create(
            model=MODEL,
            messages=messages,
            tools=tools,
            tool_choice=tool_choice,
            temperature=0,
            max_tokens=4096,
        )
        logging.debug(f"Response: {completion}")
        return completion
    except Exception as e:
        logging.error(f"Error getting response from Groq: {str(e)}")
        raise
|
64 |
|
65 |
def generate_reasoning_chain(user_prompt: str) -> Any:
    """Ask the model for a step-by-step reasoning chain for *user_prompt*.

    :param user_prompt: The raw user query.
    :return: The first response choice when the model stopped cleanly.
    :raises Exception: Logged and re-raised on any API or finish-reason failure.
    """
    chain_messages = create_message(user_prompt, "reasoning_chain")
    logging.debug(f"Generating reasoning chain with messages: {chain_messages}")
    try:
        cot_response = get_groq_response(chain_messages, tool_choice="none")
        logging.info(f"COT response: {cot_response.choices[0].message.content}")
        first_choice = cot_response.choices[0]
        if first_choice.finish_reason == "stop":
            return first_choice
        # Delegate failure reporting to the shared helper (expected to raise).
        raise_error("Failed to generate reasoning chain. Got response: " + str(cot_response), Exception)
    except Exception as e:
        logging.error(f"Error generating reasoning chain: {str(e)}")
        raise
78 |
+
|
def validate_parameter(param_name: str, param_value: Any, tool_params: Dict[str, Any]) -> bool:
    """Check that *param_name* is declared in *tool_params* and that *param_value*
    can be cast to its declared type.

    :param param_name: Argument name supplied by the model.
    :param param_value: Argument value supplied by the model.
    :param tool_params: "properties" mapping from the tool's JSON-schema definition.
    :return: True when the parameter is known and castable, False otherwise.
    """
    logging.debug(f"Validating parameter: {param_name} with value: {param_value}")
    param_def = tool_params.get(param_name)
    if param_def is None:
        logging.error(f"Parameter {param_name} not found in tools. Dropping this tool call.")
        return False
    try:
        # python_type maps the JSON-schema type name to a Python constructor.
        python_type(param_def["type"])(param_value)
        return True
    except (ValueError, TypeError):
        # TypeError covers un-castable values (e.g. int(None) or int([])) that a
        # bare ValueError handler would let propagate and crash validation
        # instead of dropping the tool call as intended.
        logging.error(f"Parameter {param_name} value cannot be cast to {param_def['type']}. Dropping this tool call.")
        return False
|
|
|
|
91 |
|
def is_tool_valid(tool_name: str) -> Optional[Dict[str, Any]]:
    """Return the definition of the tool named *tool_name*, or None if unknown."""
    for tool_def in tools:
        if tool_def["function"]["name"] == tool_name:
            return tool_def
    return None
94 |
|
95 |
def validate_tool_parameters(tool_def: Dict[str, Any], tool_args: Dict[str, Any]) -> bool:
    """Validate every argument in *tool_args* against *tool_def*'s parameter schema.

    :return: True when all arguments validate (or the tool declares no
        parameters), False at the first invalid argument.
    """
    schema_props = tool_def.get("function", {}).get("parameters", {}).get("properties", {})
    if not schema_props:
        # The tool declares no parameters, so any arguments are accepted.
        return True
    for name, value in tool_args.items():
        if not validate_parameter(name, value, schema_props):
            return False
    return True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
100 |
|
def verify_tool_calls(tool_calls: List[Any]) -> List[Any]:
    """Filter *tool_calls*, keeping only those whose tool exists and whose
    arguments pass schema validation.

    Rejected calls are logged at ERROR level and dropped; the surviving calls
    are logged once as a JSON summary.
    """
    accepted = []
    for call in tool_calls:
        name = call.function.name
        parsed_args = json.loads(call.function.arguments)
        definition = is_tool_valid(name)
        if definition and validate_tool_parameters(definition, parsed_args):
            accepted.append(call)
        else:
            logging.error(f"Invalid tool call: {name}. Dropping this tool call.")

    summary = json.dumps([call.__dict__ for call in accepted], default=str, indent=2)
    logging.info('Tool calls validated successfully. Valid tool calls are: %s', summary)
    return accepted
115 |
|
116 |
def execute_tool(tool_call: Any) -> Dict[str, Any]:
    """Run the Python function registered for *tool_call* and return its result.

    Raises KeyError for an unregistered tool name and whatever the underlying
    function raises for bad arguments.
    """
    logging.info(f"Executing tool: \n Name: {tool_call.function.name}\n Parameters: {tool_call.function.arguments}")
    target = NAMES_TO_FUNCTIONS[tool_call.function.name]
    kwargs = json.loads(tool_call.function.arguments)
    return target(**kwargs)
121 |
|
122 |
def gather_movie_data(messages: List[Dict[str, str]], max_tool_calls: int = 3) -> Optional[List[Dict[str, Any]]]:
    """Drive the model's tool-calling loop until `discover_movie` yields results.

    Recursively asks the model for the next tool call, executes it, appends the
    tool output to the conversation, and repeats — stopping when the model calls
    `discover_movie` (the terminal call), when `max_tool_calls` tool messages
    have accumulated, or when no valid tool call remains.

    :param messages: Conversation so far, as role/content dicts.
    :param max_tool_calls: Cap on "tool" messages in the thread before giving up.
    :return: The `discover_movie` results list, or None on exhaustion/any error
        (every exception is logged and swallowed — deliberate best-effort).
    """
    logging.debug(f"Gathering movie data with messages: {messages}")
    try:
        # tool_choice="required" forces the model to respond with a tool call.
        response = get_groq_response(messages, tool_choice="required")
        logging.debug(f"Calling tools based on the response: {response}")

        if response.choices[0].finish_reason != "tool_calls":
            raise Exception("Failed to gather movie data. Got response: " + str(response))

        tool_calls = response.choices[0].message.tool_calls
        valid_tool_calls = verify_tool_calls(tool_calls)

        # Count prior tool round-trips already threaded into the conversation.
        tool_messages_count = len([msg for msg in messages if msg["role"] == "tool"])

        if tool_messages_count >= max_tool_calls or not valid_tool_calls:
            return None  # No results found or max tool calls reached

        tool_call = valid_tool_calls[0]  # Run one tool call at a time
        logging.debug(f"Tool call: {tool_call.function.name}, Tool call parameters: {tool_call.function.arguments}")

        tool_output = execute_tool(tool_call)
        logging.debug(f"Tool call output: {json.dumps(tool_output, indent=2)}")

        if tool_call.function.name == "discover_movie":
            return tool_output["results"]  # A list of movies

        # Thread the tool output back into the conversation (new list; the
        # caller's `messages` is not mutated) and recurse for the next step.
        updated_messages = messages + [
            {
                "tool_call_id": tool_call.id,
                "role": "tool",
                "name": tool_call.function.name,
                "content": str(tool_output),
            }
        ]
        return gather_movie_data(updated_messages, max_tool_calls)

    except Exception as e:
        logging.error(f"Error gathering movie data: {str(e)}")
        return None
161 |
+
|
162 |
def chatbot(user_prompt: str) -> Optional[List[Dict[str, Any]]]:
    """End-to-end pipeline: reasoning chain -> tool-calling loop -> movie list.

    :param user_prompt: The user's movie query.
    :return: A list of movie dicts, or None on any failure (errors are logged,
        never raised to the caller).
    """
    try:
        reasoning_choice = generate_reasoning_chain(user_prompt)
        conversation = create_message(user_prompt, "function_call")
        conversation.append({
            'role': reasoning_choice.message.role,
            'content': reasoning_choice.message.content,
        })
        return gather_movie_data(conversation)
    except Exception as e:
        logging.error(f"Error in chatbot: {str(e)}")
        return None
|
|
|
|
|
|
|
|
|
|
|
|
175 |
|
if __name__ == "__main__":
    # Manual smoke test: run one canned query and pretty-print the outcome.
    result = chatbot("List some movies of Tom Cruise")
    if result:
        print(json.dumps(result, indent=2))
    else:
        print("No results found")