tmzh commited on
Commit
90d520c
1 Parent(s): 95bf5e4

different system message for reasoning vs tool use chains

Browse files
Files changed (2) hide show
  1. agent.py +51 -14
  2. functions.py +1 -1
agent.py CHANGED
@@ -18,12 +18,36 @@ names_to_functions = {func: functools.partial(
18
  getattr(functions, func)) for func in all_functions}
19
 
20
 
21
- def create_message(prompt):
22
- logging.debug(f"Creating message with prompt: {prompt}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  return [
24
  {
25
  "role": "system",
26
- "content": "You are a movie search assistant bot who uses TMDB to help users find movies. Think step by step and identify the sequence of function calls that will help to answer.",
27
  },
28
  {
29
  "role": "user",
@@ -35,7 +59,7 @@ def create_message(prompt):
35
  def get_response(client, model, messages, tool_choice="auto"):
36
  logging.debug(
37
  f"Getting response with model: {model}, messages: {messages}, tool_choice: {tool_choice}")
38
- return client.chat.completions.create(
39
  model=model,
40
  messages=messages,
41
  tools=tools,
@@ -43,31 +67,38 @@ def get_response(client, model, messages, tool_choice="auto"):
43
  temperature=0,
44
  max_tokens=4096,
45
  )
 
 
46
 
47
 
48
- def generate_reasoning_chain(messages):
 
49
  logging.debug(f"Generating reasoning chain with messages: {messages}")
50
  cot_response = get_response(client, MODEL, messages, tool_choice="none")
 
51
  if cot_response.choices[0].finish_reason == "stop":
52
- messages.append({
53
- 'role': cot_response.choices[0].message.role,
54
- 'content': cot_response.choices[0].message.content})
55
- return messages
56
  else:
57
- return None
 
 
58
 
59
 
60
  def gather_movie_data(messages, iteration=0, max_iterations=2):
61
  logging.debug(
62
  f"Gathering movie data with messages: {messages}, iteration: {iteration}")
63
  response = get_response(client, MODEL, messages, tool_choice="required")
 
 
64
  if response.choices[0].finish_reason == "tool_calls":
65
  tool_calls = response.choices[0].message.tool_calls
66
  updated_messages = messages.copy()
67
  for tool_call in tool_calls:
 
 
68
  tool_output = execute_tool(tool_call)
69
- logging.debug(
70
- f"Tool call: {tool_call.function.name}, output: {tool_output}")
71
  if tool_call.function.name == "discover_movie":
72
  return tool_output["results"] # A list of movies
73
  else:
@@ -81,6 +112,9 @@ def gather_movie_data(messages, iteration=0, max_iterations=2):
81
  )
82
  if iteration < max_iterations:
83
  return gather_movie_data(updated_messages, iteration + 1)
 
 
 
84
 
85
 
86
  def execute_tool(tool_call):
@@ -91,8 +125,11 @@ def execute_tool(tool_call):
91
 
92
 
93
  def chatbot(user_prompt):
94
- messages = create_message(user_prompt)
95
- cot = generate_reasoning_chain(messages)
 
 
 
96
  movie_list = gather_movie_data(cot)
97
  return movie_list
98
 
 
18
  getattr(functions, func)) for func in all_functions}
19
 
20
 
