# PersonaChatEngine_hf-serve / model_loader.py
import os, json, torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
from config import BASE_MODEL, ADAPTERS, DEVICE, HF_TOKEN
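
# Note: config.py is not part of this file; it is assumed to expose BASE_MODEL
# (base model repo id), ADAPTERS (LoRA adapter repo id), DEVICE (e.g. "cuda"),
# and HF_TOKEN (Hugging Face access token) as module-level constants.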

def get_current_branch():
    """Return the adapter revision recorded in current_branch.txt, or "latest"."""
    if os.path.exists("current_branch.txt"):
        with open("current_branch.txt", "r") as f:
            return f.read().strip()
    return "latest"

class ModelWrapper:
def __init__(self):
        # Load the flag list (names and order) from flags.json
flags_path = os.path.join(os.path.dirname(__file__), "flags.json")
self.flags_order = json.load(open(flags_path, encoding="utf-8"))["ALL_FLAGS"]
self.num_flags = len(self.flags_order)
# ํ† ํฌ๋‚˜์ด์ €๋Š” ๋ฒ ์ด์Šค ๋ชจ๋ธ์—์„œ ๋กœ๋“œ
self.tokenizer = AutoTokenizer.from_pretrained(
BASE_MODEL,
use_fast=True,
token=HF_TOKEN
)
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
self.tokenizer.padding_side = "right"
# ๋ฒ ์ด์Šค ๋ชจ๋ธ ๋กœ๋“œ
branch = get_current_branch()
base = AutoModelForCausalLM.from_pretrained(
BASE_MODEL,
device_map="auto",
trust_remote_code=True,
token=HF_TOKEN
)
self.model.resize_token_embeddings(len(self.tokenizer))
# LoRA ์–ด๋Œ‘ํ„ฐ ์ ์šฉ
self.model = PeftModel.from_pretrained(
base,
ADAPTERS,
revision=branch,
device_map="auto",
token=HF_TOKEN
)
        # Attach the custom prediction heads
hidden_size = self.model.config.hidden_size
self.model.delta_head = nn.Linear(hidden_size, 2).to(DEVICE)
self.model.flag_head = nn.Linear(hidden_size, self.num_flags).to(DEVICE)
self.model.flag_threshold_head = nn.Linear(hidden_size, self.num_flags).to(DEVICE)
# .pt ํŒŒ์ผ์ด ์—†์œผ๋ฉด ๊ทธ๋ƒฅ ๋„˜์–ด๊ฐ
for head_name, file_name in [
("delta_head", "delta_head.pt"),
("flag_head", "flag_head.pt"),
("flag_threshold_head", "flag_threshold_head.pt")
]:
try:
if os.path.exists(file_name):
getattr(self.model, head_name).load_state_dict(
torch.load(file_name, map_location=DEVICE)
)
except Exception as e:
print(f"[WARN] Failed to load {file_name}: {e}")
self.model.eval()
def get(self):
return self.tokenizer, self.model, self.flags_order
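

if __name__ == "__main__":
    # Minimal usage sketch (assumption: run locally with the base model, adapter,
    # and head weights available). It only illustrates how a caller might obtain
    # the tokenizer/model/flag order and read the custom heads from the last
    # hidden state; the prompt string is illustrative.
    tokenizer, model, flags_order = ModelWrapper().get()

    inputs = tokenizer("Hello!", return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        out = model(**inputs, output_hidden_states=True)
        # Last hidden state of the final token, cast to the heads' dtype/device
        hidden = out.hidden_states[-1][:, -1, :].to(DEVICE).float()
        delta = model.delta_head(hidden)                      # shape: (1, 2)
        flag_probs = torch.sigmoid(model.flag_head(hidden))   # shape: (1, num_flags)

    print("delta:", delta.squeeze(0).tolist())
    print(dict(zip(flags_order, flag_probs.squeeze(0).tolist())))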