narugo1992 commited on
Commit
d1055cf
·
verified ·
1 Parent(s): bdb04fa

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +85 -0
app.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ from pprint import pformat
3
+
4
+ import gradio as gr
5
+ from hbutils.string import titleize
6
+ from hfutils.repository import hf_hub_repo_url
7
+
8
+ from imgutils.tagging.pixai import _open_default_category_thresholds, get_pixai_tags
9
+
10
+ REPO_ID = 'deepghs/pixai-tagger-v0.9-onnx'
11
+
12
+ if __name__ == '__main__':
13
+ with gr.Blocks() as demo:
14
+ with gr.Row():
15
+ with gr.Column():
16
+ repo_url = hf_hub_repo_url(repo_id=REPO_ID, repo_type='model')
17
+ gr.HTML(f'<h2 style="text-align: center;">Tagger Demo For {REPO_ID}</h2>')
18
+ gr.Markdown(f'This is the quick demo for tagger model [{REPO_ID}]({repo_url}). '
19
+ f'Powered by `dghs-imgutils`\'s quick demo module.')
20
+
21
+ with gr.Row():
22
+ thresholds, names = _open_default_category_thresholds(model_name=REPO_ID)
23
+ categories = sorted(set(names.keys()))
24
+
25
+ with gr.Column():
26
+ with gr.Row():
27
+ gr_input_image = gr.Image(type='pil', label='Original Image')
28
+ with gr.Row():
29
+ gr_thresholds = []
30
+ for category in categories:
31
+ gr_cate_threshold = gr.Slider(
32
+ minimum=0.0,
33
+ maximum=1.0,
34
+ value=thresholds[category],
35
+ step=0.001,
36
+ label=f'Threshold for {titleize(names[category])}',
37
+ )
38
+ gr_thresholds.append(gr_cate_threshold)
39
+ with gr.Row():
40
+ gr_submit = gr.Button(value='Submit', variant='primary')
41
+
42
+ with gr.Column():
43
+ with gr.Tabs():
44
+ gr_preds = []
45
+ for category in categories:
46
+ with gr.Tab(f'{titleize(names[category])}'):
47
+ gr_cate_label = gr.Label(f'{titleize(names[category])} Prediction')
48
+ gr_preds.append(gr_cate_label)
49
+ with gr.Tab('IPs Mapping'):
50
+ gr_ips_mapping = gr.TextArea(label="IPs (string)", lines=15)
51
+ with gr.Tab('Text Output'):
52
+ gr_text_output = gr.TextArea(label="Output (string)", lines=15)
53
+
54
+
55
+ def _fn_submit(image, *thresholds):
56
+ _ths = {
57
+ category: cate_ths
58
+ for category, cate_ths in zip(categories, thresholds)
59
+ }
60
+
61
+ fmt = {
62
+ **names,
63
+ 'ips_mapping': 'ips_mapping',
64
+ 'ips': 'ips',
65
+ }
66
+ res = get_pixai_tags(image=image, model_name=REPO_ID, thresholds=_ths, fmt=fmt)
67
+ with io.StringIO() as sf:
68
+ for category in categories:
69
+ print(f'# {names[category]} (#{category})', file=sf)
70
+ print(f', '.join(res[category].keys()), file=sf)
71
+ print(f'', file=sf)
72
+ print(f'# IPs', file=sf)
73
+ print(f', '.join(res['ips']), file=sf)
74
+ print(f'', file=sf)
75
+ return sf.getvalue(), pformat(res['ips_mapping']), \
76
+ *[res[category] for category in categories]
77
+
78
+
79
+ gr_submit.click(
80
+ fn=_fn_submit,
81
+ inputs=[gr_input_image, *gr_thresholds],
82
+ outputs=[gr_text_output, gr_ips_mapping, *gr_preds]
83
+ )
84
+
85
+ demo.launch()