File size: 2,253 Bytes
ebf567b
 
 
 
 
 
6660737
 
 
 
 
ebf567b
6660737
 
ebf567b
 
 
 
 
 
 
 
 
6660737
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ebf567b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6660737
ebf567b
 
 
6660737
ebf567b
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71

import re
import json

import functools
import functions
import torch
from tools import tools
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig
)

# Names of every callable attribute found on the `functions` module, skipping
# dunder names.
# NOTE(review): dir() also lists callables that functions.py merely imports
# (and single-underscore helpers), so they end up in the registry too --
# confirm that is intended before exposing this registry to model output.
all_functions = [
    name
    for name in dir(functions)
    if callable(getattr(functions, name)) and not name.startswith("__")
]

# Lookup table from function name to a partial wrapping that function, so a
# tool call can be dispatched by the name the model emits.
names_to_functions = {
    name: functools.partial(getattr(functions, name))
    for name in all_functions
}


# Hugging Face Hub identifier of the checkpoint used for generation.
model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"

# Quantization settings: load weights in 4-bit NF4 and compute in bfloat16.
quantization_config = BitsAndBytesConfig(
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_type="nf4",
    load_in_4bit=True,
)

# Tokenizer and quantized model are both loaded at import time;
# device_map="auto" leaves device placement of the weights to the library.
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=quantization_config,
    device_map="auto",
)


def extract_function_call(output):
    """Extract the first JSON tool call embedded in decoded model output.

    The call is expected between a ``<|python_tag|>`` marker and the next
    ``<|eom_id|>`` marker.

    Args:
        output: Decoded model text, special tokens included.

    Returns:
        The parsed JSON payload of the first tool call, or ``None`` when no
        marker pair is present or the text between the markers is not valid
        JSON.
    """
    # Non-greedy so that only the first call is captured when the output
    # contains several; DOTALL so that pretty-printed (multi-line) JSON is
    # matched as well.
    match = re.search(
        r'<\|python_tag\|>(.*?)<\|eom_id\|>', output, re.DOTALL)
    if match is None:
        return None

    try:
        return json.loads(match.group(1))
    except json.JSONDecodeError:
        # Model output is not guaranteed to be well-formed; treat a broken
        # payload the same as "no tool call" so the caller's fallback runs.
        return None


def chatbot(query):
    """Answer a movie query by having the model choose and run one tool call.

    Builds a system + user chat, renders it through the tokenizer's chat
    template together with the ``tools`` schema, generates up to 128 new
    tokens, parses a tool call from the completion and dispatches it through
    ``names_to_functions``. Progress and results are printed along the way.

    Args:
        query: The user's question as text.

    Returns:
        The ``'results'`` entry of the value returned by the invoked function,
        or ``None`` when the completion contains no usable tool call (no call
        found, malformed call, or a function name that is not registered).
    """
    messages = [
        {"role": "system", "content": "You are a movie search assistant bot who uses TMDB to help users find movies. Think step by step and identify the sequence of function calls that will help to answer."},
        {"role": "user", "content": query},
    ]

    tokenized_chat = tokenizer.apply_chat_template(
        messages, tools=tools, add_generation_prompt=True, tokenize=True, return_tensors="pt")
    # The model is loaded with device_map="auto", so its weights may sit on an
    # accelerator; the prompt ids have to be on the same device for generate().
    tokenized_chat = tokenized_chat.to(model.device)

    outputs = model.generate(tokenized_chat, max_new_tokens=128)
    # Slice off the prompt so only the newly generated tokens are decoded.
    prompt_length = tokenized_chat.shape[1]
    answer = tokenizer.batch_decode(outputs[:, prompt_length:])[0]
    tool_call = extract_function_call(answer)

    if not tool_call:
        print("No tool calls found in the answer.")
        return None

    # The payload comes from the model, so validate its shape before use.
    if not isinstance(tool_call, dict) or 'name' not in tool_call:
        print("Malformed tool call in the answer: ", tool_call)
        return None

    function_name = tool_call['name']
    # A call without arguments may omit the key (or set it to null).
    function_params = tool_call.get('parameters') or {}
    print("\nfunction_name: ", function_name,
          "\nfunction_params: ", function_params)

    if function_name not in names_to_functions:
        # The model can invent names that were never registered.
        print("Unknown function requested: ", function_name)
        return None

    function_result = names_to_functions[function_name](**function_params)
    print(function_result['results'])
    return function_result['results']