Spaces:
Running
Running
| import os | |
| import time | |
| import shutil | |
| from pathlib import Path | |
| from typing import Optional | |
| import gradio as gr | |
| from huggingface_hub import snapshot_download | |
| from PIL import Image | |
| # Import your existing inference endpoint implementation | |
| from handler import EndpointHandler | |
# ------------------------------------------------------------------------------
# Asset setup: download weights/tags/mapping so local filenames are unchanged
# ------------------------------------------------------------------------------
REPO_ID = os.environ.get("ASSETS_REPO_ID", "pixai-labs/pixai-tagger-v0.9")
REVISION = os.environ.get("ASSETS_REVISION")  # optional pin, e.g. "main" or a commit
MODEL_DIR = os.environ.get("MODEL_DIR", "./assets")  # where the handler will look

# Optional: Hugging Face token for private repos (first non-empty env var wins).
HF_TOKEN = (
    os.environ.get("HUGGINGFACE_HUB_TOKEN")
    or os.environ.get("HF_TOKEN")
    or os.environ.get("HUGGINGFACE_TOKEN")
    or os.environ.get("HUGGINGFACEHUB_API_TOKEN")
)

REQUIRED_FILES = [
    "model_v0.9.pth",
    "tags_v0.9_13k.json",
    "char_ip_map.json",
]


def ensure_assets(repo_id: str, revision: Optional[str], target_dir: str):
    """Make sure every file in REQUIRED_FILES exists under *target_dir*.

    1) snapshot_download the upstream repo (cached by HF Hub)
    2) copy the required files into *target_dir* with the exact filenames
       the handler expects

    Raises FileNotFoundError if a required file is absent from the snapshot.
    """
    destination = Path(target_dir)
    destination.mkdir(parents=True, exist_ok=True)

    # Fast path: skip the download entirely when nothing is missing.
    if all((destination / name).exists() for name in REQUIRED_FILES):
        return

    # Pull a (filtered) snapshot; authentication only matters for private repos.
    snapshot_root = Path(
        snapshot_download(
            repo_id=repo_id,
            revision=revision,
            allow_patterns=REQUIRED_FILES,  # only pull what we need
            token=HF_TOKEN,
        )
    )

    # Copy each required file into place under its expected name.
    for name in REQUIRED_FILES:
        source = snapshot_root / name
        if not source.exists():
            raise FileNotFoundError(
                f"Expected '{name}' not found in snapshot for {repo_id} @ {revision or 'default'}"
            )
        shutil.copyfile(source, destination / name)
# Fetch assets (no-op if they already exist)
ensure_assets(REPO_ID, REVISION, MODEL_DIR)
# ------------------------------------------------------------------------------
# Initialize the handler
# ------------------------------------------------------------------------------
# NOTE: runs at import time — first launch blocks here while assets download
# and the model loads.
handler = EndpointHandler(MODEL_DIR)
# Assumes handler.device is a lowercase string like "cpu"/"cuda" — TODO confirm
# against the EndpointHandler implementation.
DEVICE_LABEL = f"Device: {handler.device.upper()}"
# ------------------------------------------------------------------------------
# Gradio wiring
# ------------------------------------------------------------------------------
# Placeholder for empty tag groups. (Restored from a mojibake'd "—" in the
# scraped source — confirm against the deployed Space.)
EMPTY = "—"


def _humanize(tag: str) -> str:
    """Turn an underscore-separated tag into a space-separated one."""
    return tag.replace("_", " ")


def _join_scored(scores: dict) -> str:
    """Comma-join 'tag (score)' entries, highest score first; EMPTY if none."""
    ordered = sorted(scores, key=scores.get, reverse=True)
    return ", ".join(f"{_humanize(k)} ({scores[k]:.2f})" for k in ordered) or EMPTY


