tomcas committed on
Commit
06a6a51
1 Parent(s): 1464cae

Upload 6 files

Browse files
Files changed (6) hide show
  1. README.md +32 -7
  2. app.py +339 -0
  3. gitattributes +27 -0
  4. gitignore +1 -0
  5. power.jpg +0 -0
  6. requirements.txt +3 -0
README.md CHANGED
@@ -1,13 +1,38 @@
1
  ---
2
- title: Wdtag
3
- emoji: 🔥
4
- colorFrom: yellow
5
- colorTo: indigo
6
  sdk: gradio
7
- sdk_version: 4.31.5
8
  app_file: app.py
9
  pinned: false
10
- license: mit
11
  ---
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: WaifuDiffusion Tagger
3
+ emoji: 💬
4
+ colorFrom: blue
5
+ colorTo: red
6
  sdk: gradio
7
+ sdk_version: 4.20.1
8
  app_file: app.py
9
  pinned: false
 
10
  ---
11
 
12
+ # Configuration
13
+
14
+ `title`: _string_
15
+ Display title for the Space
16
+
17
+ `emoji`: _string_
18
+ Space emoji (emoji-only character allowed)
19
+
20
+ `colorFrom`: _string_
21
+ Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
22
+
23
+ `colorTo`: _string_
24
+ Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
25
+
26
+ `sdk`: _string_
27
+ Can be either `gradio`, `streamlit`, or `static`
28
+
29
+ `sdk_version` : _string_
30
+ Only applicable for `streamlit` SDK.
31
+ See [doc](https://hf.co/docs/hub/spaces) for more info on supported versions.
32
+
33
+ `app_file`: _string_
34
+ Path to your main application file (which contains either `gradio` or `streamlit` Python code, or `static` html code).
35
+ Path is relative to the root of the repository.
36
+
37
+ `pinned`: _boolean_
38
+ Whether the Space stays on top of your list.
app.py ADDED
@@ -0,0 +1,339 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+
4
+ import gradio as gr
5
+ import huggingface_hub
6
+ import numpy as np
7
+ import onnxruntime as rt
8
+ import pandas as pd
9
+ from PIL import Image
10
+
11
TITLE = "WaifuDiffusion Tagger"
DESCRIPTION = """
Demo for the WaifuDiffusion tagger models

Example image by [ほし☆☆☆](https://www.pixiv.net/en/users/43565085)
"""

# Optional Hugging Face token. Using .get() instead of indexing avoids a
# KeyError crash at import time when the variable is unset; the tagger
# repos below are public, so downloads also work with token=None.
HF_TOKEN = os.environ.get("HF_TOKEN")

# Dataset v3 series of models:
SWINV2_MODEL_DSV3_REPO = "SmilingWolf/wd-swinv2-tagger-v3"
CONV_MODEL_DSV3_REPO = "SmilingWolf/wd-convnext-tagger-v3"
VIT_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-tagger-v3"

# Dataset v2 series of models:
MOAT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-moat-tagger-v2"
SWIN_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2"
CONV_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
CONV2_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnextv2-tagger-v2"
VIT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-vit-tagger-v2"

# Files to download from the repos
MODEL_FILENAME = "model.onnx"
LABEL_FILENAME = "selected_tags.csv"

# Tags whose underscores are part of the tag itself and must NOT be
# replaced with spaces when prettifying tag names.
# https://github.com/toriato/stable-diffusion-webui-wd14-tagger/blob/a9eacb1eff904552d3012babfa28b57e1d3e295c/tagger/ui.py#L368
kaomojis = [
    "0_0",
    "(o)_(o)",
    "+_+",
    "+_-",
    "._.",
    "<o>_<o>",
    "<|>_<|>",
    "=_=",
    ">_<",
    "3_3",
    "6_9",
    ">_o",
    "@_@",
    "^_^",
    "o_o",
    "u_u",
    "x_x",
    "|_|",
    "||_||",
]
58
+
59
+
60
def parse_args() -> argparse.Namespace:
    """Parse the demo's command-line options.

    Returns:
        argparse.Namespace with score_slider_step,
        score_general_threshold, score_character_threshold and share.
    """
    parser = argparse.ArgumentParser()
    # All score options are floats; declare them data-driven.
    for flag, default in (
        ("--score-slider-step", 0.05),
        ("--score-general-threshold", 0.35),
        ("--score-character-threshold", 0.85),
    ):
        parser.add_argument(flag, type=float, default=default)
    parser.add_argument("--share", action="store_true")
    return parser.parse_args()
67
+
68
+
69
def load_labels(dataframe) -> tuple[list[str], list[int], list[int], list[int]]:
    """Split a selected_tags.csv dataframe into names and category indexes.

    Args:
        dataframe: pandas DataFrame with "name" and "category" columns
            (the tagger repos' selected_tags.csv).

    Returns:
        A 4-tuple (tag_names, rating_indexes, general_indexes,
        character_indexes).  The original annotation said ``list[str]``,
        which was wrong — four lists are returned.  Underscores in tag
        names are replaced with spaces, except for kaomoji tags where the
        underscore is part of the tag itself.
    """
    name_series = dataframe["name"]
    name_series = name_series.map(
        lambda x: x.replace("_", " ") if x not in kaomojis else x
    )
    tag_names = name_series.tolist()

    # Category codes used by the CSV: 9 = rating, 0 = general, 4 = character.
    rating_indexes = list(np.where(dataframe["category"] == 9)[0])
    general_indexes = list(np.where(dataframe["category"] == 0)[0])
    character_indexes = list(np.where(dataframe["category"] == 4)[0])
    return tag_names, rating_indexes, general_indexes, character_indexes
80
+
81
+
82
def mcut_threshold(probs):
    """
    Maximum Cut Thresholding (MCut)

    Largeron, C., Moulin, C., & Gery, M. (2012). MCut: A Thresholding Strategy
    for Multi-label Classification. In 11th International Symposium, IDA 2012
    (pp. 172-183).

    Sorts the probabilities in descending order and places the threshold
    at the midpoint of the largest gap between consecutive values.
    """
    descending = np.sort(probs)[::-1]
    gaps = descending[:-1] - descending[1:]
    cut = int(np.argmax(gaps))
    return (descending[cut] + descending[cut + 1]) / 2
94
+
95
+
96
class Predictor:
    """Lazily downloads and runs a WaifuDiffusion ONNX tagger model.

    The model and its label CSV are fetched from the Hugging Face Hub on
    first use and cached; switching repos reloads only when the repo
    actually changes.
    """

    def __init__(self):
        # Populated by load_model() on first use.
        self.model_target_size = None
        self.last_loaded_repo = None

    def download_model(self, model_repo):
        """Fetch the label CSV and ONNX model files for ``model_repo``.

        Returns:
            (csv_path, model_path) paths of the locally cached files.
        """
        # ``token`` replaces the deprecated ``use_auth_token`` argument of
        # hf_hub_download; HF_TOKEN may be None for these public repos.
        csv_path = huggingface_hub.hf_hub_download(
            model_repo,
            LABEL_FILENAME,
            token=HF_TOKEN,
        )
        model_path = huggingface_hub.hf_hub_download(
            model_repo,
            MODEL_FILENAME,
            token=HF_TOKEN,
        )
        return csv_path, model_path

    def load_model(self, model_repo):
        """Load model and labels for ``model_repo`` unless already loaded."""
        if model_repo == self.last_loaded_repo:
            return

        csv_path, model_path = self.download_model(model_repo)

        tags_df = pd.read_csv(csv_path)
        sep_tags = load_labels(tags_df)

        self.tag_names = sep_tags[0]
        self.rating_indexes = sep_tags[1]
        self.general_indexes = sep_tags[2]
        self.character_indexes = sep_tags[3]

        model = rt.InferenceSession(model_path)
        # Input is NHWC and square, so the height alone fixes the size.
        _, height, width, _ = model.get_inputs()[0].shape
        self.model_target_size = height

        self.last_loaded_repo = model_repo
        self.model = model

    def prepare_image(self, image):
        """Convert an RGBA PIL image to the model's float32 NHWC batch.

        Composites transparency onto white, pads to a square, resizes to
        the model input size, and swaps RGB -> BGR channel order.
        """
        target_size = self.model_target_size

        # Flatten any transparency onto a white background.
        canvas = Image.new("RGBA", image.size, (255, 255, 255))
        canvas.alpha_composite(image)
        image = canvas.convert("RGB")

        # Pad image to square
        image_shape = image.size
        max_dim = max(image_shape)
        pad_left = (max_dim - image_shape[0]) // 2
        pad_top = (max_dim - image_shape[1]) // 2

        padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255))
        padded_image.paste(image, (pad_left, pad_top))

        # Resize
        if max_dim != target_size:
            padded_image = padded_image.resize(
                (target_size, target_size),
                Image.BICUBIC,
            )

        # Convert to numpy array
        image_array = np.asarray(padded_image, dtype=np.float32)

        # Convert PIL-native RGB to BGR
        image_array = image_array[:, :, ::-1]

        return np.expand_dims(image_array, axis=0)

    def predict(
        self,
        image,
        model_repo,
        general_thresh,
        general_mcut_enabled,
        character_thresh,
        character_mcut_enabled,
    ):
        """Tag an image with the selected model.

        Args:
            image: RGBA PIL image.
            model_repo: Hugging Face repo id of the tagger model.
            general_thresh: confidence threshold for general tags.
            general_mcut_enabled: if True, derive the general threshold
                via MCut instead of ``general_thresh``.
            character_thresh: confidence threshold for character tags.
            character_mcut_enabled: if True, derive the character
                threshold via MCut (floored at 0.15).

        Returns:
            (escaped comma-joined general tag string, rating dict,
            character tag dict, general tag dict).
        """
        self.load_model(model_repo)

        image = self.prepare_image(image)

        input_name = self.model.get_inputs()[0].name
        label_name = self.model.get_outputs()[0].name
        preds = self.model.run([label_name], {input_name: image})[0]

        labels = list(zip(self.tag_names, preds[0].astype(float)))

        # First 4 labels are actually ratings: pick one with argmax
        ratings_names = [labels[i] for i in self.rating_indexes]
        rating = dict(ratings_names)

        # Then we have general tags: pick any where prediction confidence > threshold
        general_names = [labels[i] for i in self.general_indexes]

        if general_mcut_enabled:
            general_probs = np.array([x[1] for x in general_names])
            general_thresh = mcut_threshold(general_probs)

        general_res = [x for x in general_names if x[1] > general_thresh]
        general_res = dict(general_res)

        # Everything else is characters: pick any where prediction confidence > threshold
        character_names = [labels[i] for i in self.character_indexes]

        if character_mcut_enabled:
            character_probs = np.array([x[1] for x in character_names])
            character_thresh = mcut_threshold(character_probs)
            character_thresh = max(0.15, character_thresh)

        character_res = [x for x in character_names if x[1] > character_thresh]
        character_res = dict(character_res)

        sorted_general_strings = sorted(
            general_res.items(),
            key=lambda x: x[1],
            reverse=True,
        )
        sorted_general_strings = [x[0] for x in sorted_general_strings]
        # Escape parentheses for downstream prompt usage.  "\\(" fixes the
        # original "\(" which is an invalid escape sequence (SyntaxWarning
        # in modern CPython); the runtime string is identical.
        sorted_general_strings = (
            ", ".join(sorted_general_strings).replace("(", "\\(").replace(")", "\\)")
        )

        return sorted_general_strings, rating, character_res, general_res
