# movies-app / agent.py
# tmzh — commit 9f83bcc ("call one tool at a time")
import functools
import json
import os
import logging
from groq import Groq
import functions
from utils import python_type, raise_error
from tools import tools
# Set up logging (root logger at DEBUG, so every debug/info line below is emitted)
logging.basicConfig(level=logging.DEBUG)

# Shared Groq client; GROQ_API_KEY must be set in the environment, otherwise
# importing this module raises KeyError.
client = Groq(api_key=os.environ["GROQ_API_KEY"])
MODEL = "llama3-groq-70b-8192-tool-use-preview"

# Names of every callable attribute of the `functions` module, excluding
# dunder names. Tool calls from the model are dispatched through this table.
all_functions = [
    name
    for name in dir(functions)
    if callable(getattr(functions, name)) and not name.startswith("__")
]
# Map tool name -> callable. The callable is stored directly: wrapping it in a
# functools.partial with no bound arguments added nothing.
names_to_functions = {name: getattr(functions, name) for name in all_functions}
def create_message(prompt, message_type):
    """Build the opening chat transcript for one stage of the agent.

    Args:
        prompt: The user's query; sent verbatim as the user message.
        message_type: "reasoning_chain" to ask the model for a step-by-step
            plan, or "function_call" to ask it to execute tool calls that end
            with a discover_movie call.

    Returns:
        A two-element list: the stage-specific system message followed by the
        user message.

    Raises:
        ValueError: If message_type is neither supported value.
    """
    logging.debug(
        f"Creating message with prompt: {prompt} and message type: {message_type}")
    system_prompts = (
        (
            "reasoning_chain",
            "You are a movie search assistant bot who uses TMDB to help users "
            "find movies. Think step by step and identify the sequence of "
            "reasoning steps that will help to answer the user's query.",
        ),
        (
            "function_call",
            "You are a movie search assistant bot that utilizes TMDB to help users find movies. "
            "Approach each query step by step, determining the sequence of function calls needed to gather the necessary information. "
            "Execute functions sequentially, using the output from one function to inform the next function call when required. "
            "Only call multiple functions simultaneously when they can run independently of each other. "
            "Once you have identified all the required parameters from previous calls, "
            "finalize your process with a discover_movie function call that returns a list of movie IDs. "
            "Ensure that this call includes all necessary parameters to accurately filter the movies.",
        ),
    )
    # Linear scan with == (not a dict lookup) so any input, hashable or not,
    # falls through to the ValueError below.
    chosen = next(
        (text for kind, text in system_prompts if message_type == kind), None)
    if chosen is None:
        raise ValueError(
            "Invalid message type. Expected 'reasoning_chain' or 'function_call'")
    system_entry = {"role": "system", "content": chosen}
    user_entry = {"role": "user", "content": prompt}
    return [system_entry, user_entry]
def get_response(client, model, messages, tool_choice="auto"):
    """Send one chat-completion request using this module's tool definitions.

    Args:
        client: Client whose ``chat.completions.create`` issues the request.
        model: Model name to query.
        messages: Chat transcript (list of role/content dicts); it must be
            JSON-serializable because it is dumped into the info log.
        tool_choice: Forwarded to the API; callers here use "auto", "none"
            and "required".

    Returns:
        The completion response object exactly as returned by the client.
    """
    logging.info(
        "Getting response with model: %s, \nmessages: %s, \ntool_choice: %s",
        model, json.dumps(messages, indent=2), tool_choice)
    request_kwargs = dict(
        model=model,
        messages=messages,
        tools=tools,
        tool_choice=tool_choice,
        temperature=0,  # deterministic decoding
        max_tokens=4096,
    )
    response = client.chat.completions.create(**request_kwargs)
    logging.debug("Response: %s", response)
    return response
def generate_reasoning_chain(user_prompt):
    """Ask the model for a step-by-step plan for the query, without tool calls.

    Args:
        user_prompt: The user's movie query.

    Returns:
        The first completion choice when it finished with "stop". Otherwise
        the error message and ``Exception`` are passed to ``raise_error``
        (presumably it raises; if it does not, None is returned).
    """
    messages = create_message(user_prompt, "reasoning_chain")
    logging.debug(f"Generating reasoning chain with messages: {messages}")
    cot_response = get_response(client, MODEL, messages, tool_choice="none")
    first_choice = cot_response.choices[0]
    logging.info(f"COT response: {first_choice.message.content}")
    if first_choice.finish_reason == "stop":
        return first_choice
    raise_error("Failed to generate reasoning chain. Got response: " +
                str(cot_response), Exception)
