# NOTE(review): removed stray "Spaces: / Running / Running" page-scrape text
# that preceded the module — it is Hugging Face Spaces UI chrome, not Python,
# and would raise a SyntaxError at import time.
import json
import os

import torch
import torch.nn as nn
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

from config import ADAPTER_MODEL, BASE_MODEL, DEVICE, HF_TOKEN
def get_current_branch():
    """Return the adapter branch name recorded in ``current_branch.txt``.

    Reads the branch name from ``current_branch.txt`` in the current
    working directory (stripping surrounding whitespace) and falls back
    to ``"latest"`` when the file does not exist.

    Returns:
        str: the recorded branch name, or ``"latest"`` if no file exists.
    """
    # EAFP: avoids the exists()/open() race of the original LBYL check
    # and pins the text encoding instead of relying on the platform default.
    try:
        with open("current_branch.txt", "r", encoding="utf-8") as f:
            return f.read().strip()
    except FileNotFoundError:
        return "latest"
class ModelWrapper:
    """Load the tokenizer, base causal LM + PEFT adapter, and auxiliary heads.

    On construction this:
      1. reads the flag vocabulary (``ALL_FLAGS``) from ``flags.json`` next
         to this module;
      2. loads the tokenizer from the adapter repo and ensures a pad token;
      3. loads the base model and applies the PEFT adapter at the branch
         returned by :func:`get_current_branch`;
      4. attaches three linear heads (delta, flag, flag-threshold) and
         restores their weights from ``*.pt`` checkpoints in the working
         directory when present;
      5. puts the model in eval mode.
    """

    def __init__(self):
        # Flag vocabulary lives next to this module, not in the cwd.
        flags_path = os.path.join(os.path.dirname(__file__), "flags.json")
        # Fix: the original used json.load(open(...)) and never closed the
        # file handle; use a context manager so it is released promptly.
        with open(flags_path, encoding="utf-8") as f:
            self.flags_order = json.load(f)["ALL_FLAGS"]
        self.num_flags = len(self.flags_order)

        # Pass the HF auth token.  (Comment translated from Korean.)
        self.tokenizer = AutoTokenizer.from_pretrained(
            ADAPTER_MODEL,
            use_fast=True,
            token=HF_TOKEN,
        )
        # Causal LMs often ship without a pad token; reuse EOS so padded
        # batches tokenize without error.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.padding_side = "right"

        branch = get_current_branch()
        base = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL,
            device_map="auto",
            trust_remote_code=True,
            token=HF_TOKEN,
        )
        # Apply the LoRA/PEFT adapter at the recorded hub branch (revision).
        self.model = PeftModel.from_pretrained(
            base,
            ADAPTER_MODEL,
            revision=branch,
            device_map="auto",
            token=HF_TOKEN,
        )

        # Auxiliary heads on top of the LM hidden states:
        #   delta_head: 2 outputs; flag_head / flag_threshold_head: one per flag.
        hidden_size = self.model.config.hidden_size
        self.model.delta_head = nn.Linear(hidden_size, 2).to(DEVICE)
        self.model.flag_head = nn.Linear(hidden_size, self.num_flags).to(DEVICE)
        self.model.flag_threshold_head = nn.Linear(hidden_size, self.num_flags).to(DEVICE)

        # Restore head weights from cwd checkpoints when present (the three
        # copy-pasted if-blocks collapsed into one loop).
        # NOTE(review): torch.load unpickles arbitrary objects — only load
        # checkpoints you trust; consider weights_only=True on newer torch.
        for head_name in ("delta_head", "flag_head", "flag_threshold_head"):
            ckpt_path = f"{head_name}.pt"
            if os.path.exists(ckpt_path):
                getattr(self.model, head_name).load_state_dict(
                    torch.load(ckpt_path, map_location=DEVICE)
                )

        self.model.eval()

    def get(self):
        """Return the ``(tokenizer, model, flags_order)`` triple."""
        return self.tokenizer, self.model, self.flags_order