Spaces:
Paused
Paused
File size: 9,707 Bytes
ebf567b 95bf5e4 ebf567b 9f83bcc 6660737 95bf5e4 ebf567b 90d520c 9f83bcc 90d520c 95bf5e4 90d520c 95bf5e4 6660737 95bf5e4 9f83bcc 90d520c 95bf5e4 90d520c 6660737 90d520c 95bf5e4 90d520c 95bf5e4 90d520c ebf567b 9f83bcc ebf567b 9f83bcc 95bf5e4 9f83bcc 90d520c 9f83bcc 95bf5e4 9f83bcc 95bf5e4 9f83bcc 90d520c 95bf5e4 9f83bcc 90d520c 9f83bcc 95bf5e4 9f83bcc 90d520c ebf567b 95bf5e4 9f83bcc 95bf5e4 90d520c 95bf5e4 9f83bcc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 |
import functools
import json
import os
import logging
from groq import Groq
import functions
from utils import python_type, raise_error
from tools import tools
# Module setup: verbose logging, the Groq client, the model id and the
# name -> callable table used to execute tool calls.
logging.basicConfig(level=logging.DEBUG)
# Raises KeyError at import time if GROQ_API_KEY is not set.
client = Groq(api_key=os.environ["GROQ_API_KEY"])
MODEL = "llama3-groq-70b-8192-tool-use-preview"
# Names of every callable attribute of the local `functions` module, skipping
# names that start with "__".
all_functions = list(filter(
    lambda attr: callable(getattr(functions, attr)) and not attr.startswith("__"),
    dir(functions),
))
# functools.partial with no bound arguments simply forwards to the wrapped
# function; callers look tools up here by name and call them with kwargs.
names_to_functions = {
    attr: functools.partial(getattr(functions, attr))
    for attr in all_functions
}
def create_message(prompt, message_type):
    """Build the system + user message pair for one stage of the pipeline.

    :param prompt: the user's query; used verbatim as the user message content.
    :param message_type: "reasoning_chain" for the step-by-step planning
        prompt, or "function_call" for the tool-calling prompt.
    :return: a two-element list: the system message dict, then the user
        message dict.
    :raises ValueError: if ``message_type`` is neither supported value.
    """
    logging.debug(
        f"Creating message with prompt: {prompt} and message type: {message_type}")
    # (type, system prompt) pairs; compared with == so any input type is safe.
    prompts_by_type = (
        (
            "reasoning_chain",
            "You are a movie search assistant bot who uses TMDB to help users "
            "find movies. Think step by step and identify the sequence of "
            "reasoning steps that will help to answer the user's query.",
        ),
        (
            "function_call",
            "You are a movie search assistant bot that utilizes TMDB to help users find movies. "
            "Approach each query step by step, determining the sequence of function calls needed to gather the necessary information. "
            "Execute functions sequentially, using the output from one function to inform the next function call when required. "
            "Only call multiple functions simultaneously when they can run independently of each other. "
            "Once you have identified all the required parameters from previous calls, "
            "finalize your process with a discover_movie function call that returns a list of movie IDs. "
            "Ensure that this call includes all necessary parameters to accurately filter the movies.",
        ),
    )
    system_text = next(
        (text for kind, text in prompts_by_type if message_type == kind), None)
    if system_text is None:
        raise ValueError(
            "Invalid message type. Expected 'reasoning_chain' or 'function_call'")
    system_entry = {"role": "system", "content": system_text}
    user_entry = {"role": "user", "content": prompt}
    return [system_entry, user_entry]
def get_response(client, model, messages, tool_choice="auto"):
    """Send one chat-completion request that exposes the module's ``tools``.

    :param client: Groq client used to issue the request.
    :param model: model id to query.
    :param messages: chat history (list of message dicts); also logged as JSON.
    :param tool_choice: forwarded as-is ("auto", "none", "required", ...).
    :return: the raw completion response object.
    """
    logging.info(
        f"Getting response with model: {model}, \nmessages: {json.dumps(messages, indent=2)}, \ntool_choice: {tool_choice}")
    # Deterministic sampling (temperature 0) with a 4096-token completion cap.
    request_kwargs = dict(
        model=model,
        messages=messages,
        tools=tools,
        tool_choice=tool_choice,
        temperature=0,
        max_tokens=4096,
    )
    completion = client.chat.completions.create(**request_kwargs)
    logging.debug(f"Response: {completion}")
    return completion
def generate_reasoning_chain(user_prompt):
    """Ask the model (tools disabled) for a step-by-step plan for the query.

    :param user_prompt: the user's movie query.
    :return: the first response choice when it finished with "stop".
        Otherwise ``raise_error`` is called with ``Exception``. That helper is
        presumably meant to raise (verify in utils); if it returns, this
        function returns None.
    """
    messages = create_message(user_prompt, "reasoning_chain")
    logging.debug(f"Generating reasoning chain with messages: {messages}")
    cot_response = get_response(client, MODEL, messages, tool_choice="none")
    first_choice = cot_response.choices[0]
    logging.info(f"COT response: {first_choice.message.content}")
    if first_choice.finish_reason != "stop":
        raise_error("Failed to generate reasoning chain. Got response: " +
                    str(cot_response), Exception)
        return None
    return first_choice
