# movies-app / agent.py (author: tmzh)
# Commit 90d520c — "different system message for reasoning vs tool use chains"
import functools
import json
import os
import logging
from groq import Groq
import functions
from tools import tools
# Set up logging
logging.basicConfig(level=logging.DEBUG)

# Groq API client — requires the GROQ_API_KEY environment variable
# (raises KeyError at import time if it is unset).
client = Groq(api_key=os.environ["GROQ_API_KEY"])
# Groq-hosted Llama-3 variant fine-tuned for tool use.
MODEL = "llama3-groq-70b-8192-tool-use-preview"

# Collect the names of every public callable defined in the local
# `functions` module (skipping dunder attributes).
all_functions = [func for func in dir(functions) if callable(
    getattr(functions, func)) and not func.startswith("__")]
# Map function name -> callable, used by execute_tool to dispatch tool calls.
# NOTE(review): functools.partial with no bound arguments is a no-op wrapper
# here — plain getattr(functions, func) would presumably behave identically;
# TODO confirm whether the partial is intentional.
names_to_functions = {func: functools.partial(
    getattr(functions, func)) for func in all_functions}
def create_message(prompt, message_type):
    """Build the initial two-message chat seed for one chain type.

    Args:
        prompt: The user's natural-language query.
        message_type: Either "reasoning_chain" (plan the steps) or
            "function_call" (emit tool calls ending in discover_movie).

    Returns:
        A list of two chat messages: the type-specific system message
        followed by the user prompt.

    Raises:
        ValueError: If *message_type* is not one of the two known types.
    """
    logging.debug(
        f"Creating message with prompt: {prompt} and message type: {message_type}")
    # One system prompt per chain type; lookup replaces the if/elif chain.
    system_prompts = {
        "reasoning_chain": (
            "You are a movie search assistant bot who uses TMDB to help users "
            "find movies. Think step by step and identify the sequence of "
            "reasoning steps that will help to answer the user's query."
        ),
        "function_call": (
            "You are a movie search assistant bot who uses TMDB to help users "
            "find movies. Think step by step and identify the sequence of "
            "function calls that will help to answer the user's query. Use the "
            "available functions to gather the necessary data. "
            "Do not call multiple functions when they need to be executed in sequence. "
            "Only call multiple functions when they can be executed in parallel. "
            "Stop with a discover_movie function call that returns a list of movie ids. "
            "Ensure the discover_movie function call includes all the necessary parameters to filter the movies accurately."
        ),
    }
    if message_type not in system_prompts:
        raise ValueError(
            "Invalid message type. Expected 'reasoning_chain' or 'function_call'")
    return [
        {"role": "system", "content": system_prompts[message_type]},
        {"role": "user", "content": prompt},
    ]
def get_response(client, model, messages, tool_choice="auto"):
    """Send one chat-completion request to Groq and return the raw response.

    Args:
        client: A Groq client instance.
        model: Model identifier to query.
        messages: Chat history to send.
        tool_choice: Tool-use policy ("auto", "none", or "required").

    Returns:
        The completion object returned by the Groq SDK.
    """
    logging.debug(
        f"Getting response with model: {model}, messages: {messages}, tool_choice: {tool_choice}")
    # Deterministic settings: temperature 0, fixed token budget, and the
    # module-level tool schema attached to every request.
    request_kwargs = {
        "model": model,
        "messages": messages,
        "tools": tools,
        "tool_choice": tool_choice,
        "temperature": 0,
        "max_tokens": 4096,
    }
    response = client.chat.completions.create(**request_kwargs)
    logging.debug(f"Response: {response}")
    return response
