File size: 2,147 Bytes
0fc77f3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import torch
from config import DEVICE, MAX_LENGTH, GEN_MAX_NEW_TOKENS, GEN_TEMPERATURE, GEN_TOP_P
from model_loader import ModelWrapper

# μ „μ—­ λ‘œλ“œ (μ„œλ²„ μ‹œμž‘ μ‹œ 1회)
wrapper = ModelWrapper()
tokenizer, model, flags_order = wrapper.get()

GEN_PARAMS = {
    "max_new_tokens": GEN_MAX_NEW_TOKENS,
    "temperature": GEN_TEMPERATURE,
    "top_p": GEN_TOP_P,
    "do_sample": True,
    "repetition_penalty": 1.05,
}

def run_inference(prompt: str):
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=MAX_LENGTH).to(DEVICE)

    with torch.no_grad():
        gen_ids = model.generate(**inputs, **GEN_PARAMS)
        generated_text = tokenizer.decode(
            gen_ids[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True
        )

        outputs = model(**inputs, output_hidden_states=True)
        h = outputs.hidden_states[-1]

        STATE_ID = tokenizer.convert_tokens_to_ids("<STATE>")
        ids = inputs["input_ids"]
        mask = (ids == STATE_ID).unsqueeze(-1)
        if mask.any():
            counts = mask.sum(dim=1).clamp_min(1)
            pooled = (h * mask).sum(dim=1) / counts
        else:
            pooled = h[:, -1, :]

        delta_pred = torch.tanh(model.delta_head(pooled))[0].cpu().tolist()
        flag_prob = torch.sigmoid(model.flag_head(pooled))[0].cpu().tolist()
        flag_thr = torch.sigmoid(model.flag_threshold_head(pooled))[0].cpu().tolist()

    flags_prob_dict = {name: round(prob, 6) for name, prob in zip(flags_order, flag_prob)}
    flags_thr_dict = {name: round(thr, 6) for name, thr in zip(flags_order, flag_thr)}

    return {
        "npc_output_text": generated_text.strip(),
        "deltas": {
            "trust": float(delta_pred[0]),
            "relationship": float(delta_pred[1]),
        },
        "flags_prob": flags_prob_dict,
        "flags_thr": flags_thr_dict,
    }

def reload_model(branch="latest"):
    global wrapper, tokenizer, model, flags_order
    wrapper = ModelWrapper(branch=branch)  # branch 인자둜 latest 전달
    tokenizer, model, flags_order = wrapper.get()
    print(f"Model reloaded from branch: {branch}")