Spaces:
Running
Running
edit gradio ui
Browse files
app.py
CHANGED
@@ -2,13 +2,13 @@ import gradio as gr
|
|
2 |
from inference import run_inference, reload_model # reload_model is the model-reloading function
|
3 |
from utils_prompt import build_webtest_prompt
|
4 |
|
|
|
5 |
def gradio_infer(npc_id, npc_location, player_utt):
|
6 |
prompt = build_webtest_prompt(npc_id, npc_location, player_utt)
|
7 |
result = run_inference(prompt)
|
8 |
return result["npc_output_text"], result["deltas"], result["flags_prob"]
|
9 |
|
10 |
-
|
11 |
-
# For API calls
|
12 |
def api_infer(session_id, npc_id, prompt, max_tokens=200):
|
13 |
result = run_inference(prompt)
|
14 |
return {
|
@@ -20,7 +20,7 @@ def api_infer(session_id, npc_id, prompt, max_tokens=200):
|
|
20 |
"thresholds": result["flags_thr"]
|
21 |
}
|
22 |
|
23 |
-
#
|
24 |
def ping_reload():
|
25 |
reload_model(branch="latest") # Re-download and load from the latest branch
|
26 |
return {"status": "reloaded"}
|
@@ -36,10 +36,22 @@ with gr.Blocks() as demo:
|
|
36 |
deltas = gr.JSON(label="Deltas")
|
37 |
flags = gr.JSON(label="Flags Probabilities")
|
38 |
btn = gr.Button("Run Inference")
|
39 |
-
btn.click(fn=gradio_infer, inputs=[npc_id, npc_loc, player_utt], outputs=[npc_resp, deltas, flags])
|
40 |
|
41 |
-
|
42 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
43 |
|
44 |
if __name__ == "__main__":
|
45 |
demo.launch(server_name="0.0.0.0", server_port=7860)
|
|
|
2 |
from inference import run_inference, reload_model # reload_model is the model-reloading function
|
3 |
from utils_prompt import build_webtest_prompt
|
4 |
|
5 |
+
# Function called from the UI
|
6 |
def gradio_infer(npc_id, npc_location, player_utt):
|
7 |
prompt = build_webtest_prompt(npc_id, npc_location, player_utt)
|
8 |
result = run_inference(prompt)
|
9 |
return result["npc_output_text"], result["deltas"], result["flags_prob"]
|
10 |
|
11 |
+
# Function for API calls
|
|
|
12 |
def api_infer(session_id, npc_id, prompt, max_tokens=200):
|
13 |
result = run_inference(prompt)
|
14 |
return {
|
|
|
20 |
"thresholds": result["flags_thr"]
|
21 |
}
|
22 |
|
23 |
+
# Function for reloading the model
|
24 |
def ping_reload():
|
25 |
reload_model(branch="latest") # Re-download and load from the latest branch
|
26 |
return {"status": "reloaded"}
|
|
|
36 |
deltas = gr.JSON(label="Deltas")
|
37 |
flags = gr.JSON(label="Flags Probabilities")
|
38 |
btn = gr.Button("Run Inference")
|
|
|
39 |
|
40 |
+
# Wiring the UI button click also auto-generates an API endpoint
|
41 |
+
btn.click(
|
42 |
+
fn=gradio_infer,
|
43 |
+
inputs=[npc_id, npc_loc, player_utt],
|
44 |
+
outputs=[npc_resp, deltas, flags],
|
45 |
+
api_name="predict_main" # Creates the /api/predict_main endpoint
|
46 |
+
)
|
47 |
+
|
48 |
+
# Endpoint that exposes the API only, without a separate UI
|
49 |
+
gr.Button("Reload Model").click(
|
50 |
+
fn=ping_reload,
|
51 |
+
inputs=[],
|
52 |
+
outputs=[],
|
53 |
+
api_name="ping_reload" # Creates the /api/ping_reload endpoint
|
54 |
+
)
|
55 |
|
56 |
if __name__ == "__main__":
|
57 |
demo.launch(server_name="0.0.0.0", server_port=7860)
|