"""Movie-search chatbot that drives TMDB lookups through Groq tool calling.

The flow is: ask the model for a step-by-step reasoning chain, then let it
issue tool calls one at a time until it calls ``discover_movie``, whose
``results`` entry is returned to the caller.
"""
import functools
import json
import os
import logging

from groq import Groq

import functions
from utils import python_type, raise_error
from tools import tools

# Set up logging
logging.basicConfig(level=logging.DEBUG)

client = Groq(api_key=os.environ["GROQ_API_KEY"])
MODEL = "llama3-groq-70b-8192-tool-use-preview"

# A tool call is only executed while the conversation holds at most this many
# tool messages; beyond that gather_movie_data gives up.
MAX_TOOL_MESSAGES = 3

# Every public callable exposed by the local ``functions`` module, by name.
all_functions = [
    func
    for func in dir(functions)
    if callable(getattr(functions, func)) and not func.startswith("__")
]
names_to_functions = {
    func: functools.partial(getattr(functions, func)) for func in all_functions
}


def create_message(prompt, message_type):
    """Build the initial two-message conversation (system + user).

    :param prompt: The user's query, used verbatim as the user message.
    :param message_type: ``"reasoning_chain"`` or ``"function_call"``;
        selects which system prompt is used.
    :return: A list with one system message and one user message.
    :raises ValueError: If ``message_type`` is not one of the two known values.
    """
    logging.debug(
        f"Creating message with prompt: {prompt} and message type: {message_type}")
    system_message = ""
    if message_type == "reasoning_chain":
        system_message = (
            "You are a movie search assistant bot who uses TMDB to help users "
            "find movies. Think step by step and identify the sequence of "
            "reasoning steps that will help to answer the user's query."
        )
    elif message_type == "function_call":
        system_message = (
            "You are a movie search assistant bot that utilizes TMDB to help users find movies. "
            "Approach each query step by step, determining the sequence of function calls needed to gather the necessary information. "
            "Execute functions sequentially, using the output from one function to inform the next function call when required. "
            "Only call multiple functions simultaneously when they can run independently of each other. "
            "Once you have identified all the required parameters from previous calls, "
            "finalize your process with a discover_movie function call that returns a list of movie IDs. "
            "Ensure that this call includes all necessary parameters to accurately filter the movies."
        )
    else:
        raise ValueError(
            "Invalid message type. Expected 'reasoning_chain' or 'function_call'")
    return [
        {
            "role": "system",
            "content": system_message,
        },
        {
            "role": "user",
            "content": prompt,
        },
    ]


def get_response(client, model, messages, tool_choice="auto"):
    """Request a chat completion with the module's tool definitions attached.

    :param client: Client object exposing ``chat.completions.create``.
    :param model: Model identifier passed through to the API.
    :param messages: Conversation so far, as a list of message dicts.
    :param tool_choice: Tool-choice mode forwarded to the API.
    :return: The raw completion response.
    """
    logging.info(
        f"Getting response with model: {model}, \nmessages: {json.dumps(messages, indent=2)}, \ntool_choice: {tool_choice}")
    response = client.chat.completions.create(
        model=model,
        messages=messages,
        tools=tools,
        tool_choice=tool_choice,
        temperature=0,
        max_tokens=4096,
    )
    logging.debug(f"Response: {response}")
    return response


def generate_reasoning_chain(user_prompt):
    """Ask the model for a reasoning chain, with tool calling disabled.

    :param user_prompt: The user's query.
    :return: The first response choice when its finish reason is ``"stop"``.
        Otherwise the failure is delegated to ``raise_error`` from ``utils``.
    """
    messages = create_message(user_prompt, "reasoning_chain")
    logging.debug(f"Generating reasoning chain with messages: {messages}")
    cot_response = get_response(client, MODEL, messages, tool_choice="none")
    logging.info(f"COT response: {cot_response.choices[0].message.content}")
    if cot_response.choices[0].finish_reason == "stop":
        return cot_response.choices[0]
    else:
        raise_error("Failed to generate reasoning chain. Got response: " +
                    str(cot_response), Exception)


