m97j committed on
Commit
f1f1dc0
·
1 Parent(s): b621289
Files changed (1) hide show
  1. model_loader.py +23 -12
model_loader.py CHANGED
@@ -4,6 +4,10 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
4
  from peft import PeftModel
5
  from config import BASE_MODEL, ADAPTERS, DEVICE, HF_TOKEN
6
 
 
 
 
 
7
  def get_current_branch():
8
  if os.path.exists("current_branch.txt"):
9
  with open("current_branch.txt", "r") as f:
@@ -12,49 +16,54 @@ def get_current_branch():
12
 
13
  class ModelWrapper:
14
  def __init__(self):
15
- # Flags 정보 로드
16
  flags_path = os.path.join(os.path.dirname(__file__), "flags.json")
17
  self.flags_order = json.load(open(flags_path, encoding="utf-8"))["ALL_FLAGS"]
18
  self.num_flags = len(self.flags_order)
19
 
20
- # 토크나이저는 베이스 모델에서 로드
21
  self.tokenizer = AutoTokenizer.from_pretrained(
22
  BASE_MODEL,
23
  use_fast=True,
24
- token=HF_TOKEN
 
25
  )
26
  if self.tokenizer.pad_token is None:
27
  self.tokenizer.pad_token = self.tokenizer.eos_token
28
  self.tokenizer.padding_side = "right"
 
 
29
 
30
- # 베이스 모델 로드
31
- branch = get_current_branch()
32
  base = AutoModelForCausalLM.from_pretrained(
33
  BASE_MODEL,
34
- device_map="auto",
 
35
  trust_remote_code=True,
36
  token=HF_TOKEN
37
  )
38
 
39
- base.resize_token_embeddings(151672)
 
40
 
41
- # LoRA 어댑터 적용
 
42
  self.model = PeftModel.from_pretrained(
43
  base,
44
  ADAPTERS,
45
  revision=branch,
46
- device_map="auto",
 
47
  token=HF_TOKEN
48
  )
49
 
50
-
51
- # 커스텀 헤드 추가
52
  hidden_size = self.model.config.hidden_size
53
  self.model.delta_head = nn.Linear(hidden_size, 2).to(DEVICE)
54
  self.model.flag_head = nn.Linear(hidden_size, self.num_flags).to(DEVICE)
55
  self.model.flag_threshold_head = nn.Linear(hidden_size, self.num_flags).to(DEVICE)
56
 
57
- # .pt 파일이 없으면 그냥 넘어감
58
  for head_name, file_name in [
59
  ("delta_head", "delta_head.pt"),
60
  ("flag_head", "flag_head.pt"),
@@ -68,6 +77,8 @@ class ModelWrapper:
68
  except Exception as e:
69
  print(f"[WARN] Failed to load {file_name}: {e}")
70
 
 
 
71
  self.model.eval()
72
 
73
  def get(self):
 
4
  from peft import PeftModel
5
  from config import BASE_MODEL, ADAPTERS, DEVICE, HF_TOKEN
6
 
7
+ ADAPTER_VOCAB_SIZE = 151672 # 학습 시점 vocab size (로그 기준)
8
+
9
+ SPECIALS = ["<SYS>", "<CTX>", "<PLAYER>", "<NPC>", "<STATE>", "<RAG>", "<PLAYER_STATE>"]
10
+
11
  def get_current_branch():
12
  if os.path.exists("current_branch.txt"):
13
  with open("current_branch.txt", "r") as f:
 
16
 
17
  class ModelWrapper:
18
  def __init__(self):
19
+ # Flags 정보
20
  flags_path = os.path.join(os.path.dirname(__file__), "flags.json")
21
  self.flags_order = json.load(open(flags_path, encoding="utf-8"))["ALL_FLAGS"]
22
  self.num_flags = len(self.flags_order)
23
 
24
+ # 1) 토크나이저 (학습과 동일 옵션 + SPECIALS)
25
  self.tokenizer = AutoTokenizer.from_pretrained(
26
  BASE_MODEL,
27
  use_fast=True,
28
+ token=HF_TOKEN,
29
+ trust_remote_code=True
30
  )
31
  if self.tokenizer.pad_token is None:
32
  self.tokenizer.pad_token = self.tokenizer.eos_token
33
  self.tokenizer.padding_side = "right"
34
+ # 학습 시 추가했던 특수 토큰 재현
35
+ self.tokenizer.add_special_tokens({"additional_special_tokens": SPECIALS})
36
 
37
+ # 2) 베이스 모델 (오프로딩 끄고 로드)
 
38
  base = AutoModelForCausalLM.from_pretrained(
39
  BASE_MODEL,
40
+ device_map=None, # ✅ 오프로딩 비활성화
41
+ low_cpu_mem_usage=False, # ✅ meta 텐서 생성 방지
42
  trust_remote_code=True,
43
  token=HF_TOKEN
44
  )
45
 
46
+ # 3) 학습 시 vocab size로 강제 리사이즈 (어댑터 로드 전에)
47
+ base.resize_token_embeddings(ADAPTER_VOCAB_SIZE)
48
 
49
+ # 4) LoRA 어댑터 적용 (오프로딩 끄고 로드)
50
+ branch = get_current_branch()
51
  self.model = PeftModel.from_pretrained(
52
  base,
53
  ADAPTERS,
54
  revision=branch,
55
+ device_map=None, # ✅ 오프로딩 비활성화
56
+ low_cpu_mem_usage=False, # ✅ meta 텐서 생성 방지
57
  token=HF_TOKEN
58
  )
59
 
60
+ # 5) 커스텀 헤드
 
61
  hidden_size = self.model.config.hidden_size
62
  self.model.delta_head = nn.Linear(hidden_size, 2).to(DEVICE)
63
  self.model.flag_head = nn.Linear(hidden_size, self.num_flags).to(DEVICE)
64
  self.model.flag_threshold_head = nn.Linear(hidden_size, self.num_flags).to(DEVICE)
65
 
66
+ # 6) 커스텀 헤드 가중치 로드(있을 경우)
67
  for head_name, file_name in [
68
  ("delta_head", "delta_head.pt"),
69
  ("flag_head", "flag_head.pt"),
 
77
  except Exception as e:
78
  print(f"[WARN] Failed to load {file_name}: {e}")
79
 
80
+ # 7) 디바이스 배치
81
+ self.model.to(DEVICE)
82
  self.model.eval()
83
 
84
  def get(self):