21
+ def create_message(prompt, message_type):
22
+ logging.debug(
23
+ f"Creating message with prompt: {prompt} and message type: {message_type}")
24
+ system_message = ""
25
+
26
+ if message_type == "reasoning_chain":
27
+ system_message = (
28
+ "You are a movie search assistant bot who uses TMDB to help users "
29
+ "find movies. Think step by step and identify the sequence of "
30
+ "reasoning steps that will help to answer the user's query."
31
+ )
32
+ elif message_type == "function_call":
33
+ system_message = (
34
+ "You are a movie search assistant bot who uses TMDB to help users "
35
+ "find movies. Think step by step and identify the sequence of "
36
+ "function calls that will help to answer the user's query. Use the "
37
+ "available functions to gather the necessary data. "
38
+ "Do not call multiple functions when they need to be executed in sequence. "
39
+ "Only call multiple functions when they can be executed in parallel. "
40
+ "Stop with a discover_movie function call that returns a list of movie ids. "
41
+ "Ensure the discover_movie function call includes all the necessary parameters to filter the movies accurately."
42
+ )
43
+ else:
44
+ raise ValueError(
45
+ "Invalid message type. Expected 'reasoning_chain' or 'function_call'")
46
+
47
  return [
48
  {
49
  "role": "system",
50
+ "content": system_message,
51
  },
52
  {
53
  "role": "user",
 
59
  def get_response(client, model, messages, tool_choice="auto"):
60
  logging.debug(
61
  f"Getting response with model: {model}, messages: {messages}, tool_choice: {tool_choice}")
62
+ response = client.chat.completions.create(
63
  model=model,
64
  messages=messages,
65
  tools=tools,
 
67
  temperature=0,
68
  max_tokens=4096,
69
  )
70
+ logging.debug(f"Response: {response}")
71
+ return response
72
 
73
 
74
+ def generate_reasoning_chain(user_prompt):
75
+ messages = create_message(user_prompt, "reasoning_chain")
76
  logging.debug(f"Generating reasoning chain with messages: {messages}")
77
  cot_response = get_response(client, MODEL, messages, tool_choice="none")
78
+ logging.info(f"COT response: {cot_response.choices[0].message.content}")
79
  if cot_response.choices[0].finish_reason == "stop":
80
+ return cot_response.choices[0]
 
 
 
81
  else:
82
+ logging.error(
83
+ "Failed to generate reasoning chain. Got response: ", cot_response)
84
+ raise Exception("Failed to generate reasoning chain")
85
 
86
 
87
  def gather_movie_data(messages, iteration=0, max_iterations=2):
88
  logging.debug(
89
  f"Gathering movie data with messages: {messages}, iteration: {iteration}")
90
  response = get_response(client, MODEL, messages, tool_choice="required")
91
+ logging.info(
92
+ f"Gathering movie data response: {response}")
93
  if response.choices[0].finish_reason == "tool_calls":
94
  tool_calls = response.choices[0].message.tool_calls
95
  updated_messages = messages.copy()
96
  for tool_call in tool_calls:
97
+ logging.info(
98
+ f"Tool call: {tool_call.function.name}, Tool call parameters: {tool_call.function.arguments}")
99
  tool_output = execute_tool(tool_call)
100
+ logging.info(
101
+ f"Tool call output: {json.dumps(tool_output, indent=2)}")
102
  if tool_call.function.name == "discover_movie":
103
  return tool_output["results"] # A list of movies
104
  else:
 
112
  )
113
  if iteration < max_iterations:
114
  return gather_movie_data(updated_messages, iteration + 1)
115
+ else:
116
+ raise Exception(
117
+ "Failed to gather movie data. Got response: ", response)
118
 
119
 
120
  def execute_tool(tool_call):
 
125
 
126
 
127
  def chatbot(user_prompt):
128
+ cot_response_choice = generate_reasoning_chain(user_prompt)
129
+ cot = create_message(user_prompt, "function_call")
130
+ cot.append({
131
+ 'role': cot_response_choice.message.role,
132
+ 'content': cot_response_choice.message.content})
133
  movie_list = gather_movie_data(cot)
134
  return movie_list
135
 
functions.py CHANGED
@@ -7,7 +7,7 @@ BASE_URL = "https://api.themoviedb.org/3"
7
  API_KEY = os.environ['TMDB_API_KEY']
8
 
9
  # Set up logging
10
- logging.basicConfig(level=logging.DEBUG)
11
 
12
 
13
  def query_tmdb(endpoint, params={}):
 
7
  API_KEY = os.environ['TMDB_API_KEY']
8
 
9
  # Set up logging
10
+ logging.basicConfig(level=logging.INFO)
11
 
12
 
13
  def query_tmdb(endpoint, params={}):