def validate_params(tool_params, param_name, param_value):
    """Check that one tool-call argument is declared and castable to its type.

    Args:
        tool_params: The "properties" mapping of a tool's parameter schema
            (parameter name -> definition containing a "type" entry).
        param_name: Argument name supplied by the model.
        param_value: Argument value supplied by the model.

    Returns:
        True if ``param_name`` is declared in ``tool_params`` and the callable
        returned by ``python_type(<declared type>)`` accepts ``param_value``.
        Otherwise an error is logged and False is returned.
    """
    logging.debug(
        f"Validating parameter: {param_name} with value: {param_value}")
    param_def = tool_params.get(param_name)
    if param_def is None:
        logging.error(
            f"Parameter {param_name} not found in tools. Dropping this tool call.")
        return False
    try:
        # Only whether the cast succeeds matters; the converted value is unused.
        python_type(param_def["type"])(param_value)
    except (TypeError, ValueError):
        # Conversions such as int(None) or int([1]) raise TypeError rather than
        # ValueError. They must also drop the tool call instead of propagating
        # and aborting validation of the whole response.
        logging.error(
            f"Parameter {param_name} value cannot be cast to {param_def['type']}. Dropping this tool call.")
        return False
    return True
def extract_leaf_values(json_obj):
    """Flatten a JSON value (or a JSON-encoded string) into its leaf values.

    Strings are first tried as JSON. A string that does not decode is itself
    a leaf. Dict values and list items are visited recursively in order, so
    nested strings are decoded too (e.g. "5" becomes 5). Anything else is a
    leaf.

    Returns:
        A new list of leaf values in depth-first order.
    """
    if isinstance(json_obj, str):
        try:
            json_obj = json.loads(json_obj)
        except json.JSONDecodeError:
            return [json_obj]  # plain text, not JSON: keep it as a leaf

    if isinstance(json_obj, dict):
        children = json_obj.values()
    elif isinstance(json_obj, list):
        children = json_obj
    else:
        return [json_obj]
    return [leaf for child in children for leaf in extract_leaf_values(child)]
def is_tool_valid(tool_name):
    """Look up a tool definition by function name.

    Despite the name, this returns the first definition in ``tools`` whose
    ``["function"]["name"]`` equals ``tool_name`` (truthy), or None when
    there is no match, so callers can use the result directly.
    """
    for candidate in tools:
        if candidate["function"]["name"] == tool_name:
            return candidate
    return None
def validate_tool_parameters(tool_def, tool_args):
    """Check every supplied argument against the tool's declared parameters.

    Args:
        tool_def: A tool definition; its
            ``["function"]["parameters"]["properties"]`` mapping is used.
        tool_args: Decoded arguments from the model (name -> value).

    Returns:
        True when validate_params accepts every supplied argument. At the
        first rejected argument, an error naming the tool is logged and False
        is returned; later arguments are not checked. Missing required
        parameters are not detected, since only supplied arguments are visited.
    """
    declared = tool_def["function"]["parameters"]["properties"]
    all_passed = object()  # sentinel: no argument failed
    first_bad = next(
        (name for name, value in tool_args.items()
         if not validate_params(declared, name, value)),
        all_passed,
    )
    if first_bad is all_passed:
        return True
    logging.error(
        f"Invalid parameter {first_bad} for tool {tool_def['function']['name']}. Dropping this tool call.")
    return False
def are_arguments_valid(tool_args, user_query_values, previous_values):
    """Check that every argument value can be traced to a known source.

    A value passes if ``str(value) in user_query_values`` (a substring test
    when that is the raw query text) or if ``value in previous_values``.
    The second test runs only when the first fails.

    Returns:
        True if every value in ``tool_args`` passes (vacuously True when
        there are no arguments), otherwise False.
    """
    for value in tool_args.values():
        found_in_query = str(value) in user_query_values
        if not found_in_query and value not in previous_values:
            return False
    return True
def verify_tool_calls(tool_calls, messages):
    """
    Filter the model's tool calls down to those that can be executed.

    A tool call is kept only if:
      * its ``function.arguments`` decode to a JSON object,
      * its name matches a definition in ``tools``, and
      * every argument passes ``validate_tool_parameters``.
    Invalid calls are logged and dropped instead of aborting the whole run.

    Note: this does not check where argument values came from, i.e. whether
    they appear in the user query or earlier tool outputs
    (see ``are_arguments_valid``).

    :param tool_calls: Tool calls from the model response. Each exposes
        ``function.name`` and ``function.arguments`` (a JSON string).
        None is treated as no calls.
    :param messages: Conversation so far. Currently unused; kept so the
        call signature stays stable.
    :return: List of valid tool calls, in their original order.
    """
    valid_tool_calls = []
    for tool_call in tool_calls or []:
        tool_name = tool_call.function.name
        raw_args = tool_call.function.arguments
        try:
            tool_args = json.loads(raw_args)
        except (TypeError, ValueError):
            # json.JSONDecodeError subclasses ValueError; TypeError covers a
            # missing (None) arguments payload from the model.
            logging.error(
                f"Tool {tool_name} has malformed arguments {raw_args!r}. Dropping this tool call.")
            continue
        if not isinstance(tool_args, dict):
            # Arguments are later used as keyword arguments, so they must
            # decode to a JSON object.
            logging.error(
                f"Tool {tool_name} arguments are not a JSON object: {raw_args!r}. Dropping this tool call.")
            continue
        tool_def = is_tool_valid(tool_name)
        if not tool_def:
            logging.error(
                f"Tool {tool_name} not found in tools. Dropping this tool call.")
            continue
        if validate_tool_parameters(tool_def, tool_args):
            valid_tool_calls.append(tool_call)

    tool_calls_str = [json.dumps(tool_call.__dict__, default=str)
                      for tool_call in valid_tool_calls]
    logging.info(
        'Tool calls validated successfully. Valid tool calls are: %s', tool_calls_str)
    return valid_tool_calls
