Spaces:
Paused
Paused
tmzh
commited on
Commit
·
90d520c
1
Parent(s):
95bf5e4
different system message for reasoning vs tool use chains
Browse files- agent.py +51 -14
- functions.py +1 -1
agent.py
CHANGED
|
@@ -18,12 +18,36 @@ names_to_functions = {func: functools.partial(
|
|
| 18 |
getattr(functions, func)) for func in all_functions}
|
| 19 |
|
| 20 |
|
| 21 |
-
def create_message(prompt):
|
| 22 |
-
logging.debug(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
return [
|
| 24 |
{
|
| 25 |
"role": "system",
|
| 26 |
-
"content":
|
| 27 |
},
|
| 28 |
{
|
| 29 |
"role": "user",
|
|
@@ -35,7 +59,7 @@ def create_message(prompt):
|
|
| 35 |
def get_response(client, model, messages, tool_choice="auto"):
|
| 36 |
logging.debug(
|
| 37 |
f"Getting response with model: {model}, messages: {messages}, tool_choice: {tool_choice}")
|
| 38 |
-
|
| 39 |
model=model,
|
| 40 |
messages=messages,
|
| 41 |
tools=tools,
|
|
@@ -43,31 +67,38 @@ def get_response(client, model, messages, tool_choice="auto"):
|
|
| 43 |
temperature=0,
|
| 44 |
max_tokens=4096,
|
| 45 |
)
|
|
|
|
|
|
|
| 46 |
|
| 47 |
|
| 48 |
-
def generate_reasoning_chain(
|
|
|
|
| 49 |
logging.debug(f"Generating reasoning chain with messages: {messages}")
|
| 50 |
cot_response = get_response(client, MODEL, messages, tool_choice="none")
|
|
|
|
| 51 |
if cot_response.choices[0].finish_reason == "stop":
|
| 52 |
-
|
| 53 |
-
'role': cot_response.choices[0].message.role,
|
| 54 |
-
'content': cot_response.choices[0].message.content})
|
| 55 |
-
return messages
|
| 56 |
else:
|
| 57 |
-
|
|
|
|
|
|
|
| 58 |
|
| 59 |
|
| 60 |
def gather_movie_data(messages, iteration=0, max_iterations=2):
|
| 61 |
logging.debug(
|
| 62 |
f"Gathering movie data with messages: {messages}, iteration: {iteration}")
|
| 63 |
response = get_response(client, MODEL, messages, tool_choice="required")
|
|
|
|
|
|
|
| 64 |
if response.choices[0].finish_reason == "tool_calls":
|
| 65 |
tool_calls = response.choices[0].message.tool_calls
|
| 66 |
updated_messages = messages.copy()
|
| 67 |
for tool_call in tool_calls:
|
|
|
|
|
|
|
| 68 |
tool_output = execute_tool(tool_call)
|
| 69 |
-
logging.
|
| 70 |
-
f"Tool call: {
|
| 71 |
if tool_call.function.name == "discover_movie":
|
| 72 |
return tool_output["results"] # A list of movies
|
| 73 |
else:
|
|
@@ -81,6 +112,9 @@ def gather_movie_data(messages, iteration=0, max_iterations=2):
|
|
| 81 |
)
|
| 82 |
if iteration < max_iterations:
|
| 83 |
return gather_movie_data(updated_messages, iteration + 1)
|
|
|
|
|
|
|
|
|
|
| 84 |
|
| 85 |
|
| 86 |
def execute_tool(tool_call):
|
|
@@ -91,8 +125,11 @@ def execute_tool(tool_call):
|
|
| 91 |
|
| 92 |
|
| 93 |
def chatbot(user_prompt):
|
| 94 |
-
|
| 95 |
-
cot =
|
|
|
|
|
|
|
|
|
|
| 96 |
movie_list = gather_movie_data(cot)
|
| 97 |
return movie_list
|
| 98 |
|
|
|
|
| 18 |
getattr(functions, func)) for func in all_functions}
|
| 19 |
|
| 20 |
|
| 21 |
+
def create_message(prompt, message_type):
|
| 22 |
+
logging.debug(
|
| 23 |
+
f"Creating message with prompt: {prompt} and message type: {message_type}")
|
| 24 |
+
system_message = ""
|
| 25 |
+
|
| 26 |
+
if message_type == "reasoning_chain":
|
| 27 |
+
system_message = (
|
| 28 |
+
"You are a movie search assistant bot who uses TMDB to help users "
|
| 29 |
+
"find movies. Think step by step and identify the sequence of "
|
| 30 |
+
"reasoning steps that will help to answer the user's query."
|
| 31 |
+
)
|
| 32 |
+
elif message_type == "function_call":
|
| 33 |
+
system_message = (
|
| 34 |
+
"You are a movie search assistant bot who uses TMDB to help users "
|
| 35 |
+
"find movies. Think step by step and identify the sequence of "
|
| 36 |
+
"function calls that will help to answer the user's query. Use the "
|
| 37 |
+
"available functions to gather the necessary data. "
|
| 38 |
+
"Do not call multiple functions when they need to be executed in sequence. "
|
| 39 |
+
"Only call multiple functions when they can be executed in parallel. "
|
| 40 |
+
"Stop with a discover_movie function call that returns a list of movie ids. "
|
| 41 |
+
"Ensure the discover_movie function call includes all the necessary parameters to filter the movies accurately."
|
| 42 |
+
)
|
| 43 |
+
else:
|
| 44 |
+
raise ValueError(
|
| 45 |
+
"Invalid message type. Expected 'reasoning_chain' or 'function_call'")
|
| 46 |
+
|
| 47 |
return [
|
| 48 |
{
|
| 49 |
"role": "system",
|
| 50 |
+
"content": system_message,
|
| 51 |
},
|
| 52 |
{
|
| 53 |
"role": "user",
|
|
|
|
| 59 |
def get_response(client, model, messages, tool_choice="auto"):
|
| 60 |
logging.debug(
|
| 61 |
f"Getting response with model: {model}, messages: {messages}, tool_choice: {tool_choice}")
|
| 62 |
+
response = client.chat.completions.create(
|
| 63 |
model=model,
|
| 64 |
messages=messages,
|
| 65 |
tools=tools,
|
|
|
|
| 67 |
temperature=0,
|
| 68 |
max_tokens=4096,
|
| 69 |
)
|
| 70 |
+
logging.debug(f"Response: {response}")
|
| 71 |
+
return response
|
| 72 |
|
| 73 |
|
| 74 |
+
def generate_reasoning_chain(user_prompt):
|
| 75 |
+
messages = create_message(user_prompt, "reasoning_chain")
|
| 76 |
logging.debug(f"Generating reasoning chain with messages: {messages}")
|
| 77 |
cot_response = get_response(client, MODEL, messages, tool_choice="none")
|
| 78 |
+
logging.info(f"COT response: {cot_response.choices[0].message.content}")
|
| 79 |
if cot_response.choices[0].finish_reason == "stop":
|
| 80 |
+
return cot_response.choices[0]
|
|
|
|
|
|
|
|
|
|
| 81 |
else:
|
| 82 |
+
logging.error(
|
| 83 |
+
"Failed to generate reasoning chain. Got response: ", cot_response)
|
| 84 |
+
raise Exception("Failed to generate reasoning chain")
|
| 85 |
|
| 86 |
|
| 87 |
def gather_movie_data(messages, iteration=0, max_iterations=2):
|
| 88 |
logging.debug(
|
| 89 |
f"Gathering movie data with messages: {messages}, iteration: {iteration}")
|
| 90 |
response = get_response(client, MODEL, messages, tool_choice="required")
|
| 91 |
+
logging.info(
|
| 92 |
+
f"Gathering movie data response: {response}")
|
| 93 |
if response.choices[0].finish_reason == "tool_calls":
|
| 94 |
tool_calls = response.choices[0].message.tool_calls
|
| 95 |
updated_messages = messages.copy()
|
| 96 |
for tool_call in tool_calls:
|
| 97 |
+
logging.info(
|
| 98 |
+
f"Tool call: {tool_call.function.name}, Tool call parameters: {tool_call.function.arguments}")
|
| 99 |
tool_output = execute_tool(tool_call)
|
| 100 |
+
logging.info(
|
| 101 |
+
f"Tool call output: {json.dumps(tool_output, indent=2)}")
|
| 102 |
if tool_call.function.name == "discover_movie":
|
| 103 |
return tool_output["results"] # A list of movies
|
| 104 |
else:
|
|
|
|
| 112 |
)
|
| 113 |
if iteration < max_iterations:
|
| 114 |
return gather_movie_data(updated_messages, iteration + 1)
|
| 115 |
+
else:
|
| 116 |
+
raise Exception(
|
| 117 |
+
"Failed to gather movie data. Got response: ", response)
|
| 118 |
|
| 119 |
|
| 120 |
def execute_tool(tool_call):
|
|
|
|
| 125 |
|
| 126 |
|
| 127 |
def chatbot(user_prompt):
|
| 128 |
+
cot_response_choice = generate_reasoning_chain(user_prompt)
|
| 129 |
+
cot = create_message(user_prompt, "function_call")
|
| 130 |
+
cot.append({
|
| 131 |
+
'role': cot_response_choice.message.role,
|
| 132 |
+
'content': cot_response_choice.message.content})
|
| 133 |
movie_list = gather_movie_data(cot)
|
| 134 |
return movie_list
|
| 135 |
|
functions.py
CHANGED
|
@@ -7,7 +7,7 @@ BASE_URL = "https://api.themoviedb.org/3"
|
|
| 7 |
API_KEY = os.environ['TMDB_API_KEY']
|
| 8 |
|
| 9 |
# Set up logging
|
| 10 |
-
logging.basicConfig(level=logging.
|
| 11 |
|
| 12 |
|
| 13 |
def query_tmdb(endpoint, params={}):
|
|
|
|
| 7 |
API_KEY = os.environ['TMDB_API_KEY']
|
| 8 |
|
| 9 |
# Set up logging
|
| 10 |
+
logging.basicConfig(level=logging.INFO)
|
| 11 |
|
| 12 |
|
| 13 |
def query_tmdb(endpoint, params={}):
|