import os
import json

import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel

from config import BASE_MODEL, ADAPTERS, DEVICE, HF_TOKEN


def get_current_branch():
    # Read the adapter revision from current_branch.txt (relative to the CWD);
    # fall back to "latest" if the file does not exist.
    if os.path.exists("current_branch.txt"):
        with open("current_branch.txt", "r") as f:
            return f.read().strip()
    return "latest"


class ModelWrapper:
    def __init__(self):
        # Load flag metadata from flags.json next to this file
        flags_path = os.path.join(os.path.dirname(__file__), "flags.json")
        with open(flags_path, encoding="utf-8") as f:
            self.flags_order = json.load(f)["ALL_FLAGS"]
        self.num_flags = len(self.flags_order)

        # Load the tokenizer from the base model
        self.tokenizer = AutoTokenizer.from_pretrained(
            BASE_MODEL, use_fast=True, token=HF_TOKEN
        )
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.padding_side = "right"

        # Load the base model
        branch = get_current_branch()
        base = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL, device_map="auto", trust_remote_code=True, token=HF_TOKEN
        )

        # Apply the LoRA adapter, pinned to the current branch revision
        self.model = PeftModel.from_pretrained(
            base, ADAPTERS, revision=branch, device_map="auto", token=HF_TOKEN
        )

        # Attach custom heads on top of the backbone's hidden states
        hidden_size = self.model.config.hidden_size
        self.model.delta_head = nn.Linear(hidden_size, 2).to(DEVICE)
        self.model.flag_head = nn.Linear(hidden_size, self.num_flags).to(DEVICE)
        self.model.flag_threshold_head = nn.Linear(hidden_size, self.num_flags).to(DEVICE)

        # Load head weights if the .pt files exist; skip silently otherwise
        for head_name, file_name in [
            ("delta_head", "delta_head.pt"),
            ("flag_head", "flag_head.pt"),
            ("flag_threshold_head", "flag_threshold_head.pt"),
        ]:
            try:
                if os.path.exists(file_name):
                    getattr(self.model, head_name).load_state_dict(
                        torch.load(file_name, map_location=DEVICE)
                    )
            except Exception as e:
                print(f"[WARN] Failed to load {file_name}: {e}")

        self.model.eval()

    def get(self):
        return self.tokenizer, self.model, self.flags_order
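
# Minimal usage sketch (illustrative, not part of the module's public surface).
# Assumes config.py provides BASE_MODEL, ADAPTERS, DEVICE, and HF_TOKEN, and
# that flags.json sits next to this file. Last-token pooling below is an
# assumed choice; the original training code may pool differently.
if __name__ == "__main__":
    wrapper = ModelWrapper()
    tokenizer, model, flags_order = wrapper.get()

    # Tokenize a prompt and run a forward pass that returns hidden states,
    # then score flags with the custom head attached in __init__.
    inputs = tokenizer("Hello, world!", return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True)
        # Pool the final layer's last-token hidden state; move it to DEVICE in
        # case device_map="auto" placed the last layer elsewhere.
        last_hidden = outputs.hidden_states[-1][:, -1, :].to(DEVICE)
        flag_logits = model.flag_head(last_hidden)
    print(f"{len(flags_order)} flags, logits shape: {tuple(flag_logits.shape)}")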