# Hugging Face Space app (Space status: Running)
import os
import shutil
import tempfile
import time
from pathlib import Path
from typing import Optional

import gradio as gr
from huggingface_hub import snapshot_download
from PIL import Image

# Import your existing inference endpoint implementation
from handler import EndpointHandler
# ------------------------------------------------------------------------------
# Asset setup: download weights/tags/mapping so local filenames are unchanged
# ------------------------------------------------------------------------------
# `os.environ.get(NAME) or DEFAULT` (rather than `.get(NAME, DEFAULT)`) so that a
# variable which is set but empty falls back to the default instead of being
# used verbatim (e.g. revision="" or MODEL_DIR="" -> current directory).
REPO_ID = os.environ.get("ASSETS_REPO_ID") or "pixai-labs/pixai-tagger-v0.9"
REVISION = os.environ.get("ASSETS_REVISION") or None  # optional pin, e.g. "main" or a commit
MODEL_DIR = os.environ.get("MODEL_DIR") or "./assets"  # where the handler will look

# Optional: Hugging Face token for private repos. The first non-empty variable
# wins; if none is set (or all are empty) the token is None.
HF_TOKEN = (
    os.environ.get("HUGGINGFACE_HUB_TOKEN")
    or os.environ.get("HF_TOKEN")
    or os.environ.get("HUGGINGFACE_TOKEN")
    or os.environ.get("HUGGINGFACEHUB_API_TOKEN")
    or None
)

# Files that must exist in MODEL_DIR, under exactly these names.
REQUIRED_FILES = [
    "model_v0.9.pth",
    "tags_v0.9_13k.json",
    "char_ip_map.json",
]
def ensure_assets(repo_id: str, revision: Optional[str], target_dir: str) -> None:
    """Make sure every file in REQUIRED_FILES exists in ``target_dir``.

    1) If all required files are already present, return immediately.
    2) Otherwise ``snapshot_download`` the upstream repo (cached by HF Hub),
       restricted to REQUIRED_FILES.
    3) Copy the required files into ``target_dir`` with the exact filenames
       the handler expects.

    Each file is first copied to a temporary file inside ``target_dir`` and
    then renamed over the destination. An interrupted copy therefore never
    leaves a truncated file that a later run would mistake for a complete one.

    Args:
        repo_id: Hugging Face Hub repository to download from.
        revision: Branch, tag or commit to pin; ``None`` means the default.
        target_dir: Directory to populate; created if it does not exist.

    Raises:
        FileNotFoundError: If a required file is absent from the downloaded
            snapshot. This is checked before anything in ``target_dir`` is
            modified.
    """
    target = Path(target_dir)
    target.mkdir(parents=True, exist_ok=True)

    # Only download if something is missing.
    missing = [f for f in REQUIRED_FILES if not (target / f).exists()]
    if not missing:
        return

    snapshot_path = Path(
        snapshot_download(
            repo_id=repo_id,
            revision=revision,
            allow_patterns=REQUIRED_FILES,  # only pull what we need
            token=HF_TOKEN,  # authenticate if repo is private
        )
    )

    # Validate the whole snapshot first so a missing file fails fast without
    # leaving target_dir half-updated.
    for fname in REQUIRED_FILES:
        if not (snapshot_path / fname).exists():
            raise FileNotFoundError(
                f"Expected '{fname}' not found in snapshot for {repo_id} @ {revision or 'default'}"
            )

    # Re-copy every required file (not only the missing ones), so the files
    # in target_dir all come from this one snapshot.
    for fname in REQUIRED_FILES:
        dst = target / fname
        fd, tmp_name = tempfile.mkstemp(
            dir=str(target), prefix=f".{fname}.", suffix=".partial"
        )
        os.close(fd)
        tmp = Path(tmp_name)
        try:
            shutil.copyfile(snapshot_path / fname, tmp)
            # Same directory => same filesystem, so the rename is atomic.
            os.replace(tmp, dst)
        finally:
            # No-op after a successful replace; removes the partial file otherwise.
            tmp.unlink(missing_ok=True)
# Fetch the model assets up front; ensure_assets returns early without
# downloading when every required file is already present in MODEL_DIR.
ensure_assets(REPO_ID, REVISION, MODEL_DIR)

# ------------------------------------------------------------------------------
# Initialize the handler
# ------------------------------------------------------------------------------
# Built once at import time; run_inference below reuses this single instance.
handler = EndpointHandler(MODEL_DIR)
DEVICE_LABEL = "Device: {}".format(handler.device.upper())
# ------------------------------------------------------------------------------
# Gradio wiring
# ------------------------------------------------------------------------------
def run_inference(
    source_choice: str,
    image: Optional[Image.Image],
    url: str,
    general_threshold: float,
    character_threshold: float,
):
    """Run the tagger on one image and format the results for the UI.

    Args:
        source_choice: "Upload image" to tag ``image``; any other value tags ``url``.
        image: Uploaded PIL image (only used in upload mode).
        url: Image URL (only used in URL mode); surrounding whitespace is stripped.
        general_threshold: Forwarded to the handler as ``general_threshold``.
        character_threshold: Forwarded to the handler as ``character_threshold``.

    Returns:
        ``(features, characters, ips, meta, out)``. The first three are
        comma-separated tag strings, or "—" when empty. ``meta`` holds the
        device and total latency merged with the handler's ``_timings``.
        ``out`` is the raw handler output.

    Raises:
        gr.Error: If the selected input is missing or the handler raises.
    """
    if source_choice == "Upload image":
        if image is None:
            raise gr.Error("Please upload an image.")
        inputs = image
    else:
        if not url or not url.strip():
            raise gr.Error("Please provide an image URL.")
        inputs = {"url": url.strip()}

    data = {
        "inputs": inputs,
        "parameters": {
            "general_threshold": float(general_threshold),
            "character_threshold": float(character_threshold),
        },
    }

    # perf_counter is monotonic and high-resolution; time.time() can jump if
    # the wall clock is adjusted, skewing (or negating) the measured latency.
    started = time.perf_counter()
    try:
        out = handler(data)
    except Exception as e:  # UI boundary: surface any handler failure to the user
        raise gr.Error(f"Inference error: {e}") from e
    latency = round(time.perf_counter() - started, 4)

    # `or []` / `or {}` also tolerate keys that are present but set to None.
    features = ", ".join(sorted(out.get("feature") or [])) or "—"
    characters = ", ".join(sorted(out.get("character") or [])) or "—"
    # IP tags are left in the order the handler returned them (not sorted).
    ips = ", ".join(out.get("ip") or []) or "—"

    meta = {
        "device": handler.device,
        "latency_s_total": latency,
        # Handler timings are merged last, so they win on key collisions.
        **(out.get("_timings") or {}),
    }
    return features, characters, ips, meta, out