def validate_params(tool_params, param_name, param_value):
    """Check that a parameter is declared in tools.py and castable to its type.

    :param tool_params: The ``properties`` mapping of the tool definition.
    :param param_name: Name of the argument supplied by the model.
    :param param_value: Value of the argument supplied by the model.
    :return: ``True`` if the parameter is declared and the value can be cast
        to the declared type, ``False`` otherwise.
    """
    logging.debug(
        f"Validating parameter: {param_name} with value: {param_value}")
    param_def = tool_params.get(param_name)
    if param_def is None:
        logging.error(
            f"Parameter {param_name} not found in tools. Dropping this tool call.")
        return False
    try:
        # Only castability matters here; the converted value is discarded.
        python_type(param_def["type"])(param_value)
    except (ValueError, TypeError):
        # TypeError covers values such as None or a list, which constructors
        # like int() reject without raising ValueError.
        logging.error(
            f"Parameter {param_name} value cannot be cast to {param_def['type']}. Dropping this tool call.")
        return False
    return True


def extract_leaf_values(json_obj):
    """Recursively extract leaf values from a JSON object or string.

    Strings are first parsed as JSON; a string that is not valid JSON is
    returned as a single leaf.
    """
    # Check if the input is a string and try to parse it
    if isinstance(json_obj, str):
        try:
            json_obj = json.loads(json_obj)
        except json.JSONDecodeError:
            return [json_obj]  # Return the string if it's not valid JSON

    if isinstance(json_obj, dict):
        values = []
        for value in json_obj.values():
            values.extend(extract_leaf_values(value))
        return values
    elif isinstance(json_obj, list):
        values = []
        for item in json_obj:
            values.extend(extract_leaf_values(item))
        return values
    else:
        return [json_obj]


def is_tool_valid(tool_name):
    """Return the definition of the tool named ``tool_name``, or ``None``."""
    return next(
        (tool for tool in tools if tool["function"]["name"] == tool_name), None)


def validate_tool_parameters(tool_def, tool_args):
    """Validate every supplied argument against the tool definition.

    :param tool_def: Tool definition as returned by ``is_tool_valid``.
    :param tool_args: Mapping of argument names to values.
    :return: ``True`` only if every argument passes ``validate_params``.
    """
    tool_params = tool_def["function"]["parameters"]["properties"]
    for param_name, param_value in tool_args.items():
        if not validate_params(tool_params, param_name, param_value):
            logging.error(
                f"Invalid parameter {param_name} for tool {tool_def['function']['name']}. Dropping this tool call.")
            return False
    return True


def are_arguments_valid(tool_args, user_query_values, previous_values):
    """Check that every argument value appears in the query or earlier outputs.

    :param tool_args: Mapping of argument names to values.
    :param user_query_values: Container tested with ``str(value) in ...``.
    :param previous_values: Container tested with ``value in ...``.
    :return: ``True`` if each value is found in at least one of the two.
    """
    arg_values = tool_args.values()
    return all(
        str(value) in user_query_values or value in previous_values
        for value in arg_values
    )


def verify_tool_calls(tool_calls, messages):
    """
    Keep only the tool calls that name a known tool with valid parameters.

    :param tool_calls: List of tool calls with arguments.
    :param messages: Conversation messages. Accepted for interface
        compatibility; the current checks do not read it.
    :return: List of valid tool calls.
    """
    valid_tool_calls = []
    for tool_call in tool_calls or []:
        tool_name = tool_call.function.name
        # The arguments are model-generated text, so they may not be valid
        # JSON or may not decode to an object; drop such calls, do not crash.
        try:
            tool_args = json.loads(tool_call.function.arguments)
        except (json.JSONDecodeError, TypeError):
            tool_args = None
        if not isinstance(tool_args, dict):
            logging.error(
                f"Tool {tool_name} has malformed arguments. Dropping this tool call.")
            continue

        tool_def = is_tool_valid(tool_name)
        if tool_def:
            if validate_tool_parameters(tool_def, tool_args):
                valid_tool_calls.append(tool_call)
        else:
            logging.error(
                f"Tool {tool_name} not found in tools. Dropping this tool call.")

    tool_calls_str = [json.dumps(tool_call.__dict__, default=str)
                      for tool_call in valid_tool_calls]
    logging.info(
        'Tool calls validated successfully. Valid tool calls are: %s',
        tool_calls_str)
    return valid_tool_calls


