Spaces:
Paused
Paused
import functools | |
import json | |
import os | |
import logging | |
from groq import Groq | |
import functions | |
from tools import tools | |
# Set up logging: DEBUG level, so full prompts and API responses are logged.
logging.basicConfig(level=logging.DEBUG)

# Shared Groq client; os.environ[...] raises KeyError at import time if
# GROQ_API_KEY is not set.
client = Groq(api_key=os.environ["GROQ_API_KEY"])
MODEL = "llama3-groq-70b-8192-tool-use-preview"

# Names of every callable attribute of the local `functions` module, skipping
# dunder names.
# NOTE(review): dir() also picks up any callables `functions` merely imports;
# since the model chooses which name to call, confirm that module exposes only
# intended tools.
all_functions = [
    name
    for name in dir(functions)
    if callable(getattr(functions, name)) and not name.startswith("__")
]

# Dispatch table used by execute_tool: tool name -> the callable itself.
names_to_functions = {name: getattr(functions, name) for name in all_functions}
def create_message(prompt, message_type):
    """Build the system + user chat messages for one stage of the pipeline.

    Args:
        prompt: The user's query, sent verbatim as the "user" message.
        message_type: Selects the system prompt. "reasoning_chain" asks for a
            step-by-step plan; "function_call" asks for a sequence of tool
            calls ending in discover_movie.

    Returns:
        A two-element list: the system message dict followed by the user
        message dict, each with "role" and "content" keys.

    Raises:
        ValueError: If message_type is not one of the two supported values.
    """
    logging.debug(
        f"Creating message with prompt: {prompt} and message type: {message_type}")

    # Both system prompts open with the same persona/instruction sentence.
    shared_intro = (
        "You are a movie search assistant bot who uses TMDB to help users "
        "find movies. Think step by step and identify the sequence of "
    )
    prompts_by_type = (
        (
            "reasoning_chain",
            shared_intro
            + "reasoning steps that will help to answer the user's query.",
        ),
        (
            "function_call",
            shared_intro
            + (
                "function calls that will help to answer the user's query. Use the "
                "available functions to gather the necessary data. "
                "Do not call multiple functions when they need to be executed in sequence. "
                "Only call multiple functions when they can be executed in parallel. "
                "Stop with a discover_movie function call that returns a list of movie ids. "
                "Ensure the discover_movie function call includes all the necessary parameters to filter the movies accurately."
            ),
        ),
    )

    # Linear scan with == (not a dict lookup) so any message_type value,
    # hashable or not, is compared exactly as an equality check.
    for kind, system_text in prompts_by_type:
        if message_type == kind:
            break
    else:
        raise ValueError(
            "Invalid message type. Expected 'reasoning_chain' or 'function_call'")

    return [
        {"role": "system", "content": system_text},
        {"role": "user", "content": prompt},
    ]
def get_response(client, model, messages, tool_choice="auto"):
    """Send one chat-completion request, offering the module-level `tools`.

    Args:
        client: The Groq client to call (client.chat.completions.create).
        model: Model identifier string.
        messages: Chat history to send.
        tool_choice: Passed through to the API ("auto", "none", "required", ...).

    Returns:
        The raw completion response object returned by the client.
    """
    logging.debug(
        f"Getting response with model: {model}, messages: {messages}, tool_choice: {tool_choice}")
    # Deterministic decoding (temperature 0) with a fixed 4096-token cap.
    request_kwargs = dict(
        model=model,
        messages=messages,
        tools=tools,
        tool_choice=tool_choice,
        temperature=0,
        max_tokens=4096,
    )
    completion = client.chat.completions.create(**request_kwargs)
    logging.debug(f"Response: {completion}")
    return completion
def generate_reasoning_chain(user_prompt):
    """Ask the model for a plain-text, step-by-step plan for *user_prompt*.

    Tools are disabled for this request (tool_choice="none").

    Args:
        user_prompt: The user's movie-search query.

    Returns:
        The first choice of the completion; its .message carries the role
        and the reasoning text.

    Raises:
        Exception: If the first choice's finish_reason is anything other
            than "stop".
    """
    messages = create_message(user_prompt, "reasoning_chain")
    logging.debug(f"Generating reasoning chain with messages: {messages}")
    cot_response = get_response(client, MODEL, messages, tool_choice="none")
    choice = cot_response.choices[0]
    logging.info(f"COT response: {choice.message.content}")
    if choice.finish_reason == "stop":
        return choice
    # The response must be a %-style argument: an extra positional argument
    # with no placeholder makes logging emit a formatting error instead of
    # this message.
    logging.error(
        "Failed to generate reasoning chain. Got response: %s", cot_response)
    raise Exception("Failed to generate reasoning chain")
