m97j committed on
Commit
800d562
·
1 Parent(s): f1f1dc0

edit gradio ui

Browse files
Files changed (1) hide show
  1. app.py +18 -6
app.py CHANGED
@@ -2,13 +2,13 @@ import gradio as gr
2
  from inference import run_inference, reload_model # reload_model์€ ๋ชจ๋ธ ์žฌ๋กœ๋”ฉ ํ•จ์ˆ˜
3
  from utils_prompt import build_webtest_prompt
4
 
 
5
  def gradio_infer(npc_id, npc_location, player_utt):
6
  prompt = build_webtest_prompt(npc_id, npc_location, player_utt)
7
  result = run_inference(prompt)
8
  return result["npc_output_text"], result["deltas"], result["flags_prob"]
9
 
10
-
11
- # API 호출용
12
  def api_infer(session_id, npc_id, prompt, max_tokens=200):
13
  result = run_inference(prompt)
14
  return {
@@ -20,7 +20,7 @@ def api_infer(session_id, npc_id, prompt, max_tokens=200):
20
  "thresholds": result["flags_thr"]
21
  }
22
 
23
- # Colab์—์„œ ํ˜ธ์ถœํ•  ping endpoint
24
  def ping_reload():
25
  reload_model(branch="latest") # latest 브랜치에서 재다운로드 & 로드
26
  return {"status": "reloaded"}
@@ -36,10 +36,22 @@ with gr.Blocks() as demo:
36
  deltas = gr.JSON(label="Deltas")
37
  flags = gr.JSON(label="Flags Probabilities")
38
  btn = gr.Button("Run Inference")
39
- btn.click(fn=gradio_infer, inputs=[npc_id, npc_loc, player_utt], outputs=[npc_resp, deltas, flags])
40
 
41
- demo.add_api_route("/predict_main", api_infer, methods=["POST"], api_name="predict_main")
42
- demo.add_api_route("/ping_reload", lambda: ping_reload(), methods=["POST"], api_name="ping_reload")
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
  if __name__ == "__main__":
45
  demo.launch(server_name="0.0.0.0", server_port=7860)
 
2
  from inference import run_inference, reload_model # reload_model์€ ๋ชจ๋ธ ์žฌ๋กœ๋”ฉ ํ•จ์ˆ˜
3
  from utils_prompt import build_webtest_prompt
4
 
5
+ # UI์—์„œ ํ˜ธ์ถœํ•  ํ•จ์ˆ˜
6
  def gradio_infer(npc_id, npc_location, player_utt):
7
  prompt = build_webtest_prompt(npc_id, npc_location, player_utt)
8
  result = run_inference(prompt)
9
  return result["npc_output_text"], result["deltas"], result["flags_prob"]
10
 
11
+ # API 호출용 함수
 
12
  def api_infer(session_id, npc_id, prompt, max_tokens=200):
13
  result = run_inference(prompt)
14
  return {
 
20
  "thresholds": result["flags_thr"]
21
  }
22
 
23
+ # 모델 재로딩용 함수
24
  def ping_reload():
25
  reload_model(branch="latest") # latest 브랜치에서 재다운로드 & 로드
26
  return {"status": "reloaded"}
 
36
  deltas = gr.JSON(label="Deltas")
37
  flags = gr.JSON(label="Flags Probabilities")
38
  btn = gr.Button("Run Inference")
 
39
 
40
+ # UI 버튼 클릭 시 API 엔드포인트도 자동 생성
41
+ btn.click(
42
+ fn=gradio_infer,
43
+ inputs=[npc_id, npc_loc, player_utt],
44
+ outputs=[npc_resp, deltas, flags],
45
+ api_name="predict_main" # /api/predict_main 엔드포인트 생성
46
+ )
47
+
48
+ # ๋ณ„๋„์˜ UI ์—†์ด API๋งŒ ์ œ๊ณตํ•˜๋Š” ์—”๋“œํฌ์ธํŠธ
49
+ gr.Button("Reload Model").click(
50
+ fn=ping_reload,
51
+ inputs=[],
52
+ outputs=[],
53
+ api_name="ping_reload" # /api/ping_reload 엔드포인트 생성
54
+ )
55
 
56
  if __name__ == "__main__":
57
  demo.launch(server_name="0.0.0.0", server_port=7860)