def gather_movie_data(messages):
    """Run model-requested tools one at a time until ``discover_movie`` runs.

    Each round requests a completion with tool use required, validates the
    returned tool calls, and executes the first valid one. The output of any
    tool other than ``discover_movie`` is appended as a tool message and the
    function recurses with the extended conversation.

    :param messages: Conversation so far; not mutated.
    :return: ``tool_output["results"]`` of the ``discover_movie`` call, or the
        string ``"No results found"`` when no valid tool call is available or
        the conversation already holds more than ``MAX_TOOL_MESSAGES`` tool
        messages.
    :raises Exception: If the completion did not finish with tool calls.
    """
    logging.debug(f"Gathering movie data with messages: {messages}")
    response = get_response(client, MODEL, messages, tool_choice="required")
    logging.debug(f"Calling tools based on the response: {response}")

    if response.choices[0].finish_reason != "tool_calls":
        raise Exception(
            f"Failed to gather movie data. Got response: {response}")

    tool_calls = response.choices[0].message.tool_calls
    # validate tool calls
    valid_tool_calls = verify_tool_calls(tool_calls, messages)
    tool_messages_count = len(
        [msg for msg in messages if msg["role"] == "tool"])
    if tool_messages_count > MAX_TOOL_MESSAGES or not valid_tool_calls:
        return "No results found"

    tool_call = valid_tool_calls[0]  # Run one tool call at a time
    logging.info(
        f"Tool call: {tool_call.function.name}, Tool call parameters: {tool_call.function.arguments}")
    tool_output = execute_tool(tool_call)
    logging.debug(
        f"Tool call output: {json.dumps(tool_output, indent=2)}")

    if tool_call.function.name == "discover_movie":
        return tool_output["results"]  # A list of movies

    updated_messages = messages.copy()
    updated_messages.append(
        {
            "tool_call_id": tool_call.id,
            "role": "tool",
            "name": tool_call.function.name,
            "content": str(tool_output),
        }
    )
    return gather_movie_data(updated_messages)


def execute_tool(tool_call):
    """Call the function named by ``tool_call`` with its decoded arguments.

    :param tool_call: Tool call whose ``function.name`` is a key of
        ``names_to_functions`` and whose ``function.arguments`` is a JSON
        object.
    :return: Whatever the underlying function returns.
    """
    logging.info(f"Executing tool: {tool_call.function.name}")
    function_to_call = names_to_functions[tool_call.function.name]
    function_args = json.loads(tool_call.function.arguments)
    return function_to_call(**function_args)


def chatbot(user_prompt):
    """Answer a movie query: reasoning chain first, then tool-driven lookup.

    :param user_prompt: The user's query.
    :return: The value returned by ``gather_movie_data``.
    """
    cot_response_choice = generate_reasoning_chain(user_prompt)
    cot = create_message(user_prompt, "function_call")
    # Feed the reasoning chain back in so the tool-calling round can follow it.
    cot.append({
        'role': cot_response_choice.message.role,
        'content': cot_response_choice.message.content})
    movie_list = gather_movie_data(cot)
    return movie_list


if __name__ == "__main__":
    print(json.dumps(chatbot("Movies of Tom Hanks that were released in 1993"), indent=2))
    # print(json.dumps(chatbot("List some movies of Tom Cruise"), indent=2))