def run_inference(
    source_choice: str,
    image: Optional[Image.Image],
    url: str,
    general_threshold: float,
    character_threshold: float,
    mode_val: str,
    topk_general_val: int,
    topk_character_val: int,
    include_scores_val: bool,
    underscore_mode_val: bool,
):
    """Run the tagger on an uploaded image or a URL and format the outputs.

    source_choice is a flag ("image" or "url") injected via gr.State by
    whichever Run button invoked us. Returns
    (features, characters, ips, combined, meta, raw_output).

    Raises gr.Error on missing input or handler failure.
    """
    if source_choice == "image":
        if image is None:
            raise gr.Error("Please upload an image.")
        inputs = image
    else:
        if not url or not url.strip():
            raise gr.Error("Please provide an image URL.")
        inputs = {"url": url.strip()}

    params = {
        "general_threshold": float(general_threshold),
        "character_threshold": float(character_threshold),
        "mode": mode_val,
        "topk_general": int(topk_general_val),
        "topk_character": int(topk_character_val),
        "include_scores": bool(include_scores_val),
    }
    data = {"inputs": inputs, "parameters": params}

    started = time.time()
    try:
        out = handler(data)
    except Exception as e:
        raise gr.Error(f"Inference error: {e}") from e
    latency = round(time.time() - started, 4)

    char_tags = out.get("character", [])
    ip_tags = out.get("ip", [])
    feat_tags = out.get("feature", [])

    # --- Per-group outputs --------------------------------------------------
    if underscore_mode_val:
        # Raw booru-style tags, space-separated, underscores preserved.
        characters = " ".join(char_tags) or EMPTY
        ips = " ".join(ip_tags) or EMPTY
        features = " ".join(feat_tags) or EMPTY
    elif include_scores_val:
        # NOTE(review): these render the *_scores dicts, which may not contain
        # exactly the same keys as the plain tag lists — confirm with handler.
        characters = _join_scored(out.get("character_scores", {}))
        ips = ", ".join(_humanize(t) for t in ip_tags) or EMPTY
        features = _join_scored(out.get("feature_scores", {}))
    else:
        characters = ", ".join(sorted(_humanize(t) for t in char_tags)) or EMPTY
        ips = ", ".join(_humanize(t) for t in ip_tags) or EMPTY
        features = ", ".join(sorted(_humanize(t) for t in feat_tags)) or EMPTY

    # --- Combined output ------------------------------------------------------
    # Probability-descending when scores are available; otherwise character,
    # then IP, then general.
    if underscore_mode_val:
        combined = " ".join(char_tags + ip_tags + feat_tags) or EMPTY
    else:
        char_scores = out.get("character_scores") or {}
        gen_scores = out.get("feature_scores") or {}
        if include_scores_val and (char_scores or gen_scores):
            pairs = (
                [(k, float(char_scores.get(k, 0.0))) for k in char_tags]
                + [(k, 1.0) for k in ip_tags]  # IP has no score; rank it first
                + [(k, float(gen_scores.get(k, 0.0))) for k in feat_tags]
            )
            pairs.sort(key=lambda kv: kv[1], reverse=True)
            combined = ", ".join(
                f"{_humanize(k)} ({score:.2f})"
                if (k in char_scores or k in gen_scores)
                else _humanize(k)
                for k, score in pairs
            ) or EMPTY
        else:
            combined = ", ".join(
                sorted(_humanize(t) for t in char_tags)
                + [_humanize(t) for t in ip_tags]
                + sorted(_humanize(t) for t in feat_tags)
            ) or EMPTY

    meta = {
        "device": handler.device,
        "latency_s_total": latency,
        **out.get("_timings", {}),
        "params": out.get("_params", {}),
    }
    return features, characters, ips, combined, meta, out
# Soft indigo/violet theme for the demo.
theme = gr.themes.Soft(primary_hue="indigo", secondary_hue="violet", radius_size="lg")

