import os, json, torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
from config import BASE_MODEL, ADAPTERS, DEVICE, HF_TOKEN

ADAPTER_VOCAB_SIZE = 151672  # vocab size at training time (taken from the training logs)
SPECIALS = ["<SYS>", "<CTX>", "<PLAYER>", "<NPC>", "<STATE>", "<RAG>", "<PLAYER_STATE>"]


def get_current_branch():
    if os.path.exists("current_branch.txt"):
        with open("current_branch.txt", "r") as f:
            return f.read().strip()
    return "latest"
class ModelWrapper:
    def __init__(self):
        # Flag info
        flags_path = os.path.join(os.path.dirname(__file__), "flags.json")
        with open(flags_path, encoding="utf-8") as f:
            self.flags_order = json.load(f)["ALL_FLAGS"]
        self.num_flags = len(self.flags_order)

        # 1) Tokenizer (same options as in training, plus SPECIALS)
        self.tokenizer = AutoTokenizer.from_pretrained(
            BASE_MODEL,
            use_fast=True,
            token=HF_TOKEN,
            trust_remote_code=True
        )
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.padding_side = "right"
        # Re-add the special tokens that were added during training
        self.tokenizer.add_special_tokens({"additional_special_tokens": SPECIALS})
        # 2) Base model (loaded without offloading)
        base = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL,
            device_map=None,          # disable offloading
            low_cpu_mem_usage=False,  # avoid creating meta tensors
            trust_remote_code=True,
            token=HF_TOKEN
        )

        # 3) Force-resize the embeddings to the training-time vocab size
        #    (must happen before the adapter is loaded)
        base.resize_token_embeddings(ADAPTER_VOCAB_SIZE)

        # 4) Apply the LoRA adapter (loaded without offloading)
        branch = get_current_branch()
        self.model = PeftModel.from_pretrained(
            base,
            ADAPTERS,
            revision=branch,
            device_map=None,          # disable offloading
            low_cpu_mem_usage=False,  # avoid creating meta tensors
            token=HF_TOKEN
        )
        # 5) Custom heads
        hidden_size = self.model.config.hidden_size
        self.model.delta_head = nn.Linear(hidden_size, 2).to(DEVICE)
        self.model.flag_head = nn.Linear(hidden_size, self.num_flags).to(DEVICE)
        self.model.flag_threshold_head = nn.Linear(hidden_size, self.num_flags).to(DEVICE)

        # 6) Load custom head weights (if the files exist)
        for head_name, file_name in [
            ("delta_head", "delta_head.pt"),
            ("flag_head", "flag_head.pt"),
            ("flag_threshold_head", "flag_threshold_head.pt")
        ]:
            try:
                if os.path.exists(file_name):
                    getattr(self.model, head_name).load_state_dict(
                        torch.load(file_name, map_location=DEVICE)
                    )
            except Exception as e:
                print(f"[WARN] Failed to load {file_name}: {e}")

        # 7) Move everything to the target device
        self.model.to(DEVICE)
        self.model.eval()

    def get(self):
        return self.tokenizer, self.model, self.flags_order
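

# --- Usage sketch ------------------------------------------------------------
# A minimal, illustrative example of driving this wrapper for inference with the
# custom heads. The prompt text and the choice to pool the final layer's
# last-token hidden state are assumptions, not part of this module.
if __name__ == "__main__":
    wrapper = ModelWrapper()
    tokenizer, model, flags_order = wrapper.get()

    prompt = "<SYS> You are a village NPC. <PLAYER> Hello there!"
    inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)

    with torch.no_grad():
        out = model(**inputs, output_hidden_states=True)
        # Hidden state of the last token from the final layer.
        last_hidden = out.hidden_states[-1][:, -1, :]
        deltas = model.delta_head(last_hidden)             # 2 regression outputs
        flag_logits = model.flag_head(last_hidden)         # one logit per flag
        thresholds = model.flag_threshold_head(last_hidden)

    print("deltas:", deltas.squeeze(0).tolist())
    print("flag logits:", dict(zip(flags_order, flag_logits.squeeze(0).tolist())))
    print("flag thresholds:", dict(zip(flags_order, thresholds.squeeze(0).tolist())))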