def _assistant_turn(message):
    """Return a chat-history dict replaying an assistant message that requested tool calls."""
    return {
        "role": "assistant",
        "content": message.content,
        "tool_calls": [
            {
                "id": tool_call.id,
                "type": "function",
                "function": {
                    "name": tool_call.function.name,
                    "arguments": tool_call.function.arguments,
                },
            }
            for tool_call in message.tool_calls
        ],
    }


def gather_movie_data(messages, iteration=0, max_iterations=2):
    """Run tool-calling rounds until the model calls ``discover_movie``.

    Each round requests tool calls (tool_choice="required") and executes them
    with execute_tool. Results of other tools are fed back as "tool" messages
    for the next round. The first ``discover_movie`` call ends the search.

    Args:
        messages: Chat history to send. The caller's list is not mutated.
        iteration: Index of the current round (0-based); used by recursion.
        max_iterations: Highest round index allowed before giving up, so at
            most ``max_iterations - iteration + 1`` requests are made.

    Returns:
        ``tool_output["results"]`` from the first discover_movie call.

    Raises:
        Exception: If no discover_movie call was made by the last round.
    """
    logging.debug(
        f"Gathering movie data with messages: {messages}, iteration: {iteration}")
    response = get_response(client, MODEL, messages, tool_choice="required")
    logging.info(
        f"Gathering movie data response: {response}")

    choice = response.choices[0]
    # Copy so the caller's list is never mutated. It is bound before the branch
    # so that a round without tool calls retries with the unchanged history
    # instead of failing on an unbound name.
    updated_messages = list(messages)
    if choice.finish_reason == "tool_calls":
        tool_calls = choice.message.tool_calls
        # In the tool-use protocol, "tool" messages must follow the assistant
        # turn whose tool_calls they answer.
        updated_messages.append(_assistant_turn(choice.message))
        for tool_call in tool_calls:
            logging.info(
                f"Tool call: {tool_call.function.name}, Tool call parameters: {tool_call.function.arguments}")
            tool_output = execute_tool(tool_call)
            logging.info(
                f"Tool call output: {json.dumps(tool_output, indent=2)}")
            if tool_call.function.name == "discover_movie":
                return tool_output["results"]  # A list of movies
            updated_messages.append(
                {
                    "tool_call_id": tool_call.id,
                    "role": "tool",
                    "name": tool_call.function.name,
                    # Send valid JSON rather than a Python repr. tool_output is
                    # known to serialize because the log line above already
                    # json.dumps it.
                    "content": json.dumps(tool_output),
                }
            )

    if iteration < max_iterations:
        # Forward max_iterations so a caller-supplied limit applies to every round.
        return gather_movie_data(updated_messages, iteration + 1, max_iterations)
    raise Exception(
        f"Failed to gather movie data. Got response: {response}")
def execute_tool(tool_call):
    """Dispatch one model-requested tool call to its Python implementation.

    The handler is looked up by name in names_to_functions and called with
    the tool call's JSON-encoded arguments as keyword arguments.

    Args:
        tool_call: A tool call from the completion (uses .function.name and
            .function.arguments).

    Returns:
        Whatever the matched handler returns.

    Raises:
        KeyError: If the requested name is not in names_to_functions.
        json.JSONDecodeError: If the arguments string is not valid JSON.
    """
    tool_name = tool_call.function.name
    logging.debug(f"Executing tool: {tool_name}")
    handler = names_to_functions[tool_name]
    kwargs = json.loads(tool_call.function.arguments)
    return handler(**kwargs)
def chatbot(user_prompt):
    """Answer *user_prompt* with the two-stage reasoning-then-tools pipeline.

    First a plain-text reasoning plan is generated. That plan is then
    appended as an assistant turn to a fresh "function_call" prompt, which
    drives the tool-calling rounds.

    Args:
        user_prompt: The user's movie-search query.

    Returns:
        The movie list produced by gather_movie_data.
    """
    plan_choice = generate_reasoning_chain(user_prompt)
    conversation = create_message(user_prompt, "function_call")
    plan_message = plan_choice.message
    conversation.append({'role': plan_message.role, 'content': plan_message.content})
    return gather_movie_data(conversation)
if __name__ == "__main__":
    # Demo query: print the discovered movies as indented JSON.
    movies = chatbot("List comedy movies with tom cruise in it")
    print(json.dumps(movies, indent=2))