Spaces:
Running
Running
import gradio as gr | |
from inference import run_inference, reload_model # reload_model์ ๋ชจ๋ธ ์ฌ๋ก๋ฉ ํจ์ | |
from utils_prompt import build_webtest_prompt | |
def gradio_infer(npc_id, npc_location, player_utt):
    """Handle a click of the web test UI's "Run Inference" button.

    Builds a prompt from the three form values with build_webtest_prompt,
    feeds it to run_inference, and returns a 3-tuple taken from the result
    mapping: (npc_output_text, deltas, flags_prob).
    """
    inference = run_inference(
        build_webtest_prompt(npc_id, npc_location, player_utt)
    )
    # Order matters: it matches the outputs list wired to the button.
    wanted_keys = ("npc_output_text", "deltas", "flags_prob")
    return tuple(inference[key] for key in wanted_keys)
# For API calls
def api_infer(session_id, npc_id, prompt, max_tokens=200):
    """Run inference on a ready-made prompt and package the result for API use.

    session_id and npc_id are echoed back unchanged. The remaining fields are
    copied from the mapping returned by run_inference(prompt).

    NOTE(review): max_tokens is accepted but never forwarded or otherwise
    used in this function -- confirm whether run_inference should receive it.
    """
    inference = run_inference(prompt)

    # (response key, key in the run_inference result)
    field_map = (
        ("npc_response", "npc_output_text"),
        ("deltas", "deltas"),
        ("flags", "flags_prob"),
        ("thresholds", "flags_thr"),
    )

    payload = {"session_id": session_id, "npc_id": npc_id}
    payload.update(
        (response_key, inference[result_key])
        for response_key, result_key in field_map
    )
    return payload
# Ping endpoint to be called from Colab
def ping_reload():
    """Trigger a model reload and report that it was requested.

    Calls reload_model(branch="latest") and then returns
    {"status": "reloaded"}. Per the original author's note the reload
    re-downloads and loads the model from the "latest" branch; the actual
    behavior lives in inference.reload_model and is not visible here.
    """
    reload_model(branch="latest")
    status = {"status": "reloaded"}
    return status
# Build the Gradio UI: one tab with three text inputs, three outputs and a
# button that routes the inputs through gradio_infer.
with gr.Blocks() as demo:
    gr.Markdown("## NPC Main Model Inference")

    with gr.Tab("Web Test UI"):
        # Inputs
        npc_id = gr.Textbox(label="NPC ID")
        npc_loc = gr.Textbox(label="NPC Location")
        player_utt = gr.Textbox(label="Player Utterance")

        # Outputs (same order as the tuple returned by gradio_infer)
        npc_resp = gr.Textbox(label="NPC Response")
        deltas = gr.JSON(label="Deltas")
        flags = gr.JSON(label="Flags Probabilities")

        btn = gr.Button("Run Inference")
        btn.click(
            fn=gradio_infer,
            inputs=[npc_id, npc_loc, player_utt],
            outputs=[npc_resp, deltas, flags],
        )

# Extra POST endpoints for programmatic callers.
# NOTE(review): add_api_route with an api_name keyword is not an obviously
# documented gr.Blocks method -- confirm the installed Gradio version
# supports this call, otherwise these two lines fail at import time.
demo.add_api_route(
    "/predict_main",
    api_infer,
    methods=["POST"],
    api_name="predict_main",
)
demo.add_api_route(
    "/ping_reload",
    lambda: ping_reload(),
    methods=["POST"],
    api_name="ping_reload",
)

if __name__ == "__main__":
    # Listen on all interfaces on port 7860.
    demo.launch(server_name="0.0.0.0", server_port=7860)