import torch
from config import DEVICE, MAX_LENGTH, GEN_MAX_NEW_TOKENS, GEN_TEMPERATURE, GEN_TOP_P
from model_loader import ModelWrapper

# Load the model once at server startup
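# wrapper.get() hands back the tokenizer, the model (which also carries the
# auxiliary delta_head / flag_head / flag_threshold_head modules used below),
# and flags_order, the list of flag names matching the heads' output order.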
wrapper = ModelWrapper()
tokenizer, model, flags_order = wrapper.get()
GEN_PARAMS = {
    "max_new_tokens": GEN_MAX_NEW_TOKENS,
    "temperature": GEN_TEMPERATURE,
    "top_p": GEN_TOP_P,
    "do_sample": True,
    "repetition_penalty": 1.05,
}
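
# run_inference makes two passes over the prompt: one generate() call for the
# NPC reply text, then a plain forward pass whose last-layer hidden states are
# pooled over <STATE> token positions and fed to the three auxiliary heads.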
def run_inference(prompt: str):
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=MAX_LENGTH).to(DEVICE)
    with torch.no_grad():
        gen_ids = model.generate(**inputs, **GEN_PARAMS)
    generated_text = tokenizer.decode(
        gen_ids[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True
    )
    # Second forward pass (also gradient-free) to get hidden states for the heads.
    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True)
        h = outputs.hidden_states[-1]  # last layer: [batch, seq_len, hidden]
        # Mean-pool hidden states over <STATE> token positions; fall back to
        # the final position if the prompt contains no <STATE> token.
        STATE_ID = tokenizer.convert_tokens_to_ids("<STATE>")
        ids = inputs["input_ids"]
        mask = (ids == STATE_ID).unsqueeze(-1)
        if mask.any():
            counts = mask.sum(dim=1).clamp_min(1)
            pooled = (h * mask).sum(dim=1) / counts
        else:
            pooled = h[:, -1, :]
        delta_pred = torch.tanh(model.delta_head(pooled))[0].cpu().tolist()
        flag_prob = torch.sigmoid(model.flag_head(pooled))[0].cpu().tolist()
        flag_thr = torch.sigmoid(model.flag_threshold_head(pooled))[0].cpu().tolist()
    flags_prob_dict = {name: round(prob, 6) for name, prob in zip(flags_order, flag_prob)}
    flags_thr_dict = {name: round(thr, 6) for name, thr in zip(flags_order, flag_thr)}
    return {
        "npc_output_text": generated_text.strip(),
        "deltas": {
            "trust": float(delta_pred[0]),
            "relationship": float(delta_pred[1]),
        },
        "flags_prob": flags_prob_dict,
        "flags_thr": flags_thr_dict,
    }

def reload_model(branch="latest"):
    global wrapper, tokenizer, model, flags_order
    wrapper = ModelWrapper(branch=branch)  # pass the branch name, e.g. "latest"
    tokenizer, model, flags_order = wrapper.get()
    print(f"Model reloaded from branch: {branch}")