Spaces:
Paused
Paused
File size: 2,253 Bytes
ebf567b 6660737 ebf567b 6660737 ebf567b 6660737 ebf567b 6660737 ebf567b 6660737 ebf567b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 |
import re
import json
import functools
import functions
import torch
from tools import tools
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
BitsAndBytesConfig
)
# Names (dunders excluded, sorted as dir() returns them) of the callables that
# are *defined in* functions.py. Callables that functions.py merely imports from
# elsewhere (e.g. `from os import system`) are skipped on purpose: the dispatch
# table below is indexed by tool names produced by the LLM, so it must only
# ever reach this project's own functions.
all_functions = [
    name
    for name in dir(functions)
    if not name.startswith("__")
    and callable(getattr(functions, name))
    and getattr(getattr(functions, name), "__module__", None) == functions.__name__
]

# Tool name -> implementation, used to dispatch the model's tool calls.
names_to_functions = {name: getattr(functions, name) for name in all_functions}
# Hub checkpoint shared by the tokenizer and the model below.
model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"

# Load the weights in 4-bit NF4 via bitsandbytes, with bfloat16 as the compute dtype.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_type="nf4",
)

tokenizer = AutoTokenizer.from_pretrained(model_id)

# device_map="auto" leaves layer placement across available devices to transformers/accelerate.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=quantization_config,
    device_map="auto",
)
# Tool call wrapped as <|python_tag|>...<|eom_id|>; DOTALL so the JSON may span lines.
_PYTHON_TAG_CALL_RE = re.compile(r'<\|python_tag\|>(.*?)<\|eom_id\|>', re.DOTALL)
# Trailing end-of-turn / end-of-text special tokens left in the decoded reply.
_TRAILING_SPECIAL_TOKENS_RE = re.compile(
    r'(?:\s*<\|(?:eot_id|eom_id|end_of_text)\|>)+\s*$')


def extract_function_call(output):
    """Parse a tool call out of decoded model output.

    Two formats are accepted:

    * ``<|python_tag|>{...}<|eom_id|>`` anywhere in the text: the JSON between
      the tags is used (it may span several lines).
    * A reply consisting of just a JSON object, optionally followed by
      ``<|eot_id|>`` / ``<|eom_id|>`` / ``<|end_of_text|>`` tokens.

    Args:
        output: Text decoded from the model's generated tokens, with special
            tokens still present.

    Returns:
        The parsed call as a dict (expected keys ``'name'`` and
        ``'parameters'``), or ``None`` when no well-formed JSON object with a
        ``'name'`` key is found, including when the JSON is invalid or
        truncated.
    """
    match = _PYTHON_TAG_CALL_RE.search(output)
    if match:
        payload = match.group(1)
    else:
        # JSON-style tool call: the whole reply is the call plus end token(s).
        payload = _TRAILING_SPECIAL_TOKENS_RE.sub('', output)
    try:
        call = json.loads(payload)
    except json.JSONDecodeError:
        # Plain prose, or JSON cut off (e.g. by the max_new_tokens limit).
        return None
    if isinstance(call, dict) and 'name' in call:
        return call
    return None
def chatbot(query, max_new_tokens=128):
    """Answer a movie query by letting the model choose a tool and running it.

    The system prompt and ``query`` are rendered through the tokenizer's chat
    template with ``tools``. The model generates a reply, and any tool call
    found in that reply is dispatched through ``names_to_functions``.

    Args:
        query: The user's natural-language request.
        max_new_tokens: Generation budget for the model's reply (default 128).
            Raise it if tool-call JSON gets cut off.

    Returns:
        The ``'results'`` entry of the called tool's return value, or ``None``
        when the reply contains no usable tool call or names an unknown tool.
    """
    messages = [
        {"role": "system", "content": "You are a movie search assistant bot who uses TMDB to help users find movies. Think step by step and identify the sequence of function calls that will help to answer."},
        {"role": "user", "content": query},
    ]
    # Put the prompt on the same device as the model's (first) parameters.
    tokenized_chat = tokenizer.apply_chat_template(
        messages, tools=tools, add_generation_prompt=True, tokenize=True, return_tensors="pt"
    ).to(model.device)
    outputs = model.generate(tokenized_chat, max_new_tokens=max_new_tokens)
    # Decode only the newly generated tokens, not the echoed prompt.
    answer = tokenizer.batch_decode(outputs[:, tokenized_chat.shape[1]:])[0]

    tool_call = extract_function_call(answer)
    if not tool_call:
        print("No tool calls found in the answer.")
        return None

    function_name = tool_call['name']
    function_params = tool_call.get('parameters') or {}
    print("\nfunction_name: ", function_name,
          "\nfunction_params: ", function_params)

    # The tool name comes from model output, so it may not be a known function.
    function = names_to_functions.get(function_name)
    if function is None:
        print(f"Unknown function requested by the model: {function_name}")
        return None

    function_result = function(**function_params)
    # NOTE(review): assumes every tool returns a mapping with a 'results' key
    # TODO confirm against functions.py.
    print(function_result['results'])
    return function_result['results']
|