# Scraped Hugging Face file-view header (not part of the program):
# committed by m97j · commit 69e85d6 "return debug4" · raw | history | blame · 2.03 kB
import gradio as gr
from inference import run_inference, reload_model # reload_model์€ ๋ชจ๋ธ ์žฌ๋กœ๋”ฉ ํ•จ์ˆ˜
from utils_prompt import build_webtest_prompt
def gradio_infer(npc_id, npc_location, player_utt):
    """Run one inference for the "Web Test UI" tab.

    Builds a prompt from the three textbox values with
    ``build_webtest_prompt``, runs it through ``run_inference`` and picks
    the three fields the tab displays.

    Returns:
        tuple: ``(npc_output_text, deltas, flags_prob)`` taken from the
        inference result, in the order of the "Run Inference" button's
        ``outputs`` list (NPC Response, Deltas, Flags Probabilities).
    """
    outcome = run_inference(
        build_webtest_prompt(npc_id, npc_location, player_utt)
    )
    shown_fields = ("npc_output_text", "deltas", "flags_prob")
    return tuple(outcome[field] for field in shown_fields)
def api_infer(session_id, npc_id, prompt, max_tokens=200):
    """Run inference on a ready-made prompt and package it for API callers.

    Args:
        session_id: Echoed back unchanged in the response.
        npc_id: Echoed back unchanged in the response.
        prompt: Passed as-is to ``run_inference``.
        max_tokens: Accepted but currently NOT forwarded to ``run_inference``,
            so it has no effect.

    Returns:
        dict with keys ``session_id``, ``npc_id``, ``npc_response``,
        ``deltas``, ``flags`` and ``thresholds``; the last four are copied
        from the inference result's ``npc_output_text``, ``deltas``,
        ``flags_prob`` and ``flags_thr`` entries respectively.
    """
    # NOTE(review): max_tokens is ignored; wire it through once
    # run_inference is confirmed to accept a generation-length limit.
    out = run_inference(prompt)
    # API response key -> inference result key
    renamed = {
        "npc_response": "npc_output_text",
        "deltas": "deltas",
        "flags": "flags_prob",
        "thresholds": "flags_thr",
    }
    return {
        "session_id": session_id,
        "npc_id": npc_id,
        **{api_key: out[src_key] for api_key, src_key in renamed.items()},
    }
def ping_reload():
    """Reload the model from the "latest" branch and report success.

    Delegates to ``reload_model(branch="latest")``. The original author's
    note says this re-downloads and loads the model from that branch;
    confirm against the ``inference`` module.

    Returns:
        dict: ``{"status": "reloaded"}``. Only reached if ``reload_model``
        returns without raising.
    """
    target_branch = "latest"
    reload_model(branch=target_branch)
    return dict(status="reloaded")
# Gradio app: a web test tab for the main NPC model plus a model-reload
# endpoint.
with gr.Blocks() as demo:
    gr.Markdown("## NPC Main Model Inference")

    with gr.Tab("Web Test UI"):
        npc_id = gr.Textbox(label="NPC ID")
        npc_loc = gr.Textbox(label="NPC Location")
        player_utt = gr.Textbox(label="Player Utterance")
        npc_resp = gr.Textbox(label="NPC Response")
        deltas = gr.JSON(label="Deltas")
        flags = gr.JSON(label="Flags Probabilities")
        btn = gr.Button("Run Inference")

        # Passing api_name makes the button's handler callable as a named
        # API endpoint ("predict_main") in addition to the UI click.
        btn.click(
            fn=gradio_infer,
            inputs=[npc_id, npc_loc, player_utt],
            outputs=[npc_resp, deltas, flags],
            api_name="predict_main",
        )

    # Model-reload endpoint, exposed under the API name "ping_reload".
    # NOTE(review): it is also rendered as a visible "Reload Model" button,
    # not as an API-only endpoint.
    reload_btn = gr.Button("Reload Model")
    # ping_reload returns a status dict. With outputs=[] Gradio has no
    # component to route that value to and discards it, so neither the UI
    # nor API clients would ever see the status. Bind it to a JSON output.
    reload_status = gr.JSON(label="Reload Status")
    reload_btn.click(
        fn=ping_reload,
        inputs=[],
        outputs=[reload_status],
        api_name="ping_reload",
    )
if __name__ == "__main__":
    # Serve on all network interfaces ("0.0.0.0") at port 7860.
    launch_options = {"server_name": "0.0.0.0", "server_port": 7860}
    demo.launch(**launch_options)