File size: 5,296 Bytes
ebf567b
95bf5e4
 
 
 
ebf567b
6660737
 
95bf5e4
 
 
 
 
 
ebf567b
 
 
 
 
 
90d520c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95bf5e4
 
 
90d520c
95bf5e4
 
 
 
 
 
6660737
 
95bf5e4
 
 
90d520c
95bf5e4
 
 
 
 
 
 
90d520c
 
6660737
 
90d520c
 
95bf5e4
 
90d520c
95bf5e4
90d520c
ebf567b
90d520c
 
 
ebf567b
 
95bf5e4
 
 
 
90d520c
 
95bf5e4
 
 
 
90d520c
 
95bf5e4
90d520c
 
95bf5e4
 
 
 
 
 
 
 
 
 
 
 
 
90d520c
 
 
ebf567b
95bf5e4
 
 
 
 
 
 
 
 
90d520c
 
 
 
 
95bf5e4
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import functools
import json
import os
import logging
from groq import Groq
import functions
from tools import tools

# Set up logging
logging.basicConfig(level=logging.DEBUG)

client = Groq(api_key=os.environ["GROQ_API_KEY"])

MODEL = "llama3-groq-70b-8192-tool-use-preview"
# Every callable attribute of the project-local `functions` module whose name
# does not start with "__" is treated as a tool the model may call.
# NOTE(review): this also picks up any class or function that `functions`
# merely imports — confirm that module only exposes intended tools.
all_functions = [func for func in dir(functions) if callable(
    getattr(functions, func)) and not func.startswith("__")]
# Tool name -> callable. The callables are stored directly; wrapping them in
# functools.partial with no bound arguments added nothing.
names_to_functions = {func: getattr(functions, func) for func in all_functions}


def create_message(prompt, message_type):
    """Build the system + user message pair for one stage of the pipeline.

    Args:
        prompt: The user's query; used verbatim as the user message content.
        message_type: "reasoning_chain" to ask for a step-by-step reasoning
            outline, or "function_call" to ask for a sequence of tool calls
            ending in discover_movie.

    Returns:
        A two-element list: the system instruction dict followed by the user
        prompt dict, each with "role" and "content" keys.

    Raises:
        ValueError: If message_type is neither supported value.
    """
    logging.debug(
        f"Creating message with prompt: {prompt} and message type: {message_type}")

    # Both system prompts share the same opening; only the tail differs.
    shared_opening = (
        "You are a movie search assistant bot who uses TMDB to help users "
        "find movies. Think step by step and identify the sequence of "
    )
    templates = (
        (
            "reasoning_chain",
            shared_opening
            + "reasoning steps that will help to answer the user's query.",
        ),
        (
            "function_call",
            shared_opening
            + "function calls that will help to answer the user's query. Use the "
            "available functions to gather the necessary data. "
            "Do not call multiple functions when they need to be executed in sequence. "
            "Only call multiple functions when they can be executed in parallel. "
            "Stop with a discover_movie function call that returns a list of movie ids. "
            "Ensure the discover_movie function call includes all the necessary parameters to filter the movies accurately.",
        ),
    )

    matching = [text for kind, text in templates if message_type == kind]
    if not matching:
        raise ValueError(
            "Invalid message type. Expected 'reasoning_chain' or 'function_call'")

    system_turn = {"role": "system", "content": matching[0]}
    user_turn = {"role": "user", "content": prompt}
    return [system_turn, user_turn]


def get_response(client, model, messages, tool_choice="auto"):
    """Send one chat-completion request with the module-level tool schema attached.

    Args:
        client: Client exposing ``chat.completions.create`` (a Groq client here).
        model: Model identifier to query.
        messages: Chat history to send.
        tool_choice: Forwarded unchanged to the API (this file passes
            "auto", "none" or "required").

    Returns:
        The completion object returned by the client, unmodified.
    """
    logging.debug(
        f"Getting response with model: {model}, messages: {messages}, tool_choice: {tool_choice}")
    request_kwargs = dict(
        model=model,
        messages=messages,
        tools=tools,
        tool_choice=tool_choice,
        # Temperature 0 to minimise sampling randomness between runs.
        temperature=0,
        max_tokens=4096,
    )
    completion = client.chat.completions.create(**request_kwargs)
    logging.debug(f"Response: {completion}")
    return completion


