"""Load the tokenizer, the adapter-wrapped causal LM, and its auxiliary heads."""

import json
import os

import torch
import torch.nn as nn
from peft import PeftModel
from transformers import AutoTokenizer, AutoModelForCausalLM

from config import BASE_MODEL, ADAPTER_MODEL, DEVICE, HF_TOKEN


def get_current_branch() -> str:
    """Return the adapter revision name to load.

    Reads ``current_branch.txt`` (relative to the current working directory)
    and returns its contents with surrounding whitespace stripped.

    Returns:
        The stripped file contents, or ``"latest"`` when the file does not
        exist or contains only whitespace.
    """
    try:
        with open("current_branch.txt", "r") as branch_file:
            branch = branch_file.read().strip()
    except FileNotFoundError:
        return "latest"
    # An empty file would otherwise yield "" as the revision; use the same
    # default as the missing-file case instead.
    return branch or "latest"


class ModelWrapper:
    """Bundle the tokenizer, the PEFT-wrapped model and the flag ordering.

    Construction reads ``flags.json`` next to this module, loads the tokenizer
    from ``ADAPTER_MODEL``, loads ``BASE_MODEL`` and wraps it with the adapter
    at the revision returned by :func:`get_current_branch`, then attaches three
    ``nn.Linear`` heads (``delta_head``, ``flag_head``, ``flag_threshold_head``)
    to the model and puts the model in eval mode.
    """

    def __init__(self):
        flags_path = os.path.join(os.path.dirname(__file__), "flags.json")
        # Use a context manager so the file handle is closed deterministically.
        with open(flags_path, encoding="utf-8") as flags_file:
            self.flags_order = json.load(flags_file)["ALL_FLAGS"]
        self.num_flags = len(self.flags_order)

        # Pass the access token through to the hub download.
        self.tokenizer = AutoTokenizer.from_pretrained(
            ADAPTER_MODEL,
            use_fast=True,
            token=HF_TOKEN,
        )
        if self.tokenizer.pad_token is None:
            # No pad token defined: reuse the EOS token for padding.
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.padding_side = "right"

        branch = get_current_branch()

        base = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL,
            device_map="auto",
            trust_remote_code=True,
            token=HF_TOKEN,
        )
        self.model = PeftModel.from_pretrained(
            base,
            ADAPTER_MODEL,
            revision=branch,
            device_map="auto",
            token=HF_TOKEN,
        )

        hidden_size = self.model.config.hidden_size
        # Heads are created in the same order as before so that random
        # initialisation consumes the RNG identically.
        # NOTE(review): the weight files are looked up relative to the working
        # directory, unlike flags.json which is resolved next to this module.
        self.model.delta_head = self._make_head(hidden_size, 2, "delta_head.pt")
        self.model.flag_head = self._make_head(
            hidden_size, self.num_flags, "flag_head.pt"
        )
        self.model.flag_threshold_head = self._make_head(
            hidden_size, self.num_flags, "flag_threshold_head.pt"
        )

        self.model.eval()

    @staticmethod
    def _make_head(in_features, out_features, weights_path):
        """Create a linear head on ``DEVICE``, loading saved weights if present.

        Args:
            in_features: Input width of the linear layer.
            out_features: Output width of the linear layer.
            weights_path: Path of a state-dict file; loaded only if it exists,
                otherwise the head keeps its random initialisation.

        Returns:
            The ``nn.Linear`` module, already moved to ``DEVICE``.
        """
        head = nn.Linear(in_features, out_features).to(DEVICE)
        if os.path.exists(weights_path):
            # NOTE(review): torch.load unpickles the file; consider passing
            # weights_only=True if the installed torch version supports it.
            head.load_state_dict(torch.load(weights_path, map_location=DEVICE))
        return head

    def get(self):
        """Return ``(tokenizer, model, flags_order)`` as a tuple."""
        return self.tokenizer, self.model, self.flags_order