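"""Gradio app for NPC main-model inference.

Exposes a "Web Test UI" tab for manual testing and a `ping` API endpoint
used as a health-check / wake-up call for the Space.
"""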
import gradio as gr
from inference import run_inference
from webtest_prompt import build_webtest_prompt

# Handler for the Web Test UI tab
def gradio_infer(npc_id, npc_location, player_utt):
    prompt = build_webtest_prompt(npc_id, npc_location, player_utt)
    result = run_inference(prompt)
    return result["npc_output_text"], result["deltas"], result["flags_prob"]
# ping: ์ƒํƒœ ํ™•์ธ ๋ฐ ๊นจ์šฐ๊ธฐ
def ping():
# ๋ชจ๋ธ์ด ๋กœ๋“œ๋˜์–ด ์žˆ๋Š”์ง€ ํ™•์ธ, ์—†์œผ๋ฉด ๋กœ๋“œ
global wrapper, tokenizer, model, flags_order
if 'model' not in globals() or model is None:
from model_loader import ModelWrapper
wrapper = ModelWrapper()
tokenizer, model, flags_order = wrapper.get()
return {"status": "awake"}

with gr.Blocks() as demo:
    gr.Markdown("## NPC Main Model Inference")

    with gr.Tab("Web Test UI"):
        npc_id = gr.Textbox(label="NPC ID")
        npc_loc = gr.Textbox(label="NPC Location")
        player_utt = gr.Textbox(label="Player Utterance")
        npc_resp = gr.Textbox(label="NPC Response")
        deltas = gr.JSON(label="Deltas")
        flags = gr.JSON(label="Flags Probabilities")
        btn = gr.Button("Run Inference")

        # Web Test only (api_name removed)
        btn.click(
            fn=gradio_infer,
            inputs=[npc_id, npc_loc, player_utt],
            outputs=[npc_resp, deltas, flags]
        )

        # ping endpoint (health check / wake-up)
        gr.Button("Ping Server").click(
            fn=ping,
            inputs=[],
            outputs=[],
            api_name="ping"
        )

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)