import os, json, torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
from config import BASE_MODEL, ADAPTERS, DEVICE, HF_TOKEN

ADAPTER_VOCAB_SIZE = 151672  # vocab size at training time (from the training logs)
SPECIALS = ["<SYS>", "<CTX>", "<PLAYER>", "<NPC>", "<STATE>", "<RAG>", "<PLAYER_STATE>"]
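# Presumably the section delimiters of the training prompt format; they must be
# re-added in the same order as at training time so the assigned special-token
# ids line up with the fine-tuned embeddings.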

def get_current_branch():
    if os.path.exists("current_branch.txt"):
        with open("current_branch.txt", "r") as f:
            return f.read().strip()
    return "latest"

class ModelWrapper:
    def __init__(self):
        # Flag info
        flags_path = os.path.join(os.path.dirname(__file__), "flags.json")
        with open(flags_path, encoding="utf-8") as f:
            self.flags_order = json.load(f)["ALL_FLAGS"]
        self.num_flags = len(self.flags_order)

        # 1) Tokenizer (same options as during training, plus SPECIALS)
        self.tokenizer = AutoTokenizer.from_pretrained(
            BASE_MODEL,
            use_fast=True,
            token=HF_TOKEN,
            trust_remote_code=True
        )
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.padding_side = "right"
        # Re-register the special tokens that were added during training
        self.tokenizer.add_special_tokens({"additional_special_tokens": SPECIALS})
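        # Note: len(self.tokenizer) should now line up with ADAPTER_VOCAB_SIZE
        # used below; a mismatch would leave token ids without a matching
        # embedding row after the resize.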

        # 2) Base model (loaded with offloading disabled)
        base = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL,
            device_map=None,          # disable offloading
            low_cpu_mem_usage=False,  # avoid creating meta tensors
            trust_remote_code=True,
            token=HF_TOKEN
        )

        # 3) Force-resize to the training-time vocab size (before loading the adapter)
        base.resize_token_embeddings(ADAPTER_VOCAB_SIZE)
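        # Presumably the saved adapter tensors (and any saved embedding rows)
        # are shaped for the training-time vocabulary, so the base model has to
        # be resized *before* PeftModel.from_pretrained tries to load them.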

        # 4) Apply the LoRA adapter (loaded with offloading disabled)
        branch = get_current_branch()
        self.model = PeftModel.from_pretrained(
            base,
            ADAPTERS,
            revision=branch,
            device_map=None,          # disable offloading
            low_cpu_mem_usage=False,  # avoid creating meta tensors
            token=HF_TOKEN
        )

        # 5) Custom heads
        hidden_size = self.model.config.hidden_size
        self.model.delta_head = nn.Linear(hidden_size, 2).to(DEVICE)
        self.model.flag_head = nn.Linear(hidden_size, self.num_flags).to(DEVICE)
        self.model.flag_threshold_head = nn.Linear(hidden_size, self.num_flags).to(DEVICE)
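        # Attribute assignment registers the heads as submodules of self.model,
        # so the later self.model.to(DEVICE) call keeps them on the same device
        # as the adapter.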

        # 6) Load custom head weights (if present)
        for head_name, file_name in [
            ("delta_head", "delta_head.pt"),
            ("flag_head", "flag_head.pt"),
            ("flag_threshold_head", "flag_threshold_head.pt")
        ]:
            try:
                if os.path.exists(file_name):
                    getattr(self.model, head_name).load_state_dict(
                        torch.load(file_name, map_location=DEVICE)
                    )
            except Exception as e:
                print(f"[WARN] Failed to load {file_name}: {e}")

        # 7) Move everything to the target device
        self.model.to(DEVICE)
        self.model.eval()

    def get(self):
        return self.tokenizer, self.model, self.flags_order
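

# --- Usage sketch (illustrative; the prompt string and generation settings are
# assumptions, not taken from the original code). Requires config.py to define
# BASE_MODEL / ADAPTERS / DEVICE / HF_TOKEN and a flags.json next to this file.
if __name__ == "__main__":
    wrapper = ModelWrapper()
    tokenizer, model, flags_order = wrapper.get()
    prompt = "<SYS> You are a village NPC. <PLAYER> Hello there!"
    inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        output_ids = model.generate(**inputs, max_new_tokens=64)
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))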