adding test space demo
Files changed:
- .gitignore +5 -0
- app.py +211 -0
- handler.py +215 -0
- requirements.txt +7 -0
.gitignore
ADDED
@@ -0,0 +1,5 @@
_test/
__pycache__/
.venv/
.ruff_cache/
assets/
app.py
ADDED
@@ -0,0 +1,211 @@
import os
import time
import shutil
from pathlib import Path
from typing import Optional

import gradio as gr
from huggingface_hub import snapshot_download
from PIL import Image

# Import your existing inference endpoint implementation
from handler import EndpointHandler


# ------------------------------------------------------------------------------
# Asset setup: download weights/tags/mapping so local filenames are unchanged
# ------------------------------------------------------------------------------

REPO_ID = os.environ.get("ASSETS_REPO_ID", "pixai-labs/pixai-tagger-v0.9")
REVISION = os.environ.get("ASSETS_REVISION")  # optional pin, e.g. "main" or a commit
MODEL_DIR = os.environ.get("MODEL_DIR", "./assets")  # where the handler will look

REQUIRED_FILES = [
    "model_v0.9.pth",
    "tags_v0.9_13k.json",
    "char_ip_map.json",
]

def ensure_assets(repo_id: str, revision: Optional[str], target_dir: str):
    """
    1) snapshot_download the upstream repo (cached by HF Hub)
    2) copy the required files into `target_dir` with the exact filenames expected
    """
    target = Path(target_dir)
    target.mkdir(parents=True, exist_ok=True)

    # Only download if something is missing
    missing = [f for f in REQUIRED_FILES if not (target / f).exists()]
    if not missing:
        return

    # Download snapshot (optionally filtered to speed up)
    snapshot_path = snapshot_download(
        repo_id=repo_id,
        revision=revision,
        allow_patterns=REQUIRED_FILES,  # only pull what we need
    )

    # Copy files into target_dir with the required names
    for fname in REQUIRED_FILES:
        src = Path(snapshot_path) / fname
        dst = target / fname
        if not src.exists():
            raise FileNotFoundError(
                f"Expected '{fname}' not found in snapshot for {repo_id} @ {revision or 'default'}"
            )
        shutil.copyfile(src, dst)


# Fetch assets (no-op if they already exist)
ensure_assets(REPO_ID, REVISION, MODEL_DIR)


# ------------------------------------------------------------------------------
# Initialize the handler
# ------------------------------------------------------------------------------

handler = EndpointHandler(MODEL_DIR)
DEVICE_LABEL = f"Device: {handler.device.upper()}"


# ------------------------------------------------------------------------------
# Gradio wiring
# ------------------------------------------------------------------------------

def run_inference(
    source_choice: str,
    image: Optional[Image.Image],
    url: str,
    general_threshold: float,
    character_threshold: float,
):
    if source_choice == "Upload image":
        if image is None:
            raise gr.Error("Please upload an image.")
        inputs = image
    else:
        if not url or not url.strip():
            raise gr.Error("Please provide an image URL.")
        inputs = {"url": url.strip()}

    data = {
        "inputs": inputs,
        "parameters": {
            "general_threshold": float(general_threshold),
            "character_threshold": float(character_threshold),
        },
    }

    started = time.time()
    try:
        out = handler(data)
    except Exception as e:
        raise gr.Error(f"Inference error: {e}") from e
    latency = round(time.time() - started, 4)

    features = ", ".join(sorted(out.get("feature", []))) or "—"
    characters = ", ".join(sorted(out.get("character", []))) or "—"
    ips = ", ".join(out.get("ip", [])) or "—"

    meta = {
        "device": handler.device,
        "latency_s_total": latency,
        **out.get("_timings", {}),
    }

    return features, characters, ips, meta, out


with gr.Blocks(title="PixAI Tagger v0.9 — Demo", fill_height=True) as demo:
    gr.Markdown(
        """
        # PixAI Tagger v0.9 — Gradio Demo
        Downloads model assets from **pixai-labs/pixai-tagger-v0.9** on first run,
        then uses your imported `EndpointHandler` to predict **general**, **character**, and **IP** tags.

        **Expected local filenames** (kept unchanged):
        - `model_v0.9.pth`
        - `tags_v0.9_13k.json`
        - `char_ip_map.json`

        Configure via env vars:
        - `ASSETS_REPO_ID` (default: `pixai-labs/pixai-tagger-v0.9`)
        - `ASSETS_REVISION` (optional)
        - `MODEL_DIR` (default: `./assets`)
        """
    )
    with gr.Row():
        gr.Markdown(f"**{DEVICE_LABEL}**")

    with gr.Row():
        source_choice = gr.Radio(
            choices=["Upload image", "From URL"],
            value="Upload image",
            label="Image source",
        )

    with gr.Row(variant="panel"):
        with gr.Column(scale=2):
            image = gr.Image(label="Upload image", type="pil", visible=True)
            url = gr.Textbox(label="Image URL", placeholder="https://…", visible=False)

            def toggle_inputs(choice):
                return (
                    gr.update(visible=(choice == "Upload image")),
                    gr.update(visible=(choice == "From URL")),
                )

            source_choice.change(toggle_inputs, [source_choice], [image, url])

        with gr.Column(scale=1):
            general_threshold = gr.Slider(
                minimum=0.0, maximum=1.0, step=0.01, value=0.30, label="General threshold"
            )
            character_threshold = gr.Slider(
                minimum=0.0, maximum=1.0, step=0.01, value=0.85, label="Character threshold"
            )
            run_btn = gr.Button("Run", variant="primary")
            clear_btn = gr.Button("Clear")

    with gr.Row():
        with gr.Column():
            gr.Markdown("### Predicted Tags")
            features_out = gr.Textbox(label="General tags", lines=4)
            characters_out = gr.Textbox(label="Character tags", lines=4)
            ip_out = gr.Textbox(label="IP tags", lines=2)

        with gr.Column():
            gr.Markdown("### Metadata & Raw Output")
            meta_out = gr.JSON(label="Timings/Device")
            raw_out = gr.JSON(label="Raw JSON")

    examples = gr.Examples(
        label="Examples (URL mode)",
        examples=[
            ["From URL", None, "https://cdn.donmai.us/sample/50/b7/__komeiji_koishi_touhou_drawn_by_cui_ying__sample-50b7006f16e0144d5b5db44cadc2d22f.jpg", 0.30, 0.85],
        ],
        inputs=[source_choice, image, url, general_threshold, character_threshold],
        cache_examples=False,
    )

    def clear():
        return (None, "", 0.30, 0.85, "", "", "", {}, {})

    run_btn.click(
        run_inference,
        inputs=[source_choice, image, url, general_threshold, character_threshold],
        outputs=[features_out, characters_out, ip_out, meta_out, raw_out],
        api_name="predict",
    )
    clear_btn.click(
        clear,
        inputs=None,
        outputs=[
            image, url, general_threshold, character_threshold,
            features_out, characters_out, ip_out, meta_out, raw_out
        ],
    )

if __name__ == "__main__":
    demo.queue(max_size=8).launch()
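
Because run_btn.click registers api_name="predict", the running Space can also be called programmatically. A minimal client-side sketch using gradio_client (not part of this Space's requirements; the Space ID and image URL below are placeholders):

from gradio_client import Client

# Placeholder Space ID -- replace with the actual owner/space name.
client = Client("your-username/pixai-tagger-demo")

# Positional arguments mirror run_inference: source, image, url, thresholds.
features, characters, ips, meta, raw = client.predict(
    "From URL",                        # source_choice
    None,                              # image (unused in URL mode)
    "https://example.com/sample.jpg",  # url -- placeholder
    0.30,                              # general_threshold
    0.85,                              # character_threshold
    api_name="/predict",
)
print(features)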
handler.py
ADDED
@@ -0,0 +1,215 @@
import base64
import io
import json
import logging
import time
from pathlib import Path
from typing import Any

import requests
import timm
import torch
import torchvision.transforms as transforms
from PIL import Image


class TaggingHead(torch.nn.Module):
    def __init__(self, input_dim, num_classes):
        super().__init__()
        self.input_dim = input_dim
        self.num_classes = num_classes
        self.head = torch.nn.Sequential(torch.nn.Linear(input_dim, num_classes))

    def forward(self, x):
        logits = self.head(x)
        probs = torch.sigmoid(logits)  # torch.nn.functional.sigmoid is deprecated
        return probs


def get_tags(tags_file: Path) -> tuple[dict[str, int], int, int]:
    with tags_file.open("r", encoding="utf-8") as f:
        tag_info = json.load(f)
    tag_map = tag_info["tag_map"]
    tag_split = tag_info["tag_split"]
    gen_tag_count = tag_split["gen_tag_count"]
    character_tag_count = tag_split["character_tag_count"]
    return tag_map, gen_tag_count, character_tag_count


def get_character_ip_mapping(mapping_file: Path):
    with mapping_file.open("r", encoding="utf-8") as f:
        mapping = json.load(f)
    return mapping


def get_encoder():
    base_model_repo = "hf_hub:SmilingWolf/wd-eva02-large-tagger-v3"
    encoder = timm.create_model(base_model_repo, pretrained=False)
    encoder.reset_classifier(0)
    return encoder


def get_decoder():
    decoder = TaggingHead(1024, 13461)
    return decoder


def get_model():
    encoder = get_encoder()
    decoder = get_decoder()
    model = torch.nn.Sequential(encoder, decoder)
    return model


def load_model(weights_file, device):
    model = get_model()
    states_dict = torch.load(weights_file, map_location=device, weights_only=True)
    model.load_state_dict(states_dict)
    model.to(device)
    model.eval()
    return model


def pure_pil_alpha_to_color_v2(
    image: Image.Image, color: tuple[int, int, int] = (255, 255, 255)
) -> Image.Image:
    """
    Convert a PIL image with an alpha channel to an RGB image.
    This is a workaround for the fact that the model expects an RGB image, but the image may have an alpha channel.
    This function converts the image to an RGB image, compositing the alpha channel over the given color.
    The alpha channel is the 4th channel of the image.
    """
    image.load()  # needed for split()
    background = Image.new("RGB", image.size, color)
    background.paste(image, mask=image.split()[3])  # 3 is the alpha channel
    return background


def pil_to_rgb(image: Image.Image) -> Image.Image:
    if image.mode == "RGBA":
        image = pure_pil_alpha_to_color_v2(image)
    elif image.mode == "P":
        image = pure_pil_alpha_to_color_v2(image.convert("RGBA"))
    else:
        image = image.convert("RGB")
    return image


class EndpointHandler:
    def __init__(self, path: str):
        repo_path = Path(path)
        assert repo_path.is_dir(), f"Model directory not found: {repo_path}"
        weights_file = repo_path / "model_v0.9.pth"
        tags_file = repo_path / "tags_v0.9_13k.json"
        mapping_file = repo_path / "char_ip_map.json"
        if not weights_file.exists():
            raise FileNotFoundError(f"Model file not found: {weights_file}")
        if not tags_file.exists():
            raise FileNotFoundError(f"Tags file not found: {tags_file}")
        if not mapping_file.exists():
            raise FileNotFoundError(f"Mapping file not found: {mapping_file}")

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = load_model(str(weights_file), self.device)
        self.transform = transforms.Compose(
            [
                transforms.Resize((448, 448)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
            ]
        )
        self.fetch_image_timeout = 5.0
        self.default_general_threshold = 0.3
        self.default_character_threshold = 0.85

        tag_map, self.gen_tag_count, self.character_tag_count = get_tags(tags_file)

        # Invert the tag_map for efficient index-to-tag lookups
        self.index_to_tag_map = {v: k for k, v in tag_map.items()}

        self.character_ip_mapping = get_character_ip_mapping(mapping_file)

    def __call__(self, data: dict[str, Any]) -> dict[str, Any]:
        inputs = data.pop("inputs", data)

        fetch_start_time = time.time()
        if isinstance(inputs, Image.Image):
            image = inputs
        elif image_url := inputs.pop("url", None):
            with requests.get(
                image_url, stream=True, timeout=self.fetch_image_timeout
            ) as res:
                res.raise_for_status()
                # Read the body while the connection is still open; PIL decodes
                # lazily, so Image.open(res.raw) can fail after the response closes.
                image = Image.open(io.BytesIO(res.content))
        elif image_base64_encoded := inputs.pop("image", None):
            image = Image.open(io.BytesIO(base64.b64decode(image_base64_encoded)))
        else:
            raise ValueError(f"No image or url provided: {data}")
        # remove alpha channel if it exists
        image = pil_to_rgb(image)
        fetch_time = time.time() - fetch_start_time

        parameters = data.pop("parameters", {})
        general_threshold = parameters.pop(
            "general_threshold", self.default_general_threshold
        )
        character_threshold = parameters.pop(
            "character_threshold", self.default_character_threshold
        )

        inference_start_time = time.time()
        with torch.inference_mode():
            # Preprocess image on CPU; pin memory only when CUDA is available,
            # since pin_memory() raises on CPU-only machines
            image_tensor = self.transform(image).unsqueeze(0)
            if self.device == "cuda":
                image_tensor = image_tensor.pin_memory()

            # Asynchronously move image to GPU
            image_tensor = image_tensor.to(self.device, non_blocking=True)

            # Run model on GPU
            probs = self.model(image_tensor)[0]  # Get probs for the single image

            # Perform thresholding directly on the GPU
            general_mask = probs[: self.gen_tag_count] > general_threshold
            character_mask = probs[self.gen_tag_count :] > character_threshold

            # Get the indices of positive tags on the GPU
            general_indices = general_mask.nonzero(as_tuple=True)[0]
            character_indices = (
                character_mask.nonzero(as_tuple=True)[0] + self.gen_tag_count
            )

            # Combine indices and move the small result tensor to the CPU
            combined_indices = torch.cat((general_indices, character_indices)).cpu()

        inference_time = time.time() - inference_start_time

        post_process_start_time = time.time()

        cur_gen_tags = []
        cur_char_tags = []

        # Use the efficient pre-computed map for lookups
        for i in combined_indices:
            idx = i.item()
            tag = self.index_to_tag_map[idx]
            if idx < self.gen_tag_count:
                cur_gen_tags.append(tag)
            else:
                cur_char_tags.append(tag)

        ip_tags = []
        for tag in cur_char_tags:
            if tag in self.character_ip_mapping:
                ip_tags.extend(self.character_ip_mapping[tag])
        ip_tags = sorted(set(ip_tags))
        post_process_time = time.time() - post_process_start_time

        logging.info(
            f"Timing - Fetch: {fetch_time:.3f}s, Inference: {inference_time:.3f}s, Post-process: {post_process_time:.3f}s, Total: {fetch_time + inference_time + post_process_time:.3f}s"
        )

        return {
            "feature": cur_gen_tags,
            "character": cur_char_tags,
            "ip": ip_tags,
        }
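
handler.py can also be exercised on its own, which is handy for debugging outside Gradio. A minimal sketch, assuming the assets already sit in ./assets and test.jpg is any local image (both paths are placeholders):

import base64

from handler import EndpointHandler

handler = EndpointHandler("./assets")  # placeholder model dir
with open("test.jpg", "rb") as f:      # placeholder image path
    encoded = base64.b64encode(f.read()).decode("utf-8")

# Same payload shape that app.py builds: base64 image plus thresholds.
result = handler({
    "inputs": {"image": encoded},
    "parameters": {"general_threshold": 0.3, "character_threshold": 0.85},
})
print(result["feature"], result["character"], result["ip"])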
requirements.txt
ADDED
@@ -0,0 +1,7 @@
gradio>=4.31.0
huggingface_hub>=0.24.0
torch
torchvision
timm
pillow
requests