def gather_movie_data(messages, max_tool_messages=3):
    """Run the tool-calling loop until discover_movie produces movies.

    Each round asks the model for tool calls (tool_choice="required"),
    filters them with verify_tool_calls and executes only the first valid
    one. Unless that call was discover_movie, it appends the call and its
    output to a copy of the transcript and recurses.

    Args:
        messages: Chat transcript so far (list of role/content dicts). It is
            not mutated.
        max_tool_messages: A tool is executed only while the transcript holds
            at most this many "tool" messages (default 3).

    Returns:
        ``tool_output["results"]`` from the discover_movie call, or the string
        "No results found" when no valid tool call is available or the
        tool-message budget is exhausted.

    Raises:
        Exception: If the model's reply did not finish with tool calls.
    """
    logging.debug(f"Gathering movie data with messages: {messages}")
    response = get_response(client, MODEL, messages, tool_choice="required")
    logging.debug(f"Calling tools based on the response: {response}")
    choice = response.choices[0]
    if choice.finish_reason != "tool_calls":
        raise Exception(
            f"Failed to gather movie data. Got response: {response}")

    valid_tool_calls = verify_tool_calls(choice.message.tool_calls, messages)
    tool_messages_count = sum(1 for msg in messages if msg["role"] == "tool")
    if tool_messages_count > max_tool_messages or not valid_tool_calls:
        return "No results found"

    tool_call = valid_tool_calls[0]  # Run one tool call at a time
    logging.info(
        f"Tool call: {tool_call.function.name}, Tool call parameters: {tool_call.function.arguments}")
    tool_output = execute_tool(tool_call)
    logging.debug(
        f"Tool call output: {json.dumps(tool_output, indent=2)}")

    if tool_call.function.name == "discover_movie":
        return tool_output["results"]  # A list of movies

    # In the OpenAI-compatible chat API, a "tool" message must answer a call
    # listed in a preceding assistant message. So record the executed call
    # (only that one, since only it gets a response) before its output.
    assistant_message = {
        "role": "assistant",
        "tool_calls": [
            {
                "id": tool_call.id,
                "type": "function",
                "function": {
                    "name": tool_call.function.name,
                    "arguments": tool_call.function.arguments,
                },
            }
        ],
    }
    tool_message = {
        "tool_call_id": tool_call.id,
        "role": "tool",
        "name": tool_call.function.name,
        # JSON rather than str(): str() yields a Python repr (single quotes)
        # that neither the model nor extract_leaf_values can parse as JSON.
        # tool_output is already JSON-serialized for the debug log above.
        "content": json.dumps(tool_output),
    }
    return gather_movie_data(
        messages + [assistant_message, tool_message], max_tool_messages)
def execute_tool(tool_call):
    """Invoke the project function named by a tool call.

    The call's ``function.arguments`` JSON is decoded and passed as keyword
    arguments. The function's return value is returned unchanged. An unknown
    name raises KeyError from the ``names_to_functions`` lookup.
    """
    tool_name = tool_call.function.name
    logging.info(f"Executing tool: {tool_name}")
    handler = names_to_functions[tool_name]
    kwargs = json.loads(tool_call.function.arguments)
    return handler(**kwargs)
def chatbot(user_prompt):
    """Answer a movie query end to end.

    First asks the model for a reasoning plan. Then it seeds the
    "function_call" transcript with that plan as an extra message and hands
    the transcript to gather_movie_data.

    Returns:
        Whatever gather_movie_data returns (discover_movie results or
        "No results found").
    """
    plan_choice = generate_reasoning_chain(user_prompt)
    transcript = create_message(user_prompt, "function_call")
    plan_message = {
        'role': plan_choice.message.role,
        'content': plan_choice.message.content,
    }
    transcript.append(plan_message)
    return gather_movie_data(transcript)
if __name__ == "__main__":
    # Demo query; another prompt such as "List some movies of Tom Cruise"
    # can be substituted here.
    demo_query = "Movies of Tom Hanks that were released in 1993"
    answer = chatbot(demo_query)
    print(json.dumps(answer, indent=2))