def validate_params(tool_params, param_name, param_value):
    """
    Check that a tool-call argument is declared in tools.py and is castable.

    :param tool_params: the tool's ``properties`` mapping from tools.py
        (parameter name -> definition containing a "type" entry).
    :param param_name: argument name produced by the model.
    :param param_value: argument value produced by the model.
    :return: True when the parameter is declared and
        ``python_type(<declared type>)(param_value)`` succeeds; False otherwise.
        The reason is logged as an error.
    """
    logging.debug(
        f"Validating parameter: {param_name} with value: {param_value}")
    param_def = tool_params.get(param_name)
    if param_def is None:
        logging.error(
            f"Parameter {param_name} not found in tools. Dropping this tool call.")
        return False
    try:
        # Convertibility check only; the converted value is not used.
        python_type(param_def["type"])(param_value)
    except (TypeError, ValueError):
        # TypeError too: e.g. int(None) or int([1]) raise TypeError, not
        # ValueError, and must drop the call rather than crash the pipeline.
        logging.error(
            f"Parameter {param_name} value cannot be cast to {param_def['type']}. Dropping this tool call.")
        return False
    return True
def extract_leaf_values(json_obj):
    """Flatten a JSON value (or a JSON-encoded string) into its leaf values.

    Strings are first tried as JSON. A string that does not parse is itself
    a leaf. Dict values and list items are flattened recursively, in order,
    and every other value is a leaf. Note that nested strings are also
    re-parsed, so e.g. "5" becomes the int 5.

    :return: list of leaf values in traversal order.
    """
    if isinstance(json_obj, str):
        try:
            json_obj = json.loads(json_obj)
        except json.JSONDecodeError:
            return [json_obj]  # not JSON: keep the raw string as a leaf
    if isinstance(json_obj, dict):
        children = json_obj.values()
    elif isinstance(json_obj, list):
        children = json_obj
    else:
        return [json_obj]
    return [leaf for child in children for leaf in extract_leaf_values(child)]
def is_tool_valid(tool_name):
    """Return the first definition in ``tools`` whose function name matches.

    Despite the name, this returns the tool definition itself, or None when
    nothing matches. Callers rely on its truthiness.
    """
    for candidate in tools:
        if candidate["function"]["name"] == tool_name:
            return candidate
    return None
def validate_tool_parameters(tool_def, tool_args):
    """Validate every argument of a tool call against the tool's definition.

    :param tool_def: a tool definition from tools.py. ``parameters`` and its
        ``properties`` may be absent; a tool without declared properties then
        accepts only an empty argument set.
    :param tool_args: decoded argument mapping (name -> value) from the model.
    :return: True if every argument passes ``validate_params``, else False
        (stops at the first invalid argument, which is logged).
    """
    function_def = tool_def["function"]
    # ``parameters``/``properties`` are optional in the function-tool schema;
    # indexing them directly raised KeyError for parameterless tools.
    tool_params = (function_def.get("parameters") or {}).get("properties") or {}
    for param_name, param_value in tool_args.items():
        if not validate_params(tool_params, param_name, param_value):
            logging.error(
                f"Invalid parameter {param_name} for tool {tool_def['function']['name']}. Dropping this tool call.")
            return False
    return True
def are_arguments_valid(tool_args, user_query_values, previous_values):
    """Check that every argument value is grounded in earlier context.

    A value is grounded if its ``str()`` form occurs in ``user_query_values``
    (membership, i.e. a substring test when that is a string) or if the value
    itself is in ``previous_values``.

    :return: True when all values are grounded (vacuously True for no args).
    """
    for arg_value in tool_args.values():
        if str(arg_value) in user_query_values:
            continue
        if arg_value in previous_values:
            continue
        return False
    return True
def verify_tool_calls(tool_calls, messages):
    """
    Filter the model's tool calls down to the ones that can be executed.

    A tool call is kept when all of the following hold:
    - its ``arguments`` string decodes to a JSON object;
    - its name matches a definition in ``tools``;
    - every argument passes ``validate_tool_parameters``.
    Dropped calls are logged and skipped instead of aborting the pipeline.

    NOTE(review): arguments are not cross-checked against the user query or
    previous tool outputs here (``are_arguments_valid`` is not called).

    :param tool_calls: tool calls from the model response (objects exposing
        ``.function.name`` and ``.function.arguments``); None is treated as empty.
    :param messages: conversation so far. Not consulted by the current checks;
        kept for interface compatibility.
    :return: List of valid tool calls, in their original order.
    """
    valid_tool_calls = []
    for tool_call in tool_calls or []:
        tool_name = tool_call.function.name
        try:
            tool_args = json.loads(tool_call.function.arguments)
        except (TypeError, ValueError):
            # Model output is untrusted: malformed JSON drops only this call.
            logging.error(
                f"Tool {tool_name} has malformed JSON arguments. Dropping this tool call.")
            continue
        if not isinstance(tool_args, dict):
            logging.error(
                f"Tool {tool_name} arguments are not a JSON object. Dropping this tool call.")
            continue
        tool_def = is_tool_valid(tool_name)
        if not tool_def:
            logging.error(
                f"Tool {tool_name} not found in tools. Dropping this tool call.")
            continue
        if validate_tool_parameters(tool_def, tool_args):
            valid_tool_calls.append(tool_call)
    tool_calls_str = [json.dumps(tool_call.__dict__, default=str)
                      for tool_call in valid_tool_calls]
    logging.info(
        'Tool calls validated successfully. Valid tool calls are: %s', tool_calls_str)
    return valid_tool_calls
