# Hugging Face Spaces page header captured with this source.
# Space status at capture time: Runtime error
import functools

import modal
from modal import App, Image
# Setup - define our infrastructure with code!
# Modal app for the pricer service. The container image installs the stack
# used to load a 4-bit quantized causal LM with a PEFT adapter.
app = App("pricer-service")
image = Image.debian_slim().pip_install(
    "torch", "transformers", "bitsandbytes", "accelerate", "peft"
)
# Modal secret named "hf-secret"; presumably holds the Hugging Face token
# needed to download the base model and adapter -- confirm in Modal settings.
secrets = [modal.Secret.from_name("hf-secret")]
# Constants
GPU = "T4"  # intended Modal GPU type for the inference function
BASE_MODEL = "meta-llama/Meta-Llama-3.1-8B"  # Hugging Face Hub id of the base model
PROJECT_NAME = "pricer"
HF_USER = "ed-donner"  # your HF name here! Or use mine if you just want to reproduce my results.
RUN_NAME = "2024-09-13_13.04.39"  # date/time-stamped name of the fine-tuning run
PROJECT_RUN_NAME = f"{PROJECT_NAME}-{RUN_NAME}"
# Commit of the adapter repository to load, pinned so results are reproducible.
REVISION = "e8d637df551603dc86cd7a1598a8f44af4d7ae36"
# Hub repository ("<user>/<project>-<run>") that holds the fine-tuned adapter.
FINETUNED_MODEL = f"{HF_USER}/{PROJECT_RUN_NAME}"
@functools.lru_cache(maxsize=1)
def _load_model():
    """Load the tokenizer and adapter-wrapped 4-bit model once per process.

    Cached so a warm container reuses the loaded weights instead of
    reloading the 8B base model on every ``price`` call.

    Returns:
        A ``(tokenizer, model)`` tuple, where ``model`` is the base model
        wrapped with the PEFT adapter at ``FINETUNED_MODEL``/``REVISION``.
    """
    # Heavy imports stay local so the module can be imported without them.
    import torch
    from peft import PeftModel
    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

    # 4-bit NF4 quantization with double quantization; matmuls run in bfloat16.
    quant_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_quant_type="nf4",
    )

    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
    # Reuse EOS as the pad token (presumably the base tokenizer defines none).
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

    base_model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        quantization_config=quant_config,
        device_map="auto",
    )
    model = PeftModel.from_pretrained(base_model, FINETUNED_MODEL, revision=REVISION)
    return tokenizer, model


@app.function(image=image, secrets=secrets, gpu=GPU)
def price(description: str) -> float:
    """Estimate a product's price from its text description.

    Prompts the fine-tuned model with a fixed question, the description and
    the prefix ``"Price is $"``, then parses the number it generates.

    Args:
        description: Free-text product description inserted into the prompt.

    Returns:
        The first number found in the generated continuation (commas
        removed), or ``0.0`` if the continuation contains no number.
    """
    import re

    import torch
    from transformers import set_seed

    question = "How much does this cost to the nearest dollar?"
    prefix = "Price is $"
    prompt = f"{question}\n{description}\n{prefix}"

    tokenizer, model = _load_model()

    # Re-seed on each call so any sampling done by generate() is repeatable.
    set_seed(42)
    inputs = tokenizer.encode(prompt, return_tensors="pt").to("cuda")
    attention_mask = torch.ones_like(inputs)
    outputs = model.generate(
        inputs,
        attention_mask=attention_mask,
        max_new_tokens=5,
        num_return_sequences=1,
    )

    # Decode only the newly generated tokens. Splitting the full decoded text
    # on the prefix would grab the wrong span if the description itself
    # contained "Price is $".
    prompt_length = inputs.shape[1]
    completion = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
    completion = completion.replace(",", "")

    # Group the alternatives so the optional sign applies to integers as well
    # as decimals.
    match = re.search(r"[-+]?(?:\d*\.\d+|\d+)", completion)
    return float(match.group()) if match else 0.0