def generate_reasoning_chain(user_prompt):
    """Ask the model (with tools disabled) to plan the reasoning steps.

    Args:
        user_prompt: The user's natural-language movie query.

    Returns:
        The first response choice, whose message content holds the
        step-by-step plan.

    Raises:
        Exception: If the model did not finish with a normal "stop".
    """
    messages = create_message(user_prompt, "reasoning_chain")
    logging.debug(f"Generating reasoning chain with messages: {messages}")
    # tool_choice="none": this pass is pure planning, no tool calls wanted.
    cot_response = get_response(client, MODEL, messages, tool_choice="none")
    logging.info(f"COT response: {cot_response.choices[0].message.content}")
    if cot_response.choices[0].finish_reason == "stop":
        return cot_response.choices[0]
    # Bug fix: the original passed cot_response as a %-format argument to a
    # format string with no placeholder, so the response was never logged.
    logging.error(
        "Failed to generate reasoning chain. Got response: %s", cot_response)
    raise Exception("Failed to generate reasoning chain")
def gather_movie_data(messages, iteration=0, max_iterations=2):
    """Run tool-calling rounds until discover_movie yields a movie list.

    Each round forces a tool call (tool_choice="required"), executes every
    returned tool call locally, and feeds the outputs back as "tool"
    messages for the next round — stopping as soon as discover_movie runs.

    Args:
        messages: Chat history seeded by create_message(..., "function_call").
        iteration: Current recursion depth (internal).
        max_iterations: Maximum number of follow-up rounds allowed.

    Returns:
        The "results" list from the discover_movie tool output.

    Raises:
        Exception: If the model stops calling tools, or the iteration
            budget is exhausted before discover_movie is called.
    """
    logging.debug(
        f"Gathering movie data with messages: {messages}, iteration: {iteration}")
    response = get_response(client, MODEL, messages, tool_choice="required")
    logging.info(
        f"Gathering movie data response: {response}")
    choice = response.choices[0]
    if choice.finish_reason != "tool_calls":
        # Bug fix: the original fell through and referenced updated_messages
        # before assignment (NameError) on this path; fail loudly instead.
        raise Exception(
            "Failed to gather movie data. Got response: ", response)
    updated_messages = messages.copy()
    for tool_call in choice.message.tool_calls:
        logging.info(
            f"Tool call: {tool_call.function.name}, Tool call parameters: {tool_call.function.arguments}")
        tool_output = execute_tool(tool_call)
        logging.info(
            f"Tool call output: {json.dumps(tool_output, indent=2)}")
        if tool_call.function.name == "discover_movie":
            return tool_output["results"]  # A list of movies
        # Feed the intermediate tool result back into the conversation.
        updated_messages.append(
            {
                "tool_call_id": tool_call.id,
                "role": "tool",
                "name": tool_call.function.name,
                "content": str(tool_output),
            }
        )
    if iteration < max_iterations:
        # Bug fix: propagate max_iterations (the original silently reset
        # it to the default on every recursive call).
        return gather_movie_data(updated_messages, iteration + 1, max_iterations)
    raise Exception(
        "Failed to gather movie data. Got response: ", response)
def execute_tool(tool_call):
    """Dispatch one model tool call to the matching local Python function.

    Args:
        tool_call: A tool call object carrying function.name and a JSON
            string of keyword arguments in function.arguments.

    Returns:
        Whatever the dispatched function returns.
    """
    logging.debug(f"Executing tool: {tool_call.function.name}")
    kwargs = json.loads(tool_call.function.arguments)
    target = names_to_functions[tool_call.function.name]
    return target(**kwargs)
def chatbot(user_prompt):
    """Answer a movie query: plan first, then execute tool calls.

    Runs a reasoning pass over the prompt, appends the resulting plan to a
    fresh function-call conversation, and drives tool execution until a
    movie list is produced.

    Args:
        user_prompt: The user's natural-language movie query.

    Returns:
        The list of movies returned by gather_movie_data.
    """
    plan_choice = generate_reasoning_chain(user_prompt)
    conversation = create_message(user_prompt, "function_call")
    conversation.append({
        "role": plan_choice.message.role,
        "content": plan_choice.message.content,
    })
    return gather_movie_data(conversation)
if __name__ == "__main__":
    # Smoke-test query: run the full plan-then-call pipeline once.
    movies = chatbot("List comedy movies with tom cruise in it")
    print(json.dumps(movies, indent=2))