Spaces:
Running
Running
import re | |
import json | |
import functools | |
import functions | |
import torch | |
from tools import tools | |
from transformers import ( | |
AutoModelForCausalLM, | |
AutoTokenizer, | |
BitsAndBytesConfig | |
) | |
# Collect the name of every callable attribute of functions.py (including any
# callables it imports); dunder names such as "__loader__" are skipped.
all_functions = [
    name
    for name in dir(functions)
    if callable(getattr(functions, name)) and not name.startswith("__")
]
# Map each name to the callable itself so a model-emitted tool call can be
# dispatched by name. Plain references (no functools.partial wrapper) keep
# __name__/__doc__ intact for logging and introspection.
# NOTE(review): the dispatched name comes from model output, so every callable
# listed here is invocable by the model — consider restricting this mapping to
# the tool names declared in tools.py.
names_to_functions = {name: getattr(functions, name) for name in all_functions}
# Hugging Face Hub checkpoint used for both the tokenizer and the model.
model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"

# bitsandbytes settings: store weights in 4-bit NF4, compute in bfloat16.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

tokenizer = AutoTokenizer.from_pretrained(model_id)
# device_map="auto" lets transformers/accelerate choose layer placement.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=quantization_config,
    device_map="auto",
)
# Matches a tool call of the form "<|python_tag|>{...}<|eom_id|>".
# re.DOTALL lets the payload span several lines (pretty-printed JSON), and the
# lazy group stops at the first <|eom_id|> rather than the last one.
_FUNCTION_CALL_RE = re.compile(r'<\|python_tag\|>(.*?)<\|eom_id\|>', re.DOTALL)


def extract_function_call(output):
    """Extract the JSON tool call embedded in a raw model completion.

    Args:
        output: Decoded model text that still contains special tokens such as
            ``<|python_tag|>`` and ``<|eom_id|>``.

    Returns:
        The value parsed from the JSON between ``<|python_tag|>`` and the
        first following ``<|eom_id|>``, or ``None`` if no such span exists.

    Raises:
        json.JSONDecodeError: If the tagged payload is not valid JSON.
    """
    match = _FUNCTION_CALL_RE.search(output)
    if match is None:
        return None
    return json.loads(match.group(1))
def chatbot(query, *, max_new_tokens=128):
    """Answer a movie-search query by letting the model choose a tool call.

    The query is rendered with the tokenizer's chat template together with the
    tool schemas from ``tools``. The model generates a reply, and if that reply
    contains a tool call, the matching callable from ``names_to_functions`` is
    invoked with the model-supplied parameters.

    Args:
        query: The user's natural-language request.
        max_new_tokens: Generation budget for the model's reply. A value that
            is too small can cut the tool-call JSON off before ``<|eom_id|>``.

    Returns:
        ``function_result['results']`` from the invoked function, or ``None``
        when the reply contains no tool call or names an unknown function.
    """
    messages = [
        {"role": "system", "content": "You are a movie search assistant bot who uses TMDB to help users find movies. Think step by step and identify the sequence of function calls that will help to answer."},
        {"role": "user", "content": query},
    ]
    # Move the prompt onto the model's device; with device_map="auto" the
    # model is not guaranteed to live on the CPU where the tensor is created.
    tokenized_chat = tokenizer.apply_chat_template(
        messages, tools=tools, add_generation_prompt=True, tokenize=True,
        return_tensors="pt",
    ).to(model.device)
    outputs = model.generate(tokenized_chat, max_new_tokens=max_new_tokens)
    # Decode only the newly generated tokens. Special tokens are deliberately
    # kept because extract_function_call searches for <|python_tag|>/<|eom_id|>.
    answer = tokenizer.batch_decode(outputs[:, tokenized_chat.shape[1]:])[0]

    tool_call = extract_function_call(answer)
    if not tool_call:
        print("No tool calls found in the answer.")
        return None

    function_name = tool_call['name']
    # A tool that takes no arguments may come back without a "parameters" key.
    function_params = tool_call.get('parameters') or {}
    print("\nfunction_name: ", function_name,
          "\nfunction_params: ", function_params)

    function = names_to_functions.get(function_name)
    if function is None:
        # The name is model-generated and may not exist in functions.py.
        print(f"Unknown function requested by the model: {function_name}")
        return None

    # NOTE(review): assumes every function returns a dict with a 'results' key
    # (TMDB-style response) — confirm against functions.py.
    function_result = function(**function_params)
    results = function_result['results']
    print(results)
    return results