Spaces:
Running
Running
tmzh
commited on
Commit
•
90d520c
1
Parent(s):
95bf5e4
different system message for reasoning vs tool use chains
Browse files- agent.py +51 -14
- functions.py +1 -1
agent.py
CHANGED
@@ -18,12 +18,36 @@ names_to_functions = {func: functools.partial(
|
|
18 |
getattr(functions, func)) for func in all_functions}
|
19 |
|
20 |
|
21 |
-
def create_message(prompt):
|
22 |
-
logging.debug(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
return [
|
24 |
{
|
25 |
"role": "system",
|
26 |
-
"content":
|
27 |
},
|
28 |
{
|
29 |
"role": "user",
|
@@ -35,7 +59,7 @@ def create_message(prompt):
|
|
35 |
def get_response(client, model, messages, tool_choice="auto"):
|
36 |
logging.debug(
|
37 |
f"Getting response with model: {model}, messages: {messages}, tool_choice: {tool_choice}")
|
38 |
-
|
39 |
model=model,
|
40 |
messages=messages,
|
41 |
tools=tools,
|
@@ -43,31 +67,38 @@ def get_response(client, model, messages, tool_choice="auto"):
|
|
43 |
temperature=0,
|
44 |
max_tokens=4096,
|
45 |
)
|
|
|
|
|
46 |
|
47 |
|
48 |
-
def generate_reasoning_chain(
|
|
|
49 |
logging.debug(f"Generating reasoning chain with messages: {messages}")
|
50 |
cot_response = get_response(client, MODEL, messages, tool_choice="none")
|
|
|
51 |
if cot_response.choices[0].finish_reason == "stop":
|
52 |
-
|
53 |
-
'role': cot_response.choices[0].message.role,
|
54 |
-
'content': cot_response.choices[0].message.content})
|
55 |
-
return messages
|
56 |
else:
|
57 |
-
|
|
|
|
|
58 |
|
59 |
|
60 |
def gather_movie_data(messages, iteration=0, max_iterations=2):
|
61 |
logging.debug(
|
62 |
f"Gathering movie data with messages: {messages}, iteration: {iteration}")
|
63 |
response = get_response(client, MODEL, messages, tool_choice="required")
|
|
|
|
|
64 |
if response.choices[0].finish_reason == "tool_calls":
|
65 |
tool_calls = response.choices[0].message.tool_calls
|
66 |
updated_messages = messages.copy()
|
67 |
for tool_call in tool_calls:
|
|
|
|
|
68 |
tool_output = execute_tool(tool_call)
|
69 |
-
logging.
|
70 |
-
f"Tool call: {
|
71 |
if tool_call.function.name == "discover_movie":
|
72 |
return tool_output["results"] # A list of movies
|
73 |
else:
|
@@ -81,6 +112,9 @@ def gather_movie_data(messages, iteration=0, max_iterations=2):
|
|
81 |
)
|
82 |
if iteration < max_iterations:
|
83 |
return gather_movie_data(updated_messages, iteration + 1)
|
|
|
|
|
|
|
84 |
|
85 |
|
86 |
def execute_tool(tool_call):
|
@@ -91,8 +125,11 @@ def execute_tool(tool_call):
|
|
91 |
|
92 |
|
93 |
def chatbot(user_prompt):
|
94 |
-
|
95 |
-
cot =
|
|
|
|
|
|
|
96 |
movie_list = gather_movie_data(cot)
|
97 |
return movie_list
|
98 |
|
|
|
18 |
getattr(functions, func)) for func in all_functions}
|
19 |
|
20 |
|
21 |
+
def create_message(prompt, message_type):
|
22 |
+
logging.debug(
|
23 |
+
f"Creating message with prompt: {prompt} and message type: {message_type}")
|
24 |
+
system_message = ""
|
25 |
+
|
26 |
+
if message_type == "reasoning_chain":
|
27 |
+
system_message = (
|
28 |
+
"You are a movie search assistant bot who uses TMDB to help users "
|
29 |
+
"find movies. Think step by step and identify the sequence of "
|
30 |
+
"reasoning steps that will help to answer the user's query."
|
31 |
+
)
|
32 |
+
elif message_type == "function_call":
|
33 |
+
system_message = (
|
34 |
+
"You are a movie search assistant bot who uses TMDB to help users "
|
35 |
+
"find movies. Think step by step and identify the sequence of "
|
36 |
+
"function calls that will help to answer the user's query. Use the "
|
37 |
+
"available functions to gather the necessary data. "
|
38 |
+
"Do not call multiple functions when they need to be executed in sequence. "
|
39 |
+
"Only call multiple functions when they can be executed in parallel. "
|
40 |
+
"Stop with a discover_movie function call that returns a list of movie ids. "
|
41 |
+
"Ensure the discover_movie function call includes all the necessary parameters to filter the movies accurately."
|
42 |
+
)
|
43 |
+
else:
|
44 |
+
raise ValueError(
|
45 |
+
"Invalid message type. Expected 'reasoning_chain' or 'function_call'")
|
46 |
+
|
47 |
return [
|
48 |
{
|
49 |
"role": "system",
|
50 |
+
"content": system_message,
|
51 |
},
|
52 |
{
|
53 |
"role": "user",
|
|
|
59 |
def get_response(client, model, messages, tool_choice="auto"):
|
60 |
logging.debug(
|
61 |
f"Getting response with model: {model}, messages: {messages}, tool_choice: {tool_choice}")
|
62 |
+
response = client.chat.completions.create(
|
63 |
model=model,
|
64 |
messages=messages,
|
65 |
tools=tools,
|
|
|
67 |
temperature=0,
|
68 |
max_tokens=4096,
|
69 |
)
|
70 |
+
logging.debug(f"Response: {response}")
|
71 |
+
return response
|
72 |
|
73 |
|
74 |
+
def generate_reasoning_chain(user_prompt):
|
75 |
+
messages = create_message(user_prompt, "reasoning_chain")
|
76 |
logging.debug(f"Generating reasoning chain with messages: {messages}")
|
77 |
cot_response = get_response(client, MODEL, messages, tool_choice="none")
|
78 |
+
logging.info(f"COT response: {cot_response.choices[0].message.content}")
|
79 |
if cot_response.choices[0].finish_reason == "stop":
|
80 |
+
return cot_response.choices[0]
|
|
|
|
|
|
|
81 |
else:
|
82 |
+
logging.error(
|
83 |
+
"Failed to generate reasoning chain. Got response: ", cot_response)
|
84 |
+
raise Exception("Failed to generate reasoning chain")
|
85 |
|
86 |
|
87 |
def gather_movie_data(messages, iteration=0, max_iterations=2):
|
88 |
logging.debug(
|
89 |
f"Gathering movie data with messages: {messages}, iteration: {iteration}")
|
90 |
response = get_response(client, MODEL, messages, tool_choice="required")
|
91 |
+
logging.info(
|
92 |
+
f"Gathering movie data response: {response}")
|
93 |
if response.choices[0].finish_reason == "tool_calls":
|
94 |
tool_calls = response.choices[0].message.tool_calls
|
95 |
updated_messages = messages.copy()
|
96 |
for tool_call in tool_calls:
|
97 |
+
logging.info(
|
98 |
+
f"Tool call: {tool_call.function.name}, Tool call parameters: {tool_call.function.arguments}")
|
99 |
tool_output = execute_tool(tool_call)
|
100 |
+
logging.info(
|
101 |
+
f"Tool call output: {json.dumps(tool_output, indent=2)}")
|
102 |
if tool_call.function.name == "discover_movie":
|
103 |
return tool_output["results"] # A list of movies
|
104 |
else:
|
|
|
112 |
)
|
113 |
if iteration < max_iterations:
|
114 |
return gather_movie_data(updated_messages, iteration + 1)
|
115 |
+
else:
|
116 |
+
raise Exception(
|
117 |
+
"Failed to gather movie data. Got response: ", response)
|
118 |
|
119 |
|
120 |
def execute_tool(tool_call):
|
|
|
125 |
|
126 |
|
127 |
def chatbot(user_prompt):
|
128 |
+
cot_response_choice = generate_reasoning_chain(user_prompt)
|
129 |
+
cot = create_message(user_prompt, "function_call")
|
130 |
+
cot.append({
|
131 |
+
'role': cot_response_choice.message.role,
|
132 |
+
'content': cot_response_choice.message.content})
|
133 |
movie_list = gather_movie_data(cot)
|
134 |
return movie_list
|
135 |
|
functions.py
CHANGED
@@ -7,7 +7,7 @@ BASE_URL = "https://api.themoviedb.org/3"
|
|
7 |
API_KEY = os.environ['TMDB_API_KEY']
|
8 |
|
9 |
# Set up logging
|
10 |
-
logging.basicConfig(level=logging.
|
11 |
|
12 |
|
13 |
def query_tmdb(endpoint, params={}):
|
|
|
7 |
API_KEY = os.environ['TMDB_API_KEY']
|
8 |
|
9 |
# Set up logging
|
10 |
+
logging.basicConfig(level=logging.INFO)
|
11 |
|
12 |
|
13 |
def query_tmdb(endpoint, params={}):
|