# Default slider values, shared by the sliders, the example row and Clear.
DEFAULT_GENERAL_THRESHOLD = 0.30
DEFAULT_CHARACTER_THRESHOLD = 0.85

# Intro text shown at the top of the page. Kept flush-left so Markdown does
# not render it as an indented code block.
_INTRO_MD = f"""
# PixAI Tagger v0.9 — Gradio Demo

Downloads model assets from **{REPO_ID}** on first run,
then uses your imported `EndpointHandler` to predict **general**, **character**, and **IP** tags.

**Expected local filenames** (kept unchanged):

- `model_v0.9.pth`
- `tags_v0.9_13k.json`
- `char_ip_map.json`

Configure via env vars:

- `ASSETS_REPO_ID` (default: `pixai-labs/pixai-tagger-v0.9`)
- `ASSETS_REVISION` (optional)
- `MODEL_DIR` (default: `./assets`)
- `HUGGINGFACE_HUB_TOKEN` / `HF_TOKEN` / `HUGGINGFACE_TOKEN` / `HUGGINGFACEHUB_API_TOKEN` (optional, for private repos)
"""

with gr.Blocks(title="PixAI Tagger v0.9 — Demo", fill_height=True) as demo:
    gr.Markdown(_INTRO_MD)

    with gr.Row():
        gr.Markdown(f"**{DEVICE_LABEL}**")

    with gr.Row():
        source_choice = gr.Radio(
            choices=["Upload image", "From URL"],
            value="Upload image",
            label="Image source",
        )

    with gr.Row(variant="panel"):
        with gr.Column(scale=2):
            image = gr.Image(label="Upload image", type="pil", visible=True)
            url = gr.Textbox(label="Image URL", placeholder="https://…", visible=False)

        with gr.Column(scale=1):
            general_threshold = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                step=0.01,
                value=DEFAULT_GENERAL_THRESHOLD,
                label="General threshold",
            )
            character_threshold = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                step=0.01,
                value=DEFAULT_CHARACTER_THRESHOLD,
                label="Character threshold",
            )
            run_btn = gr.Button("Run", variant="primary")
            clear_btn = gr.Button("Clear")

    def toggle_inputs(choice):
        """Show only the input widget that matches the selected image source."""
        return (
            gr.update(visible=(choice == "Upload image")),
            gr.update(visible=(choice == "From URL")),
        )

    source_choice.change(toggle_inputs, [source_choice], [image, url])

    with gr.Row():
        with gr.Column():
            gr.Markdown("### Predicted Tags")
            features_out = gr.Textbox(label="General tags", lines=4)
            characters_out = gr.Textbox(label="Character tags", lines=4)
            ip_out = gr.Textbox(label="IP tags", lines=2)
        with gr.Column():
            gr.Markdown("### Metadata & Raw Output")
            meta_out = gr.JSON(label="Timings/Device")
            raw_out = gr.JSON(label="Raw JSON")

    examples = gr.Examples(
        label="Examples (URL mode)",
        examples=[
            [
                "From URL",
                None,
                "https://cdn.donmai.us/sample/50/b7/__komeiji_koishi_touhou_drawn_by_cui_ying__sample-50b7006f16e0144d5b5db44cadc2d22f.jpg",
                DEFAULT_GENERAL_THRESHOLD,
                DEFAULT_CHARACTER_THRESHOLD,
            ],
        ],
        inputs=[source_choice, image, url, general_threshold, character_threshold],
        cache_examples=False,
    )

    def clear():
        """Reset the inputs and thresholds to their defaults and blank every output."""
        return (
            None,  # image
            "",  # url
            DEFAULT_GENERAL_THRESHOLD,
            DEFAULT_CHARACTER_THRESHOLD,
            "",  # features_out
            "",  # characters_out
            "",  # ip_out
            {},  # meta_out
            {},  # raw_out
        )

    run_btn.click(
        run_inference,
        inputs=[source_choice, image, url, general_threshold, character_threshold],
        outputs=[features_out, characters_out, ip_out, meta_out, raw_out],
        api_name="predict",
    )
    clear_btn.click(
        clear,
        inputs=None,
        outputs=[
            image, url, general_threshold, character_threshold,
            features_out, characters_out, ip_out, meta_out, raw_out,
        ],
    )
if __name__ == "__main__":
    # Enable the request queue (at most 8 pending jobs), then start the server.
    queued_demo = demo.queue(max_size=8)
    queued_demo.launch()