# ------------------------------------------------------------------------------
# UI layout and event bindings
# ------------------------------------------------------------------------------
with gr.Blocks(title="PixAI Tagger v0.9 — Demo", fill_height=True, theme=theme, analytics_enabled=False) as demo:
    gr.Markdown(
        """
# PixAI Tagger v0.9 — Gradio Demo
Downloads model assets from **pixai-labs/pixai-tagger-v0.9** on first run,
then uses your imported `EndpointHandler` to predict **general**, **character**, and **IP** tags.
Configure via env vars:
- `ASSETS_REPO_ID` (default: `pixai-labs/pixai-tagger-v0.9`)
- `ASSETS_REVISION` (optional)
- `MODEL_DIR` (default: `./assets`)
"""
    )
    with gr.Row():
        gr.Markdown(f"**{DEVICE_LABEL}** — adjust thresholds or switch to Top-K mode.")

    with gr.Accordion("Settings", open=False):
        mode = gr.Radio(choices=["threshold", "topk"], value="threshold", label="Mode")
        with gr.Group(visible=True) as threshold_group:
            general_threshold = gr.Slider(
                minimum=0.0, maximum=1.0, step=0.01, value=0.30, label="General threshold"
            )
            character_threshold = gr.Slider(
                minimum=0.0, maximum=1.0, step=0.01, value=0.85, label="Character threshold"
            )
        with gr.Group(visible=False) as topk_group:
            topk_general = gr.Slider(minimum=0, maximum=100, step=1, value=25, label="Top-K general")
            topk_character = gr.Slider(minimum=0, maximum=100, step=1, value=10, label="Top-K character")
        include_scores = gr.Checkbox(value=False, label="Include scores in output")
        underscore_mode = gr.Checkbox(value=False, label="Underscore-separated output")

        def toggle_mode(selected):
            # Show the settings group matching the selected mode, hide the other.
            return (
                gr.update(visible=(selected == "threshold")),
                gr.update(visible=(selected == "topk")),
            )

        mode.change(toggle_mode, inputs=[mode], outputs=[threshold_group, topk_group])

    with gr.Row(variant="panel"):
        with gr.Column(scale=2):
            image = gr.Image(label="Upload image", type="pil", visible=True, height="420px")
            url = gr.Textbox(label="Image URL", placeholder="https://…", visible=True)
        with gr.Column(scale=3):
            # No source choice widget; both inputs stay visible and each Run
            # button selects its own source via a gr.State flag.
            with gr.Row():
                run_image_btn = gr.Button("Run from image", variant="primary")
                run_url_btn = gr.Button("Run from URL")
                clear_btn = gr.Button("Clear")
            gr.Markdown("### Combined Output (character → IP → general)")
            combined_out = gr.Textbox(label="Combined tags", lines=10)
            copy_combined = gr.Button("Copy combined")

    with gr.Row():
        with gr.Column():
            gr.Markdown("### Character / General / IP")
            with gr.Row():
                with gr.Column():
                    characters_out = gr.Textbox(label="Character tags", lines=5)
                with gr.Column():
                    features_out = gr.Textbox(label="General tags", lines=5)
                with gr.Column():
                    ip_out = gr.Textbox(label="IP tags", lines=5)
            with gr.Row():
                copy_characters = gr.Button("Copy character")
                copy_features = gr.Button("Copy general")
                copy_ip = gr.Button("Copy IP")

    with gr.Accordion("Metadata & Raw Output", open=False):
        with gr.Row():
            with gr.Column():
                meta_out = gr.JSON(label="Timings/Device")
                raw_out = gr.JSON(label="Raw JSON")
                copy_raw = gr.Button("Copy raw JSON")

    examples = gr.Examples(
        label="Examples (URL mode)",
        examples=[
            [
                None,
                "https://cdn.donmai.us/sample/50/b7/__komeiji_koishi_touhou_drawn_by_cui_ying__sample-50b7006f16e0144d5b5db44cadc2d22f.jpg",
                0.30, 0.85, "threshold", 25, 10, False, False,
            ],
        ],
        inputs=[
            image, url, general_threshold, character_threshold,
            mode, topk_general, topk_character, include_scores, underscore_mode,
        ],
        cache_examples=False,
    )

    def clear():
        # Reset both inputs to defaults and blank every output panel.
        # Order matches the outputs list of clear_btn.click below.
        return (None, "", 0.30, 0.85, "", "", "", "", {}, {})

    # Bind the Run buttons separately, passing a source flag via gr.State.
    run_url_btn.click(
        run_inference,
        inputs=[
            gr.State("url"), image, url,
            general_threshold, character_threshold,
            mode, topk_general, topk_character, include_scores, underscore_mode,
        ],
        outputs=[features_out, characters_out, ip_out, combined_out, meta_out, raw_out],
        api_name="predict_url",
    )
    run_image_btn.click(
        run_inference,
        inputs=[
            gr.State("image"), image, url,
            general_threshold, character_threshold,
            mode, topk_general, topk_character, include_scores, underscore_mode,
        ],
        outputs=[features_out, characters_out, ip_out, combined_out, meta_out, raw_out],
        api_name="predict_image",
    )

    # NOTE(review): these "copy" handlers only echo each textbox back to
    # itself — they do not place text on the clipboard. Consider
    # show_copy_button=True on the textboxes instead.
    copy_combined.click(lambda x: x, inputs=[combined_out], outputs=[combined_out])
    copy_characters.click(lambda x: x, inputs=[characters_out], outputs=[characters_out])
    copy_features.click(lambda x: x, inputs=[features_out], outputs=[features_out])
    copy_ip.click(lambda x: x, inputs=[ip_out], outputs=[ip_out])
    copy_raw.click(lambda x: x, inputs=[raw_out], outputs=[raw_out])

    # BUG FIX: clear() returns 10 values but the original outputs list had
    # only 9 components (combined_out was missing), which errors at runtime.
    clear_btn.click(
        clear,
        inputs=None,
        outputs=[
            image, url, general_threshold, character_threshold,
            features_out, characters_out, ip_out, combined_out, meta_out, raw_out,
        ],
    )

if __name__ == "__main__":
    demo.queue(max_size=8).launch()