import os, json, torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
from config import BASE_MODEL, ADAPTERS, DEVICE, HF_TOKEN

ADAPTER_VOCAB_SIZE = 151672  # vocab size at training time (per the training logs)

SPECIALS = ["<SYS>", "<CTX>", "<PLAYER>", "<NPC>", "<STATE>", "<RAG>", "<PLAYER_STATE>"]

def get_current_branch():
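    # current_branch.txt is presumably written by a separate deployment/update
    # step; fall back to the adapter repo's "latest" revision when it is absent.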
    if os.path.exists("current_branch.txt"):
        with open("current_branch.txt", "r") as f:
            return f.read().strip()
    return "latest"

class ModelWrapper:
    def __init__(self):
        # Flag metadata
        flags_path = os.path.join(os.path.dirname(__file__), "flags.json")
        with open(flags_path, encoding="utf-8") as f:
            self.flags_order = json.load(f)["ALL_FLAGS"]
        self.num_flags = len(self.flags_order)

        # 1) ν† ν¬λ‚˜μ΄μ € (ν•™μŠ΅κ³Ό 동일 μ˜΅μ…˜ + SPECIALS)
        self.tokenizer = AutoTokenizer.from_pretrained(
            BASE_MODEL,
            use_fast=True,
            token=HF_TOKEN,
            trust_remote_code=True
        )
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.padding_side = "right"
        # Reproduce the special tokens added during training
        self.tokenizer.add_special_tokens({"additional_special_tokens": SPECIALS})
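        # Note: adding SPECIALS grows the tokenizer vocab; the model's embedding
        # matrix is resized separately (step 3) to the logged training-time size
        # rather than to len(self.tokenizer).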

        # 2) Base model (load with offloading disabled)
        base = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL,
            device_map=None,              # βœ… disable offloading
            low_cpu_mem_usage=False,      # βœ… avoid creating meta tensors
            trust_remote_code=True,
            token=HF_TOKEN
        )

        # 3) Force-resize to the training-time vocab size (before loading the adapter)
        base.resize_token_embeddings(ADAPTER_VOCAB_SIZE)
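        # Matching the checkpoint's embedding shape here helps avoid
        # size-mismatch errors when the adapter weights are applied below, and
        # keeps the added special-token rows aligned with training.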

        # 4) Apply the LoRA adapter (load with offloading disabled)
        branch = get_current_branch()
        self.model = PeftModel.from_pretrained(
            base,
            ADAPTERS,
            revision=branch,
            device_map=None,              # βœ… disable offloading
            low_cpu_mem_usage=False,      # βœ… avoid creating meta tensors
            token=HF_TOKEN
        )

        # 5) Custom heads
        hidden_size = self.model.config.hidden_size
        self.model.delta_head = nn.Linear(hidden_size, 2).to(DEVICE)
        self.model.flag_head = nn.Linear(hidden_size, self.num_flags).to(DEVICE)
        self.model.flag_threshold_head = nn.Linear(hidden_size, self.num_flags).to(DEVICE)
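        # The heads' output semantics are defined by the training code; the
        # names suggest two scalar deltas, per-flag logits, and learned
        # per-flag decision thresholds.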

        # 6) Load custom head weights (if present)
        for head_name, file_name in [
            ("delta_head", "delta_head.pt"),
            ("flag_head", "flag_head.pt"),
            ("flag_threshold_head", "flag_threshold_head.pt")
        ]:
            try:
                if os.path.exists(file_name):
                    getattr(self.model, head_name).load_state_dict(
                        torch.load(file_name, map_location=DEVICE)
                    )
            except Exception as e:
                print(f"[WARN] Failed to load {file_name}: {e}")

        # 7) λ””λ°”μ΄μŠ€ 배치
        self.model.to(DEVICE)
        self.model.eval()

    def get(self):
        return self.tokenizer, self.model, self.flags_order
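
# Usage sketch (illustrative, not part of the original module). A minimal
# smoke test assuming config.py provides BASE_MODEL / ADAPTERS / DEVICE /
# HF_TOKEN and that flags.json sits next to this file. The prompt template
# built from the SPECIALS tags is a guess; match it to the actual training
# format. How the custom heads are consumed downstream is not shown in this
# file, so the head usage below is one plausible wiring, not the definitive one.
if __name__ == "__main__":
    tokenizer, model, flags_order = ModelWrapper().get()

    # Hypothetical prompt; the real tag ordering depends on the training data.
    prompt = "<SYS>You are a village NPC.<PLAYER>Hello there!<NPC>"
    inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)

    with torch.no_grad():
        # Standard text generation through the LoRA-adapted base model.
        output_ids = model.generate(**inputs, max_new_tokens=64)
        print(tokenizer.decode(output_ids[0], skip_special_tokens=True))

        # Apply a custom head to the last token's final hidden state.
        out = model(**inputs, output_hidden_states=True)
        last_hidden = out.hidden_states[-1][:, -1, :]
        flag_probs = torch.sigmoid(model.flag_head(last_hidden))
        print({f: round(p.item(), 3) for f, p in zip(flags_order, flag_probs[0])})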