trojblue committed
Commit 1412dfd · Parent: d0dae1b

adding test space demo

Files changed (4)
  1. .gitignore +5 -0
  2. app.py +211 -0
  3. handler.py +215 -0
  4. requirements.txt +7 -0
.gitignore ADDED
+ _test/
+ __pycache__/
+ .venv/
+ .ruff_cache/
+ assets/
app.py ADDED
@@ -0,0 +1,211 @@
+ import os
+ import time
+ import shutil
+ from pathlib import Path
+ from typing import Optional
+
+ import gradio as gr
+ from huggingface_hub import snapshot_download
+ from PIL import Image
+
+ # Import your existing inference endpoint implementation
+ from handler import EndpointHandler
+
+
+ # ------------------------------------------------------------------------------
+ # Asset setup: download weights/tags/mapping so local filenames are unchanged
+ # ------------------------------------------------------------------------------
+
+ REPO_ID = os.environ.get("ASSETS_REPO_ID", "pixai-labs/pixai-tagger-v0.9")
+ REVISION = os.environ.get("ASSETS_REVISION")  # optional pin, e.g. "main" or a commit
+ MODEL_DIR = os.environ.get("MODEL_DIR", "./assets")  # where the handler will look
+
+ REQUIRED_FILES = [
+     "model_v0.9.pth",
+     "tags_v0.9_13k.json",
+     "char_ip_map.json",
+ ]
+
+
+ def ensure_assets(repo_id: str, revision: Optional[str], target_dir: str):
+     """
+     1) snapshot_download the upstream repo (cached by HF Hub)
+     2) copy the required files into `target_dir` with the exact filenames expected
+     """
+     target = Path(target_dir)
+     target.mkdir(parents=True, exist_ok=True)
+
+     # Only download if something is missing
+     missing = [f for f in REQUIRED_FILES if not (target / f).exists()]
+     if not missing:
+         return
+
+     # Download snapshot (optionally filtered to speed up)
+     snapshot_path = snapshot_download(
+         repo_id=repo_id,
+         revision=revision,
+         allow_patterns=REQUIRED_FILES,  # only pull what we need
+     )
+
+     # Copy files into target_dir with the required names
+     for fname in REQUIRED_FILES:
+         src = Path(snapshot_path) / fname
+         dst = target / fname
+         if not src.exists():
+             raise FileNotFoundError(
+                 f"Expected '{fname}' not found in snapshot for {repo_id} @ {revision or 'default'}"
+             )
+         shutil.copyfile(src, dst)
+
+
+ # Fetch assets (no-op if they already exist)
+ ensure_assets(REPO_ID, REVISION, MODEL_DIR)
+
+
+ # ------------------------------------------------------------------------------
+ # Initialize the handler
+ # ------------------------------------------------------------------------------
+
+ handler = EndpointHandler(MODEL_DIR)
+ DEVICE_LABEL = f"Device: {handler.device.upper()}"
+
+
+ # ------------------------------------------------------------------------------
+ # Gradio wiring
+ # ------------------------------------------------------------------------------
+
+ def run_inference(
+     source_choice: str,
+     image: Optional[Image.Image],
+     url: str,
+     general_threshold: float,
+     character_threshold: float,
+ ):
+     if source_choice == "Upload image":
+         if image is None:
+             raise gr.Error("Please upload an image.")
+         inputs = image
+     else:
+         if not url or not url.strip():
+             raise gr.Error("Please provide an image URL.")
+         inputs = {"url": url.strip()}
+
+     data = {
+         "inputs": inputs,
+         "parameters": {
+             "general_threshold": float(general_threshold),
+             "character_threshold": float(character_threshold),
+         },
+     }
+
+     started = time.time()
+     try:
+         out = handler(data)
+     except Exception as e:
+         raise gr.Error(f"Inference error: {e}") from e
+     latency = round(time.time() - started, 4)
+
+     features = ", ".join(sorted(out.get("feature", []))) or "—"
+     characters = ", ".join(sorted(out.get("character", []))) or "—"
+     ips = ", ".join(out.get("ip", [])) or "—"
+
+     meta = {
+         "device": handler.device,
+         "latency_s_total": latency,
+         **out.get("_timings", {}),
+     }
+
+     return features, characters, ips, meta, out
+
+
+ with gr.Blocks(title="PixAI Tagger v0.9 — Demo", fill_height=True) as demo:
+     gr.Markdown(
+         """
+         # PixAI Tagger v0.9 — Gradio Demo
+         Downloads model assets from **pixai-labs/pixai-tagger-v0.9** on first run,
+         then uses your imported `EndpointHandler` to predict **general**, **character**, and **IP** tags.
+
+         **Expected local filenames** (kept unchanged):
+         - `model_v0.9.pth`
+         - `tags_v0.9_13k.json`
+         - `char_ip_map.json`
+
+         Configure via env vars:
+         - `ASSETS_REPO_ID` (default: `pixai-labs/pixai-tagger-v0.9`)
+         - `ASSETS_REVISION` (optional)
+         - `MODEL_DIR` (default: `./assets`)
+         """
+     )
+     with gr.Row():
+         gr.Markdown(f"**{DEVICE_LABEL}**")
+
+     with gr.Row():
+         source_choice = gr.Radio(
+             choices=["Upload image", "From URL"],
+             value="Upload image",
+             label="Image source",
+         )
+
+     with gr.Row(variant="panel"):
+         with gr.Column(scale=2):
+             image = gr.Image(label="Upload image", type="pil", visible=True)
+             url = gr.Textbox(label="Image URL", placeholder="https://…", visible=False)
+
+             def toggle_inputs(choice):
+                 return (
+                     gr.update(visible=(choice == "Upload image")),
+                     gr.update(visible=(choice == "From URL")),
+                 )
+
+             source_choice.change(toggle_inputs, [source_choice], [image, url])
+
+         with gr.Column(scale=1):
+             general_threshold = gr.Slider(
+                 minimum=0.0, maximum=1.0, step=0.01, value=0.30, label="General threshold"
+             )
+             character_threshold = gr.Slider(
+                 minimum=0.0, maximum=1.0, step=0.01, value=0.85, label="Character threshold"
+             )
+             run_btn = gr.Button("Run", variant="primary")
+             clear_btn = gr.Button("Clear")
+
+     with gr.Row():
+         with gr.Column():
+             gr.Markdown("### Predicted Tags")
+             features_out = gr.Textbox(label="General tags", lines=4)
+             characters_out = gr.Textbox(label="Character tags", lines=4)
+             ip_out = gr.Textbox(label="IP tags", lines=2)
+
+         with gr.Column():
+             gr.Markdown("### Metadata & Raw Output")
+             meta_out = gr.JSON(label="Timings/Device")
+             raw_out = gr.JSON(label="Raw JSON")
+
+     examples = gr.Examples(
+         label="Examples (URL mode)",
+         examples=[
+             ["From URL", None, "https://cdn.donmai.us/sample/50/b7/__komeiji_koishi_touhou_drawn_by_cui_ying__sample-50b7006f16e0144d5b5db44cadc2d22f.jpg", 0.30, 0.85],
+         ],
+         inputs=[source_choice, image, url, general_threshold, character_threshold],
+         cache_examples=False,
+     )
+
+     def clear():
+         return (None, "", 0.30, 0.85, "", "", "", {}, {})
+
+     run_btn.click(
+         run_inference,
+         inputs=[source_choice, image, url, general_threshold, character_threshold],
+         outputs=[features_out, characters_out, ip_out, meta_out, raw_out],
+         api_name="predict",
+     )
+     clear_btn.click(
+         clear,
+         inputs=None,
+         outputs=[
+             image, url, general_threshold, character_threshold,
+             features_out, characters_out, ip_out, meta_out, raw_out
+         ],
+     )
+
+ if __name__ == "__main__":
+     demo.queue(max_size=8).launch()
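
Since `run_btn.click(..., api_name="predict")` exposes a named endpoint, the demo can also be driven programmatically. Below is a minimal sketch using `gradio_client`; the Space id and sample URL are placeholders, and the argument order mirrors the `inputs` list above.

from gradio_client import Client

# Hypothetical Space id; substitute the actual <owner>/<space> once deployed.
client = Client("your-username/pixai-tagger-demo")

# Arguments follow the order of `inputs` in run_btn.click above.
features, characters, ips, meta, raw = client.predict(
    "From URL",                        # source_choice
    None,                              # image (unused in URL mode)
    "https://example.com/sample.jpg",  # url (placeholder)
    0.30,                              # general_threshold
    0.85,                              # character_threshold
    api_name="/predict",
)
print(characters, ips)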
handler.py ADDED
@@ -0,0 +1,215 @@
+ import base64
+ import io
+ import json
+ import logging
+ import time
+ from pathlib import Path
+ from typing import Any
+
+ import requests
+ import timm
+ import torch
+ import torchvision.transforms as transforms
+ from PIL import Image
+
+
+ class TaggingHead(torch.nn.Module):
+     def __init__(self, input_dim, num_classes):
+         super().__init__()
+         self.input_dim = input_dim
+         self.num_classes = num_classes
+         self.head = torch.nn.Sequential(torch.nn.Linear(input_dim, num_classes))
+
+     def forward(self, x):
+         logits = self.head(x)
+         probs = torch.nn.functional.sigmoid(logits)
+         return probs
+
+
+ def get_tags(tags_file: Path) -> tuple[dict[str, int], int, int]:
+     with tags_file.open("r", encoding="utf-8") as f:
+         tag_info = json.load(f)
+     tag_map = tag_info["tag_map"]
+     tag_split = tag_info["tag_split"]
+     gen_tag_count = tag_split["gen_tag_count"]
+     character_tag_count = tag_split["character_tag_count"]
+     return tag_map, gen_tag_count, character_tag_count
+
+
+ def get_character_ip_mapping(mapping_file: Path):
+     with mapping_file.open("r", encoding="utf-8") as f:
+         mapping = json.load(f)
+     return mapping
+
+
+ def get_encoder():
+     base_model_repo = "hf_hub:SmilingWolf/wd-eva02-large-tagger-v3"
+     encoder = timm.create_model(base_model_repo, pretrained=False)
+     encoder.reset_classifier(0)
+     return encoder
+
+
+ def get_decoder():
+     decoder = TaggingHead(1024, 13461)
+     return decoder
+
+
+ def get_model():
+     encoder = get_encoder()
+     decoder = get_decoder()
+     model = torch.nn.Sequential(encoder, decoder)
+     return model
+
+
+ def load_model(weights_file, device):
+     model = get_model()
+     states_dict = torch.load(weights_file, map_location=device, weights_only=True)
+     model.load_state_dict(states_dict)
+     model.to(device)
+     model.eval()
+     return model
+
+
+ def pure_pil_alpha_to_color_v2(
+     image: Image.Image, color: tuple[int, int, int] = (255, 255, 255)
+ ) -> Image.Image:
+     """
+     Convert a PIL image with an alpha channel to an RGB image.
+     The model expects an RGB image, but the input may carry an alpha channel.
+     This function composites the image onto a background of the given color.
+     The alpha channel is the 4th channel of the image.
+     """
+     image.load()  # needed for split()
+     background = Image.new("RGB", image.size, color)
+     background.paste(image, mask=image.split()[3])  # 3 is the alpha channel
+     return background
+
+
+ def pil_to_rgb(image: Image.Image) -> Image.Image:
+     if image.mode == "RGBA":
+         image = pure_pil_alpha_to_color_v2(image)
+     elif image.mode == "P":
+         image = pure_pil_alpha_to_color_v2(image.convert("RGBA"))
+     else:
+         image = image.convert("RGB")
+     return image
+
+
+ class EndpointHandler:
+     def __init__(self, path: str):
+         repo_path = Path(path)
+         assert repo_path.is_dir(), f"Model directory not found: {repo_path}"
+         weights_file = repo_path / "model_v0.9.pth"
+         tags_file = repo_path / "tags_v0.9_13k.json"
+         mapping_file = repo_path / "char_ip_map.json"
+         if not weights_file.exists():
+             raise FileNotFoundError(f"Model file not found: {weights_file}")
+         if not tags_file.exists():
+             raise FileNotFoundError(f"Tags file not found: {tags_file}")
+         if not mapping_file.exists():
+             raise FileNotFoundError(f"Mapping file not found: {mapping_file}")
+
+         self.device = "cuda" if torch.cuda.is_available() else "cpu"
+         self.model = load_model(str(weights_file), self.device)
+         self.transform = transforms.Compose(
+             [
+                 transforms.Resize((448, 448)),
+                 transforms.ToTensor(),
+                 transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+             ]
+         )
+         self.fetch_image_timeout = 5.0
+         self.default_general_threshold = 0.3
+         self.default_character_threshold = 0.85
+
+         tag_map, self.gen_tag_count, self.character_tag_count = get_tags(tags_file)
+
+         # Invert the tag_map for efficient index-to-tag lookups
+         self.index_to_tag_map = {v: k for k, v in tag_map.items()}
+
+         self.character_ip_mapping = get_character_ip_mapping(mapping_file)
+
+     def __call__(self, data: dict[str, Any]) -> dict[str, Any]:
+         inputs = data.pop("inputs", data)
+
+         fetch_start_time = time.time()
+         if isinstance(inputs, Image.Image):
+             image = inputs
+         elif image_url := inputs.pop("url", None):
+             with requests.get(
+                 image_url, stream=True, timeout=self.fetch_image_timeout
+             ) as res:
+                 res.raise_for_status()
+                 # Read the full payload before the connection closes;
+                 # PIL decodes lazily and res.raw is not seekable.
+                 image = Image.open(io.BytesIO(res.content))
+         elif image_base64_encoded := inputs.pop("image", None):
+             image = Image.open(io.BytesIO(base64.b64decode(image_base64_encoded)))
+         else:
+             raise ValueError(f"No image or url provided: {data}")
+         # remove alpha channel if it exists
+         image = pil_to_rgb(image)
+         fetch_time = time.time() - fetch_start_time
+
+         parameters = data.pop("parameters", {})
+         general_threshold = parameters.pop(
+             "general_threshold", self.default_general_threshold
+         )
+         character_threshold = parameters.pop(
+             "character_threshold", self.default_character_threshold
+         )
+
+         inference_start_time = time.time()
+         with torch.inference_mode():
+             # Preprocess image on CPU; pin memory only when a GPU is present,
+             # since Tensor.pin_memory() raises on CPU-only machines
+             image_tensor = self.transform(image).unsqueeze(0)
+             if self.device == "cuda":
+                 image_tensor = image_tensor.pin_memory()
+
+             # Asynchronously move image to GPU (no-op on CPU)
+             image_tensor = image_tensor.to(self.device, non_blocking=True)
+
+             # Run model on GPU
+             probs = self.model(image_tensor)[0]  # Get probs for the single image
+
+             # Perform thresholding directly on the GPU
+             general_mask = probs[: self.gen_tag_count] > general_threshold
+             character_mask = probs[self.gen_tag_count :] > character_threshold
+
+             # Get the indices of positive tags on the GPU
+             general_indices = general_mask.nonzero(as_tuple=True)[0]
+             character_indices = (
+                 character_mask.nonzero(as_tuple=True)[0] + self.gen_tag_count
+             )
+
+             # Combine indices and move the small result tensor to the CPU
+             combined_indices = torch.cat((general_indices, character_indices)).cpu()
+
+         inference_time = time.time() - inference_start_time
+
+         post_process_start_time = time.time()
+
+         cur_gen_tags = []
+         cur_char_tags = []
+
+         # Use the efficient pre-computed map for lookups
+         for i in combined_indices:
+             idx = i.item()
+             tag = self.index_to_tag_map[idx]
+             if idx < self.gen_tag_count:
+                 cur_gen_tags.append(tag)
+             else:
+                 cur_char_tags.append(tag)
+
+         ip_tags = []
+         for tag in cur_char_tags:
+             if tag in self.character_ip_mapping:
+                 ip_tags.extend(self.character_ip_mapping[tag])
+         ip_tags = sorted(set(ip_tags))
+         post_process_time = time.time() - post_process_start_time
+
+         logging.info(
+             f"Timing - Fetch: {fetch_time:.3f}s, Inference: {inference_time:.3f}s, "
+             f"Post-process: {post_process_time:.3f}s, "
+             f"Total: {fetch_time + inference_time + post_process_time:.3f}s"
+         )
+
+         return {
+             "feature": cur_gen_tags,
+             "character": cur_char_tags,
+             "ip": ip_tags,
+         }
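
For reference, handler.py can also be exercised without the Gradio UI. The sketch below assumes the assets already sit in `./assets` and uses a hypothetical local file; it mirrors the payload shape `__call__` expects (`inputs` as a PIL image, `{"url": ...}`, or `{"image": <base64>}`, plus optional `parameters`).

import base64
from handler import EndpointHandler

handler = EndpointHandler("./assets")  # directory holding the three asset files

with open("sample.jpg", "rb") as f:  # hypothetical local image
    b64 = base64.b64encode(f.read()).decode("utf-8")

out = handler({
    "inputs": {"image": b64},
    "parameters": {"general_threshold": 0.3, "character_threshold": 0.85},
})
print(out["feature"], out["character"], out["ip"])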
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ gradio>=4.31.0
+ huggingface_hub>=0.24.0
+ torch
+ torchvision
+ timm
+ pillow
+ requests