import io
from pprint import pformat
import gradio as gr
from hbutils.string import titleize
from hfutils.repository import hf_hub_repo_url
from imgutils.tagging.pixai import _open_default_category_thresholds, get_pixai_tags
REPO_ID = 'deepghs/pixai-tagger-v0.9-onnx'


def format_text_output(res, categories, names):
    """
    Render tagging results as the plain-text summary shown in the "Text Output" tab.

    For each category, in the given order, this writes a ``# <name> (#<category>)`` header line,
    the comma-separated keys of ``res[category]``, and a blank line. It then writes an ``# IPs``
    section built from ``res['ips']`` in the same way.

    :param res: Result of ``get_pixai_tags`` called with a dict ``fmt``. Each category must map to
        a mapping whose keys are tag strings (presumably tag -> score; only the keys are used here),
        and ``'ips'`` must map to an iterable of strings.
    :param categories: Categories to render, in display order.
    :param names: Mapping from category to the name used in its header line.
    :return: The formatted multi-line text (every line, including the last, ends with ``\\n``).
    """
    with io.StringIO() as sf:
        for category in categories:
            print(f'# {names[category]} (#{category})', file=sf)
            print(', '.join(res[category].keys()), file=sf)
            print('', file=sf)
        print('# IPs', file=sf)
        print(', '.join(res['ips']), file=sf)
        print('', file=sf)
        return sf.getvalue()


if __name__ == '__main__':
    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                repo_url = hf_hub_repo_url(repo_id=REPO_ID, repo_type='model')
                gr.HTML(f'<h2 style="text-align: center;">Tagger Demo For {REPO_ID}</h2>')
                gr.Markdown(f'This is the quick demo for tagger model [{REPO_ID}]({repo_url}). '
                            f'Powered by `dghs-imgutils`\'s quick demo module.')

        with gr.Row():
            # NOTE(review): `_open_default_category_thresholds` is a private imgutils helper and may
            # change between releases. As used below, `thresholds[category]` is the default slider
            # value and `names[category]` is the category's display name.
            thresholds, names = _open_default_category_thresholds(model_name=REPO_ID)
            categories = sorted(names)

            with gr.Column():
                with gr.Row():
                    gr_input_image = gr.Image(type='pil', label='Original Image')
                with gr.Row():
                    # One threshold slider per category, in the same order as `categories`.
                    gr_thresholds = []
                    for category in categories:
                        gr_cate_threshold = gr.Slider(
                            minimum=0.0,
                            maximum=1.0,
                            value=thresholds[category],
                            step=0.001,
                            label=f'Threshold for {titleize(names[category])}',
                        )
                        gr_thresholds.append(gr_cate_threshold)
                with gr.Row():
                    gr_submit = gr.Button(value='Submit', variant='primary')

            with gr.Column():
                with gr.Tabs():
                    # One prediction tab per category, in the same order as `categories`.
                    gr_preds = []
                    for category in categories:
                        with gr.Tab(titleize(names[category])):
                            # `label` must be passed by keyword: the first positional parameter of
                            # gr.Label is `value`, which would render this caption as the prediction.
                            gr_cate_label = gr.Label(label=f'{titleize(names[category])} Prediction')
                            gr_preds.append(gr_cate_label)
                    with gr.Tab('IPs Mapping'):
                        gr_ips_mapping = gr.TextArea(label="IPs (string)", lines=15)
                    with gr.Tab('Text Output'):
                        gr_text_output = gr.TextArea(label="Output (string)", lines=15)

        def _fn_submit(image, *threshold_values):
            """
            Run the tagger on ``image`` and produce values for every output component.

            :param image: The uploaded PIL image, or ``None`` when nothing was uploaded.
            :param threshold_values: Slider values, positionally aligned with ``categories``.
            :return: ``(text summary, pformat'd IPs mapping, *per-category predictions)``, matching
                the ``outputs`` list wired up in ``gr_submit.click`` below.
            :raises gr.Error: If no image has been uploaded.
            """
            if image is None:
                raise gr.Error('Please upload an image first.')

            _ths = dict(zip(categories, threshold_values))
            # Each category key maps to its name; presumably `get_pixai_tags` uses the value to pick
            # which result item to return under that key (TODO confirm against imgutils docs).
            fmt = {
                **names,
                'ips_mapping': 'ips_mapping',
                'ips': 'ips',
            }
            res = get_pixai_tags(image=image, model_name=REPO_ID, thresholds=_ths, fmt=fmt)
            return (
                format_text_output(res, categories, names),
                pformat(res['ips_mapping']),
                *[res[category] for category in categories],
            )

        gr_submit.click(
            fn=_fn_submit,
            inputs=[gr_input_image, *gr_thresholds],
            outputs=[gr_text_output, gr_ips_mapping, *gr_preds],
        )

    demo.launch()