# movies-app / agent.py
# tmzh — "add natural language queries" (commit ebf567b)
# Scraped from the Hugging Face Hub file view (raw | history | blame; "No virus" scan badge; 2.25 kB).
import re
import json
import functools
import functions
import torch
from tools import tools
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
BitsAndBytesConfig
)
# Collect the names of every callable attribute of the `functions` module,
# skipping dunder names. dir() lists all module attributes, so any callable
# that functions.py itself imports by name is included as well.
all_functions = [
    name
    for name in dir(functions)
    if callable(getattr(functions, name)) and not name.startswith("__")
]
# Map each name to its callable so a tool call emitted by the model can be
# dispatched by name. The callable is stored directly: a functools.partial
# with no bound arguments adds nothing but an extra wrapper layer.
names_to_functions = {name: getattr(functions, name) for name in all_functions}
# Hugging Face Hub id of the checkpoint used for both tokenizer and model.
model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"

# bitsandbytes quantization settings: weights loaded in 4-bit NF4 form, with
# bfloat16 as the compute dtype.
quantization_config = BitsAndBytesConfig(
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_type="nf4",
    load_in_4bit=True,
)

# Tokenizer and quantized model are loaded once, when this module is imported.
# device_map="auto" leaves layer placement across available devices to
# transformers/accelerate.
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=quantization_config,
    device_map="auto",
)
# One tagged tool call in Llama 3.1 output: everything after <|python_tag|> up
# to the first end-of-message (<|eom_id|>) or end-of-turn (<|eot_id|>) token.
# Non-greedy + DOTALL so JSON spanning several lines is captured, and so a
# second tool call later in the same output is not merged into the first.
_FUNCTION_CALL_RE = re.compile(
    r"<\|python_tag\|>(.*?)(?:<\|eom_id\|>|<\|eot_id\|>)", re.DOTALL
)


def extract_function_call(output):
    """Parse the first tool call embedded in a decoded model reply.

    Args:
        output: Decoded generation text with special tokens kept, so the
            ``<|python_tag|>`` / ``<|eom_id|>`` markers are still present.

    Returns:
        The tool-call JSON object as a dict (the chatbot reads its ``name``
        and ``parameters`` keys), or ``None`` when no tagged call is found,
        the payload is not valid JSON, or the payload is not a JSON object.
    """
    match = _FUNCTION_CALL_RE.search(output)
    if match is None:
        return None
    try:
        call = json.loads(match.group(1))
    except json.JSONDecodeError:
        # Generated text is free-form; a malformed payload means "no usable
        # call" rather than an exception that aborts the whole turn.
        return None
    return call if isinstance(call, dict) else None
def chatbot(query, max_new_tokens=128):
    """Answer a movie query by letting the model pick a function and running it.

    The query is wrapped with a system prompt and rendered through the
    tokenizer's chat template together with the ``tools`` schemas. The model
    generates a reply. If the reply contains a tool call (see
    ``extract_function_call``), the named entry of ``names_to_functions`` is
    called with the model-supplied parameters. The ``'results'`` entry of the
    function's return value is then printed and returned. The function result
    is not fed back to the model.

    Args:
        query: The user's natural-language question.
        max_new_tokens: Generation budget for the model's reply. Defaults to
            128, the value that was previously hard-coded.

    Returns:
        ``function_result['results']`` for a successful tool call, otherwise
        ``None`` (no tool call found, or the model named a function that is
        not in ``names_to_functions``).
    """
    messages = [
        {"role": "system", "content": "You are a movie search assistant bot who uses TMDB to help users find movies. Think step by step and identify the sequence of function calls that will help to answer."},
        {"role": "user", "content": query},
    ]
    # return_dict=True also returns the attention mask. The Llama 3.1
    # tokenizer defines no pad token, so generate() cannot infer the mask.
    # The inputs are moved to the model's device explicitly.
    inputs = tokenizer.apply_chat_template(
        messages,
        tools=tools,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device)
    prompt_len = inputs["input_ids"].shape[1]
    outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
    # Decode only the newly generated tokens. Special tokens are kept because
    # the tool-call markers (<|python_tag|>, <|eom_id|>) are special tokens.
    answer = tokenizer.batch_decode(outputs[:, prompt_len:])[0]
    tool_call = extract_function_call(answer)
    if not tool_call or not isinstance(tool_call, dict):
        print("No tool calls found in the answer.")
        return None

    function_name = tool_call.get("name")
    function_params = tool_call.get("parameters") or {}
    print("\nfunction_name: ", function_name,
          "\nfunction_params: ", function_params)
    function = names_to_functions.get(function_name)
    if function is None:
        # The model can emit a name that does not exist in functions.py;
        # report it instead of raising KeyError.
        print(f"Unknown function requested by the model: {function_name}")
        return None
    function_result = function(**function_params)
    results = function_result['results']
    print(results)
    return results