def generate_reasoning_chain(user_prompt):
    """Ask the model for a prose, step-by-step plan for answering the query.

    Tool use is disabled for this request (tool_choice="none").

    Args:
        user_prompt: The user's movie query.

    Returns:
        The first choice of the completion; its ``.message`` holds the plan.

    Raises:
        Exception: If the choice's finish_reason is anything other than "stop".
    """
    messages = create_message(user_prompt, "reasoning_chain")
    logging.debug(f"Generating reasoning chain with messages: {messages}")
    cot_response = get_response(client, MODEL, messages, tool_choice="none")
    choice = cot_response.choices[0]
    logging.info(f"COT response: {choice.message.content}")
    if choice.finish_reason != "stop":
        # logging treats extra positional args as %-format arguments, so the
        # message needs a %s placeholder for the response to be rendered.
        logging.error(
            "Failed to generate reasoning chain. Got response: %s", cot_response)
        raise Exception("Failed to generate reasoning chain")
    return choice


def gather_movie_data(messages, iteration=0, max_iterations=2):
    """Drive the model through tool calls until it issues discover_movie.

    Each round forces a tool call (tool_choice="required"), executes every
    requested tool, and appends the outputs to the conversation for the next
    round. The first executed discover_movie call ends the search.

    Args:
        messages: Chat history to start from. It is copied, never mutated.
        iteration: Index of the first round. Rounds ``iteration`` through
            ``max_iterations`` inclusive are attempted.
        max_iterations: Index of the last round allowed.

    Returns:
        The "results" entry of the discover_movie output, or None if the
        round budget is exhausted before discover_movie is called.

    Raises:
        Exception: If a response finishes for a reason other than "tool_calls".
    """
    updated_messages = list(messages)
    while True:
        logging.debug(
            f"Gathering movie data with messages: {updated_messages}, iteration: {iteration}")
        response = get_response(
            client, MODEL, updated_messages, tool_choice="required")
        logging.info(
            f"Gathering movie data response: {response}")
        choice = response.choices[0]
        if choice.finish_reason != "tool_calls":
            raise Exception(
                f"Failed to gather movie data. Got response: {response}")

        # Tool-result messages must answer the tool_calls of a preceding
        # assistant message, so record the assistant turn before them.
        updated_messages.append(choice.message)
        for tool_call in choice.message.tool_calls:
            logging.info(
                f"Tool call: {tool_call.function.name}, Tool call parameters: {tool_call.function.arguments}")
            tool_output = execute_tool(tool_call)
            logging.info(
                f"Tool call output: {json.dumps(tool_output, indent=2)}")
            if tool_call.function.name == "discover_movie":
                return tool_output["results"]  # A list of movies
            updated_messages.append(
                {
                    "tool_call_id": tool_call.id,
                    "role": "tool",
                    "name": tool_call.function.name,
                    "content": str(tool_output),
                }
            )

        if iteration >= max_iterations:
            # Previously this fell through to an implicit None; make the
            # give-up visible in the logs.
            logging.warning(
                "Stopped gathering movie data at iteration %d without a "
                "discover_movie call", iteration)
            return None
        iteration += 1


def execute_tool(tool_call):
    """Run the local Python function named by a model tool call.

    Args:
        tool_call: A tool call from the model's response. Its
            ``function.arguments`` is a JSON string decoding to a mapping of
            keyword arguments.

    Returns:
        Whatever the named function returns.

    Raises:
        KeyError: If the model requested a name not in names_to_functions.
        json.JSONDecodeError: If the arguments string is not valid JSON.
    """
    name = tool_call.function.name
    logging.debug(f"Executing tool: {name}")
    try:
        function_to_call = names_to_functions[name]
    except KeyError:
        raise KeyError(f"Model requested unknown tool: {name!r}") from None
    # Treat an empty or missing argument string as "no arguments" instead of
    # letting json.loads fail on it.
    function_args = json.loads(tool_call.function.arguments or "{}")
    return function_to_call(**function_args)


def chatbot(user_prompt):
    """Answer a movie query in two stages: plan in prose, then call tools.

    Args:
        user_prompt: The user's natural-language movie request.

    Returns:
        Whatever gather_movie_data returns for the planned conversation.
    """
    plan_choice = generate_reasoning_chain(user_prompt)
    plan_message = plan_choice.message
    conversation = create_message(user_prompt, "function_call")
    # Append the reasoning text as a prior turn so the tool-calling request
    # sees the plan produced in the first stage.
    conversation.append(
        {'role': plan_message.role, 'content': plan_message.content})
    return gather_movie_data(conversation)


if __name__ == "__main__":
    # Demo run: send a sample query and print the result as indented JSON.
    movies = chatbot("List comedy movies with tom cruise in it")
    print(json.dumps(movies, indent=2))