# NOTE(review): removed HuggingFace Spaces page-scrape residue (UI labels,
# commit hashes, and the line-number gutter) — it was not part of the source.
import functools
import json
import os
import logging
from groq import Groq
import functions
from tools import tools
# Set up logging
logging.basicConfig(level=logging.DEBUG)
client = Groq(api_key=os.environ["GROQ_API_KEY"])
MODEL = "llama3-groq-70b-8192-tool-use-preview"
all_functions = [func for func in dir(functions) if callable(
getattr(functions, func)) and not func.startswith("__")]
names_to_functions = {func: functools.partial(
getattr(functions, func)) for func in all_functions}
def create_message(prompt, message_type):
logging.debug(
f"Creating message with prompt: {prompt} and message type: {message_type}")
system_message = ""
if message_type == "reasoning_chain":
system_message = (
"You are a movie search assistant bot who uses TMDB to help users "
"find movies. Think step by step and identify the sequence of "
"reasoning steps that will help to answer the user's query."
)
elif message_type == "function_call":
system_message = (
"You are a movie search assistant bot who uses TMDB to help users "
"find movies. Think step by step and identify the sequence of "
"function calls that will help to answer the user's query. Use the "
"available functions to gather the necessary data. "
"Do not call multiple functions when they need to be executed in sequence. "
"Only call multiple functions when they can be executed in parallel. "
"Stop with a discover_movie function call that returns a list of movie ids. "
"Ensure the discover_movie function call includes all the necessary parameters to filter the movies accurately."
)
else:
raise ValueError(
"Invalid message type. Expected 'reasoning_chain' or 'function_call'")
return [
{
"role": "system",
"content": system_message,
},
{
"role": "user",
"content": prompt,
},
]
def get_response(client, model, messages, tool_choice="auto"):
logging.debug(
f"Getting response with model: {model}, messages: {messages}, tool_choice: {tool_choice}")
response = client.chat.completions.create(
model=model,
messages=messages,
tools=tools,
tool_choice=tool_choice,
temperature=0,
max_tokens=4096,
)
logging.debug(f"Response: {response}")
return response
def generate_reasoning_chain(user_prompt):
messages = create_message(user_prompt, "reasoning_chain")
logging.debug(f"Generating reasoning chain with messages: {messages}")
cot_response = get_response(client, MODEL, messages, tool_choice="none")
logging.info(f"COT response: {cot_response.choices[0].message.content}")
if cot_response.choices[0].finish_reason == "stop":
return cot_response.choices[0]
else:
logging.error(
"Failed to generate reasoning chain. Got response: ", cot_response)
raise Exception("Failed to generate reasoning chain")
def gather_movie_data(messages, iteration=0, max_iterations=2):
logging.debug(
f"Gathering movie data with messages: {messages}, iteration: {iteration}")
response = get_response(client, MODEL, messages, tool_choice="required")
logging.info(
f"Gathering movie data response: {response}")
if response.choices[0].finish_reason == "tool_calls":
tool_calls = response.choices[0].message.tool_calls
updated_messages = messages.copy()
for tool_call in tool_calls:
logging.info(
f"Tool call: {tool_call.function.name}, Tool call parameters: {tool_call.function.arguments}")
tool_output = execute_tool(tool_call)
logging.info(
f"Tool call output: {json.dumps(tool_output, indent=2)}")
if tool_call.function.name == "discover_movie":
return tool_output["results"] # A list of movies
else:
updated_messages.append(
{
"tool_call_id": tool_call.id,
"role": "tool",
"name": tool_call.function.name,
"content": str(tool_output),
}
)
if iteration < max_iterations:
return gather_movie_data(updated_messages, iteration + 1)
else:
raise Exception(
"Failed to gather movie data. Got response: ", response)
def execute_tool(tool_call):
logging.debug(f"Executing tool: {tool_call.function.name}")
function_to_call = names_to_functions[tool_call.function.name]
function_args = json.loads(tool_call.function.arguments)
return function_to_call(**function_args)
def chatbot(user_prompt):
cot_response_choice = generate_reasoning_chain(user_prompt)
cot = create_message(user_prompt, "function_call")
cot.append({
'role': cot_response_choice.message.role,
'content': cot_response_choice.message.content})
movie_list = gather_movie_data(cot)
return movie_list
if __name__ == "__main__":
print(json.dumps(chatbot("List comedy movies with tom cruise in it"), indent=2))