File size: 1,767 Bytes
0fc77f3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import gradio as gr
from inference import run_inference, reload_model  # reload_model reloads the model
from utils_prompt import build_webtest_prompt

def gradio_infer(npc_id, npc_location, player_utt):
    """Run one inference for the web test UI.

    A prompt is built from the three form fields with
    ``build_webtest_prompt`` and handed to ``run_inference``.

    Returns:
        A 3-tuple holding the ``"npc_output_text"``, ``"deltas"`` and
        ``"flags_prob"`` entries of the inference result, in that order.
    """
    outcome = run_inference(
        build_webtest_prompt(npc_id, npc_location, player_utt)
    )
    wanted_keys = ("npc_output_text", "deltas", "flags_prob")
    return tuple(outcome[key] for key in wanted_keys)


# For API calls
def api_infer(session_id, npc_id, prompt, max_tokens=200):
    """Run inference on a ready-made prompt and package the result for API clients.

    Args:
        session_id: Echoed back unchanged under ``"session_id"``.
        npc_id: Echoed back unchanged under ``"npc_id"``.
        prompt: Passed as-is to ``run_inference``.
        max_tokens: Accepted for interface compatibility; it is not
            forwarded to ``run_inference`` here.
            NOTE(review): confirm whether it should be.

    Returns:
        A dict with the keys ``session_id``, ``npc_id``, ``npc_response``,
        ``deltas``, ``flags`` and ``thresholds``.
    """
    outcome = run_inference(prompt)

    # Start with the request identifiers, then copy over the model outputs
    # under their API-facing names.
    payload = {"session_id": session_id, "npc_id": npc_id}
    payload["npc_response"] = outcome["npc_output_text"]
    payload["deltas"] = outcome["deltas"]
    payload["flags"] = outcome["flags_prob"]
    payload["thresholds"] = outcome["flags_thr"]
    return payload

# Ping endpoint to be called from Colab
def ping_reload():
    """Reload the model from the ``latest`` branch and report completion.

    Returns:
        ``{"status": "reloaded"}`` once ``reload_model`` has returned.
    """
    # Re-download and load from the "latest" branch.
    reload_model(branch="latest")
    acknowledgement = {"status": "reloaded"}
    return acknowledgement

# Build the Gradio app: one interactive tab for manual testing plus two
# named API endpoints for programmatic callers.
with gr.Blocks() as demo:
    gr.Markdown("## NPC Main Model Inference")

    # Manual test form: three text inputs feed gradio_infer, whose three
    # return values fill the response box and the two JSON viewers.
    with gr.Tab("Web Test UI"):
        npc_id = gr.Textbox(label="NPC ID")
        npc_loc = gr.Textbox(label="NPC Location")
        player_utt = gr.Textbox(label="Player Utterance")
        npc_resp = gr.Textbox(label="NPC Response")
        deltas = gr.JSON(label="Deltas")
        flags = gr.JSON(label="Flags Probabilities")
        btn = gr.Button("Run Inference")
        btn.click(fn=gradio_infer, inputs=[npc_id, npc_loc, player_utt], outputs=[npc_resp, deltas, flags])

    # Programmatic endpoints.
    # gr.Blocks has no add_api_route() method (that belongs to FastAPI, and
    # it does not accept api_name), so the previous calls failed with
    # AttributeError while the module was being imported. The endpoints are
    # instead registered as ordinary Gradio events carrying an api_name,
    # attached to components kept out of the visible UI.
    # NOTE(review): clients must call these as Gradio named endpoints
    # ("predict_main" / "ping_reload"), not as raw POST routes - confirm
    # against the Colab caller.
    with gr.Row(visible=False):
        api_session_id = gr.Textbox(label="session_id")
        api_npc_id = gr.Textbox(label="npc_id")
        api_prompt = gr.Textbox(label="prompt")
        # Default mirrors api_infer's max_tokens=200; precision=0 yields an int.
        api_max_tokens = gr.Number(label="max_tokens", value=200, precision=0)
        api_result = gr.JSON(label="result")
        predict_trigger = gr.Button("predict_main")
        predict_trigger.click(
            fn=api_infer,
            inputs=[api_session_id, api_npc_id, api_prompt, api_max_tokens],
            outputs=api_result,
            api_name="predict_main",
        )

        reload_result = gr.JSON(label="status")
        reload_trigger = gr.Button("ping_reload")
        # ping_reload takes no arguments, so it is passed directly instead of
        # being wrapped in a lambda.
        reload_trigger.click(
            fn=ping_reload,
            inputs=None,
            outputs=reload_result,
            api_name="ping_reload",
        )

if __name__ == "__main__":
    # Start the server only when run as a script: listen on all network
    # interfaces ("0.0.0.0") on port 7860.
    launch_options = {"server_name": "0.0.0.0", "server_port": 7860}
    demo.launch(**launch_options)