# Model loader: base causal LM + LoRA adapter (branch-selected) + custom prediction heads.
import json
import os

import torch
import torch.nn as nn
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

from config import ADAPTERS, BASE_MODEL, DEVICE, HF_TOKEN
def get_current_branch():
    """Return the adapter branch name recorded in ``current_branch.txt``.

    Returns:
        str: the stripped file contents, or ``"latest"`` when the file
        does not exist (e.g. before any branch has been selected).
    """
    # EAFP: attempting the read avoids the exists()/open() race of the
    # previous os.path.exists() check.
    try:
        with open("current_branch.txt", "r") as f:
            return f.read().strip()
    except FileNotFoundError:
        return "latest"
class ModelWrapper:
    """Loads the base causal LM, applies the LoRA adapter from the currently
    selected branch, and attaches custom prediction heads on top.

    The loaded objects are exposed via :meth:`get`.
    """

    def __init__(self):
        # Load the flag vocabulary; the list order defines each flag's
        # output index in the flag heads below.
        flags_path = os.path.join(os.path.dirname(__file__), "flags.json")
        # Use a context manager so the file handle is closed promptly
        # (the original json.load(open(...)) leaked it).
        with open(flags_path, encoding="utf-8") as f:
            self.flags_order = json.load(f)["ALL_FLAGS"]
        self.num_flags = len(self.flags_order)

        # Tokenizer comes from the base model.
        self.tokenizer = AutoTokenizer.from_pretrained(
            BASE_MODEL,
            use_fast=True,
            token=HF_TOKEN,
        )
        if self.tokenizer.pad_token is None:
            # Reuse EOS as padding so batching works for models that ship
            # without a dedicated pad token.
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.padding_side = "right"

        # Load the base model, then apply the LoRA adapter pinned to the
        # currently selected repository branch.
        branch = get_current_branch()
        base = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL,
            device_map="auto",
            trust_remote_code=True,
            token=HF_TOKEN,
        )
        self.model = PeftModel.from_pretrained(
            base,
            ADAPTERS,
            revision=branch,
            device_map="auto",
            token=HF_TOKEN,
        )

        # Custom heads over the LM hidden states: delta_head emits 2
        # values; the two flag heads emit one value per known flag.
        hidden_size = self.model.config.hidden_size
        self.model.delta_head = nn.Linear(hidden_size, 2).to(DEVICE)
        self.model.flag_head = nn.Linear(hidden_size, self.num_flags).to(DEVICE)
        self.model.flag_threshold_head = nn.Linear(hidden_size, self.num_flags).to(DEVICE)

        # Best-effort restore of saved head weights; a missing .pt file is
        # simply skipped so a fresh checkout still starts with the
        # randomly initialized heads above.
        for head_name, file_name in [
            ("delta_head", "delta_head.pt"),
            ("flag_head", "flag_head.pt"),
            ("flag_threshold_head", "flag_threshold_head.pt"),
        ]:
            if not os.path.exists(file_name):
                continue
            try:
                # NOTE(review): torch.load unpickles arbitrary objects —
                # only load checkpoints from a trusted source.
                state = torch.load(file_name, map_location=DEVICE)
                getattr(self.model, head_name).load_state_dict(state)
            except Exception as e:
                # Keep serving rather than crash on a corrupt/incompatible file.
                print(f"[WARN] Failed to load {file_name}: {e}")

        # Inference-only: also puts the attached heads into eval mode.
        self.model.eval()

    def get(self):
        """Return ``(tokenizer, model, flags_order)`` as loaded at init."""
        return self.tokenizer, self.model, self.flags_order