def gather_movie_data(messages, max_tool_messages=3):
    """Drive the tool-calling loop until discover_movie yields a movie list.

    Each round does the following:
    1. Request tool calls from the model (``tool_choice="required"``).
    2. Validate them and execute only the first valid one.
    3. If that call was discover_movie, return its results.
    4. Otherwise append the assistant tool-call message and the tool result
       to a copy of the history, and recurse.

    :param messages: chat history (list of message dicts). Not mutated.
    :param max_tool_messages: another tool call is executed only while the
        history holds at most this many "tool" messages (default 3, the
        previous hard-coded limit).
    :return: ``tool_output["results"]`` from the discover_movie call (assumes
        that tool returns a mapping with a "results" key — TODO confirm in
        functions.py), or the string "No results found" when no valid tool
        call remains or the limit is reached.
    :raises Exception: if the model stops for any reason other than tool calls.
    """
    logging.debug(f"Gathering movie data with messages: {messages}")
    response = get_response(client, MODEL, messages, tool_choice="required")
    logging.debug(f"Calling tools based on the response: {response}")
    choice = response.choices[0]
    if choice.finish_reason != "tool_calls":
        raise Exception(
            f"Failed to gather movie data. Got response: {response}")

    valid_tool_calls = verify_tool_calls(choice.message.tool_calls, messages)
    tool_messages_count = sum(1 for msg in messages if msg["role"] == "tool")
    if tool_messages_count > max_tool_messages or not valid_tool_calls:
        return "No results found"

    tool_call = valid_tool_calls[0]  # Run one tool call at a time
    logging.info(
        f"Tool call: {tool_call.function.name}, Tool call parameters: {tool_call.function.arguments}")
    tool_output = execute_tool(tool_call)
    logging.debug(
        f"Tool call output: {json.dumps(tool_output, indent=2)}")
    if tool_call.function.name == "discover_movie":
        return tool_output["results"]  # A list of movies

    updated_messages = messages.copy()
    # The chat-completions protocol requires a "tool" message to answer a
    # preceding assistant message that carries the matching tool_calls entry.
    # Only the executed call is listed, since only it gets a result.
    updated_messages.append(
        {
            "role": "assistant",
            "content": choice.message.content or "",
            "tool_calls": [
                {
                    "id": tool_call.id,
                    "type": "function",
                    "function": {
                        "name": tool_call.function.name,
                        "arguments": tool_call.function.arguments,
                    },
                }
            ],
        }
    )
    updated_messages.append(
        {
            "tool_call_id": tool_call.id,
            "role": "tool",
            "name": tool_call.function.name,
            # JSON rather than Python repr: tool_output is already known to be
            # JSON-serializable (it was json.dumps-ed for logging above).
            "content": json.dumps(tool_output),
        }
    )
    return gather_movie_data(updated_messages, max_tool_messages)
def execute_tool(tool_call):
    """Invoke the local function named by a tool call.

    :param tool_call: object exposing ``.function.name`` and
        ``.function.arguments`` (a JSON string of keyword arguments).
    :return: whatever the matching function returns.
    :raises KeyError: if the name is not in ``names_to_functions``.
    """
    func_name = tool_call.function.name
    logging.info(f"Executing tool: {func_name}")
    handler = names_to_functions[func_name]
    kwargs = json.loads(tool_call.function.arguments)
    return handler(**kwargs)
def chatbot(user_prompt):
    """Answer a movie query end to end: plan first, then run tool calls.

    The planning answer from ``generate_reasoning_chain`` is appended after
    the function-call prompt so that the tool-calling stage can follow it.

    :param user_prompt: the user's movie query.
    :return: the value returned by ``gather_movie_data``.
    """
    reasoning = generate_reasoning_chain(user_prompt)
    conversation = create_message(user_prompt, "function_call")
    planning_message = {
        'role': reasoning.message.role,
        'content': reasoning.message.content,
    }
    conversation.append(planning_message)
    return gather_movie_data(conversation)
if __name__ == "__main__":
    # Alternative example query: "List some movies of Tom Cruise"
    demo_query = "Movies of Tom Hanks that were released in 1993"
    print(json.dumps(chatbot(demo_query), indent=2))
|