"""Quick Gradio demo for the PixAI tagger model ``deepghs/pixai-tagger-v0.9-onnx``.

Run this file directly to launch a web UI. In the UI you upload an image and tune one
threshold per tag category. The UI then shows the per-category predictions, the
IP (copyright) mapping and a plain-text summary of the tags.
"""
import io
from pprint import pformat

import gradio as gr
from hbutils.string import titleize
from hfutils.repository import hf_hub_repo_url
from imgutils.tagging.pixai import _open_default_category_thresholds, get_pixai_tags

REPO_ID = 'deepghs/pixai-tagger-v0.9-onnx'

if __name__ == '__main__':
    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                repo_url = hf_hub_repo_url(repo_id=REPO_ID, repo_type='model')
                gr.HTML(f'\n\nTagger Demo For {REPO_ID}\n\n')
                gr.Markdown(f'This is the quick demo for tagger model [{REPO_ID}]({repo_url}). '
                            f'Powered by `dghs-imgutils`\'s quick demo module.')

        with gr.Row():
            # `thresholds` holds the default threshold for each category and
            # `names` holds the category's name. Both are keyed by the same category ids.
            thresholds, names = _open_default_category_thresholds(model_name=REPO_ID)
            categories = sorted(names)
            # Compute the human-readable title once per category. The sliders and the tabs both use it.
            category_titles = {category: titleize(names[category]) for category in categories}

            with gr.Column():
                with gr.Row():
                    gr_input_image = gr.Image(type='pil', label='Original Image')
                with gr.Row():
                    # Add one slider per category, in the same order as `categories`.
                    # `_fn_submit` depends on this order to match slider values to categories.
                    gr_thresholds = [
                        gr.Slider(
                            minimum=0.0,
                            maximum=1.0,
                            value=thresholds[category],
                            step=0.001,
                            label=f'Threshold for {category_titles[category]}',
                        )
                        for category in categories
                    ]
                with gr.Row():
                    gr_submit = gr.Button(value='Submit', variant='primary')

            with gr.Column():
                with gr.Tabs():
                    gr_preds = []
                    for category in categories:
                        with gr.Tab(category_titles[category]):
                            # Pass the caption as `label`. The first positional argument of
                            # gr.Label is `value`, which would display the caption as a prediction.
                            gr_cate_label = gr.Label(label=f'{category_titles[category]} Prediction')
                        gr_preds.append(gr_cate_label)
                    with gr.Tab('IPs Mapping'):
                        gr_ips_mapping = gr.TextArea(label="IPs (string)", lines=15)
                    with gr.Tab('Text Output'):
                        gr_text_output = gr.TextArea(label="Output (string)", lines=15)

        def _fn_submit(image, *threshold_values):
            """Tag ``image`` and build the values for every output widget.

            :param image: The uploaded PIL image. This is ``None`` when nothing was uploaded.
            :param threshold_values: The current slider values, one per category, in ``categories`` order.
            :return: A tuple containing the text summary, the pretty-printed IPs mapping, and
                one prediction per category (in ``categories`` order). This order matches the
                ``outputs`` list wired to the submit button.
            :raises gr.Error: If no image has been uploaded.
            """
            if image is None:
                # Show a readable message in the UI instead of letting the tagger fail on None.
                raise gr.Error('Please upload an image first.')

            _ths = dict(zip(categories, threshold_values))
            # Request one output per category, plus the IP fields. `res` is then indexed by
            # the same keys below. Presumably get_pixai_tags mirrors the `fmt` keys in its
            # result; verify this against the imgutils docs.
            fmt = {
                **names,
                'ips_mapping': 'ips_mapping',
                'ips': 'ips',
            }
            res = get_pixai_tags(image=image, model_name=REPO_ID, thresholds=_ths, fmt=fmt)

            with io.StringIO() as sf:
                for category in categories:
                    print(f'# {names[category]} (#{category})', file=sf)
                    print(', '.join(res[category].keys()), file=sf)
                    print('', file=sf)
                print('# IPs', file=sf)
                print(', '.join(res['ips']), file=sf)
                print('', file=sf)
                text_output = sf.getvalue()

            return (
                text_output,
                pformat(res['ips_mapping']),
                *[res[category] for category in categories],
            )

        gr_submit.click(
            fn=_fn_submit,
            inputs=[gr_input_image, *gr_thresholds],
            outputs=[gr_text_output, gr_ips_mapping, *gr_preds],
        )

    demo.launch()