File size: 3,549 Bytes
d1055cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import io
from pprint import pformat

import gradio as gr
from hbutils.string import titleize
from hfutils.repository import hf_hub_repo_url

from imgutils.tagging.pixai import _open_default_category_thresholds, get_pixai_tags

# Hugging Face model repository whose tagger this demo loads and runs.
REPO_ID = 'deepghs/pixai-tagger-v0.9-onnx'

if __name__ == '__main__':
    # Default per-category thresholds and the category -> name mapping for this model.
    # The sorted category keys fix the order of the sliders, of the prediction tabs and
    # of the values returned by ``_fn_submit``; Gradio pairs inputs/outputs by position,
    # so all three must use this same ordering.
    default_thresholds, names = _open_default_category_thresholds(model_name=REPO_ID)
    categories = sorted(names)
    titles = {category: titleize(names[category]) for category in categories}

    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                repo_url = hf_hub_repo_url(repo_id=REPO_ID, repo_type='model')
                gr.HTML(f'<h2 style="text-align: center;">Tagger Demo For {REPO_ID}</h2>')
                gr.Markdown(f'This is the quick demo for tagger model [{REPO_ID}]({repo_url}). '
                            f'Powered by `dghs-imgutils`\'s quick demo module.')

        with gr.Row():
            with gr.Column():
                with gr.Row():
                    gr_input_image = gr.Image(type='pil', label='Original Image')
                with gr.Row():
                    # One threshold slider per category, initialised to the model default.
                    gr_thresholds = []
                    for category in categories:
                        gr_thresholds.append(gr.Slider(
                            minimum=0.0,
                            maximum=1.0,
                            value=default_thresholds[category],
                            step=0.001,
                            label=f'Threshold for {titles[category]}',
                        ))
                with gr.Row():
                    gr_submit = gr.Button(value='Submit', variant='primary')

            with gr.Column():
                with gr.Tabs():
                    gr_preds = []
                    for category in categories:
                        with gr.Tab(titles[category]):
                            # ``label`` must be passed by keyword: the first positional
                            # parameter of ``gr.Label`` is ``value``, which would render
                            # this caption as though it were a prediction result.
                            gr_preds.append(gr.Label(label=f'{titles[category]} Prediction'))
                    with gr.Tab('IPs Mapping'):
                        gr_ips_mapping = gr.TextArea(label="IPs (string)", lines=15)
                    with gr.Tab('Text Output'):
                        gr_text_output = gr.TextArea(label="Output (string)", lines=15)

            def _fn_submit(image, *category_thresholds):
                """Tag ``image`` and format the result for the output components.

                :param image: The uploaded PIL image, or ``None`` if nothing was uploaded.
                :param category_thresholds: One slider value per entry of ``categories``,
                    in the same order.
                :return: ``(text_summary, ips_mapping_repr, *per_category_predictions)``,
                    matching the ``outputs`` list registered on the submit button.
                :raises gr.Error: If no image has been uploaded.
                """
                if image is None:
                    raise gr.Error('Please upload an image before submitting.')

                ths = dict(zip(categories, category_thresholds))
                # Request every category plus the IP outputs. NOTE(review): presumably
                # get_pixai_tags replaces each string value in ``fmt`` with the output of
                # that name, so ``res[category]`` is that category's tag -> score mapping.
                fmt = {
                    **names,
                    'ips_mapping': 'ips_mapping',
                    'ips': 'ips',
                }
                res = get_pixai_tags(image=image, model_name=REPO_ID, thresholds=ths, fmt=fmt)

                with io.StringIO() as sf:
                    for category in categories:
                        print(f'# {names[category]} (#{category})', file=sf)
                        print(', '.join(res[category].keys()), file=sf)
                        print('', file=sf)
                    print('# IPs', file=sf)
                    print(', '.join(res['ips']), file=sf)
                    print('', file=sf)
                    text_output = sf.getvalue()

                return (text_output, pformat(res['ips_mapping']),
                        *[res[category] for category in categories])

            gr_submit.click(
                fn=_fn_submit,
                inputs=[gr_input_image, *gr_thresholds],
                outputs=[gr_text_output, gr_ips_mapping, *gr_preds],
            )

    demo.launch()