221
+
222
+
223
def main():
    """Build and launch the Gradio tagging demo."""
    args = parse_args()

    predictor = Predictor()

    # Model choices shown in the dropdown; the first v3 model is the default.
    dropdown_list = [
        SWINV2_MODEL_DSV3_REPO,
        CONV_MODEL_DSV3_REPO,
        VIT_MODEL_DSV3_REPO,
        MOAT_MODEL_DSV2_REPO,
        SWIN_MODEL_DSV2_REPO,
        CONV_MODEL_DSV2_REPO,
        CONV2_MODEL_DSV2_REPO,
        VIT_MODEL_DSV2_REPO,
    ]

    with gr.Blocks(title=TITLE) as demo:
        with gr.Column():
            gr.Markdown(
                value=f"<h1 style='text-align: center; margin-bottom: 1rem'>{TITLE}</h1>"
            )
            gr.Markdown(value=DESCRIPTION)
            with gr.Row():
                # Left panel: inputs and controls.
                with gr.Column(variant="panel"):
                    # RGBA so Predictor.prepare_image can composite
                    # transparency onto a white background.
                    image = gr.Image(type="pil", image_mode="RGBA", label="Input")
                    model_repo = gr.Dropdown(
                        dropdown_list,
                        value=SWINV2_MODEL_DSV3_REPO,
                        label="Model",
                    )
                    with gr.Row():
                        general_thresh = gr.Slider(
                            0,
                            1,
                            step=args.score_slider_step,
                            value=args.score_general_threshold,
                            label="General Tags Threshold",
                            scale=3,
                        )
                        general_mcut_enabled = gr.Checkbox(
                            value=False,
                            label="Use MCut threshold",
                            scale=1,
                        )
                    with gr.Row():
                        character_thresh = gr.Slider(
                            0,
                            1,
                            step=args.score_slider_step,
                            value=args.score_character_threshold,
                            label="Character Tags Threshold",
                            scale=3,
                        )
                        character_mcut_enabled = gr.Checkbox(
                            value=False,
                            label="Use MCut threshold",
                            scale=1,
                        )
                    with gr.Row():
                        clear = gr.ClearButton(
                            components=[
                                image,
                                model_repo,
                                general_thresh,
                                general_mcut_enabled,
                                character_thresh,
                                character_mcut_enabled,
                            ],
                            variant="secondary",
                            size="lg",
                        )
                        submit = gr.Button(value="Submit", variant="primary", size="lg")
                # Right panel: prediction outputs.
                with gr.Column(variant="panel"):
                    sorted_general_strings = gr.Textbox(label="Output (string)")
                    rating = gr.Label(label="Rating")
                    character_res = gr.Label(label="Output (characters)")
                    general_res = gr.Label(label="Output (tags)")
                    # Also clear the output components with the same button.
                    clear.add(
                        [
                            sorted_general_strings,
                            rating,
                            character_res,
                            general_res,
                        ]
                    )

        # Wire the submit button to Predictor.predict.
        submit.click(
            predictor.predict,
            inputs=[
                image,
                model_repo,
                general_thresh,
                general_mcut_enabled,
                character_thresh,
                character_mcut_enabled,
            ],
            outputs=[sorted_general_strings, rating, character_res, general_res],
        )

        # One bundled example image with the default thresholds.
        gr.Examples(
            [["power.jpg", SWINV2_MODEL_DSV3_REPO, 0.35, False, 0.85, False]],
            inputs=[
                image,
                model_repo,
                general_thresh,
                general_mcut_enabled,
                character_thresh,
                character_mcut_enabled,
            ],
        )

    # Bound the request queue, then start the server.
    demo.queue(max_size=10)
    demo.launch()


if __name__ == "__main__":
    main()
gitattributes ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bin.* filter=lfs diff=lfs merge=lfs -text
5
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.model filter=lfs diff=lfs merge=lfs -text
12
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
13
+ *.onnx filter=lfs diff=lfs merge=lfs -text
14
+ *.ot filter=lfs diff=lfs merge=lfs -text
15
+ *.parquet filter=lfs diff=lfs merge=lfs -text
16
+ *.pb filter=lfs diff=lfs merge=lfs -text
17
+ *.pt filter=lfs diff=lfs merge=lfs -text
18
+ *.pth filter=lfs diff=lfs merge=lfs -text
19
+ *.rar filter=lfs diff=lfs merge=lfs -text
20
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
21
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
22
+ *.tflite filter=lfs diff=lfs merge=lfs -text
23
+ *.tgz filter=lfs diff=lfs merge=lfs -text
24
+ *.xz filter=lfs diff=lfs merge=lfs -text
25
+ *.zip filter=lfs diff=lfs merge=lfs -text
26
+ *.zstandard filter=lfs diff=lfs merge=lfs -text
27
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ images
power.jpg ADDED
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ pillow>=9.0.0
2
+ onnxruntime>=1.12.0
3
+ huggingface-hub