import os, json, torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
from config import BASE_MODEL, ADAPTERS, DEVICE, HF_TOKEN
ADAPTER_VOCAB_SIZE = 151672  # vocab size at training time (per the training logs)
SPECIALS = ["<SYS>", "<CTX>", "<PLAYER>", "<NPC>", "<STATE>", "<RAG>", "<PLAYER_STATE>"]
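# Note (assumption, not stated in the original): 151,672 is plausibly the base
# tokenizer length plus the seven SPECIALS above; the constant itself is simply
# copied from the training logs.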
def get_current_branch():
    if os.path.exists("current_branch.txt"):
        with open("current_branch.txt", "r") as f:
            return f.read().strip()
    return "latest"
class ModelWrapper:
    def __init__(self):
        # Flag metadata
        flags_path = os.path.join(os.path.dirname(__file__), "flags.json")
        with open(flags_path, encoding="utf-8") as f:
            self.flags_order = json.load(f)["ALL_FLAGS"]
        self.num_flags = len(self.flags_order)
        # 1) Tokenizer (same options as in training, plus SPECIALS)
        self.tokenizer = AutoTokenizer.from_pretrained(
            BASE_MODEL,
            use_fast=True,
            token=HF_TOKEN,
            trust_remote_code=True
        )
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.padding_side = "right"
        # Re-add the special tokens that were added during training
        self.tokenizer.add_special_tokens({"additional_special_tokens": SPECIALS})
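        # After add_special_tokens, len(self.tokenizer) should line up with
        # ADAPTER_VOCAB_SIZE; the resize in step 3 uses the logged constant so
        # the embedding shape matches the adapter even if the tokenizer length drifts.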
        # 2) Base model (loaded with offloading disabled)
        base = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL,
            device_map=None,          # disable offloading
            low_cpu_mem_usage=False,  # prevent meta-tensor creation
            trust_remote_code=True,
            token=HF_TOKEN
        )
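        # Rationale: with low_cpu_mem_usage=True, weights can be created as meta
        # tensors and materialized lazily, which can break the in-place
        # resize_token_embeddings call below; both flags are pinned to avoid that.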
        # 3) Force-resize embeddings to the training-time vocab size (before loading the adapter)
        base.resize_token_embeddings(ADAPTER_VOCAB_SIZE)
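        # If the adapter checkpoint touches the embedding or LM head (e.g. via
        # modules_to_save), its saved tensors are shaped for ADAPTER_VOCAB_SIZE,
        # so the resize must happen before PeftModel.from_pretrained.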
        # 4) Apply the LoRA adapter (loaded with offloading disabled)
        branch = get_current_branch()
        self.model = PeftModel.from_pretrained(
            base,
            ADAPTERS,
            revision=branch,
            device_map=None,          # disable offloading
            low_cpu_mem_usage=False,  # prevent meta-tensor creation
            token=HF_TOKEN
        )
        # 5) Custom heads
        hidden_size = self.model.config.hidden_size
        self.model.delta_head = nn.Linear(hidden_size, 2).to(DEVICE)
        self.model.flag_head = nn.Linear(hidden_size, self.num_flags).to(DEVICE)
        self.model.flag_threshold_head = nn.Linear(hidden_size, self.num_flags).to(DEVICE)
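        # Assumed semantics (inferred from the names alone): delta_head emits a
        # 2-dim output, flag_head per-flag logits, and flag_threshold_head
        # per-flag decision thresholds.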
        # 6) Load custom head weights (if present)
        for head_name, file_name in [
            ("delta_head", "delta_head.pt"),
            ("flag_head", "flag_head.pt"),
            ("flag_threshold_head", "flag_threshold_head.pt")
        ]:
            try:
                if os.path.exists(file_name):
                    getattr(self.model, head_name).load_state_dict(
                        torch.load(file_name, map_location=DEVICE)
                    )
            except Exception as e:
                print(f"[WARN] Failed to load {file_name}: {e}")
        # 7) Device placement and eval mode
        self.model.to(DEVICE)
        self.model.eval()
    def get(self):
        return self.tokenizer, self.model, self.flags_order
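
# --- Usage sketch (illustrative; a minimal sketch, not part of the original module) ---
# Assumes DEVICE from config is a valid torch device and that the custom heads
# read the final token's last hidden state; the prompt string and that readout
# choice are assumptions for demonstration only.
if __name__ == "__main__":
    wrapper = ModelWrapper()
    tokenizer, model, flags_order = wrapper.get()
    prompt = "<SYS> You are an NPC. <PLAYER> Hello!"
    inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        out = model(**inputs, output_hidden_states=True)
        last_hidden = out.hidden_states[-1][:, -1, :]  # (1, hidden_size)
        delta = model.delta_head(last_hidden)          # (1, 2)
        flag_logits = model.flag_head(last_hidden)     # (1, num_flags)
    print(delta.shape, flag_logits.shape, len(flags_order))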