Spaces:
Running
Running
import io
from pprint import pformat
from typing import Any, Mapping, Sequence

import gradio as gr
from hbutils.string import titleize
from hfutils.repository import hf_hub_repo_url
from imgutils.tagging.pixai import _open_default_category_thresholds, get_pixai_tags
REPO_ID = 'deepghs/pixai-tagger-v0.9-onnx'


def build_result_format(names: Mapping[Any, str]) -> dict:
    """Build the ``fmt`` mapping passed to ``get_pixai_tags``.

    Every category key of ``names`` is kept with its value, and two extra entries
    ``'ips_mapping'`` and ``'ips'`` are added under keys of the same name; the demo
    below indexes the tagging result by category key and by those two strings.

    :param names: Mapping from tag category key to its name.
    :return: A new dict; ``names`` itself is not modified.
    """
    return {
        **names,
        'ips_mapping': 'ips_mapping',
        'ips': 'ips',
    }


def format_text_output(res: Mapping[Any, Any], categories: Sequence[Any],
                       names: Mapping[Any, str]) -> str:
    """Render a tagging result as plain text.

    For each category (in ``categories`` order) a ``# <name> (#<category>)`` header
    is followed by that category's tags, comma separated; a final ``# IPs`` section
    lists the entries of ``res['ips']``. Every section is followed by a blank line.

    :param res: Tagging result; ``res[category]`` is iterated for tag names and
        ``res['ips']`` for IP names (all must be strings).
    :param categories: Category keys to render, in display order.
    :param names: Mapping from category key to its display name.
    :return: The formatted text, ending with a blank line.
    """
    lines = []
    for category in categories:
        lines.append(f'# {names[category]} (#{category})')
        lines.append(', '.join(res[category]))
        lines.append('')
    lines.append('# IPs')
    lines.append(', '.join(res['ips']))
    lines.append('')
    return '\n'.join(lines) + '\n'


if __name__ == '__main__':
    # Default per-category thresholds and the category -> name mapping for this model.
    # NOTE(review): `_open_default_category_thresholds` is a private imgutils helper,
    # so its signature may change between imgutils releases.
    thresholds, names = _open_default_category_thresholds(model_name=REPO_ID)
    categories = sorted(names)  # dict keys are already unique; sort for a stable UI order
    result_fmt = build_result_format(names)

    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                repo_url = hf_hub_repo_url(repo_id=REPO_ID, repo_type='model')
                gr.HTML(f'<h2 style="text-align: center;">Tagger Demo For {REPO_ID}</h2>')
                gr.Markdown(f'This is the quick demo for tagger model [{REPO_ID}]({repo_url}). '
                            f'Powered by `dghs-imgutils`\'s quick demo module.')

        with gr.Row():
            with gr.Column():
                with gr.Row():
                    gr_input_image = gr.Image(type='pil', label='Original Image')
                with gr.Row():
                    # One slider per category, in `categories` order; _fn_submit relies on it.
                    gr_thresholds = []
                    for category in categories:
                        gr_thresholds.append(gr.Slider(
                            minimum=0.0,
                            maximum=1.0,
                            value=thresholds[category],
                            step=0.001,
                            label=f'Threshold for {titleize(names[category])}',
                        ))
                with gr.Row():
                    gr_submit = gr.Button(value='Submit', variant='primary')

            with gr.Column():
                with gr.Tabs():
                    gr_preds = []
                    for category in categories:
                        title = titleize(names[category])
                        with gr.Tab(title):
                            # `label=` must be a keyword: gr.Label's first positional
                            # parameter is `value`, which would display this caption
                            # as though it were a prediction.
                            gr_preds.append(gr.Label(label=f'{title} Prediction'))
                    with gr.Tab('IPs Mapping'):
                        gr_ips_mapping = gr.TextArea(label="IPs (string)", lines=15)
                    with gr.Tab('Text Output'):
                        gr_text_output = gr.TextArea(label="Output (string)", lines=15)

        def _fn_submit(image, *threshold_values):
            """Tag ``image`` and return (text output, IPs mapping, *per-category labels).

            ``threshold_values`` arrive from the sliders in ``categories`` order.
            """
            if image is None:
                # Report a missing upload clearly rather than passing None to the tagger.
                raise gr.Error('Please upload an image before submitting.')
            category_thresholds = dict(zip(categories, threshold_values))
            res = get_pixai_tags(image=image, model_name=REPO_ID,
                                 thresholds=category_thresholds, fmt=result_fmt)
            return (
                format_text_output(res, categories, names),
                pformat(res['ips_mapping']),
                *(res[category] for category in categories),
            )

        # Output order must match the tuple returned by _fn_submit.
        gr_submit.click(
            fn=_fn_submit,
            inputs=[gr_input_image, *gr_thresholds],
            outputs=[gr_text_output, gr_ips_mapping, *gr_preds],
        )

    demo.launch()