Spaces:
Running
on
Zero
Running
on
Zero
Upload 22 files
Browse files- app.py +220 -452
- constants.py +453 -0
- env.py +15 -18
- modutils.py +263 -50
- requirements.txt +1 -0
- tagger/character_series_dict.csv +0 -0
- tagger/danbooru_e621.csv +0 -0
- tagger/output.py +16 -0
- tagger/tag_group.csv +0 -0
- tagger/tagger.py +556 -0
- tagger/utils.py +50 -0
- tagger/v2.py +260 -0
- utils.py +421 -0
app.py
CHANGED
@@ -1,231 +1,104 @@
|
|
1 |
import spaces
|
2 |
-
import gradio as gr
|
3 |
import os
|
4 |
from stablepy import Model_Diffusers
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
from stablepy.diffusers_vanilla.style_prompt_config import STYLE_NAMES
|
6 |
-
from stablepy.diffusers_vanilla.constants import FLUX_CN_UNION_MODES
|
7 |
import torch
|
8 |
import re
|
9 |
-
from huggingface_hub import HfApi
|
10 |
from stablepy import (
|
11 |
-
CONTROLNET_MODEL_IDS,
|
12 |
-
VALID_TASKS,
|
13 |
-
T2I_PREPROCESSOR_NAME,
|
14 |
-
FLASH_LORA,
|
15 |
-
SCHEDULER_CONFIG_MAP,
|
16 |
scheduler_names,
|
17 |
-
IP_ADAPTER_MODELS,
|
18 |
IP_ADAPTERS_SD,
|
19 |
IP_ADAPTERS_SDXL,
|
20 |
-
REPO_IMAGE_ENCODER,
|
21 |
-
ALL_PROMPT_WEIGHT_OPTIONS,
|
22 |
-
SD15_TASKS,
|
23 |
-
SDXL_TASKS,
|
24 |
)
|
25 |
import time
|
26 |
from PIL import ImageFile
|
27 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
28 |
|
29 |
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
|
|
30 |
print(os.getenv("SPACES_ZERO_GPU"))
|
31 |
|
32 |
-
PREPROCESSOR_CONTROLNET = {
|
33 |
-
"openpose": [
|
34 |
-
"Openpose",
|
35 |
-
"None",
|
36 |
-
],
|
37 |
-
"scribble": [
|
38 |
-
"HED",
|
39 |
-
"PidiNet",
|
40 |
-
"None",
|
41 |
-
],
|
42 |
-
"softedge": [
|
43 |
-
"PidiNet",
|
44 |
-
"HED",
|
45 |
-
"HED safe",
|
46 |
-
"PidiNet safe",
|
47 |
-
"None",
|
48 |
-
],
|
49 |
-
"segmentation": [
|
50 |
-
"UPerNet",
|
51 |
-
"None",
|
52 |
-
],
|
53 |
-
"depth": [
|
54 |
-
"DPT",
|
55 |
-
"Midas",
|
56 |
-
"None",
|
57 |
-
],
|
58 |
-
"normalbae": [
|
59 |
-
"NormalBae",
|
60 |
-
"None",
|
61 |
-
],
|
62 |
-
"lineart": [
|
63 |
-
"Lineart",
|
64 |
-
"Lineart coarse",
|
65 |
-
"Lineart (anime)",
|
66 |
-
"None",
|
67 |
-
"None (anime)",
|
68 |
-
],
|
69 |
-
"lineart_anime": [
|
70 |
-
"Lineart",
|
71 |
-
"Lineart coarse",
|
72 |
-
"Lineart (anime)",
|
73 |
-
"None",
|
74 |
-
"None (anime)",
|
75 |
-
],
|
76 |
-
"shuffle": [
|
77 |
-
"ContentShuffle",
|
78 |
-
"None",
|
79 |
-
],
|
80 |
-
"canny": [
|
81 |
-
"Canny",
|
82 |
-
"None",
|
83 |
-
],
|
84 |
-
"mlsd": [
|
85 |
-
"MLSD",
|
86 |
-
"None",
|
87 |
-
],
|
88 |
-
"ip2p": [
|
89 |
-
"ip2p"
|
90 |
-
],
|
91 |
-
"recolor": [
|
92 |
-
"Recolor luminance",
|
93 |
-
"Recolor intensity",
|
94 |
-
"None",
|
95 |
-
],
|
96 |
-
"tile": [
|
97 |
-
"Mild Blur",
|
98 |
-
"Moderate Blur",
|
99 |
-
"Heavy Blur",
|
100 |
-
"None",
|
101 |
-
],
|
102 |
-
}
|
103 |
-
|
104 |
-
TASK_STABLEPY = {
|
105 |
-
'txt2img': 'txt2img',
|
106 |
-
'img2img': 'img2img',
|
107 |
-
'inpaint': 'inpaint',
|
108 |
-
# 'canny T2I Adapter': 'sdxl_canny_t2i', # NO HAVE STEP CALLBACK PARAMETERS SO NOT WORKS WITH DIFFUSERS 0.29.0
|
109 |
-
# 'sketch T2I Adapter': 'sdxl_sketch_t2i',
|
110 |
-
# 'lineart T2I Adapter': 'sdxl_lineart_t2i',
|
111 |
-
# 'depth-midas T2I Adapter': 'sdxl_depth-midas_t2i',
|
112 |
-
# 'openpose T2I Adapter': 'sdxl_openpose_t2i',
|
113 |
-
'openpose ControlNet': 'openpose',
|
114 |
-
'canny ControlNet': 'canny',
|
115 |
-
'mlsd ControlNet': 'mlsd',
|
116 |
-
'scribble ControlNet': 'scribble',
|
117 |
-
'softedge ControlNet': 'softedge',
|
118 |
-
'segmentation ControlNet': 'segmentation',
|
119 |
-
'depth ControlNet': 'depth',
|
120 |
-
'normalbae ControlNet': 'normalbae',
|
121 |
-
'lineart ControlNet': 'lineart',
|
122 |
-
'lineart_anime ControlNet': 'lineart_anime',
|
123 |
-
'shuffle ControlNet': 'shuffle',
|
124 |
-
'ip2p ControlNet': 'ip2p',
|
125 |
-
'optical pattern ControlNet': 'pattern',
|
126 |
-
'recolor ControlNet': 'recolor',
|
127 |
-
'tile ControlNet': 'tile',
|
128 |
-
}
|
129 |
-
|
130 |
-
TASK_MODEL_LIST = list(TASK_STABLEPY.keys())
|
131 |
-
|
132 |
-
UPSCALER_DICT_GUI = {
|
133 |
-
None: None,
|
134 |
-
"Lanczos": "Lanczos",
|
135 |
-
"Nearest": "Nearest",
|
136 |
-
'Latent': 'Latent',
|
137 |
-
'Latent (antialiased)': 'Latent (antialiased)',
|
138 |
-
'Latent (bicubic)': 'Latent (bicubic)',
|
139 |
-
'Latent (bicubic antialiased)': 'Latent (bicubic antialiased)',
|
140 |
-
'Latent (nearest)': 'Latent (nearest)',
|
141 |
-
'Latent (nearest-exact)': 'Latent (nearest-exact)',
|
142 |
-
"RealESRGAN_x4plus": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
|
143 |
-
"RealESRNet_x4plus": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth",
|
144 |
-
"RealESRGAN_x4plus_anime_6B": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
|
145 |
-
"RealESRGAN_x2plus": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
|
146 |
-
"realesr-animevideov3": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth",
|
147 |
-
"realesr-general-x4v3": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
|
148 |
-
"realesr-general-wdn-x4v3": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth",
|
149 |
-
"4x-UltraSharp": "https://huggingface.co/Shandypur/ESRGAN-4x-UltraSharp/resolve/main/4x-UltraSharp.pth",
|
150 |
-
"4x_foolhardy_Remacri": "https://huggingface.co/FacehugmanIII/4x_foolhardy_Remacri/resolve/main/4x_foolhardy_Remacri.pth",
|
151 |
-
"Remacri4xExtraSmoother": "https://huggingface.co/hollowstrawberry/upscalers-backup/resolve/main/ESRGAN/Remacri%204x%20ExtraSmoother.pth",
|
152 |
-
"AnimeSharp4x": "https://huggingface.co/hollowstrawberry/upscalers-backup/resolve/main/ESRGAN/AnimeSharp%204x.pth",
|
153 |
-
"lollypop": "https://huggingface.co/hollowstrawberry/upscalers-backup/resolve/main/ESRGAN/lollypop.pth",
|
154 |
-
"RealisticRescaler4x": "https://huggingface.co/hollowstrawberry/upscalers-backup/resolve/main/ESRGAN/RealisticRescaler%204x.pth",
|
155 |
-
"NickelbackFS4x": "https://huggingface.co/hollowstrawberry/upscalers-backup/resolve/main/ESRGAN/NickelbackFS%204x.pth"
|
156 |
-
}
|
157 |
-
|
158 |
-
UPSCALER_KEYS = list(UPSCALER_DICT_GUI.keys())
|
159 |
-
|
160 |
-
def get_model_list(directory_path):
|
161 |
-
model_list = []
|
162 |
-
valid_extensions = {'.ckpt', '.pt', '.pth', '.safetensors', '.bin'}
|
163 |
-
|
164 |
-
for filename in os.listdir(directory_path):
|
165 |
-
if os.path.splitext(filename)[1] in valid_extensions:
|
166 |
-
# name_without_extension = os.path.splitext(filename)[0]
|
167 |
-
file_path = os.path.join(directory_path, filename)
|
168 |
-
# model_list.append((name_without_extension, file_path))
|
169 |
-
model_list.append(file_path)
|
170 |
-
print('\033[34mFILE: ' + file_path + '\033[0m')
|
171 |
-
return model_list
|
172 |
-
|
173 |
## BEGIN MOD
|
174 |
from modutils import (list_uniq, download_private_repo, get_model_id_list, get_tupled_embed_list,
|
175 |
get_lora_model_list, get_all_lora_tupled_list, update_loras, apply_lora_prompt, set_prompt_loras,
|
176 |
get_my_lora, upload_file_lora, move_file_lora, search_civitai_lora, select_civitai_lora,
|
177 |
update_civitai_selection, get_civitai_tag, CIVITAI_SORT, CIVITAI_PERIOD, CIVITAI_BASEMODEL,
|
178 |
set_textual_inversion_prompt, get_model_pipeline, change_interface_mode, get_t2i_model_info,
|
179 |
-
get_tupled_model_list, save_gallery_images, set_optimization, set_sampler_settings,
|
180 |
set_quick_presets, process_style_prompt, optimization_list, save_images, download_things,
|
181 |
-
preset_styles, preset_quality, preset_sampler_setting, translate_to_en)
|
182 |
from env import (HF_TOKEN, CIVITAI_API_KEY, HF_LORA_ESSENTIAL_PRIVATE_REPO, HF_VAE_PRIVATE_REPO,
|
183 |
HF_SDXL_EMBEDS_NEGATIVE_PRIVATE_REPO, HF_SDXL_EMBEDS_POSITIVE_PRIVATE_REPO,
|
184 |
-
|
185 |
-
|
186 |
-
|
|
|
|
|
|
|
|
|
187 |
|
188 |
# - **Download Models**
|
189 |
-
|
190 |
# - **Download VAEs**
|
191 |
-
|
192 |
# - **Download LoRAs**
|
193 |
-
|
194 |
-
|
195 |
-
download_private_repo(HF_LORA_ESSENTIAL_PRIVATE_REPO, directory_loras, True)
|
196 |
-
download_private_repo(HF_VAE_PRIVATE_REPO, directory_vaes, False)
|
197 |
-
|
198 |
-
load_diffusers_format_model = list_uniq(get_model_id_list() + load_diffusers_format_model)
|
199 |
-
## END MOD
|
200 |
|
201 |
# Download stuffs
|
202 |
-
for url in [url.strip() for url in
|
203 |
if not os.path.exists(f"./models/{url.split('/')[-1]}"):
|
204 |
-
download_things(
|
205 |
-
for url in [url.strip() for url in
|
206 |
if not os.path.exists(f"./vaes/{url.split('/')[-1]}"):
|
207 |
-
download_things(
|
208 |
-
for url in [url.strip() for url in
|
209 |
if not os.path.exists(f"./loras/{url.split('/')[-1]}"):
|
210 |
-
download_things(
|
211 |
|
212 |
# Download Embeddings
|
213 |
-
for url_embed in
|
214 |
if not os.path.exists(f"./embedings/{url_embed.split('/')[-1]}"):
|
215 |
-
download_things(
|
216 |
|
217 |
# Build list models
|
218 |
-
embed_list = get_model_list(
|
219 |
-
model_list = get_model_list(directory_models)
|
220 |
-
model_list = load_diffusers_format_model + model_list
|
221 |
-
## BEGIN MOD
|
222 |
lora_model_list = get_lora_model_list()
|
223 |
-
vae_model_list = get_model_list(
|
224 |
vae_model_list.insert(0, "None")
|
225 |
|
226 |
-
|
227 |
-
|
228 |
-
|
|
|
|
|
229 |
|
230 |
def get_embed_list(pipeline_name):
|
231 |
return get_tupled_embed_list(embed_sdxl_list if pipeline_name == "StableDiffusionXLPipeline" else embed_list)
|
@@ -248,12 +121,13 @@ warnings.filterwarnings(action="ignore", category=UserWarning, module="diffusers
|
|
248 |
warnings.filterwarnings(action="ignore", category=FutureWarning, module="transformers")
|
249 |
## BEGIN MOD
|
250 |
from stablepy import logger
|
251 |
-
logger.setLevel(logging.CRITICAL)
|
|
|
252 |
|
253 |
-
from v2 import V2_ALL_MODELS, v2_random_prompt, v2_upsampling_prompt
|
254 |
-
from utils import (gradio_copy_text, COPY_ACTION_JS, gradio_copy_prompt,
|
255 |
V2_ASPECT_RATIO_OPTIONS, V2_RATING_OPTIONS, V2_LENGTH_OPTIONS, V2_IDENTITY_OPTIONS)
|
256 |
-
from tagger import (predict_tags_wd, convert_danbooru_to_e621_prompt,
|
257 |
remove_specific_prompt, insert_recom_prompt, insert_model_recom_prompt,
|
258 |
compose_prompt_to_copy, translate_prompt, select_random_character)
|
259 |
def description_ui():
|
@@ -267,137 +141,91 @@ def description_ui():
|
|
267 |
)
|
268 |
## END MOD
|
269 |
|
270 |
-
msg_inc_vae = (
|
271 |
-
"Use the right VAE for your model to maintain image quality. The wrong"
|
272 |
-
" VAE can lead to poor results, like blurriness in the generated images."
|
273 |
-
)
|
274 |
-
|
275 |
-
SDXL_TASK = [k for k, v in TASK_STABLEPY.items() if v in SDXL_TASKS]
|
276 |
-
SD_TASK = [k for k, v in TASK_STABLEPY.items() if v in SD15_TASKS]
|
277 |
-
FLUX_TASK = list(TASK_STABLEPY.keys())[:3] + [k for k, v in TASK_STABLEPY.items() if v in FLUX_CN_UNION_MODES.keys()]
|
278 |
-
|
279 |
-
MODEL_TYPE_TASK = {
|
280 |
-
"SD 1.5": SD_TASK,
|
281 |
-
"SDXL": SDXL_TASK,
|
282 |
-
"FLUX": FLUX_TASK,
|
283 |
-
}
|
284 |
-
|
285 |
-
MODEL_TYPE_CLASS = {
|
286 |
-
"diffusers:StableDiffusionPipeline": "SD 1.5",
|
287 |
-
"diffusers:StableDiffusionXLPipeline": "SDXL",
|
288 |
-
"diffusers:FluxPipeline": "FLUX",
|
289 |
-
}
|
290 |
-
|
291 |
-
POST_PROCESSING_SAMPLER = ["Use same sampler"] + scheduler_names[:-2]
|
292 |
-
|
293 |
-
SUBTITLE_GUI = (
|
294 |
-
"### This demo uses [diffusers](https://github.com/huggingface/diffusers)"
|
295 |
-
" to perform different tasks in image generation."
|
296 |
-
)
|
297 |
-
|
298 |
-
def extract_parameters(input_string):
|
299 |
-
parameters = {}
|
300 |
-
input_string = input_string.replace("\n", "")
|
301 |
-
|
302 |
-
if "Negative prompt:" not in input_string:
|
303 |
-
if "Steps:" in input_string:
|
304 |
-
input_string = input_string.replace("Steps:", "Negative prompt: Steps:")
|
305 |
-
else:
|
306 |
-
print("Invalid metadata")
|
307 |
-
parameters["prompt"] = input_string
|
308 |
-
return parameters
|
309 |
-
|
310 |
-
parm = input_string.split("Negative prompt:")
|
311 |
-
parameters["prompt"] = parm[0].strip()
|
312 |
-
if "Steps:" not in parm[1]:
|
313 |
-
print("Steps not detected")
|
314 |
-
parameters["neg_prompt"] = parm[1].strip()
|
315 |
-
return parameters
|
316 |
-
parm = parm[1].split("Steps:")
|
317 |
-
parameters["neg_prompt"] = parm[0].strip()
|
318 |
-
input_string = "Steps:" + parm[1]
|
319 |
-
|
320 |
-
# Extracting Steps
|
321 |
-
steps_match = re.search(r'Steps: (\d+)', input_string)
|
322 |
-
if steps_match:
|
323 |
-
parameters['Steps'] = int(steps_match.group(1))
|
324 |
-
|
325 |
-
# Extracting Size
|
326 |
-
size_match = re.search(r'Size: (\d+x\d+)', input_string)
|
327 |
-
if size_match:
|
328 |
-
parameters['Size'] = size_match.group(1)
|
329 |
-
width, height = map(int, parameters['Size'].split('x'))
|
330 |
-
parameters['width'] = width
|
331 |
-
parameters['height'] = height
|
332 |
-
|
333 |
-
# Extracting other parameters
|
334 |
-
other_parameters = re.findall(r'(\w+): (.*?)(?=, \w+|$)', input_string)
|
335 |
-
for param in other_parameters:
|
336 |
-
parameters[param[0]] = param[1].strip('"')
|
337 |
-
|
338 |
-
return parameters
|
339 |
-
|
340 |
-
def info_html(json_data, title, subtitle):
|
341 |
-
return f"""
|
342 |
-
<div style='padding: 0; border-radius: 10px;'>
|
343 |
-
<p style='margin: 0; font-weight: bold;'>{title}</p>
|
344 |
-
<details>
|
345 |
-
<summary>Details</summary>
|
346 |
-
<p style='margin: 0; font-weight: bold;'>{subtitle}</p>
|
347 |
-
</details>
|
348 |
-
</div>
|
349 |
-
"""
|
350 |
-
|
351 |
-
def get_model_type(repo_id: str):
|
352 |
-
api = HfApi(token=os.environ.get("HF_TOKEN")) # if use private or gated model
|
353 |
-
default = "SD 1.5"
|
354 |
-
try:
|
355 |
-
model = api.model_info(repo_id=repo_id, timeout=5.0)
|
356 |
-
tags = model.tags
|
357 |
-
for tag in tags:
|
358 |
-
if tag in MODEL_TYPE_CLASS.keys(): return MODEL_TYPE_CLASS.get(tag, default)
|
359 |
-
except Exception:
|
360 |
-
return default
|
361 |
-
return default
|
362 |
-
|
363 |
class GuiSD:
|
364 |
-
def __init__(self):
|
365 |
self.model = None
|
366 |
-
|
367 |
-
|
368 |
-
self.
|
369 |
-
base_model_id="Lykon/dreamshaper-8",
|
370 |
-
task_name="txt2img",
|
371 |
-
vae_model=None,
|
372 |
-
type_model_precision=torch.float16,
|
373 |
-
retain_task_model_in_cache=False,
|
374 |
-
device="cpu",
|
375 |
-
)
|
376 |
-
self.model.load_beta_styles()
|
377 |
-
#self.model.device = torch.device("cpu") #
|
378 |
|
379 |
def load_new_model(self, model_name, vae_model, task, progress=gr.Progress(track_tqdm=True)):
|
380 |
|
381 |
-
yield f"Loading model: {model_name}"
|
382 |
-
|
383 |
vae_model = vae_model if vae_model != "None" else None
|
384 |
model_type = get_model_type(model_name)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
385 |
|
386 |
if vae_model:
|
387 |
vae_type = "SDXL" if "sdxl" in vae_model.lower() else "SD 1.5"
|
388 |
if model_type != vae_type:
|
389 |
-
gr.Warning(
|
390 |
|
391 |
-
|
392 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
393 |
|
394 |
-
self.model.load_pipe(
|
395 |
-
model_name,
|
396 |
-
task_name=TASK_STABLEPY[task],
|
397 |
-
vae_model=vae_model if vae_model != "None" else None,
|
398 |
-
type_model_precision=dtype_model,
|
399 |
-
retain_task_model_in_cache=False,
|
400 |
-
)
|
401 |
yield f"Model loaded: {model_name}"
|
402 |
|
403 |
#@spaces.GPU
|
@@ -506,18 +334,17 @@ class GuiSD:
|
|
506 |
mode_ip2,
|
507 |
scale_ip2,
|
508 |
pag_scale,
|
509 |
-
#progress=gr.Progress(track_tqdm=True),
|
510 |
):
|
511 |
-
|
|
|
512 |
|
513 |
vae_model = vae_model if vae_model != "None" else None
|
514 |
loras_list = [lora1, lora2, lora3, lora4, lora5]
|
515 |
vae_msg = f"VAE: {vae_model}" if vae_model else ""
|
516 |
msg_lora = ""
|
517 |
|
518 |
-
print("Config model:", model_name, vae_model, loras_list)
|
519 |
-
|
520 |
## BEGIN MOD
|
|
|
521 |
global lora_model_list
|
522 |
lora_model_list = get_lora_model_list()
|
523 |
lora1, lora_scale1, lora2, lora_scale2, lora3, lora_scale3, lora4, lora_scale4, lora5, lora_scale5 = \
|
@@ -526,6 +353,8 @@ class GuiSD:
|
|
526 |
prompt, neg_prompt = insert_model_recom_prompt(prompt, neg_prompt, model_name)
|
527 |
## END MOD
|
528 |
|
|
|
|
|
529 |
task = TASK_STABLEPY[task]
|
530 |
|
531 |
params_ip_img = []
|
@@ -548,7 +377,8 @@ class GuiSD:
|
|
548 |
params_ip_mode.append(modeip)
|
549 |
params_ip_scale.append(scaleip)
|
550 |
|
551 |
-
|
|
|
552 |
|
553 |
if task != "txt2img" and not image_control:
|
554 |
raise ValueError("No control image found: To use this function, you have to upload an image in 'Image ControlNet/Inpaint/Img2img'")
|
@@ -621,15 +451,15 @@ class GuiSD:
|
|
621 |
"high_threshold": high_threshold,
|
622 |
"value_threshold": value_threshold,
|
623 |
"distance_threshold": distance_threshold,
|
624 |
-
"lora_A": lora1 if lora1 != "None"
|
625 |
"lora_scale_A": lora_scale1,
|
626 |
-
"lora_B": lora2 if lora2 != "None"
|
627 |
"lora_scale_B": lora_scale2,
|
628 |
-
"lora_C": lora3 if lora3 != "None"
|
629 |
"lora_scale_C": lora_scale3,
|
630 |
-
"lora_D": lora4 if lora4 != "None"
|
631 |
"lora_scale_D": lora_scale4,
|
632 |
-
"lora_E": lora5 if lora5 != "None"
|
633 |
"lora_scale_E": lora_scale5,
|
634 |
## BEGIN MOD
|
635 |
"textual_inversion": get_embed_list(self.model.class_name) if textual_inversion else [],
|
@@ -679,19 +509,24 @@ class GuiSD:
|
|
679 |
}
|
680 |
|
681 |
self.model.device = torch.device("cuda:0")
|
682 |
-
if hasattr(self.model.pipe, "transformer") and loras_list != ["None"] * 5
|
683 |
self.model.pipe.transformer.to(self.model.device)
|
684 |
print("transformer to cuda")
|
685 |
|
686 |
-
|
687 |
-
|
688 |
-
info_state = "PROCESSING "
|
689 |
for img, seed, image_path, metadata in self.model(**pipe_params):
|
690 |
-
info_state
|
|
|
691 |
if image_path:
|
692 |
-
|
693 |
if vae_msg:
|
694 |
-
|
|
|
|
|
|
|
|
|
|
|
695 |
|
696 |
for status, lora in zip(self.model.lora_status, self.model.lora_memory):
|
697 |
if status:
|
@@ -700,9 +535,9 @@ class GuiSD:
|
|
700 |
msg_lora += f"<br>Error with: {lora}"
|
701 |
|
702 |
if msg_lora:
|
703 |
-
|
704 |
|
705 |
-
|
706 |
|
707 |
download_links = "<br>".join(
|
708 |
[
|
@@ -711,19 +546,16 @@ class GuiSD:
|
|
711 |
]
|
712 |
)
|
713 |
if save_generated_images:
|
714 |
-
|
715 |
|
|
|
716 |
img = save_images(img, metadata)
|
|
|
717 |
|
718 |
-
|
719 |
-
|
720 |
-
def update_task_options(model_name, task_name):
|
721 |
-
new_choices = MODEL_TYPE_TASK[get_model_type(model_name)]
|
722 |
|
723 |
-
|
724 |
-
task_name = "txt2img"
|
725 |
|
726 |
-
return gr.update(value=task_name, choices=new_choices)
|
727 |
|
728 |
def dynamic_gpu_duration(func, duration, *args):
|
729 |
|
@@ -733,10 +565,12 @@ def dynamic_gpu_duration(func, duration, *args):
|
|
733 |
|
734 |
return wrapped_func()
|
735 |
|
|
|
736 |
@spaces.GPU
|
737 |
def dummy_gpu():
|
738 |
return None
|
739 |
|
|
|
740 |
def sd_gen_generate_pipeline(*args):
|
741 |
|
742 |
gpu_duration_arg = int(args[-1]) if args[-1] else 59
|
@@ -744,7 +578,7 @@ def sd_gen_generate_pipeline(*args):
|
|
744 |
load_lora_cpu = args[-3]
|
745 |
generation_args = args[:-3]
|
746 |
lora_list = [
|
747 |
-
None if item == "None" or item == "" else item
|
748 |
for item in [args[7], args[9], args[11], args[13], args[15]]
|
749 |
]
|
750 |
lora_status = [None] * 5
|
@@ -754,7 +588,7 @@ def sd_gen_generate_pipeline(*args):
|
|
754 |
msg_load_lora = "Updating LoRAs in CPU (Slow but saves GPU usage)..."
|
755 |
|
756 |
if lora_list != sd_gen.model.lora_memory and lora_list != [None] * 5:
|
757 |
-
yield
|
758 |
|
759 |
# Load lora in CPU
|
760 |
if load_lora_cpu:
|
@@ -780,14 +614,15 @@ def sd_gen_generate_pipeline(*args):
|
|
780 |
)
|
781 |
gr.Info(f"LoRAs in cache: {lora_cache_msg}")
|
782 |
|
783 |
-
|
|
|
784 |
gr.Info(msg_request)
|
785 |
print(msg_request)
|
786 |
-
|
787 |
-
# yield from sd_gen.generate_pipeline(*generation_args)
|
788 |
|
789 |
start_time = time.time()
|
790 |
|
|
|
791 |
yield from dynamic_gpu_duration(
|
792 |
sd_gen.generate_pipeline,
|
793 |
gpu_duration_arg,
|
@@ -795,31 +630,19 @@ def sd_gen_generate_pipeline(*args):
|
|
795 |
)
|
796 |
|
797 |
end_time = time.time()
|
|
|
|
|
|
|
|
|
798 |
|
799 |
if verbose_arg:
|
800 |
-
execution_time = end_time - start_time
|
801 |
-
msg_task_complete = (
|
802 |
-
f"GPU task complete in: {round(execution_time, 0) + 1} seconds"
|
803 |
-
)
|
804 |
gr.Info(msg_task_complete)
|
805 |
print(msg_task_complete)
|
806 |
|
807 |
-
|
808 |
-
if image is None: return ""
|
809 |
-
|
810 |
-
try:
|
811 |
-
metadata_keys = ['parameters', 'metadata', 'prompt', 'Comment']
|
812 |
-
|
813 |
-
for key in metadata_keys:
|
814 |
-
if key in image.info:
|
815 |
-
return image.info[key]
|
816 |
-
|
817 |
-
return str(image.info)
|
818 |
|
819 |
-
except Exception as e:
|
820 |
-
return f"Error extracting metadata: {str(e)}"
|
821 |
|
822 |
-
@spaces.GPU(duration=
|
823 |
def esrgan_upscale(image, upscaler_name, upscaler_size):
|
824 |
if image is None: return None
|
825 |
|
@@ -841,17 +664,20 @@ def esrgan_upscale(image, upscaler_name, upscaler_size):
|
|
841 |
|
842 |
return image_path
|
843 |
|
|
|
844 |
dynamic_gpu_duration.zerogpu = True
|
845 |
sd_gen_generate_pipeline.zerogpu = True
|
846 |
sd_gen = GuiSD()
|
847 |
|
|
|
848 |
## BEGIN MOD
|
849 |
CSS ="""
|
850 |
-
.gradio-container, #main { width:100%; height:100%; max-width:100%; padding-left:0; padding-right:0; margin-left:0; margin-right:0;
|
851 |
-
.contain { display:flex; flex-direction:column;
|
852 |
-
#component-0 { width:100%; height:100%;
|
853 |
-
#gallery { flex-grow:1;
|
854 |
-
|
|
|
855 |
#model-info { text-align:center; }
|
856 |
.title { font-size: 3em; align-items: center; text-align: center; }
|
857 |
.info { align-items: center; text-align: center; }
|
@@ -865,6 +691,15 @@ with gr.Blocks(theme='NoCrypt/miku@>=1.2.2', elem_id="main", fill_width=True, cs
|
|
865 |
with gr.Tab("Generation"):
|
866 |
with gr.Row():
|
867 |
with gr.Column(scale=2):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
868 |
interface_mode_gui = gr.Radio(label="Quick settings", choices=["Simple", "Standard", "Fast", "LoRA"], value="Standard")
|
869 |
with gr.Accordion("Model and Task", open=False) as menu_model:
|
870 |
task_gui = gr.Dropdown(label="Task", choices=SDXL_TASK, value=TASK_MODEL_LIST[0])
|
@@ -928,7 +763,7 @@ with gr.Blocks(theme='NoCrypt/miku@>=1.2.2', elem_id="main", fill_width=True, cs
|
|
928 |
[task_gui],
|
929 |
)
|
930 |
|
931 |
-
load_model_gui = gr.HTML()
|
932 |
|
933 |
result_images = gr.Gallery(
|
934 |
label="Generated images",
|
@@ -950,6 +785,13 @@ with gr.Blocks(theme='NoCrypt/miku@>=1.2.2', elem_id="main", fill_width=True, cs
|
|
950 |
|
951 |
actual_task_info = gr.HTML()
|
952 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
953 |
with gr.Row(equal_height=False, variant="default"):
|
954 |
gpu_duration_gui = gr.Number(minimum=5, maximum=240, value=59, show_label=False, container=False, info="GPU time duration (seconds)")
|
955 |
with gr.Column():
|
@@ -972,15 +814,7 @@ with gr.Blocks(theme='NoCrypt/miku@>=1.2.2', elem_id="main", fill_width=True, cs
|
|
972 |
with gr.Row():
|
973 |
sampler_gui = gr.Dropdown(label="Sampler", choices=scheduler_names, value="Euler a")
|
974 |
vae_model_gui = gr.Dropdown(label="VAE Model", choices=vae_model_list, value=vae_model_list[0])
|
975 |
-
|
976 |
-
("Compel format: (word)weight", "Compel"),
|
977 |
-
("Classic format: (word:weight)", "Classic"),
|
978 |
-
("Classic-original format: (word:weight)", "Classic-original"),
|
979 |
-
("Classic-no_norm format: (word:weight)", "Classic-no_norm"),
|
980 |
-
("Classic-ignore", "Classic-ignore"),
|
981 |
-
("None", "None"),
|
982 |
-
]
|
983 |
-
prompt_syntax_gui = gr.Dropdown(label="Prompt Syntax", choices=prompt_s_options, value=prompt_s_options[1][1])
|
984 |
|
985 |
with gr.Row(equal_height=False):
|
986 |
|
@@ -1130,8 +964,11 @@ with gr.Blocks(theme='NoCrypt/miku@>=1.2.2', elem_id="main", fill_width=True, cs
|
|
1130 |
with gr.Accordion("Select from Gallery", open=False):
|
1131 |
search_civitai_gallery_lora = gr.Gallery([], label="Results", allow_preview=False, columns=5, show_share_button=False, interactive=False)
|
1132 |
search_civitai_result_lora = gr.Dropdown(label="Search Results", choices=[("", "")], value="", allow_custom_value=True, visible=False)
|
1133 |
-
|
|
|
|
|
1134 |
button_lora = gr.Button("Get and update lists of LoRAs")
|
|
|
1135 |
with gr.Accordion("From Local", open=True, visible=True):
|
1136 |
file_output_lora = gr.File(label="Uploaded LoRA", file_types=['.ckpt', '.pt', '.pth', '.safetensors', '.bin'], file_count="multiple", interactive=False, visible=False)
|
1137 |
upload_button_lora = gr.UploadButton(label="Upload LoRA from your disk (very slow)", file_types=['.ckpt', '.pt', '.pth', '.safetensors', '.bin'], file_count="multiple")
|
@@ -1169,7 +1006,7 @@ with gr.Blocks(theme='NoCrypt/miku@>=1.2.2', elem_id="main", fill_width=True, cs
|
|
1169 |
negative_prompt_ad_a_gui = gr.Textbox(label="Negative prompt", placeholder="Main negative prompt will be use", lines=3)
|
1170 |
with gr.Row():
|
1171 |
strength_ad_a_gui = gr.Number(label="Strength:", value=0.35, step=0.01, minimum=0.01, maximum=1.0)
|
1172 |
-
face_detector_ad_a_gui = gr.Checkbox(label="Face detector", value=
|
1173 |
person_detector_ad_a_gui = gr.Checkbox(label="Person detector", value=True)
|
1174 |
hand_detector_ad_a_gui = gr.Checkbox(label="Hand detector", value=False)
|
1175 |
with gr.Row():
|
@@ -1184,7 +1021,7 @@ with gr.Blocks(theme='NoCrypt/miku@>=1.2.2', elem_id="main", fill_width=True, cs
|
|
1184 |
negative_prompt_ad_b_gui = gr.Textbox(label="Negative prompt", placeholder="Main negative prompt will be use", lines=3)
|
1185 |
with gr.Row():
|
1186 |
strength_ad_b_gui = gr.Number(label="Strength:", value=0.35, step=0.01, minimum=0.01, maximum=1.0)
|
1187 |
-
face_detector_ad_b_gui = gr.Checkbox(label="Face detector", value=
|
1188 |
person_detector_ad_b_gui = gr.Checkbox(label="Person detector", value=True)
|
1189 |
hand_detector_ad_b_gui = gr.Checkbox(label="Hand detector", value=False)
|
1190 |
with gr.Row():
|
@@ -1314,73 +1151,7 @@ with gr.Blocks(theme='NoCrypt/miku@>=1.2.2', elem_id="main", fill_width=True, cs
|
|
1314 |
|
1315 |
with gr.Accordion("Examples and help", open=True, visible=True) as menu_example:
|
1316 |
gr.Examples(
|
1317 |
-
examples=
|
1318 |
-
[
|
1319 |
-
"1girl, souryuu asuka langley, neon genesis evangelion, plugsuit, pilot suit, red bodysuit, sitting, crossing legs, black eye patch, cat hat, throne, symmetrical, looking down, from bottom, looking at viewer, outdoors, masterpiece, best quality, very aesthetic, absurdres",
|
1320 |
-
"nsfw, lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]",
|
1321 |
-
1,
|
1322 |
-
30,
|
1323 |
-
7.5,
|
1324 |
-
True,
|
1325 |
-
-1,
|
1326 |
-
"Euler a",
|
1327 |
-
1152,
|
1328 |
-
896,
|
1329 |
-
"votepurchase/animagine-xl-3.1",
|
1330 |
-
],
|
1331 |
-
[
|
1332 |
-
"solo, princess Zelda OOT, score_9, score_8_up, score_8, medium breasts, cute, eyelashes, cute small face, long hair, crown braid, hairclip, pointy ears, soft curvy body, looking at viewer, smile, blush, white dress, medium body, (((holding the Master Sword))), standing, deep forest in the background",
|
1333 |
-
"score_6, score_5, score_4, busty, ugly face, mutated hands, low res, blurry face, black and white,",
|
1334 |
-
1,
|
1335 |
-
30,
|
1336 |
-
5.,
|
1337 |
-
True,
|
1338 |
-
-1,
|
1339 |
-
"Euler a",
|
1340 |
-
1024,
|
1341 |
-
1024,
|
1342 |
-
"votepurchase/ponyDiffusionV6XL",
|
1343 |
-
],
|
1344 |
-
[
|
1345 |
-
"1girl, oomuro sakurako, yuru yuri, official art, school uniform, anime artwork, anime style, vibrant, studio anime, highly detailed, masterpiece, best quality, very aesthetic, absurdres",
|
1346 |
-
"photo, deformed, black and white, realism, disfigured, low contrast, lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]",
|
1347 |
-
1,
|
1348 |
-
40,
|
1349 |
-
7.0,
|
1350 |
-
True,
|
1351 |
-
-1,
|
1352 |
-
"Euler a",
|
1353 |
-
1024,
|
1354 |
-
1024,
|
1355 |
-
"Raelina/Rae-Diffusion-XL-V2",
|
1356 |
-
],
|
1357 |
-
[
|
1358 |
-
"1girl, akaza akari, yuru yuri, official art, anime artwork, anime style, vibrant, studio anime, highly detailed, masterpiece, best quality, very aesthetic, absurdres",
|
1359 |
-
"photo, deformed, black and white, realism, disfigured, low contrast, lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]",
|
1360 |
-
1,
|
1361 |
-
35,
|
1362 |
-
7.0,
|
1363 |
-
True,
|
1364 |
-
-1,
|
1365 |
-
"Euler a",
|
1366 |
-
1024,
|
1367 |
-
1024,
|
1368 |
-
"Raelina/Raemu-XL-V4",
|
1369 |
-
],
|
1370 |
-
[
|
1371 |
-
"yoshida yuuko, machikado mazoku, 1girl, solo, demon horns,horns, school uniform, long hair, open mouth, skirt, demon girl, ahoge, shiny, shiny hair, anime artwork",
|
1372 |
-
"nsfw, lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]",
|
1373 |
-
1,
|
1374 |
-
50,
|
1375 |
-
7.,
|
1376 |
-
True,
|
1377 |
-
-1,
|
1378 |
-
"Euler a",
|
1379 |
-
1024,
|
1380 |
-
1024,
|
1381 |
-
"cagliostrolab/animagine-xl-3.1",
|
1382 |
-
],
|
1383 |
-
],
|
1384 |
fn=sd_gen.generate_pipeline,
|
1385 |
inputs=[
|
1386 |
prompt_gui,
|
@@ -1395,16 +1166,12 @@ with gr.Blocks(theme='NoCrypt/miku@>=1.2.2', elem_id="main", fill_width=True, cs
|
|
1395 |
img_width_gui,
|
1396 |
model_name_gui,
|
1397 |
],
|
1398 |
-
outputs=[result_images, actual_task_info],
|
1399 |
cache_examples=False,
|
1400 |
#elem_id="examples",
|
1401 |
)
|
1402 |
|
1403 |
-
gr.Markdown(
|
1404 |
-
"""### Resources
|
1405 |
-
- You can also try the image generator in Colab’s free tier, which provides free GPU [link](https://github.com/R3gm/SD_diffusers_interactive).
|
1406 |
-
"""
|
1407 |
-
)
|
1408 |
## END MOD
|
1409 |
|
1410 |
with gr.Tab("Inpaint mask maker", render=True):
|
@@ -1563,7 +1330,7 @@ with gr.Blocks(theme='NoCrypt/miku@>=1.2.2', elem_id="main", fill_width=True, cs
|
|
1563 |
)
|
1564 |
search_civitai_result_lora.change(select_civitai_lora, [search_civitai_result_lora], [text_lora, search_civitai_desc_lora], queue=False, scroll_to_output=True)
|
1565 |
search_civitai_gallery_lora.select(update_civitai_selection, None, [search_civitai_result_lora], queue=False, show_api=False)
|
1566 |
-
button_lora.click(get_my_lora, [text_lora], [lora1_gui, lora2_gui, lora3_gui, lora4_gui, lora5_gui], scroll_to_output=True)
|
1567 |
upload_button_lora.upload(upload_file_lora, [upload_button_lora], [file_output_lora, upload_button_lora]).success(
|
1568 |
move_file_lora, [file_output_lora], [lora1_gui, lora2_gui, lora3_gui, lora4_gui, lora5_gui], scroll_to_output=True)
|
1569 |
|
@@ -1719,10 +1486,11 @@ with gr.Blocks(theme='NoCrypt/miku@>=1.2.2', elem_id="main", fill_width=True, cs
|
|
1719 |
verbose_info_gui,
|
1720 |
gpu_duration_gui,
|
1721 |
],
|
1722 |
-
outputs=[result_images, actual_task_info],
|
1723 |
queue=True,
|
1724 |
show_progress="full",
|
1725 |
-
).success(save_gallery_images, [result_images], [result_images, result_images_files], queue=False, show_api=False)
|
|
|
1726 |
|
1727 |
with gr.Tab("Danbooru Tags Transformer with WD Tagger", render=True):
|
1728 |
with gr.Column(scale=2):
|
@@ -1818,5 +1586,5 @@ with gr.Blocks(theme='NoCrypt/miku@>=1.2.2', elem_id="main", fill_width=True, cs
|
|
1818 |
gr.DuplicateButton(value="Duplicate Space for private use (This demo does not work on CPU. Requires GPU Space)")
|
1819 |
|
1820 |
app.queue()
|
1821 |
-
app.launch() # allowed_paths=["./images/"], show_error=True, debug=True
|
1822 |
## END MOD
|
|
|
1 |
import spaces
|
|
|
2 |
import os
|
3 |
from stablepy import Model_Diffusers
|
4 |
+
from constants import (
|
5 |
+
PREPROCESSOR_CONTROLNET,
|
6 |
+
TASK_STABLEPY,
|
7 |
+
TASK_MODEL_LIST,
|
8 |
+
UPSCALER_DICT_GUI,
|
9 |
+
UPSCALER_KEYS,
|
10 |
+
PROMPT_W_OPTIONS,
|
11 |
+
WARNING_MSG_VAE,
|
12 |
+
SDXL_TASK,
|
13 |
+
MODEL_TYPE_TASK,
|
14 |
+
POST_PROCESSING_SAMPLER,
|
15 |
+
|
16 |
+
)
|
17 |
from stablepy.diffusers_vanilla.style_prompt_config import STYLE_NAMES
|
|
|
18 |
import torch
|
19 |
import re
|
|
|
20 |
from stablepy import (
|
|
|
|
|
|
|
|
|
|
|
21 |
scheduler_names,
|
|
|
22 |
IP_ADAPTERS_SD,
|
23 |
IP_ADAPTERS_SDXL,
|
|
|
|
|
|
|
|
|
24 |
)
|
25 |
import time
|
26 |
from PIL import ImageFile
|
27 |
+
from utils import (
|
28 |
+
get_model_list,
|
29 |
+
extract_parameters,
|
30 |
+
get_model_type,
|
31 |
+
extract_exif_data,
|
32 |
+
create_mask_now,
|
33 |
+
download_diffuser_repo,
|
34 |
+
progress_step_bar,
|
35 |
+
html_template_message,
|
36 |
+
)
|
37 |
+
from datetime import datetime
|
38 |
+
import gradio as gr
|
39 |
+
import logging
|
40 |
+
import diffusers
|
41 |
+
import warnings
|
42 |
+
from stablepy import logger
|
43 |
+
# import urllib.parse
|
44 |
|
45 |
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
46 |
+
# os.environ["PYTORCH_NO_CUDA_MEMORY_CACHING"] = "1"
|
47 |
print(os.getenv("SPACES_ZERO_GPU"))
|
48 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
49 |
## BEGIN MOD
|
50 |
from modutils import (list_uniq, download_private_repo, get_model_id_list, get_tupled_embed_list,
|
51 |
get_lora_model_list, get_all_lora_tupled_list, update_loras, apply_lora_prompt, set_prompt_loras,
|
52 |
get_my_lora, upload_file_lora, move_file_lora, search_civitai_lora, select_civitai_lora,
|
53 |
update_civitai_selection, get_civitai_tag, CIVITAI_SORT, CIVITAI_PERIOD, CIVITAI_BASEMODEL,
|
54 |
set_textual_inversion_prompt, get_model_pipeline, change_interface_mode, get_t2i_model_info,
|
55 |
+
get_tupled_model_list, save_gallery_images, save_gallery_history, set_optimization, set_sampler_settings,
|
56 |
set_quick_presets, process_style_prompt, optimization_list, save_images, download_things,
|
57 |
+
preset_styles, preset_quality, preset_sampler_setting, translate_to_en, EXAMPLES_GUI, RESOURCES)
|
58 |
from env import (HF_TOKEN, CIVITAI_API_KEY, HF_LORA_ESSENTIAL_PRIVATE_REPO, HF_VAE_PRIVATE_REPO,
|
59 |
HF_SDXL_EMBEDS_NEGATIVE_PRIVATE_REPO, HF_SDXL_EMBEDS_POSITIVE_PRIVATE_REPO,
|
60 |
+
DIRECTORY_MODELS, DIRECTORY_LORAS, DIRECTORY_VAES, DIRECTORY_EMBEDS, DIRECTORY_EMBEDS_SDXL,
|
61 |
+
DIRECTORY_EMBEDS_POSITIVE_SDXL, LOAD_DIFFUSERS_FORMAT_MODEL,
|
62 |
+
DOWNLOAD_MODEL_LIST, DOWNLOAD_LORA_LIST, DOWNLOAD_VAE_LIST, DOWNLOAD_EMBEDS)
|
63 |
+
|
64 |
+
download_private_repo(HF_LORA_ESSENTIAL_PRIVATE_REPO, DIRECTORY_LORAS, True)
|
65 |
+
download_private_repo(HF_VAE_PRIVATE_REPO, DIRECTORY_VAES, False)
|
66 |
+
## END MOD
|
67 |
|
68 |
# - **Download Models**
|
69 |
+
DOWNLOAD_MODEL = ", ".join(DOWNLOAD_MODEL_LIST)
|
70 |
# - **Download VAEs**
|
71 |
+
DOWNLOAD_VAE = ", ".join(DOWNLOAD_VAE_LIST)
|
72 |
# - **Download LoRAs**
|
73 |
+
DOWNLOAD_LORA = ", ".join(DOWNLOAD_LORA_LIST)
|
|
|
|
|
|
|
|
|
|
|
|
|
74 |
|
75 |
# Download stuffs
|
76 |
+
for url in [url.strip() for url in DOWNLOAD_MODEL.split(',')]:
|
77 |
if not os.path.exists(f"./models/{url.split('/')[-1]}"):
|
78 |
+
download_things(DIRECTORY_MODELS, url, HF_TOKEN, CIVITAI_API_KEY)
|
79 |
+
for url in [url.strip() for url in DOWNLOAD_VAE.split(',')]:
|
80 |
if not os.path.exists(f"./vaes/{url.split('/')[-1]}"):
|
81 |
+
download_things(DIRECTORY_VAES, url, HF_TOKEN, CIVITAI_API_KEY)
|
82 |
+
for url in [url.strip() for url in DOWNLOAD_LORA.split(',')]:
|
83 |
if not os.path.exists(f"./loras/{url.split('/')[-1]}"):
|
84 |
+
download_things(DIRECTORY_LORAS, url, HF_TOKEN, CIVITAI_API_KEY)
|
85 |
|
86 |
# Download Embeddings
|
87 |
+
for url_embed in DOWNLOAD_EMBEDS:
|
88 |
if not os.path.exists(f"./embedings/{url_embed.split('/')[-1]}"):
|
89 |
+
download_things(DIRECTORY_EMBEDS, url_embed, HF_TOKEN, CIVITAI_API_KEY)
|
90 |
|
91 |
# Build list models
|
92 |
+
embed_list = get_model_list(DIRECTORY_EMBEDS)
|
|
|
|
|
|
|
93 |
lora_model_list = get_lora_model_list()
|
94 |
+
vae_model_list = get_model_list(DIRECTORY_VAES)
|
95 |
vae_model_list.insert(0, "None")
|
96 |
|
97 |
+
## BEGIN MOD
|
98 |
+
model_list = list_uniq(get_model_id_list() + LOAD_DIFFUSERS_FORMAT_MODEL + get_model_list(DIRECTORY_MODELS))
|
99 |
+
download_private_repo(HF_SDXL_EMBEDS_NEGATIVE_PRIVATE_REPO, DIRECTORY_EMBEDS_SDXL, False)
|
100 |
+
download_private_repo(HF_SDXL_EMBEDS_POSITIVE_PRIVATE_REPO, DIRECTORY_EMBEDS_POSITIVE_SDXL, False)
|
101 |
+
embed_sdxl_list = get_model_list(DIRECTORY_EMBEDS_SDXL) + get_model_list(DIRECTORY_EMBEDS_POSITIVE_SDXL)
|
102 |
|
103 |
def get_embed_list(pipeline_name):
|
104 |
return get_tupled_embed_list(embed_sdxl_list if pipeline_name == "StableDiffusionXLPipeline" else embed_list)
|
|
|
121 |
warnings.filterwarnings(action="ignore", category=FutureWarning, module="transformers")
|
122 |
## BEGIN MOD
|
123 |
from stablepy import logger
|
124 |
+
#logger.setLevel(logging.CRITICAL)
|
125 |
+
logger.setLevel(logging.DEBUG)
|
126 |
|
127 |
+
from tagger.v2 import V2_ALL_MODELS, v2_random_prompt, v2_upsampling_prompt
|
128 |
+
from tagger.utils import (gradio_copy_text, COPY_ACTION_JS, gradio_copy_prompt,
|
129 |
V2_ASPECT_RATIO_OPTIONS, V2_RATING_OPTIONS, V2_LENGTH_OPTIONS, V2_IDENTITY_OPTIONS)
|
130 |
+
from tagger.tagger import (predict_tags_wd, convert_danbooru_to_e621_prompt,
|
131 |
remove_specific_prompt, insert_recom_prompt, insert_model_recom_prompt,
|
132 |
compose_prompt_to_copy, translate_prompt, select_random_character)
|
133 |
def description_ui():
|
|
|
141 |
)
|
142 |
## END MOD
|
143 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
144 |
class GuiSD:
|
145 |
+
def __init__(self, stream=True):
|
146 |
self.model = None
|
147 |
+
self.status_loading = False
|
148 |
+
self.sleep_loading = 4
|
149 |
+
self.last_load = datetime.now()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
150 |
|
151 |
def load_new_model(self, model_name, vae_model, task, progress=gr.Progress(track_tqdm=True)):
|
152 |
|
|
|
|
|
153 |
vae_model = vae_model if vae_model != "None" else None
|
154 |
model_type = get_model_type(model_name)
|
155 |
+
dtype_model = torch.bfloat16 if model_type == "FLUX" else torch.float16
|
156 |
+
|
157 |
+
if not os.path.exists(model_name):
|
158 |
+
_ = download_diffuser_repo(
|
159 |
+
repo_name=model_name,
|
160 |
+
model_type=model_type,
|
161 |
+
revision="main",
|
162 |
+
token=True,
|
163 |
+
)
|
164 |
+
|
165 |
+
for i in range(68):
|
166 |
+
if not self.status_loading:
|
167 |
+
self.status_loading = True
|
168 |
+
if i > 0:
|
169 |
+
time.sleep(self.sleep_loading)
|
170 |
+
print("Previous model ops...")
|
171 |
+
break
|
172 |
+
time.sleep(0.5)
|
173 |
+
print(f"Waiting queue {i}")
|
174 |
+
yield "Waiting queue"
|
175 |
+
|
176 |
+
self.status_loading = True
|
177 |
+
|
178 |
+
yield f"Loading model: {model_name}"
|
179 |
|
180 |
if vae_model:
|
181 |
vae_type = "SDXL" if "sdxl" in vae_model.lower() else "SD 1.5"
|
182 |
if model_type != vae_type:
|
183 |
+
gr.Warning(WARNING_MSG_VAE)
|
184 |
|
185 |
+
print("Loading model...")
|
186 |
+
|
187 |
+
try:
|
188 |
+
start_time = time.time()
|
189 |
+
|
190 |
+
if self.model is None:
|
191 |
+
self.model = Model_Diffusers(
|
192 |
+
base_model_id=model_name,
|
193 |
+
task_name=TASK_STABLEPY[task],
|
194 |
+
vae_model=vae_model,
|
195 |
+
type_model_precision=dtype_model,
|
196 |
+
retain_task_model_in_cache=False,
|
197 |
+
device="cpu",
|
198 |
+
)
|
199 |
+
else:
|
200 |
+
|
201 |
+
if self.model.base_model_id != model_name:
|
202 |
+
load_now_time = datetime.now()
|
203 |
+
elapsed_time = max((load_now_time - self.last_load).total_seconds(), 0)
|
204 |
+
|
205 |
+
if elapsed_time <= 8:
|
206 |
+
print("Waiting for the previous model's time ops...")
|
207 |
+
time.sleep(8-elapsed_time)
|
208 |
+
|
209 |
+
self.model.device = torch.device("cpu")
|
210 |
+
self.model.load_pipe(
|
211 |
+
model_name,
|
212 |
+
task_name=TASK_STABLEPY[task],
|
213 |
+
vae_model=vae_model,
|
214 |
+
type_model_precision=dtype_model,
|
215 |
+
retain_task_model_in_cache=False,
|
216 |
+
)
|
217 |
+
|
218 |
+
end_time = time.time()
|
219 |
+
self.sleep_loading = max(min(int(end_time - start_time), 10), 4)
|
220 |
+
except Exception as e:
|
221 |
+
self.last_load = datetime.now()
|
222 |
+
self.status_loading = False
|
223 |
+
self.sleep_loading = 4
|
224 |
+
raise e
|
225 |
+
|
226 |
+
self.last_load = datetime.now()
|
227 |
+
self.status_loading = False
|
228 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
229 |
yield f"Model loaded: {model_name}"
|
230 |
|
231 |
#@spaces.GPU
|
|
|
334 |
mode_ip2,
|
335 |
scale_ip2,
|
336 |
pag_scale,
|
|
|
337 |
):
|
338 |
+
info_state = html_template_message("Navigating latent space...")
|
339 |
+
yield info_state, gr.update(), gr.update()
|
340 |
|
341 |
vae_model = vae_model if vae_model != "None" else None
|
342 |
loras_list = [lora1, lora2, lora3, lora4, lora5]
|
343 |
vae_msg = f"VAE: {vae_model}" if vae_model else ""
|
344 |
msg_lora = ""
|
345 |
|
|
|
|
|
346 |
## BEGIN MOD
|
347 |
+
loras_list = [s if s else "None" for s in loras_list]
|
348 |
global lora_model_list
|
349 |
lora_model_list = get_lora_model_list()
|
350 |
lora1, lora_scale1, lora2, lora_scale2, lora3, lora_scale3, lora4, lora_scale4, lora5, lora_scale5 = \
|
|
|
353 |
prompt, neg_prompt = insert_model_recom_prompt(prompt, neg_prompt, model_name)
|
354 |
## END MOD
|
355 |
|
356 |
+
print("Config model:", model_name, vae_model, loras_list)
|
357 |
+
|
358 |
task = TASK_STABLEPY[task]
|
359 |
|
360 |
params_ip_img = []
|
|
|
377 |
params_ip_mode.append(modeip)
|
378 |
params_ip_scale.append(scaleip)
|
379 |
|
380 |
+
concurrency = 5
|
381 |
+
self.model.stream_config(concurrency=concurrency, latent_resize_by=1, vae_decoding=False)
|
382 |
|
383 |
if task != "txt2img" and not image_control:
|
384 |
raise ValueError("No control image found: To use this function, you have to upload an image in 'Image ControlNet/Inpaint/Img2img'")
|
|
|
451 |
"high_threshold": high_threshold,
|
452 |
"value_threshold": value_threshold,
|
453 |
"distance_threshold": distance_threshold,
|
454 |
+
"lora_A": lora1 if lora1 != "None" else None,
|
455 |
"lora_scale_A": lora_scale1,
|
456 |
+
"lora_B": lora2 if lora2 != "None" else None,
|
457 |
"lora_scale_B": lora_scale2,
|
458 |
+
"lora_C": lora3 if lora3 != "None" else None,
|
459 |
"lora_scale_C": lora_scale3,
|
460 |
+
"lora_D": lora4 if lora4 != "None" else None,
|
461 |
"lora_scale_D": lora_scale4,
|
462 |
+
"lora_E": lora5 if lora5 != "None" else None,
|
463 |
"lora_scale_E": lora_scale5,
|
464 |
## BEGIN MOD
|
465 |
"textual_inversion": get_embed_list(self.model.class_name) if textual_inversion else [],
|
|
|
509 |
}
|
510 |
|
511 |
self.model.device = torch.device("cuda:0")
|
512 |
+
if hasattr(self.model.pipe, "transformer") and loras_list != ["None"] * 5:
|
513 |
self.model.pipe.transformer.to(self.model.device)
|
514 |
print("transformer to cuda")
|
515 |
|
516 |
+
actual_progress = 0
|
517 |
+
info_images = gr.update()
|
|
|
518 |
for img, seed, image_path, metadata in self.model(**pipe_params):
|
519 |
+
info_state = progress_step_bar(actual_progress, steps)
|
520 |
+
actual_progress += concurrency
|
521 |
if image_path:
|
522 |
+
info_images = f"Seeds: {str(seed)}"
|
523 |
if vae_msg:
|
524 |
+
info_images = info_images + "<br>" + vae_msg
|
525 |
+
|
526 |
+
if "Cannot copy out of meta tensor; no data!" in self.model.last_lora_error:
|
527 |
+
msg_ram = "Unable to process the LoRAs due to high RAM usage; please try again later."
|
528 |
+
print(msg_ram)
|
529 |
+
msg_lora += f"<br>{msg_ram}"
|
530 |
|
531 |
for status, lora in zip(self.model.lora_status, self.model.lora_memory):
|
532 |
if status:
|
|
|
535 |
msg_lora += f"<br>Error with: {lora}"
|
536 |
|
537 |
if msg_lora:
|
538 |
+
info_images += msg_lora
|
539 |
|
540 |
+
info_images = info_images + "<br>" + "GENERATION DATA:<br>" + metadata[0].replace("\n", "<br>") + "<br>-------<br>"
|
541 |
|
542 |
download_links = "<br>".join(
|
543 |
[
|
|
|
546 |
]
|
547 |
)
|
548 |
if save_generated_images:
|
549 |
+
info_images += f"<br>{download_links}"
|
550 |
|
551 |
+
## BEGIN MOD
|
552 |
img = save_images(img, metadata)
|
553 |
+
## END MOD
|
554 |
|
555 |
+
info_state = "COMPLETE"
|
|
|
|
|
|
|
556 |
|
557 |
+
yield info_state, img, info_images
|
|
|
558 |
|
|
|
559 |
|
560 |
def dynamic_gpu_duration(func, duration, *args):
|
561 |
|
|
|
565 |
|
566 |
return wrapped_func()
|
567 |
|
568 |
+
|
569 |
@spaces.GPU
|
570 |
def dummy_gpu():
|
571 |
return None
|
572 |
|
573 |
+
|
574 |
def sd_gen_generate_pipeline(*args):
|
575 |
|
576 |
gpu_duration_arg = int(args[-1]) if args[-1] else 59
|
|
|
578 |
load_lora_cpu = args[-3]
|
579 |
generation_args = args[:-3]
|
580 |
lora_list = [
|
581 |
+
None if item == "None" or item == "" else item # MOD
|
582 |
for item in [args[7], args[9], args[11], args[13], args[15]]
|
583 |
]
|
584 |
lora_status = [None] * 5
|
|
|
588 |
msg_load_lora = "Updating LoRAs in CPU (Slow but saves GPU usage)..."
|
589 |
|
590 |
if lora_list != sd_gen.model.lora_memory and lora_list != [None] * 5:
|
591 |
+
yield msg_load_lora, gr.update(), gr.update()
|
592 |
|
593 |
# Load lora in CPU
|
594 |
if load_lora_cpu:
|
|
|
614 |
)
|
615 |
gr.Info(f"LoRAs in cache: {lora_cache_msg}")
|
616 |
|
617 |
+
msg_request = f"Requesting {gpu_duration_arg}s. of GPU time.\nModel: {sd_gen.model.base_model_id}"
|
618 |
+
if verbose_arg:
|
619 |
gr.Info(msg_request)
|
620 |
print(msg_request)
|
621 |
+
yield msg_request.replace("\n", "<br>"), gr.update(), gr.update()
|
|
|
622 |
|
623 |
start_time = time.time()
|
624 |
|
625 |
+
# yield from sd_gen.generate_pipeline(*generation_args)
|
626 |
yield from dynamic_gpu_duration(
|
627 |
sd_gen.generate_pipeline,
|
628 |
gpu_duration_arg,
|
|
|
630 |
)
|
631 |
|
632 |
end_time = time.time()
|
633 |
+
execution_time = end_time - start_time
|
634 |
+
msg_task_complete = (
|
635 |
+
f"GPU task complete in: {int(round(execution_time, 0) + 1)} seconds"
|
636 |
+
)
|
637 |
|
638 |
if verbose_arg:
|
|
|
|
|
|
|
|
|
639 |
gr.Info(msg_task_complete)
|
640 |
print(msg_task_complete)
|
641 |
|
642 |
+
yield msg_task_complete, gr.update(), gr.update()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
643 |
|
|
|
|
|
644 |
|
645 |
+
@spaces.GPU(duration=15)
|
646 |
def esrgan_upscale(image, upscaler_name, upscaler_size):
|
647 |
if image is None: return None
|
648 |
|
|
|
664 |
|
665 |
return image_path
|
666 |
|
667 |
+
|
668 |
dynamic_gpu_duration.zerogpu = True
|
669 |
sd_gen_generate_pipeline.zerogpu = True
|
670 |
sd_gen = GuiSD()
|
671 |
|
672 |
+
|
673 |
## BEGIN MOD
|
674 |
CSS ="""
|
675 |
+
.gradio-container, #main { width:100%; height:100%; max-width:100%; padding-left:0; padding-right:0; margin-left:0; margin-right:0; }
|
676 |
+
.contain { display:flex; flex-direction:column; }
|
677 |
+
#component-0 { width:100%; height:100%; }
|
678 |
+
#gallery { flex-grow:1; }
|
679 |
+
#load_model { height: 50px; }
|
680 |
+
.lora { min-width:480px; }
|
681 |
#model-info { text-align:center; }
|
682 |
.title { font-size: 3em; align-items: center; text-align: center; }
|
683 |
.info { align-items: center; text-align: center; }
|
|
|
691 |
with gr.Tab("Generation"):
|
692 |
with gr.Row():
|
693 |
with gr.Column(scale=2):
|
694 |
+
|
695 |
+
def update_task_options(model_name, task_name):
|
696 |
+
new_choices = MODEL_TYPE_TASK[get_model_type(model_name)]
|
697 |
+
|
698 |
+
if task_name not in new_choices:
|
699 |
+
task_name = "txt2img"
|
700 |
+
|
701 |
+
return gr.update(value=task_name, choices=new_choices)
|
702 |
+
|
703 |
interface_mode_gui = gr.Radio(label="Quick settings", choices=["Simple", "Standard", "Fast", "LoRA"], value="Standard")
|
704 |
with gr.Accordion("Model and Task", open=False) as menu_model:
|
705 |
task_gui = gr.Dropdown(label="Task", choices=SDXL_TASK, value=TASK_MODEL_LIST[0])
|
|
|
763 |
[task_gui],
|
764 |
)
|
765 |
|
766 |
+
load_model_gui = gr.HTML(elem_id="load_model", elem_classes="contain")
|
767 |
|
768 |
result_images = gr.Gallery(
|
769 |
label="Generated images",
|
|
|
785 |
|
786 |
actual_task_info = gr.HTML()
|
787 |
|
788 |
+
with gr.Accordion("History", open=False):
|
789 |
+
history_gallery = gr.Gallery(label="History", columns=6, object_fit="contain", format="png", interactive=False, show_share_button=False,
|
790 |
+
show_download_button=True)
|
791 |
+
history_files = gr.Files(interactive=False, visible=False)
|
792 |
+
history_clear_button = gr.Button(value="Clear History", variant="secondary")
|
793 |
+
history_clear_button.click(lambda: ([], []), None, [history_gallery, history_files], queue=False, show_api=False)
|
794 |
+
|
795 |
with gr.Row(equal_height=False, variant="default"):
|
796 |
gpu_duration_gui = gr.Number(minimum=5, maximum=240, value=59, show_label=False, container=False, info="GPU time duration (seconds)")
|
797 |
with gr.Column():
|
|
|
814 |
with gr.Row():
|
815 |
sampler_gui = gr.Dropdown(label="Sampler", choices=scheduler_names, value="Euler a")
|
816 |
vae_model_gui = gr.Dropdown(label="VAE Model", choices=vae_model_list, value=vae_model_list[0])
|
817 |
+
prompt_syntax_gui = gr.Dropdown(label="Prompt Syntax", choices=PROMPT_W_OPTIONS, value=PROMPT_W_OPTIONS[1][1])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
818 |
|
819 |
with gr.Row(equal_height=False):
|
820 |
|
|
|
964 |
with gr.Accordion("Select from Gallery", open=False):
|
965 |
search_civitai_gallery_lora = gr.Gallery([], label="Results", allow_preview=False, columns=5, show_share_button=False, interactive=False)
|
966 |
search_civitai_result_lora = gr.Dropdown(label="Search Results", choices=[("", "")], value="", allow_custom_value=True, visible=False)
|
967 |
+
with gr.Row():
|
968 |
+
text_lora = gr.Textbox(label="LoRA URL", placeholder="https://civitai.com/api/download/models/28907", lines=1, scale=4)
|
969 |
+
romanize_text = gr.Checkbox(value=False, label="Transliterate name", scale=1)
|
970 |
button_lora = gr.Button("Get and update lists of LoRAs")
|
971 |
+
new_lora_status = gr.HTML()
|
972 |
with gr.Accordion("From Local", open=True, visible=True):
|
973 |
file_output_lora = gr.File(label="Uploaded LoRA", file_types=['.ckpt', '.pt', '.pth', '.safetensors', '.bin'], file_count="multiple", interactive=False, visible=False)
|
974 |
upload_button_lora = gr.UploadButton(label="Upload LoRA from your disk (very slow)", file_types=['.ckpt', '.pt', '.pth', '.safetensors', '.bin'], file_count="multiple")
|
|
|
1006 |
negative_prompt_ad_a_gui = gr.Textbox(label="Negative prompt", placeholder="Main negative prompt will be use", lines=3)
|
1007 |
with gr.Row():
|
1008 |
strength_ad_a_gui = gr.Number(label="Strength:", value=0.35, step=0.01, minimum=0.01, maximum=1.0)
|
1009 |
+
face_detector_ad_a_gui = gr.Checkbox(label="Face detector", value=False)
|
1010 |
person_detector_ad_a_gui = gr.Checkbox(label="Person detector", value=True)
|
1011 |
hand_detector_ad_a_gui = gr.Checkbox(label="Hand detector", value=False)
|
1012 |
with gr.Row():
|
|
|
1021 |
negative_prompt_ad_b_gui = gr.Textbox(label="Negative prompt", placeholder="Main negative prompt will be use", lines=3)
|
1022 |
with gr.Row():
|
1023 |
strength_ad_b_gui = gr.Number(label="Strength:", value=0.35, step=0.01, minimum=0.01, maximum=1.0)
|
1024 |
+
face_detector_ad_b_gui = gr.Checkbox(label="Face detector", value=False)
|
1025 |
person_detector_ad_b_gui = gr.Checkbox(label="Person detector", value=True)
|
1026 |
hand_detector_ad_b_gui = gr.Checkbox(label="Hand detector", value=False)
|
1027 |
with gr.Row():
|
|
|
1151 |
|
1152 |
with gr.Accordion("Examples and help", open=True, visible=True) as menu_example:
|
1153 |
gr.Examples(
|
1154 |
+
examples=EXAMPLES_GUI,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1155 |
fn=sd_gen.generate_pipeline,
|
1156 |
inputs=[
|
1157 |
prompt_gui,
|
|
|
1166 |
img_width_gui,
|
1167 |
model_name_gui,
|
1168 |
],
|
1169 |
+
outputs=[load_model_gui, result_images, actual_task_info],
|
1170 |
cache_examples=False,
|
1171 |
#elem_id="examples",
|
1172 |
)
|
1173 |
|
1174 |
+
gr.Markdown(RESOURCES)
|
|
|
|
|
|
|
|
|
1175 |
## END MOD
|
1176 |
|
1177 |
with gr.Tab("Inpaint mask maker", render=True):
|
|
|
1330 |
)
|
1331 |
search_civitai_result_lora.change(select_civitai_lora, [search_civitai_result_lora], [text_lora, search_civitai_desc_lora], queue=False, scroll_to_output=True)
|
1332 |
search_civitai_gallery_lora.select(update_civitai_selection, None, [search_civitai_result_lora], queue=False, show_api=False)
|
1333 |
+
button_lora.click(get_my_lora, [text_lora, romanize_text], [lora1_gui, lora2_gui, lora3_gui, lora4_gui, lora5_gui, new_lora_status], scroll_to_output=True)
|
1334 |
upload_button_lora.upload(upload_file_lora, [upload_button_lora], [file_output_lora, upload_button_lora]).success(
|
1335 |
move_file_lora, [file_output_lora], [lora1_gui, lora2_gui, lora3_gui, lora4_gui, lora5_gui], scroll_to_output=True)
|
1336 |
|
|
|
1486 |
verbose_info_gui,
|
1487 |
gpu_duration_gui,
|
1488 |
],
|
1489 |
+
outputs=[load_model_gui, result_images, actual_task_info],
|
1490 |
queue=True,
|
1491 |
show_progress="full",
|
1492 |
+
).success(save_gallery_images, [result_images, model_name_gui], [result_images, result_images_files], queue=False, show_api=False)\
|
1493 |
+
.success(save_gallery_history, [result_images, result_images_files, history_gallery, history_files], [history_gallery, history_files], queue=False, show_api=False)
|
1494 |
|
1495 |
with gr.Tab("Danbooru Tags Transformer with WD Tagger", render=True):
|
1496 |
with gr.Column(scale=2):
|
|
|
1586 |
gr.DuplicateButton(value="Duplicate Space for private use (This demo does not work on CPU. Requires GPU Space)")
|
1587 |
|
1588 |
app.queue()
|
1589 |
+
app.launch(show_error=True, debug=True) # allowed_paths=["./images/"], show_error=True, debug=True
|
1590 |
## END MOD
|
constants.py
ADDED
@@ -0,0 +1,453 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from stablepy.diffusers_vanilla.constants import FLUX_CN_UNION_MODES
|
3 |
+
from stablepy import (
|
4 |
+
scheduler_names,
|
5 |
+
SD15_TASKS,
|
6 |
+
SDXL_TASKS,
|
7 |
+
)
|
8 |
+
|
9 |
+
# - **Download Models**
|
10 |
+
DOWNLOAD_MODEL = "https://civitai.com/api/download/models/574369, https://huggingface.co/TechnoByte/MilkyWonderland/resolve/main/milkyWonderland_v40.safetensors"
|
11 |
+
|
12 |
+
# - **Download VAEs**
|
13 |
+
DOWNLOAD_VAE = "https://huggingface.co/nubby/blessed-sdxl-vae-fp16-fix/resolve/main/sdxl_vae-fp16fix-c-1.1-b-0.5.safetensors?download=true, https://huggingface.co/nubby/blessed-sdxl-vae-fp16-fix/resolve/main/sdxl_vae-fp16fix-blessed.safetensors?download=true, https://huggingface.co/digiplay/VAE/resolve/main/vividReal_v20.safetensors?download=true, https://huggingface.co/fp16-guy/anything_kl-f8-anime2_vae-ft-mse-840000-ema-pruned_blessed_clearvae_fp16_cleaned/resolve/main/vae-ft-mse-840000-ema-pruned_fp16.safetensors?download=true"
|
14 |
+
|
15 |
+
# - **Download LoRAs**
|
16 |
+
DOWNLOAD_LORA = "https://huggingface.co/Leopain/color/resolve/main/Coloring_book_-_LineArt.safetensors, https://civitai.com/api/download/models/135867, https://huggingface.co/Linaqruf/anime-detailer-xl-lora/resolve/main/anime-detailer-xl.safetensors?download=true, https://huggingface.co/Linaqruf/style-enhancer-xl-lora/resolve/main/style-enhancer-xl.safetensors?download=true, https://huggingface.co/ByteDance/Hyper-SD/resolve/main/Hyper-SD15-8steps-CFG-lora.safetensors?download=true, https://huggingface.co/ByteDance/Hyper-SD/resolve/main/Hyper-SDXL-8steps-CFG-lora.safetensors?download=true"
|
17 |
+
|
18 |
+
LOAD_DIFFUSERS_FORMAT_MODEL = [
|
19 |
+
'stabilityai/stable-diffusion-xl-base-1.0',
|
20 |
+
'black-forest-labs/FLUX.1-dev',
|
21 |
+
'John6666/blue-pencil-flux1-v021-fp8-flux',
|
22 |
+
'John6666/wai-ani-flux-v10forfp8-fp8-flux',
|
23 |
+
'John6666/xe-anime-flux-v04-fp8-flux',
|
24 |
+
'John6666/lyh-anime-flux-v2a1-fp8-flux',
|
25 |
+
'John6666/carnival-unchained-v10-fp8-flux',
|
26 |
+
'cagliostrolab/animagine-xl-3.1',
|
27 |
+
'John6666/epicrealism-xl-v8kiss-sdxl',
|
28 |
+
'misri/epicrealismXL_v7FinalDestination',
|
29 |
+
'misri/juggernautXL_juggernautX',
|
30 |
+
'misri/zavychromaxl_v80',
|
31 |
+
'SG161222/RealVisXL_V4.0',
|
32 |
+
'SG161222/RealVisXL_V5.0',
|
33 |
+
'misri/newrealityxlAllInOne_Newreality40',
|
34 |
+
'eienmojiki/Anything-XL',
|
35 |
+
'eienmojiki/Starry-XL-v5.2',
|
36 |
+
'gsdf/CounterfeitXL',
|
37 |
+
'KBlueLeaf/Kohaku-XL-Zeta',
|
38 |
+
'John6666/silvermoon-mix-01xl-v11-sdxl',
|
39 |
+
'WhiteAiZ/autismmixSDXL_autismmixConfetti_diffusers',
|
40 |
+
'kitty7779/ponyDiffusionV6XL',
|
41 |
+
'GraydientPlatformAPI/aniverse-pony',
|
42 |
+
'John6666/ras-real-anime-screencap-v1-sdxl',
|
43 |
+
'John6666/duchaiten-pony-xl-no-score-v60-sdxl',
|
44 |
+
'John6666/mistoon-anime-ponyalpha-sdxl',
|
45 |
+
'John6666/3x3x3mixxl-v2-sdxl',
|
46 |
+
'John6666/3x3x3mixxl-3dv01-sdxl',
|
47 |
+
'John6666/ebara-mfcg-pony-mix-v12-sdxl',
|
48 |
+
'John6666/t-ponynai3-v51-sdxl',
|
49 |
+
'John6666/t-ponynai3-v65-sdxl',
|
50 |
+
'John6666/prefect-pony-xl-v3-sdxl',
|
51 |
+
'John6666/mala-anime-mix-nsfw-pony-xl-v5-sdxl',
|
52 |
+
'John6666/wai-real-mix-v11-sdxl',
|
53 |
+
'John6666/wai-c-v6-sdxl',
|
54 |
+
'John6666/iniverse-mix-xl-sfwnsfw-pony-guofeng-v43-sdxl',
|
55 |
+
'John6666/photo-realistic-pony-v5-sdxl',
|
56 |
+
'John6666/pony-realism-v21main-sdxl',
|
57 |
+
'John6666/pony-realism-v22main-sdxl',
|
58 |
+
'John6666/cyberrealistic-pony-v63-sdxl',
|
59 |
+
'John6666/cyberrealistic-pony-v64-sdxl',
|
60 |
+
'John6666/cyberrealistic-pony-v65-sdxl',
|
61 |
+
'GraydientPlatformAPI/realcartoon-pony-diffusion',
|
62 |
+
'John6666/nova-anime-xl-pony-v5-sdxl',
|
63 |
+
'John6666/autismmix-sdxl-autismmix-pony-sdxl',
|
64 |
+
'John6666/aimz-dream-real-pony-mix-v3-sdxl',
|
65 |
+
'John6666/duchaiten-pony-real-v11fix-sdxl',
|
66 |
+
'John6666/duchaiten-pony-real-v20-sdxl',
|
67 |
+
'yodayo-ai/kivotos-xl-2.0',
|
68 |
+
'yodayo-ai/holodayo-xl-2.1',
|
69 |
+
'yodayo-ai/clandestine-xl-1.0',
|
70 |
+
'digiplay/majicMIX_sombre_v2',
|
71 |
+
'digiplay/majicMIX_realistic_v6',
|
72 |
+
'digiplay/majicMIX_realistic_v7',
|
73 |
+
'digiplay/DreamShaper_8',
|
74 |
+
'digiplay/BeautifulArt_v1',
|
75 |
+
'digiplay/DarkSushi2.5D_v1',
|
76 |
+
'digiplay/darkphoenix3D_v1.1',
|
77 |
+
'digiplay/BeenYouLiteL11_diffusers',
|
78 |
+
'Yntec/RevAnimatedV2Rebirth',
|
79 |
+
'youknownothing/cyberrealistic_v50',
|
80 |
+
'youknownothing/deliberate-v6',
|
81 |
+
'GraydientPlatformAPI/deliberate-cyber3',
|
82 |
+
'GraydientPlatformAPI/picx-real',
|
83 |
+
'GraydientPlatformAPI/perfectworld6',
|
84 |
+
'emilianJR/epiCRealism',
|
85 |
+
'votepurchase/counterfeitV30_v30',
|
86 |
+
'votepurchase/ChilloutMix',
|
87 |
+
'Meina/MeinaMix_V11',
|
88 |
+
'Meina/MeinaUnreal_V5',
|
89 |
+
'Meina/MeinaPastel_V7',
|
90 |
+
'GraydientPlatformAPI/realcartoon3d-17',
|
91 |
+
'GraydientPlatformAPI/realcartoon-pixar11',
|
92 |
+
'GraydientPlatformAPI/realcartoon-real17',
|
93 |
+
]
|
94 |
+
|
95 |
+
DIFFUSERS_FORMAT_LORAS = [
|
96 |
+
"nerijs/animation2k-flux",
|
97 |
+
"XLabs-AI/flux-RealismLora",
|
98 |
+
]
|
99 |
+
|
100 |
+
DOWNLOAD_EMBEDS = [
|
101 |
+
'https://huggingface.co/datasets/Nerfgun3/bad_prompt/blob/main/bad_prompt_version2.pt',
|
102 |
+
'https://huggingface.co/embed/negative/resolve/main/EasyNegativeV2.safetensors',
|
103 |
+
'https://huggingface.co/embed/negative/resolve/main/bad-hands-5.pt',
|
104 |
+
]
|
105 |
+
|
106 |
+
CIVITAI_API_KEY = os.environ.get("CIVITAI_API_KEY")
|
107 |
+
HF_TOKEN = os.environ.get("HF_READ_TOKEN")
|
108 |
+
|
109 |
+
DIRECTORY_MODELS = 'models'
|
110 |
+
DIRECTORY_LORAS = 'loras'
|
111 |
+
DIRECTORY_VAES = 'vaes'
|
112 |
+
DIRECTORY_EMBEDS = 'embedings'
|
113 |
+
|
114 |
+
PREPROCESSOR_CONTROLNET = {
|
115 |
+
"openpose": [
|
116 |
+
"Openpose",
|
117 |
+
"None",
|
118 |
+
],
|
119 |
+
"scribble": [
|
120 |
+
"HED",
|
121 |
+
"PidiNet",
|
122 |
+
"None",
|
123 |
+
],
|
124 |
+
"softedge": [
|
125 |
+
"PidiNet",
|
126 |
+
"HED",
|
127 |
+
"HED safe",
|
128 |
+
"PidiNet safe",
|
129 |
+
"None",
|
130 |
+
],
|
131 |
+
"segmentation": [
|
132 |
+
"UPerNet",
|
133 |
+
"None",
|
134 |
+
],
|
135 |
+
"depth": [
|
136 |
+
"DPT",
|
137 |
+
"Midas",
|
138 |
+
"None",
|
139 |
+
],
|
140 |
+
"normalbae": [
|
141 |
+
"NormalBae",
|
142 |
+
"None",
|
143 |
+
],
|
144 |
+
"lineart": [
|
145 |
+
"Lineart",
|
146 |
+
"Lineart coarse",
|
147 |
+
"Lineart (anime)",
|
148 |
+
"None",
|
149 |
+
"None (anime)",
|
150 |
+
],
|
151 |
+
"lineart_anime": [
|
152 |
+
"Lineart",
|
153 |
+
"Lineart coarse",
|
154 |
+
"Lineart (anime)",
|
155 |
+
"None",
|
156 |
+
"None (anime)",
|
157 |
+
],
|
158 |
+
"shuffle": [
|
159 |
+
"ContentShuffle",
|
160 |
+
"None",
|
161 |
+
],
|
162 |
+
"canny": [
|
163 |
+
"Canny",
|
164 |
+
"None",
|
165 |
+
],
|
166 |
+
"mlsd": [
|
167 |
+
"MLSD",
|
168 |
+
"None",
|
169 |
+
],
|
170 |
+
"ip2p": [
|
171 |
+
"ip2p"
|
172 |
+
],
|
173 |
+
"recolor": [
|
174 |
+
"Recolor luminance",
|
175 |
+
"Recolor intensity",
|
176 |
+
"None",
|
177 |
+
],
|
178 |
+
"tile": [
|
179 |
+
"Mild Blur",
|
180 |
+
"Moderate Blur",
|
181 |
+
"Heavy Blur",
|
182 |
+
"None",
|
183 |
+
],
|
184 |
+
|
185 |
+
}
|
186 |
+
|
187 |
+
TASK_STABLEPY = {
|
188 |
+
'txt2img': 'txt2img',
|
189 |
+
'img2img': 'img2img',
|
190 |
+
'inpaint': 'inpaint',
|
191 |
+
# 'canny T2I Adapter': 'sdxl_canny_t2i', # NO HAVE STEP CALLBACK PARAMETERS SO NOT WORKS WITH DIFFUSERS 0.29.0
|
192 |
+
# 'sketch T2I Adapter': 'sdxl_sketch_t2i',
|
193 |
+
# 'lineart T2I Adapter': 'sdxl_lineart_t2i',
|
194 |
+
# 'depth-midas T2I Adapter': 'sdxl_depth-midas_t2i',
|
195 |
+
# 'openpose T2I Adapter': 'sdxl_openpose_t2i',
|
196 |
+
'openpose ControlNet': 'openpose',
|
197 |
+
'canny ControlNet': 'canny',
|
198 |
+
'mlsd ControlNet': 'mlsd',
|
199 |
+
'scribble ControlNet': 'scribble',
|
200 |
+
'softedge ControlNet': 'softedge',
|
201 |
+
'segmentation ControlNet': 'segmentation',
|
202 |
+
'depth ControlNet': 'depth',
|
203 |
+
'normalbae ControlNet': 'normalbae',
|
204 |
+
'lineart ControlNet': 'lineart',
|
205 |
+
'lineart_anime ControlNet': 'lineart_anime',
|
206 |
+
'shuffle ControlNet': 'shuffle',
|
207 |
+
'ip2p ControlNet': 'ip2p',
|
208 |
+
'optical pattern ControlNet': 'pattern',
|
209 |
+
'recolor ControlNet': 'recolor',
|
210 |
+
'tile ControlNet': 'tile',
|
211 |
+
}
|
212 |
+
|
213 |
+
TASK_MODEL_LIST = list(TASK_STABLEPY.keys())
|
214 |
+
|
215 |
+
UPSCALER_DICT_GUI = {
|
216 |
+
None: None,
|
217 |
+
"Lanczos": "Lanczos",
|
218 |
+
"Nearest": "Nearest",
|
219 |
+
'Latent': 'Latent',
|
220 |
+
'Latent (antialiased)': 'Latent (antialiased)',
|
221 |
+
'Latent (bicubic)': 'Latent (bicubic)',
|
222 |
+
'Latent (bicubic antialiased)': 'Latent (bicubic antialiased)',
|
223 |
+
'Latent (nearest)': 'Latent (nearest)',
|
224 |
+
'Latent (nearest-exact)': 'Latent (nearest-exact)',
|
225 |
+
"RealESRGAN_x4plus": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
|
226 |
+
"RealESRNet_x4plus": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth",
|
227 |
+
"RealESRGAN_x4plus_anime_6B": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
|
228 |
+
"RealESRGAN_x2plus": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
|
229 |
+
"realesr-animevideov3": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth",
|
230 |
+
"realesr-general-x4v3": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
|
231 |
+
"realesr-general-wdn-x4v3": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth",
|
232 |
+
"4x-UltraSharp": "https://huggingface.co/Shandypur/ESRGAN-4x-UltraSharp/resolve/main/4x-UltraSharp.pth",
|
233 |
+
"4x_foolhardy_Remacri": "https://huggingface.co/FacehugmanIII/4x_foolhardy_Remacri/resolve/main/4x_foolhardy_Remacri.pth",
|
234 |
+
"Remacri4xExtraSmoother": "https://huggingface.co/hollowstrawberry/upscalers-backup/resolve/main/ESRGAN/Remacri%204x%20ExtraSmoother.pth",
|
235 |
+
"AnimeSharp4x": "https://huggingface.co/hollowstrawberry/upscalers-backup/resolve/main/ESRGAN/AnimeSharp%204x.pth",
|
236 |
+
"lollypop": "https://huggingface.co/hollowstrawberry/upscalers-backup/resolve/main/ESRGAN/lollypop.pth",
|
237 |
+
"RealisticRescaler4x": "https://huggingface.co/hollowstrawberry/upscalers-backup/resolve/main/ESRGAN/RealisticRescaler%204x.pth",
|
238 |
+
"NickelbackFS4x": "https://huggingface.co/hollowstrawberry/upscalers-backup/resolve/main/ESRGAN/NickelbackFS%204x.pth"
|
239 |
+
}
|
240 |
+
|
241 |
+
UPSCALER_KEYS = list(UPSCALER_DICT_GUI.keys())
|
242 |
+
|
243 |
+
PROMPT_W_OPTIONS = [
|
244 |
+
("Compel format: (word)weight", "Compel"),
|
245 |
+
("Classic format: (word:weight)", "Classic"),
|
246 |
+
("Classic-original format: (word:weight)", "Classic-original"),
|
247 |
+
("Classic-no_norm format: (word:weight)", "Classic-no_norm"),
|
248 |
+
("Classic-ignore", "Classic-ignore"),
|
249 |
+
("None", "None"),
|
250 |
+
]
|
251 |
+
|
252 |
+
WARNING_MSG_VAE = (
|
253 |
+
"Use the right VAE for your model to maintain image quality. The wrong"
|
254 |
+
" VAE can lead to poor results, like blurriness in the generated images."
|
255 |
+
)
|
256 |
+
|
257 |
+
SDXL_TASK = [k for k, v in TASK_STABLEPY.items() if v in SDXL_TASKS]
|
258 |
+
SD_TASK = [k for k, v in TASK_STABLEPY.items() if v in SD15_TASKS]
|
259 |
+
FLUX_TASK = list(TASK_STABLEPY.keys())[:3] + [k for k, v in TASK_STABLEPY.items() if v in FLUX_CN_UNION_MODES.keys()]
|
260 |
+
|
261 |
+
MODEL_TYPE_TASK = {
|
262 |
+
"SD 1.5": SD_TASK,
|
263 |
+
"SDXL": SDXL_TASK,
|
264 |
+
"FLUX": FLUX_TASK,
|
265 |
+
}
|
266 |
+
|
267 |
+
MODEL_TYPE_CLASS = {
|
268 |
+
"diffusers:StableDiffusionPipeline": "SD 1.5",
|
269 |
+
"diffusers:StableDiffusionXLPipeline": "SDXL",
|
270 |
+
"diffusers:FluxPipeline": "FLUX",
|
271 |
+
}
|
272 |
+
|
273 |
+
POST_PROCESSING_SAMPLER = ["Use same sampler"] + scheduler_names[:-2]
|
274 |
+
|
275 |
+
SUBTITLE_GUI = (
|
276 |
+
"### This demo uses [diffusers](https://github.com/huggingface/diffusers)"
|
277 |
+
" to perform different tasks in image generation."
|
278 |
+
)
|
279 |
+
|
280 |
+
HELP_GUI = (
|
281 |
+
"""### Help:
|
282 |
+
- The current space runs on a ZERO GPU which is assigned for approximately 60 seconds; Therefore, if you submit expensive tasks, the operation may be canceled upon reaching the maximum allowed time with 'GPU TASK ABORTED'.
|
283 |
+
- Distorted or strange images often result from high prompt weights, so it's best to use low weights and scales, and consider using Classic variants like 'Classic-original'.
|
284 |
+
- For better results with Pony Diffusion, try using sampler DPM++ 1s or DPM2 with Compel or Classic prompt weights.
|
285 |
+
"""
|
286 |
+
)
|
287 |
+
|
288 |
+
EXAMPLES_GUI_HELP = (
|
289 |
+
"""### The following examples perform specific tasks:
|
290 |
+
1. Generation with SDXL and upscale
|
291 |
+
2. Generation with FLUX dev
|
292 |
+
3. ControlNet Canny SDXL
|
293 |
+
4. Optical pattern (Optical illusion) SDXL
|
294 |
+
5. Convert an image to a coloring drawing
|
295 |
+
6. ControlNet OpenPose SD 1.5 and Latent upscale
|
296 |
+
|
297 |
+
- Different tasks can be performed, such as img2img or using the IP adapter, to preserve a person's appearance or a specific style based on an image.
|
298 |
+
"""
|
299 |
+
)
|
300 |
+
|
301 |
+
EXAMPLES_GUI = [
|
302 |
+
[
|
303 |
+
"1girl, souryuu asuka langley, neon genesis evangelion, rebuild of evangelion, lance of longinus, cat hat, plugsuit, pilot suit, red bodysuit, sitting, crossed legs, black eye patch, throne, looking down, from bottom, looking at viewer, outdoors, (masterpiece), (best quality), (ultra-detailed), very aesthetic, illustration, disheveled hair, perfect composition, moist skin, intricate details",
|
304 |
+
"nfsw, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, unfinished, very displeasing, oldest, early, chromatic aberration, artistic error, scan, abstract",
|
305 |
+
28,
|
306 |
+
7.0,
|
307 |
+
-1,
|
308 |
+
"None",
|
309 |
+
0.33,
|
310 |
+
"Euler a",
|
311 |
+
1152,
|
312 |
+
896,
|
313 |
+
"cagliostrolab/animagine-xl-3.1",
|
314 |
+
"txt2img",
|
315 |
+
"image.webp", # img conttol
|
316 |
+
1024, # img resolution
|
317 |
+
0.35, # strength
|
318 |
+
1.0, # cn scale
|
319 |
+
0.0, # cn start
|
320 |
+
1.0, # cn end
|
321 |
+
"Classic",
|
322 |
+
"Nearest",
|
323 |
+
45,
|
324 |
+
False,
|
325 |
+
],
|
326 |
+
[
|
327 |
+
"a digital illustration of a movie poster titled 'Finding Emo', finding nemo parody poster, featuring a depressed cartoon clownfish with black emo hair, eyeliner, and piercings, bored expression, swimming in a dark underwater scene, in the background, movie title in a dripping, grungy font, moody blue and purple color palette",
|
328 |
+
"",
|
329 |
+
24,
|
330 |
+
3.5,
|
331 |
+
-1,
|
332 |
+
"None",
|
333 |
+
0.33,
|
334 |
+
"Euler a",
|
335 |
+
1152,
|
336 |
+
896,
|
337 |
+
"black-forest-labs/FLUX.1-dev",
|
338 |
+
"txt2img",
|
339 |
+
None, # img conttol
|
340 |
+
1024, # img resolution
|
341 |
+
0.35, # strength
|
342 |
+
1.0, # cn scale
|
343 |
+
0.0, # cn start
|
344 |
+
1.0, # cn end
|
345 |
+
"Classic",
|
346 |
+
None,
|
347 |
+
70,
|
348 |
+
True,
|
349 |
+
],
|
350 |
+
[
|
351 |
+
"((masterpiece)), best quality, blonde disco girl, detailed face, realistic face, realistic hair, dynamic pose, pink pvc, intergalactic disco background, pastel lights, dynamic contrast, airbrush, fine detail, 70s vibe, midriff",
|
352 |
+
"(worst quality:1.2), (bad quality:1.2), (poor quality:1.2), (missing fingers:1.2), bad-artist-anime, bad-artist, bad-picture-chill-75v",
|
353 |
+
48,
|
354 |
+
3.5,
|
355 |
+
-1,
|
356 |
+
"None",
|
357 |
+
0.33,
|
358 |
+
"DPM++ 2M SDE Lu",
|
359 |
+
1024,
|
360 |
+
1024,
|
361 |
+
"misri/epicrealismXL_v7FinalDestination",
|
362 |
+
"canny ControlNet",
|
363 |
+
"image.webp", # img conttol
|
364 |
+
1024, # img resolution
|
365 |
+
0.35, # strength
|
366 |
+
1.0, # cn scale
|
367 |
+
0.0, # cn start
|
368 |
+
1.0, # cn end
|
369 |
+
"Classic",
|
370 |
+
None,
|
371 |
+
44,
|
372 |
+
False,
|
373 |
+
],
|
374 |
+
[
|
375 |
+
"cinematic scenery old city ruins",
|
376 |
+
"(worst quality, low quality, illustration, 3d, 2d, painting, cartoons, sketch), (illustration, 3d, 2d, painting, cartoons, sketch, blurry, film grain, noise), (low quality, worst quality:1.2)",
|
377 |
+
50,
|
378 |
+
4.0,
|
379 |
+
-1,
|
380 |
+
"None",
|
381 |
+
0.33,
|
382 |
+
"Euler a",
|
383 |
+
1024,
|
384 |
+
1024,
|
385 |
+
"misri/juggernautXL_juggernautX",
|
386 |
+
"optical pattern ControlNet",
|
387 |
+
"spiral_no_transparent.png", # img conttol
|
388 |
+
1024, # img resolution
|
389 |
+
0.35, # strength
|
390 |
+
1.0, # cn scale
|
391 |
+
0.05, # cn start
|
392 |
+
0.75, # cn end
|
393 |
+
"Classic",
|
394 |
+
None,
|
395 |
+
35,
|
396 |
+
False,
|
397 |
+
],
|
398 |
+
[
|
399 |
+
"black and white, line art, coloring drawing, clean line art, black strokes, no background, white, black, free lines, black scribbles, on paper, A blend of comic book art and lineart full of black and white color, masterpiece, high-resolution, trending on Pixiv fan box, palette knife, brush strokes, two-dimensional, planar vector, T-shirt design, stickers, and T-shirt design, vector art, fantasy art, Adobe Illustrator, hand-painted, digital painting, low polygon, soft lighting, aerial view, isometric style, retro aesthetics, 8K resolution, black sketch lines, monochrome, invert color",
|
400 |
+
"color, red, green, yellow, colored, duplicate, blurry, abstract, disfigured, deformed, animated, toy, figure, framed, 3d, bad art, poorly drawn, extra limbs, close up, b&w, weird colors, blurry, watermark, blur haze, 2 heads, long neck, watermark, elongated body, cropped image, out of frame, draft, deformed hands, twisted fingers, double image, malformed hands, multiple heads, extra limb, ugly, poorly drawn hands, missing limb, cut-off, over satured, grain, lowères, bad anatomy, poorly drawn face, mutation, mutated, floating limbs, disconnected limbs, out of focus, long body, disgusting, extra fingers, groos proportions, missing arms, mutated hands, cloned face, missing legs, ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, blurry, bad anatomy, blurred, watermark, grainy, signature, cut off, draft, deformed, blurry, bad anatomy, disfigured, poorly drawn face, mutation, bluelish, blue",
|
401 |
+
20,
|
402 |
+
4.0,
|
403 |
+
-1,
|
404 |
+
"loras/Coloring_book_-_LineArt.safetensors",
|
405 |
+
1.0,
|
406 |
+
"DPM++ 2M SDE Karras",
|
407 |
+
1024,
|
408 |
+
1024,
|
409 |
+
"cagliostrolab/animagine-xl-3.1",
|
410 |
+
"lineart ControlNet",
|
411 |
+
"color_image.png", # img conttol
|
412 |
+
896, # img resolution
|
413 |
+
0.35, # strength
|
414 |
+
1.0, # cn scale
|
415 |
+
0.0, # cn start
|
416 |
+
1.0, # cn end
|
417 |
+
"Compel",
|
418 |
+
None,
|
419 |
+
35,
|
420 |
+
False,
|
421 |
+
],
|
422 |
+
[
|
423 |
+
"1girl,face,curly hair,red hair,white background,",
|
424 |
+
"(worst quality:2),(low quality:2),(normal quality:2),lowres,watermark,",
|
425 |
+
38,
|
426 |
+
5.0,
|
427 |
+
-1,
|
428 |
+
"None",
|
429 |
+
0.33,
|
430 |
+
"DPM++ 2M SDE Karras",
|
431 |
+
512,
|
432 |
+
512,
|
433 |
+
"digiplay/majicMIX_realistic_v7",
|
434 |
+
"openpose ControlNet",
|
435 |
+
"image.webp", # img conttol
|
436 |
+
1024, # img resolution
|
437 |
+
0.35, # strength
|
438 |
+
1.0, # cn scale
|
439 |
+
0.0, # cn start
|
440 |
+
0.9, # cn end
|
441 |
+
"Compel",
|
442 |
+
"Latent (antialiased)",
|
443 |
+
46,
|
444 |
+
False,
|
445 |
+
],
|
446 |
+
]
|
447 |
+
|
448 |
+
RESOURCES = (
|
449 |
+
"""### Resources
|
450 |
+
- John6666's space has some great features you might find helpful [link](https://huggingface.co/spaces/John6666/DiffuseCraftMod).
|
451 |
+
- You can also try the image generator in Colab’s free tier, which provides free GPU [link](https://github.com/R3gm/SD_diffusers_interactive).
|
452 |
+
"""
|
453 |
+
)
|
env.py
CHANGED
@@ -2,10 +2,10 @@ import os
|
|
2 |
|
3 |
CIVITAI_API_KEY = os.environ.get("CIVITAI_API_KEY")
|
4 |
HF_TOKEN = os.environ.get("HF_TOKEN")
|
5 |
-
|
6 |
|
7 |
# - **List Models**
|
8 |
-
|
9 |
'stabilityai/stable-diffusion-xl-base-1.0',
|
10 |
'John6666/blue-pencil-flux1-v021-fp8-flux',
|
11 |
'John6666/wai-ani-flux-v10forfp8-fp8-flux',
|
@@ -106,11 +106,11 @@ HF_MODEL_USER_EX = ["John6666"] # sorted by a special rule
|
|
106 |
|
107 |
|
108 |
# - **Download Models**
|
109 |
-
|
110 |
]
|
111 |
|
112 |
# - **Download VAEs**
|
113 |
-
|
114 |
'https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/resolve/main/sdxl.vae.safetensors?download=true',
|
115 |
'https://huggingface.co/nubby/blessed-sdxl-vae-fp16-fix/resolve/main/sdxl_vae-fp16fix-c-1.1-b-0.5.safetensors?download=true',
|
116 |
'https://huggingface.co/nubby/blessed-sdxl-vae-fp16-fix/resolve/main/sdxl_vae-fp16fix-blessed.safetensors?download=true',
|
@@ -119,29 +119,26 @@ download_vae_list = [
|
|
119 |
]
|
120 |
|
121 |
# - **Download LoRAs**
|
122 |
-
|
123 |
]
|
124 |
|
125 |
# Download Embeddings
|
126 |
-
|
127 |
'https://huggingface.co/datasets/Nerfgun3/bad_prompt/blob/main/bad_prompt_version2.pt',
|
128 |
'https://huggingface.co/embed/negative/resolve/main/EasyNegativeV2.safetensors',
|
129 |
'https://huggingface.co/embed/negative/resolve/main/bad-hands-5.pt',
|
130 |
]
|
131 |
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
directory_embeds = 'embedings'
|
139 |
-
os.makedirs(directory_embeds, exist_ok=True)
|
140 |
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
os.makedirs(directory_embeds_positive_sdxl, exist_ok=True)
|
145 |
|
146 |
HF_LORA_PRIVATE_REPOS1 = ['John6666/loratest1', 'John6666/loratest3', 'John6666/loratest4', 'John6666/loratest6']
|
147 |
HF_LORA_PRIVATE_REPOS2 = ['John6666/loratest10', 'John6666/loratest11','John6666/loratest'] # to be sorted as 1 repo
|
|
|
2 |
|
3 |
CIVITAI_API_KEY = os.environ.get("CIVITAI_API_KEY")
|
4 |
HF_TOKEN = os.environ.get("HF_TOKEN")
|
5 |
+
HF_READ_TOKEN = os.environ.get('HF_READ_TOKEN') # only use for private repo
|
6 |
|
7 |
# - **List Models**
|
8 |
+
LOAD_DIFFUSERS_FORMAT_MODEL = [
|
9 |
'stabilityai/stable-diffusion-xl-base-1.0',
|
10 |
'John6666/blue-pencil-flux1-v021-fp8-flux',
|
11 |
'John6666/wai-ani-flux-v10forfp8-fp8-flux',
|
|
|
106 |
|
107 |
|
108 |
# - **Download Models**
|
109 |
+
DOWNLOAD_MODEL_LIST = [
|
110 |
]
|
111 |
|
112 |
# - **Download VAEs**
|
113 |
+
DOWNLOAD_VAE_LIST = [
|
114 |
'https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/resolve/main/sdxl.vae.safetensors?download=true',
|
115 |
'https://huggingface.co/nubby/blessed-sdxl-vae-fp16-fix/resolve/main/sdxl_vae-fp16fix-c-1.1-b-0.5.safetensors?download=true',
|
116 |
'https://huggingface.co/nubby/blessed-sdxl-vae-fp16-fix/resolve/main/sdxl_vae-fp16fix-blessed.safetensors?download=true',
|
|
|
119 |
]
|
120 |
|
121 |
# - **Download LoRAs**
|
122 |
+
DOWNLOAD_LORA_LIST = [
|
123 |
]
|
124 |
|
125 |
# Download Embeddings
|
126 |
+
DOWNLOAD_EMBEDS = [
|
127 |
'https://huggingface.co/datasets/Nerfgun3/bad_prompt/blob/main/bad_prompt_version2.pt',
|
128 |
'https://huggingface.co/embed/negative/resolve/main/EasyNegativeV2.safetensors',
|
129 |
'https://huggingface.co/embed/negative/resolve/main/bad-hands-5.pt',
|
130 |
]
|
131 |
|
132 |
+
DIRECTORY_MODELS = 'models'
|
133 |
+
DIRECTORY_LORAS = 'loras'
|
134 |
+
DIRECTORY_VAES = 'vaes'
|
135 |
+
DIRECTORY_EMBEDS = 'embedings'
|
136 |
+
DIRECTORY_EMBEDS_SDXL = 'embedings_xl'
|
137 |
+
DIRECTORY_EMBEDS_POSITIVE_SDXL = 'embedings_xl/positive'
|
|
|
|
|
138 |
|
139 |
+
directories = [DIRECTORY_MODELS, DIRECTORY_LORAS, DIRECTORY_VAES, DIRECTORY_EMBEDS, DIRECTORY_EMBEDS_SDXL, DIRECTORY_EMBEDS_POSITIVE_SDXL]
|
140 |
+
for directory in directories:
|
141 |
+
os.makedirs(directory, exist_ok=True)
|
|
|
142 |
|
143 |
HF_LORA_PRIVATE_REPOS1 = ['John6666/loratest1', 'John6666/loratest3', 'John6666/loratest4', 'John6666/loratest6']
|
144 |
HF_LORA_PRIVATE_REPOS2 = ['John6666/loratest10', 'John6666/loratest11','John6666/loratest'] # to be sorted as 1 repo
|
modutils.py
CHANGED
@@ -5,6 +5,7 @@ import os
|
|
5 |
import re
|
6 |
from pathlib import Path
|
7 |
from PIL import Image
|
|
|
8 |
import shutil
|
9 |
import requests
|
10 |
from requests.adapters import HTTPAdapter
|
@@ -12,11 +13,16 @@ from urllib3.util import Retry
|
|
12 |
import urllib.parse
|
13 |
import pandas as pd
|
14 |
from huggingface_hub import HfApi, HfFolder, hf_hub_download, snapshot_download
|
|
|
|
|
|
|
|
|
|
|
15 |
|
16 |
|
17 |
from env import (HF_LORA_PRIVATE_REPOS1, HF_LORA_PRIVATE_REPOS2,
|
18 |
HF_MODEL_USER_EX, HF_MODEL_USER_LIKES, DIFFUSERS_FORMAT_LORAS,
|
19 |
-
|
20 |
|
21 |
|
22 |
MODEL_TYPE_DICT = {
|
@@ -46,7 +52,6 @@ def is_repo_name(s):
|
|
46 |
return re.fullmatch(r'^[^/]+?/[^/]+?$', s)
|
47 |
|
48 |
|
49 |
-
from translatepy import Translator
|
50 |
translator = Translator()
|
51 |
def translate_to_en(input: str):
|
52 |
try:
|
@@ -64,6 +69,7 @@ def get_local_model_list(dir_path):
|
|
64 |
if file.suffix in valid_extensions:
|
65 |
file_path = str(Path(f"{dir_path}/{file.name}"))
|
66 |
model_list.append(file_path)
|
|
|
67 |
return model_list
|
68 |
|
69 |
|
@@ -98,21 +104,81 @@ def split_hf_url(url: str):
|
|
98 |
print(e)
|
99 |
|
100 |
|
101 |
-
def download_hf_file(directory, url, progress=gr.Progress(track_tqdm=True)):
|
102 |
-
hf_token = get_token()
|
103 |
repo_id, filename, subfolder, repo_type = split_hf_url(url)
|
|
|
|
|
|
|
104 |
try:
|
105 |
-
print(f"
|
106 |
-
|
107 |
-
else: path = hf_hub_download(repo_id=repo_id, filename=filename, repo_type=repo_type, local_dir=directory, token=hf_token)
|
108 |
return path
|
109 |
except Exception as e:
|
110 |
-
print(f"
|
111 |
return None
|
112 |
|
113 |
|
114 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
115 |
url = url.strip()
|
|
|
|
|
116 |
if "drive.google.com" in url:
|
117 |
original_dir = os.getcwd()
|
118 |
os.chdir(directory)
|
@@ -123,18 +189,48 @@ def download_things(directory, url, hf_token="", civitai_api_key=""):
|
|
123 |
# url = urllib.parse.quote(url, safe=':/') # fix encoding
|
124 |
if "/blob/" in url:
|
125 |
url = url.replace("/blob/", "/resolve/")
|
126 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
127 |
elif "civitai.com" in url:
|
128 |
-
|
129 |
-
|
130 |
-
if civitai_api_key:
|
131 |
-
url = url + f"?token={civitai_api_key}"
|
132 |
-
os.system(f"aria2c --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 -d {directory} {url}")
|
133 |
-
else:
|
134 |
print("\033[91mYou need an API key to download Civitai models.\033[0m")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
135 |
else:
|
136 |
os.system(f"aria2c --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 -d {directory} {url}")
|
137 |
|
|
|
|
|
138 |
|
139 |
def get_download_file(temp_dir, url, civitai_key="", progress=gr.Progress(track_tqdm=True)):
|
140 |
if not "http" in url and is_repo_name(url) and not Path(url).exists():
|
@@ -173,7 +269,7 @@ def to_lora_key(path: str):
|
|
173 |
|
174 |
def to_lora_path(key: str):
|
175 |
if Path(key).is_file(): return key
|
176 |
-
path = Path(f"{
|
177 |
return str(path)
|
178 |
|
179 |
|
@@ -203,25 +299,21 @@ def save_images(images: list[Image.Image], metadatas: list[str]):
|
|
203 |
raise Exception(f"Failed to save image file:") from e
|
204 |
|
205 |
|
206 |
-
def save_gallery_images(images, progress=gr.Progress(track_tqdm=True)):
|
207 |
-
from datetime import datetime, timezone, timedelta
|
208 |
progress(0, desc="Updating gallery...")
|
209 |
-
|
210 |
-
|
211 |
-
i = 1
|
212 |
-
if not images: return images, gr.update(visible=False)
|
213 |
output_images = []
|
214 |
output_paths = []
|
215 |
-
for image in images:
|
216 |
-
filename = basename
|
217 |
-
i += 1
|
218 |
oldpath = Path(image[0])
|
219 |
newpath = oldpath
|
220 |
try:
|
221 |
if oldpath.exists():
|
222 |
newpath = oldpath.resolve().rename(Path(filename).resolve())
|
223 |
except Exception as e:
|
224 |
-
|
225 |
finally:
|
226 |
output_paths.append(str(newpath))
|
227 |
output_images.append((str(newpath), str(filename)))
|
@@ -229,10 +321,47 @@ def save_gallery_images(images, progress=gr.Progress(track_tqdm=True)):
|
|
229 |
return gr.update(value=output_images), gr.update(value=output_paths, visible=True)
|
230 |
|
231 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
232 |
def download_private_repo(repo_id, dir_path, is_replace):
|
233 |
-
if not
|
234 |
try:
|
235 |
-
snapshot_download(repo_id=repo_id, local_dir=dir_path, allow_patterns=['*.ckpt', '*.pt', '*.pth', '*.safetensors', '*.bin'],
|
236 |
except Exception as e:
|
237 |
print(f"Error: Failed to download {repo_id}.")
|
238 |
print(e)
|
@@ -250,9 +379,9 @@ private_model_path_repo_dict = {} # {"local filepath": "huggingface repo_id", ..
|
|
250 |
def get_private_model_list(repo_id, dir_path):
|
251 |
global private_model_path_repo_dict
|
252 |
api = HfApi()
|
253 |
-
if not
|
254 |
try:
|
255 |
-
files = api.list_repo_files(repo_id, token=
|
256 |
except Exception as e:
|
257 |
print(f"Error: Failed to list {repo_id}.")
|
258 |
print(e)
|
@@ -270,11 +399,11 @@ def get_private_model_list(repo_id, dir_path):
|
|
270 |
def download_private_file(repo_id, path, is_replace):
|
271 |
file = Path(path)
|
272 |
newpath = Path(f'{file.parent.name}/{escape_lora_basename(file.stem)}{file.suffix}') if is_replace else file
|
273 |
-
if not
|
274 |
filename = file.name
|
275 |
dirname = file.parent.name
|
276 |
try:
|
277 |
-
hf_hub_download(repo_id=repo_id, filename=filename, local_dir=dirname,
|
278 |
except Exception as e:
|
279 |
print(f"Error: Failed to download {filename}.")
|
280 |
print(e)
|
@@ -404,9 +533,9 @@ def get_private_lora_model_lists():
|
|
404 |
models1 = []
|
405 |
models2 = []
|
406 |
for repo in HF_LORA_PRIVATE_REPOS1:
|
407 |
-
models1.extend(get_private_model_list(repo,
|
408 |
for repo in HF_LORA_PRIVATE_REPOS2:
|
409 |
-
models2.extend(get_private_model_list(repo,
|
410 |
models = list_uniq(models1 + sorted(models2))
|
411 |
private_lora_model_list = models.copy()
|
412 |
return models
|
@@ -451,7 +580,7 @@ def get_civitai_info(path):
|
|
451 |
|
452 |
|
453 |
def get_lora_model_list():
|
454 |
-
loras = list_uniq(get_private_lora_model_lists() + DIFFUSERS_FORMAT_LORAS + get_local_model_list(
|
455 |
loras.insert(0, "None")
|
456 |
loras.insert(0, "")
|
457 |
return loras
|
@@ -503,14 +632,14 @@ def update_lora_dict(path):
|
|
503 |
def download_lora(dl_urls: str):
|
504 |
global loras_url_to_path_dict
|
505 |
dl_path = ""
|
506 |
-
before = get_local_model_list(
|
507 |
urls = []
|
508 |
for url in [url.strip() for url in dl_urls.split(',')]:
|
509 |
-
local_path = f"{
|
510 |
if not Path(local_path).exists():
|
511 |
-
download_things(
|
512 |
urls.append(url)
|
513 |
-
after = get_local_model_list(
|
514 |
new_files = list_sub(after, before)
|
515 |
i = 0
|
516 |
for file in new_files:
|
@@ -761,12 +890,14 @@ def update_loras(prompt, prompt_syntax, lora1, lora1_wt, lora2, lora2_wt, lora3,
|
|
761 |
gr.update(value=tag5, label=label5, visible=on5), gr.update(visible=on5), gr.update(value=md5, visible=on5)
|
762 |
|
763 |
|
764 |
-
def get_my_lora(link_url):
|
765 |
-
|
|
|
|
|
766 |
for url in [url.strip() for url in link_url.split(',')]:
|
767 |
-
if not Path(f"{
|
768 |
-
download_things(
|
769 |
-
after = get_local_model_list(
|
770 |
new_files = list_sub(after, before)
|
771 |
for file in new_files:
|
772 |
path = Path(file)
|
@@ -774,11 +905,16 @@ def get_my_lora(link_url):
|
|
774 |
new_path = Path(f'{path.parent.name}/{escape_lora_basename(path.stem)}{path.suffix}')
|
775 |
path.resolve().rename(new_path.resolve())
|
776 |
update_lora_dict(str(new_path))
|
|
|
777 |
new_lora_model_list = get_lora_model_list()
|
778 |
new_lora_tupled_list = get_all_lora_tupled_list()
|
779 |
-
|
|
|
|
|
|
|
|
|
780 |
return gr.update(
|
781 |
-
choices=new_lora_tupled_list, value=
|
782 |
), gr.update(
|
783 |
choices=new_lora_tupled_list
|
784 |
), gr.update(
|
@@ -787,6 +923,8 @@ def get_my_lora(link_url):
|
|
787 |
choices=new_lora_tupled_list
|
788 |
), gr.update(
|
789 |
choices=new_lora_tupled_list
|
|
|
|
|
790 |
)
|
791 |
|
792 |
|
@@ -794,12 +932,12 @@ def upload_file_lora(files, progress=gr.Progress(track_tqdm=True)):
|
|
794 |
progress(0, desc="Uploading...")
|
795 |
file_paths = [file.name for file in files]
|
796 |
progress(1, desc="Uploaded.")
|
797 |
-
return gr.update(value=file_paths, visible=True), gr.update(
|
798 |
|
799 |
|
800 |
def move_file_lora(filepaths):
|
801 |
for file in filepaths:
|
802 |
-
path = Path(shutil.move(Path(file).resolve(), Path(f"./{
|
803 |
newpath = Path(f'{path.parent.name}/{escape_lora_basename(path.stem)}{path.suffix}')
|
804 |
path.resolve().rename(newpath.resolve())
|
805 |
update_lora_dict(str(newpath))
|
@@ -941,7 +1079,7 @@ def update_civitai_selection(evt: gr.SelectData):
|
|
941 |
selected = civitai_last_choices[selected_index][1]
|
942 |
return gr.update(value=selected)
|
943 |
except Exception:
|
944 |
-
return gr.update(
|
945 |
|
946 |
|
947 |
def select_civitai_lora(search_result):
|
@@ -1425,3 +1563,78 @@ def get_model_pipeline(repo_id: str):
|
|
1425 |
else:
|
1426 |
return default
|
1427 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
import re
|
6 |
from pathlib import Path
|
7 |
from PIL import Image
|
8 |
+
import numpy as np
|
9 |
import shutil
|
10 |
import requests
|
11 |
from requests.adapters import HTTPAdapter
|
|
|
13 |
import urllib.parse
|
14 |
import pandas as pd
|
15 |
from huggingface_hub import HfApi, HfFolder, hf_hub_download, snapshot_download
|
16 |
+
from translatepy import Translator
|
17 |
+
from unidecode import unidecode
|
18 |
+
import copy
|
19 |
+
from datetime import datetime, timezone, timedelta
|
20 |
+
FILENAME_TIMEZONE = timezone(timedelta(hours=9)) # JST
|
21 |
|
22 |
|
23 |
from env import (HF_LORA_PRIVATE_REPOS1, HF_LORA_PRIVATE_REPOS2,
|
24 |
HF_MODEL_USER_EX, HF_MODEL_USER_LIKES, DIFFUSERS_FORMAT_LORAS,
|
25 |
+
DIRECTORY_LORAS, HF_READ_TOKEN, HF_TOKEN, CIVITAI_API_KEY)
|
26 |
|
27 |
|
28 |
MODEL_TYPE_DICT = {
|
|
|
52 |
return re.fullmatch(r'^[^/]+?/[^/]+?$', s)
|
53 |
|
54 |
|
|
|
55 |
translator = Translator()
|
56 |
def translate_to_en(input: str):
|
57 |
try:
|
|
|
69 |
if file.suffix in valid_extensions:
|
70 |
file_path = str(Path(f"{dir_path}/{file.name}"))
|
71 |
model_list.append(file_path)
|
72 |
+
#print('\033[34mFILE: ' + file_path + '\033[0m')
|
73 |
return model_list
|
74 |
|
75 |
|
|
|
104 |
print(e)
|
105 |
|
106 |
|
107 |
+
def download_hf_file(directory, url, force_filename="", hf_token="", progress=gr.Progress(track_tqdm=True)):
|
|
|
108 |
repo_id, filename, subfolder, repo_type = split_hf_url(url)
|
109 |
+
kwargs = {}
|
110 |
+
if subfolder is not None: kwargs["subfolder"] = subfolder
|
111 |
+
if force_filename: kwargs["force_filename"] = force_filename
|
112 |
try:
|
113 |
+
print(f"Start downloading: {url} to {directory}")
|
114 |
+
path = hf_hub_download(repo_id=repo_id, filename=filename, repo_type=repo_type, local_dir=directory, token=hf_token, **kwargs)
|
|
|
115 |
return path
|
116 |
except Exception as e:
|
117 |
+
print(f"Download failed: {url} {e}")
|
118 |
return None
|
119 |
|
120 |
|
121 |
+
USER_AGENT = 'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:127.0) Gecko/20100101 Firefox/127.0'
|
122 |
+
|
123 |
+
|
124 |
+
def request_json_data(url):
|
125 |
+
model_version_id = url.split('/')[-1]
|
126 |
+
if "?modelVersionId=" in model_version_id:
|
127 |
+
match = re.search(r'modelVersionId=(\d+)', url)
|
128 |
+
model_version_id = match.group(1)
|
129 |
+
|
130 |
+
endpoint_url = f"https://civitai.com/api/v1/model-versions/{model_version_id}"
|
131 |
+
|
132 |
+
params = {}
|
133 |
+
headers = {'User-Agent': USER_AGENT, 'content-type': 'application/json'}
|
134 |
+
session = requests.Session()
|
135 |
+
retries = Retry(total=5, backoff_factor=1, status_forcelist=[500, 502, 503, 504])
|
136 |
+
session.mount("https://", HTTPAdapter(max_retries=retries))
|
137 |
+
|
138 |
+
try:
|
139 |
+
result = session.get(endpoint_url, params=params, headers=headers, stream=True, timeout=(3.0, 15))
|
140 |
+
result.raise_for_status()
|
141 |
+
json_data = result.json()
|
142 |
+
return json_data if json_data else None
|
143 |
+
except Exception as e:
|
144 |
+
print(f"Error: {e}")
|
145 |
+
return None
|
146 |
+
|
147 |
+
|
148 |
+
class ModelInformation:
|
149 |
+
def __init__(self, json_data):
|
150 |
+
self.model_version_id = json_data.get("id", "")
|
151 |
+
self.model_id = json_data.get("modelId", "")
|
152 |
+
self.download_url = json_data.get("downloadUrl", "")
|
153 |
+
self.model_url = f"https://civitai.com/models/{self.model_id}?modelVersionId={self.model_version_id}"
|
154 |
+
self.filename_url = next(
|
155 |
+
(v.get("name", "") for v in json_data.get("files", []) if str(self.model_version_id) in v.get("downloadUrl", "")), ""
|
156 |
+
)
|
157 |
+
self.filename_url = self.filename_url if self.filename_url else ""
|
158 |
+
self.description = json_data.get("description", "")
|
159 |
+
if self.description is None: self.description = ""
|
160 |
+
self.model_name = json_data.get("model", {}).get("name", "")
|
161 |
+
self.model_type = json_data.get("model", {}).get("type", "")
|
162 |
+
self.nsfw = json_data.get("model", {}).get("nsfw", False)
|
163 |
+
self.poi = json_data.get("model", {}).get("poi", False)
|
164 |
+
self.images = [img.get("url", "") for img in json_data.get("images", [])]
|
165 |
+
self.example_prompt = json_data.get("trainedWords", [""])[0] if json_data.get("trainedWords") else ""
|
166 |
+
self.original_json = copy.deepcopy(json_data)
|
167 |
+
|
168 |
+
|
169 |
+
def retrieve_model_info(url):
|
170 |
+
json_data = request_json_data(url)
|
171 |
+
if not json_data:
|
172 |
+
return None
|
173 |
+
model_descriptor = ModelInformation(json_data)
|
174 |
+
return model_descriptor
|
175 |
+
|
176 |
+
|
177 |
+
def download_things(directory, url, hf_token="", civitai_api_key="", romanize=False):
|
178 |
+
hf_token = get_token()
|
179 |
url = url.strip()
|
180 |
+
downloaded_file_path = None
|
181 |
+
|
182 |
if "drive.google.com" in url:
|
183 |
original_dir = os.getcwd()
|
184 |
os.chdir(directory)
|
|
|
189 |
# url = urllib.parse.quote(url, safe=':/') # fix encoding
|
190 |
if "/blob/" in url:
|
191 |
url = url.replace("/blob/", "/resolve/")
|
192 |
+
|
193 |
+
filename = unidecode(url.split('/')[-1]) if romanize else url.split('/')[-1]
|
194 |
+
|
195 |
+
download_hf_file(directory, url, filename, hf_token)
|
196 |
+
|
197 |
+
downloaded_file_path = os.path.join(directory, filename)
|
198 |
+
|
199 |
elif "civitai.com" in url:
|
200 |
+
|
201 |
+
if not civitai_api_key:
|
|
|
|
|
|
|
|
|
202 |
print("\033[91mYou need an API key to download Civitai models.\033[0m")
|
203 |
+
|
204 |
+
model_profile = retrieve_model_info(url)
|
205 |
+
if model_profile.download_url and model_profile.filename_url:
|
206 |
+
url = model_profile.download_url
|
207 |
+
filename = unidecode(model_profile.filename_url) if romanize else model_profile.filename_url
|
208 |
+
else:
|
209 |
+
if "?" in url:
|
210 |
+
url = url.split("?")[0]
|
211 |
+
filename = ""
|
212 |
+
|
213 |
+
url_dl = url + f"?token={civitai_api_key}"
|
214 |
+
print(f"Filename: {filename}")
|
215 |
+
|
216 |
+
param_filename = ""
|
217 |
+
if filename:
|
218 |
+
param_filename = f"-o '{filename}'"
|
219 |
+
|
220 |
+
aria2_command = (
|
221 |
+
f'aria2c --console-log-level=error --summary-interval=10 -c -x 16 '
|
222 |
+
f'-k 1M -s 16 -d "{directory}" {param_filename} "{url_dl}"'
|
223 |
+
)
|
224 |
+
os.system(aria2_command)
|
225 |
+
|
226 |
+
if param_filename and os.path.exists(os.path.join(directory, filename)):
|
227 |
+
downloaded_file_path = os.path.join(directory, filename)
|
228 |
+
|
229 |
else:
|
230 |
os.system(f"aria2c --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 -d {directory} {url}")
|
231 |
|
232 |
+
return downloaded_file_path
|
233 |
+
|
234 |
|
235 |
def get_download_file(temp_dir, url, civitai_key="", progress=gr.Progress(track_tqdm=True)):
|
236 |
if not "http" in url and is_repo_name(url) and not Path(url).exists():
|
|
|
269 |
|
270 |
def to_lora_path(key: str):
|
271 |
if Path(key).is_file(): return key
|
272 |
+
path = Path(f"{DIRECTORY_LORAS}/{escape_lora_basename(key)}.safetensors")
|
273 |
return str(path)
|
274 |
|
275 |
|
|
|
299 |
raise Exception(f"Failed to save image file:") from e
|
300 |
|
301 |
|
302 |
+
def save_gallery_images(images, model_name="", progress=gr.Progress(track_tqdm=True)):
|
|
|
303 |
progress(0, desc="Updating gallery...")
|
304 |
+
basename = f"{model_name.split('/')[-1]}_{datetime.now(FILENAME_TIMEZONE).strftime('%Y%m%d_%H%M%S')}_"
|
305 |
+
if not images: return images, gr.update()
|
|
|
|
|
306 |
output_images = []
|
307 |
output_paths = []
|
308 |
+
for i, image in enumerate(images):
|
309 |
+
filename = f"{basename}{str(i + 1)}.png"
|
|
|
310 |
oldpath = Path(image[0])
|
311 |
newpath = oldpath
|
312 |
try:
|
313 |
if oldpath.exists():
|
314 |
newpath = oldpath.resolve().rename(Path(filename).resolve())
|
315 |
except Exception as e:
|
316 |
+
print(e)
|
317 |
finally:
|
318 |
output_paths.append(str(newpath))
|
319 |
output_images.append((str(newpath), str(filename)))
|
|
|
321 |
return gr.update(value=output_images), gr.update(value=output_paths, visible=True)
|
322 |
|
323 |
|
324 |
+
def save_gallery_history(images, files, history_gallery, history_files, progress=gr.Progress(track_tqdm=True)):
|
325 |
+
if not images or not files: return gr.update(), gr.update()
|
326 |
+
if not history_gallery: history_gallery = []
|
327 |
+
if not history_files: history_files = []
|
328 |
+
output_gallery = images + history_gallery
|
329 |
+
output_files = files + history_files
|
330 |
+
return gr.update(value=output_gallery), gr.update(value=output_files, visible=True)
|
331 |
+
|
332 |
+
|
333 |
+
def save_image_history(image, gallery, files, model_name: str, progress=gr.Progress(track_tqdm=True)):
|
334 |
+
if not gallery: gallery = []
|
335 |
+
if not files: files = []
|
336 |
+
try:
|
337 |
+
basename = f"{model_name.split('/')[-1]}_{datetime.now(FILENAME_TIMEZONE).strftime('%Y%m%d_%H%M%S')}"
|
338 |
+
if image is None or not isinstance(image, (str, Image.Image, np.ndarray, tuple)): return gr.update(), gr.update()
|
339 |
+
filename = f"{basename}.png"
|
340 |
+
if isinstance(image, tuple): image = image[0]
|
341 |
+
if isinstance(image, str): oldpath = image
|
342 |
+
elif isinstance(image, Image.Image):
|
343 |
+
oldpath = "temp.png"
|
344 |
+
image.save(oldpath)
|
345 |
+
elif isinstance(image, np.ndarray):
|
346 |
+
oldpath = "temp.png"
|
347 |
+
Image.fromarray(image).convert('RGBA').save(oldpath)
|
348 |
+
oldpath = Path(oldpath)
|
349 |
+
newpath = oldpath
|
350 |
+
if oldpath.exists():
|
351 |
+
shutil.copy(oldpath.resolve(), Path(filename).resolve())
|
352 |
+
newpath = Path(filename).resolve()
|
353 |
+
files.insert(0, str(newpath))
|
354 |
+
gallery.insert(0, (str(newpath), str(filename)))
|
355 |
+
except Exception as e:
|
356 |
+
print(e)
|
357 |
+
finally:
|
358 |
+
return gr.update(value=gallery), gr.update(value=files, visible=True)
|
359 |
+
|
360 |
+
|
361 |
def download_private_repo(repo_id, dir_path, is_replace):
|
362 |
+
if not HF_READ_TOKEN: return
|
363 |
try:
|
364 |
+
snapshot_download(repo_id=repo_id, local_dir=dir_path, allow_patterns=['*.ckpt', '*.pt', '*.pth', '*.safetensors', '*.bin'], token=HF_READ_TOKEN)
|
365 |
except Exception as e:
|
366 |
print(f"Error: Failed to download {repo_id}.")
|
367 |
print(e)
|
|
|
379 |
def get_private_model_list(repo_id, dir_path):
|
380 |
global private_model_path_repo_dict
|
381 |
api = HfApi()
|
382 |
+
if not HF_READ_TOKEN: return []
|
383 |
try:
|
384 |
+
files = api.list_repo_files(repo_id, token=HF_READ_TOKEN)
|
385 |
except Exception as e:
|
386 |
print(f"Error: Failed to list {repo_id}.")
|
387 |
print(e)
|
|
|
399 |
def download_private_file(repo_id, path, is_replace):
|
400 |
file = Path(path)
|
401 |
newpath = Path(f'{file.parent.name}/{escape_lora_basename(file.stem)}{file.suffix}') if is_replace else file
|
402 |
+
if not HF_READ_TOKEN or newpath.exists(): return
|
403 |
filename = file.name
|
404 |
dirname = file.parent.name
|
405 |
try:
|
406 |
+
hf_hub_download(repo_id=repo_id, filename=filename, local_dir=dirname, token=HF_READ_TOKEN)
|
407 |
except Exception as e:
|
408 |
print(f"Error: Failed to download {filename}.")
|
409 |
print(e)
|
|
|
533 |
models1 = []
|
534 |
models2 = []
|
535 |
for repo in HF_LORA_PRIVATE_REPOS1:
|
536 |
+
models1.extend(get_private_model_list(repo, DIRECTORY_LORAS))
|
537 |
for repo in HF_LORA_PRIVATE_REPOS2:
|
538 |
+
models2.extend(get_private_model_list(repo, DIRECTORY_LORAS))
|
539 |
models = list_uniq(models1 + sorted(models2))
|
540 |
private_lora_model_list = models.copy()
|
541 |
return models
|
|
|
580 |
|
581 |
|
582 |
def get_lora_model_list():
|
583 |
+
loras = list_uniq(get_private_lora_model_lists() + DIFFUSERS_FORMAT_LORAS + get_local_model_list(DIRECTORY_LORAS))
|
584 |
loras.insert(0, "None")
|
585 |
loras.insert(0, "")
|
586 |
return loras
|
|
|
632 |
def download_lora(dl_urls: str):
|
633 |
global loras_url_to_path_dict
|
634 |
dl_path = ""
|
635 |
+
before = get_local_model_list(DIRECTORY_LORAS)
|
636 |
urls = []
|
637 |
for url in [url.strip() for url in dl_urls.split(',')]:
|
638 |
+
local_path = f"{DIRECTORY_LORAS}/{url.split('/')[-1]}"
|
639 |
if not Path(local_path).exists():
|
640 |
+
download_things(DIRECTORY_LORAS, url, HF_TOKEN, CIVITAI_API_KEY)
|
641 |
urls.append(url)
|
642 |
+
after = get_local_model_list(DIRECTORY_LORAS)
|
643 |
new_files = list_sub(after, before)
|
644 |
i = 0
|
645 |
for file in new_files:
|
|
|
890 |
gr.update(value=tag5, label=label5, visible=on5), gr.update(visible=on5), gr.update(value=md5, visible=on5)
|
891 |
|
892 |
|
893 |
+
def get_my_lora(link_url, romanize):
|
894 |
+
l_name = ""
|
895 |
+
l_path = ""
|
896 |
+
before = get_local_model_list(DIRECTORY_LORAS)
|
897 |
for url in [url.strip() for url in link_url.split(',')]:
|
898 |
+
if not Path(f"{DIRECTORY_LORAS}/{url.split('/')[-1]}").exists():
|
899 |
+
l_name = download_things(DIRECTORY_LORAS, url, HF_TOKEN, CIVITAI_API_KEY, romanize)
|
900 |
+
after = get_local_model_list(DIRECTORY_LORAS)
|
901 |
new_files = list_sub(after, before)
|
902 |
for file in new_files:
|
903 |
path = Path(file)
|
|
|
905 |
new_path = Path(f'{path.parent.name}/{escape_lora_basename(path.stem)}{path.suffix}')
|
906 |
path.resolve().rename(new_path.resolve())
|
907 |
update_lora_dict(str(new_path))
|
908 |
+
l_path = str(new_path)
|
909 |
new_lora_model_list = get_lora_model_list()
|
910 |
new_lora_tupled_list = get_all_lora_tupled_list()
|
911 |
+
msg_lora = "Downloaded"
|
912 |
+
if l_name:
|
913 |
+
msg_lora += f": <b>{l_name}</b>"
|
914 |
+
print(msg_lora)
|
915 |
+
|
916 |
return gr.update(
|
917 |
+
choices=new_lora_tupled_list, value=l_path
|
918 |
), gr.update(
|
919 |
choices=new_lora_tupled_list
|
920 |
), gr.update(
|
|
|
923 |
choices=new_lora_tupled_list
|
924 |
), gr.update(
|
925 |
choices=new_lora_tupled_list
|
926 |
+
), gr.update(
|
927 |
+
value=msg_lora
|
928 |
)
|
929 |
|
930 |
|
|
|
932 |
progress(0, desc="Uploading...")
|
933 |
file_paths = [file.name for file in files]
|
934 |
progress(1, desc="Uploaded.")
|
935 |
+
return gr.update(value=file_paths, visible=True), gr.update()
|
936 |
|
937 |
|
938 |
def move_file_lora(filepaths):
|
939 |
for file in filepaths:
|
940 |
+
path = Path(shutil.move(Path(file).resolve(), Path(f"./{DIRECTORY_LORAS}").resolve()))
|
941 |
newpath = Path(f'{path.parent.name}/{escape_lora_basename(path.stem)}{path.suffix}')
|
942 |
path.resolve().rename(newpath.resolve())
|
943 |
update_lora_dict(str(newpath))
|
|
|
1079 |
selected = civitai_last_choices[selected_index][1]
|
1080 |
return gr.update(value=selected)
|
1081 |
except Exception:
|
1082 |
+
return gr.update()
|
1083 |
|
1084 |
|
1085 |
def select_civitai_lora(search_result):
|
|
|
1563 |
else:
|
1564 |
return default
|
1565 |
|
1566 |
+
|
1567 |
+
EXAMPLES_GUI = [
|
1568 |
+
[
|
1569 |
+
"1girl, souryuu asuka langley, neon genesis evangelion, plugsuit, pilot suit, red bodysuit, sitting, crossing legs, black eye patch, cat hat, throne, symmetrical, looking down, from bottom, looking at viewer, outdoors, masterpiece, best quality, very aesthetic, absurdres",
|
1570 |
+
"nsfw, lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]",
|
1571 |
+
1,
|
1572 |
+
30,
|
1573 |
+
7.5,
|
1574 |
+
True,
|
1575 |
+
-1,
|
1576 |
+
"Euler a",
|
1577 |
+
1152,
|
1578 |
+
896,
|
1579 |
+
"votepurchase/animagine-xl-3.1",
|
1580 |
+
],
|
1581 |
+
[
|
1582 |
+
"solo, princess Zelda OOT, score_9, score_8_up, score_8, medium breasts, cute, eyelashes, cute small face, long hair, crown braid, hairclip, pointy ears, soft curvy body, looking at viewer, smile, blush, white dress, medium body, (((holding the Master Sword))), standing, deep forest in the background",
|
1583 |
+
"score_6, score_5, score_4, busty, ugly face, mutated hands, low res, blurry face, black and white,",
|
1584 |
+
1,
|
1585 |
+
30,
|
1586 |
+
5.,
|
1587 |
+
True,
|
1588 |
+
-1,
|
1589 |
+
"Euler a",
|
1590 |
+
1024,
|
1591 |
+
1024,
|
1592 |
+
"votepurchase/ponyDiffusionV6XL",
|
1593 |
+
],
|
1594 |
+
[
|
1595 |
+
"1girl, oomuro sakurako, yuru yuri, official art, school uniform, anime artwork, anime style, vibrant, studio anime, highly detailed, masterpiece, best quality, very aesthetic, absurdres",
|
1596 |
+
"photo, deformed, black and white, realism, disfigured, low contrast, lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]",
|
1597 |
+
1,
|
1598 |
+
40,
|
1599 |
+
7.0,
|
1600 |
+
True,
|
1601 |
+
-1,
|
1602 |
+
"Euler a",
|
1603 |
+
1024,
|
1604 |
+
1024,
|
1605 |
+
"Raelina/Rae-Diffusion-XL-V2",
|
1606 |
+
],
|
1607 |
+
[
|
1608 |
+
"1girl, akaza akari, yuru yuri, official art, anime artwork, anime style, vibrant, studio anime, highly detailed, masterpiece, best quality, very aesthetic, absurdres",
|
1609 |
+
"photo, deformed, black and white, realism, disfigured, low contrast, lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]",
|
1610 |
+
1,
|
1611 |
+
35,
|
1612 |
+
7.0,
|
1613 |
+
True,
|
1614 |
+
-1,
|
1615 |
+
"Euler a",
|
1616 |
+
1024,
|
1617 |
+
1024,
|
1618 |
+
"Raelina/Raemu-XL-V4",
|
1619 |
+
],
|
1620 |
+
[
|
1621 |
+
"yoshida yuuko, machikado mazoku, 1girl, solo, demon horns,horns, school uniform, long hair, open mouth, skirt, demon girl, ahoge, shiny, shiny hair, anime artwork",
|
1622 |
+
"nsfw, lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]",
|
1623 |
+
1,
|
1624 |
+
50,
|
1625 |
+
7.,
|
1626 |
+
True,
|
1627 |
+
-1,
|
1628 |
+
"Euler a",
|
1629 |
+
1024,
|
1630 |
+
1024,
|
1631 |
+
"cagliostrolab/animagine-xl-3.1",
|
1632 |
+
],
|
1633 |
+
]
|
1634 |
+
|
1635 |
+
|
1636 |
+
RESOURCES = (
|
1637 |
+
"""### Resources
|
1638 |
+
- You can also try the image generator in Colab’s free tier, which provides free GPU [link](https://github.com/R3gm/SD_diffusers_interactive).
|
1639 |
+
"""
|
1640 |
+
)
|
requirements.txt
CHANGED
@@ -12,3 +12,4 @@ translatepy
|
|
12 |
timm
|
13 |
rapidfuzz
|
14 |
sentencepiece
|
|
|
|
12 |
timm
|
13 |
rapidfuzz
|
14 |
sentencepiece
|
15 |
+
unidecode
|
tagger/character_series_dict.csv
ADDED
The diff for this file is too large to render.
See raw diff
|
|
tagger/danbooru_e621.csv
ADDED
The diff for this file is too large to render.
See raw diff
|
|
tagger/output.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
|
3 |
+
|
4 |
+
@dataclass
|
5 |
+
class UpsamplingOutput:
|
6 |
+
upsampled_tags: str
|
7 |
+
|
8 |
+
copyright_tags: str
|
9 |
+
character_tags: str
|
10 |
+
general_tags: str
|
11 |
+
rating_tag: str
|
12 |
+
aspect_ratio_tag: str
|
13 |
+
length_tag: str
|
14 |
+
identity_tag: str
|
15 |
+
|
16 |
+
elapsed_time: float = 0.0
|
tagger/tag_group.csv
ADDED
The diff for this file is too large to render.
See raw diff
|
|
tagger/tagger.py
ADDED
@@ -0,0 +1,556 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import spaces
|
2 |
+
from PIL import Image
|
3 |
+
import torch
|
4 |
+
import gradio as gr
|
5 |
+
from transformers import AutoImageProcessor, AutoModelForImageClassification
|
6 |
+
from pathlib import Path
|
7 |
+
|
8 |
+
|
9 |
+
WD_MODEL_NAMES = ["p1atdev/wd-swinv2-tagger-v3-hf"]
|
10 |
+
WD_MODEL_NAME = WD_MODEL_NAMES[0]
|
11 |
+
|
12 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
13 |
+
default_device = device
|
14 |
+
|
15 |
+
try:
|
16 |
+
wd_model = AutoModelForImageClassification.from_pretrained(WD_MODEL_NAME, trust_remote_code=True).to(default_device).eval()
|
17 |
+
wd_processor = AutoImageProcessor.from_pretrained(WD_MODEL_NAME, trust_remote_code=True)
|
18 |
+
except Exception as e:
|
19 |
+
print(e)
|
20 |
+
wd_model = wd_processor = None
|
21 |
+
|
22 |
+
def _people_tag(noun: str, minimum: int = 1, maximum: int = 5):
|
23 |
+
return (
|
24 |
+
[f"1{noun}"]
|
25 |
+
+ [f"{num}{noun}s" for num in range(minimum + 1, maximum + 1)]
|
26 |
+
+ [f"{maximum+1}+{noun}s"]
|
27 |
+
)
|
28 |
+
|
29 |
+
|
30 |
+
PEOPLE_TAGS = (
|
31 |
+
_people_tag("girl") + _people_tag("boy") + _people_tag("other") + ["no humans"]
|
32 |
+
)
|
33 |
+
|
34 |
+
|
35 |
+
RATING_MAP = {
|
36 |
+
"sfw": "safe",
|
37 |
+
"general": "safe",
|
38 |
+
"sensitive": "sensitive",
|
39 |
+
"questionable": "nsfw",
|
40 |
+
"explicit": "explicit, nsfw",
|
41 |
+
}
|
42 |
+
DANBOORU_TO_E621_RATING_MAP = {
|
43 |
+
"sfw": "rating_safe",
|
44 |
+
"general": "rating_safe",
|
45 |
+
"safe": "rating_safe",
|
46 |
+
"sensitive": "rating_safe",
|
47 |
+
"nsfw": "rating_explicit",
|
48 |
+
"explicit, nsfw": "rating_explicit",
|
49 |
+
"explicit": "rating_explicit",
|
50 |
+
"rating:safe": "rating_safe",
|
51 |
+
"rating:general": "rating_safe",
|
52 |
+
"rating:sensitive": "rating_safe",
|
53 |
+
"rating:questionable, nsfw": "rating_explicit",
|
54 |
+
"rating:explicit, nsfw": "rating_explicit",
|
55 |
+
}
|
56 |
+
|
57 |
+
|
58 |
+
# https://github.com/toriato/stable-diffusion-webui-wd14-tagger/blob/a9eacb1eff904552d3012babfa28b57e1d3e295c/tagger/ui.py#L368
|
59 |
+
kaomojis = [
|
60 |
+
"0_0",
|
61 |
+
"(o)_(o)",
|
62 |
+
"+_+",
|
63 |
+
"+_-",
|
64 |
+
"._.",
|
65 |
+
"<o>_<o>",
|
66 |
+
"<|>_<|>",
|
67 |
+
"=_=",
|
68 |
+
">_<",
|
69 |
+
"3_3",
|
70 |
+
"6_9",
|
71 |
+
">_o",
|
72 |
+
"@_@",
|
73 |
+
"^_^",
|
74 |
+
"o_o",
|
75 |
+
"u_u",
|
76 |
+
"x_x",
|
77 |
+
"|_|",
|
78 |
+
"||_||",
|
79 |
+
]
|
80 |
+
|
81 |
+
|
82 |
+
def replace_underline(x: str):
|
83 |
+
return x.strip().replace("_", " ") if x not in kaomojis else x.strip()
|
84 |
+
|
85 |
+
|
86 |
+
def to_list(s):
|
87 |
+
return [x.strip() for x in s.split(",") if not s == ""]
|
88 |
+
|
89 |
+
|
90 |
+
def list_sub(a, b):
|
91 |
+
return [e for e in a if e not in b]
|
92 |
+
|
93 |
+
|
94 |
+
def list_uniq(l):
|
95 |
+
return sorted(set(l), key=l.index)
|
96 |
+
|
97 |
+
|
98 |
+
def load_dict_from_csv(filename):
|
99 |
+
dict = {}
|
100 |
+
if not Path(filename).exists():
|
101 |
+
if Path('./tagger/', filename).exists(): filename = str(Path('./tagger/', filename))
|
102 |
+
else: return dict
|
103 |
+
try:
|
104 |
+
with open(filename, 'r', encoding="utf-8") as f:
|
105 |
+
lines = f.readlines()
|
106 |
+
except Exception:
|
107 |
+
print(f"Failed to open dictionary file: {filename}")
|
108 |
+
return dict
|
109 |
+
for line in lines:
|
110 |
+
parts = line.strip().split(',')
|
111 |
+
dict[parts[0]] = parts[1]
|
112 |
+
return dict
|
113 |
+
|
114 |
+
|
115 |
+
anime_series_dict = load_dict_from_csv('character_series_dict.csv')
|
116 |
+
|
117 |
+
|
118 |
+
def character_list_to_series_list(character_list):
|
119 |
+
output_series_tag = []
|
120 |
+
series_tag = ""
|
121 |
+
series_dict = anime_series_dict
|
122 |
+
for tag in character_list:
|
123 |
+
series_tag = series_dict.get(tag, "")
|
124 |
+
if tag.endswith(")"):
|
125 |
+
tags = tag.split("(")
|
126 |
+
character_tag = "(".join(tags[:-1])
|
127 |
+
if character_tag.endswith(" "):
|
128 |
+
character_tag = character_tag[:-1]
|
129 |
+
series_tag = tags[-1].replace(")", "")
|
130 |
+
|
131 |
+
if series_tag:
|
132 |
+
output_series_tag.append(series_tag)
|
133 |
+
|
134 |
+
return output_series_tag
|
135 |
+
|
136 |
+
|
137 |
+
def select_random_character(series: str, character: str):
|
138 |
+
from random import seed, randrange
|
139 |
+
seed()
|
140 |
+
character_list = list(anime_series_dict.keys())
|
141 |
+
character = character_list[randrange(len(character_list) - 1)]
|
142 |
+
series = anime_series_dict.get(character.split(",")[0].strip(), "")
|
143 |
+
return series, character
|
144 |
+
|
145 |
+
|
146 |
+
def danbooru_to_e621(dtag, e621_dict):
|
147 |
+
def d_to_e(match, e621_dict):
|
148 |
+
dtag = match.group(0)
|
149 |
+
etag = e621_dict.get(replace_underline(dtag), "")
|
150 |
+
if etag:
|
151 |
+
return etag
|
152 |
+
else:
|
153 |
+
return dtag
|
154 |
+
|
155 |
+
import re
|
156 |
+
tag = re.sub(r'[\w ]+', lambda wrapper: d_to_e(wrapper, e621_dict), dtag, 2)
|
157 |
+
return tag
|
158 |
+
|
159 |
+
|
160 |
+
danbooru_to_e621_dict = load_dict_from_csv('danbooru_e621.csv')
|
161 |
+
|
162 |
+
|
163 |
+
def convert_danbooru_to_e621_prompt(input_prompt: str = "", prompt_type: str = "danbooru"):
|
164 |
+
if prompt_type == "danbooru": return input_prompt
|
165 |
+
tags = input_prompt.split(",") if input_prompt else []
|
166 |
+
people_tags: list[str] = []
|
167 |
+
other_tags: list[str] = []
|
168 |
+
rating_tags: list[str] = []
|
169 |
+
|
170 |
+
e621_dict = danbooru_to_e621_dict
|
171 |
+
for tag in tags:
|
172 |
+
tag = replace_underline(tag)
|
173 |
+
tag = danbooru_to_e621(tag, e621_dict)
|
174 |
+
if tag in PEOPLE_TAGS:
|
175 |
+
people_tags.append(tag)
|
176 |
+
elif tag in DANBOORU_TO_E621_RATING_MAP.keys():
|
177 |
+
rating_tags.append(DANBOORU_TO_E621_RATING_MAP.get(tag.replace(" ",""), ""))
|
178 |
+
else:
|
179 |
+
other_tags.append(tag)
|
180 |
+
|
181 |
+
rating_tags = sorted(set(rating_tags), key=rating_tags.index)
|
182 |
+
rating_tags = [rating_tags[0]] if rating_tags else []
|
183 |
+
rating_tags = ["explicit, nsfw"] if rating_tags and rating_tags[0] == "explicit" else rating_tags
|
184 |
+
|
185 |
+
output_prompt = ", ".join(people_tags + other_tags + rating_tags)
|
186 |
+
|
187 |
+
return output_prompt
|
188 |
+
|
189 |
+
|
190 |
+
from translatepy import Translator
|
191 |
+
translator = Translator()
|
192 |
+
def translate_prompt_old(prompt: str = ""):
|
193 |
+
def translate_to_english(input: str):
|
194 |
+
try:
|
195 |
+
output = str(translator.translate(input, 'English'))
|
196 |
+
except Exception as e:
|
197 |
+
output = input
|
198 |
+
print(e)
|
199 |
+
return output
|
200 |
+
|
201 |
+
def is_japanese(s):
|
202 |
+
import unicodedata
|
203 |
+
for ch in s:
|
204 |
+
name = unicodedata.name(ch, "")
|
205 |
+
if "CJK UNIFIED" in name or "HIRAGANA" in name or "KATAKANA" in name:
|
206 |
+
return True
|
207 |
+
return False
|
208 |
+
|
209 |
+
def to_list(s):
|
210 |
+
return [x.strip() for x in s.split(",")]
|
211 |
+
|
212 |
+
prompts = to_list(prompt)
|
213 |
+
outputs = []
|
214 |
+
for p in prompts:
|
215 |
+
p = translate_to_english(p) if is_japanese(p) else p
|
216 |
+
outputs.append(p)
|
217 |
+
|
218 |
+
return ", ".join(outputs)
|
219 |
+
|
220 |
+
|
221 |
+
def translate_prompt(input: str):
|
222 |
+
try:
|
223 |
+
output = str(translator.translate(input, 'English'))
|
224 |
+
except Exception as e:
|
225 |
+
output = input
|
226 |
+
print(e)
|
227 |
+
return output
|
228 |
+
|
229 |
+
|
230 |
+
def translate_prompt_to_ja(prompt: str = ""):
|
231 |
+
def translate_to_japanese(input: str):
|
232 |
+
try:
|
233 |
+
output = str(translator.translate(input, 'Japanese'))
|
234 |
+
except Exception as e:
|
235 |
+
output = input
|
236 |
+
print(e)
|
237 |
+
return output
|
238 |
+
|
239 |
+
def is_japanese(s):
|
240 |
+
import unicodedata
|
241 |
+
for ch in s:
|
242 |
+
name = unicodedata.name(ch, "")
|
243 |
+
if "CJK UNIFIED" in name or "HIRAGANA" in name or "KATAKANA" in name:
|
244 |
+
return True
|
245 |
+
return False
|
246 |
+
|
247 |
+
def to_list(s):
|
248 |
+
return [x.strip() for x in s.split(",")]
|
249 |
+
|
250 |
+
prompts = to_list(prompt)
|
251 |
+
outputs = []
|
252 |
+
for p in prompts:
|
253 |
+
p = translate_to_japanese(p) if not is_japanese(p) else p
|
254 |
+
outputs.append(p)
|
255 |
+
|
256 |
+
return ", ".join(outputs)
|
257 |
+
|
258 |
+
|
259 |
+
def tags_to_ja(itag, dict):
|
260 |
+
def t_to_j(match, dict):
|
261 |
+
tag = match.group(0)
|
262 |
+
ja = dict.get(replace_underline(tag), "")
|
263 |
+
if ja:
|
264 |
+
return ja
|
265 |
+
else:
|
266 |
+
return tag
|
267 |
+
|
268 |
+
import re
|
269 |
+
tag = re.sub(r'[\w ]+', lambda wrapper: t_to_j(wrapper, dict), itag, 2)
|
270 |
+
|
271 |
+
return tag
|
272 |
+
|
273 |
+
|
274 |
+
def convert_tags_to_ja(input_prompt: str = ""):
|
275 |
+
tags = input_prompt.split(",") if input_prompt else []
|
276 |
+
out_tags = []
|
277 |
+
|
278 |
+
tags_to_ja_dict = load_dict_from_csv('all_tags_ja_ext.csv')
|
279 |
+
dict = tags_to_ja_dict
|
280 |
+
for tag in tags:
|
281 |
+
tag = replace_underline(tag)
|
282 |
+
tag = tags_to_ja(tag, dict)
|
283 |
+
out_tags.append(tag)
|
284 |
+
|
285 |
+
return ", ".join(out_tags)
|
286 |
+
|
287 |
+
|
288 |
+
enable_auto_recom_prompt = True
|
289 |
+
|
290 |
+
|
291 |
+
animagine_ps = to_list("masterpiece, best quality, very aesthetic, absurdres")
|
292 |
+
animagine_nps = to_list("lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]")
|
293 |
+
pony_ps = to_list("score_9, score_8_up, score_7_up, masterpiece, best quality, very aesthetic, absurdres")
|
294 |
+
pony_nps = to_list("source_pony, score_6, score_5, score_4, busty, ugly face, mutated hands, low res, blurry face, black and white, the simpsons, overwatch, apex legends")
|
295 |
+
other_ps = to_list("anime artwork, anime style, studio anime, highly detailed, cinematic photo, 35mm photograph, film, bokeh, professional, 4k, highly detailed")
|
296 |
+
other_nps = to_list("photo, deformed, black and white, realism, disfigured, low contrast, drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly")
|
297 |
+
default_ps = to_list("highly detailed, masterpiece, best quality, very aesthetic, absurdres")
|
298 |
+
default_nps = to_list("score_6, score_5, score_4, lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]")
|
299 |
+
def insert_recom_prompt(prompt: str = "", neg_prompt: str = "", type: str = "None"):
|
300 |
+
global enable_auto_recom_prompt
|
301 |
+
prompts = to_list(prompt)
|
302 |
+
neg_prompts = to_list(neg_prompt)
|
303 |
+
|
304 |
+
prompts = list_sub(prompts, animagine_ps + pony_ps)
|
305 |
+
neg_prompts = list_sub(neg_prompts, animagine_nps + pony_nps)
|
306 |
+
|
307 |
+
last_empty_p = [""] if not prompts and type != "None" else []
|
308 |
+
last_empty_np = [""] if not neg_prompts and type != "None" else []
|
309 |
+
|
310 |
+
if type == "Auto":
|
311 |
+
enable_auto_recom_prompt = True
|
312 |
+
else:
|
313 |
+
enable_auto_recom_prompt = False
|
314 |
+
if type == "Animagine":
|
315 |
+
prompts = prompts + animagine_ps
|
316 |
+
neg_prompts = neg_prompts + animagine_nps
|
317 |
+
elif type == "Pony":
|
318 |
+
prompts = prompts + pony_ps
|
319 |
+
neg_prompts = neg_prompts + pony_nps
|
320 |
+
|
321 |
+
prompt = ", ".join(list_uniq(prompts) + last_empty_p)
|
322 |
+
neg_prompt = ", ".join(list_uniq(neg_prompts) + last_empty_np)
|
323 |
+
|
324 |
+
return prompt, neg_prompt
|
325 |
+
|
326 |
+
|
327 |
+
def load_model_prompt_dict():
|
328 |
+
import json
|
329 |
+
dict = {}
|
330 |
+
path = 'model_dict.json' if Path('model_dict.json').exists() else './tagger/model_dict.json'
|
331 |
+
try:
|
332 |
+
with open('model_dict.json', encoding='utf-8') as f:
|
333 |
+
dict = json.load(f)
|
334 |
+
except Exception:
|
335 |
+
pass
|
336 |
+
return dict
|
337 |
+
|
338 |
+
|
339 |
+
model_prompt_dict = load_model_prompt_dict()
|
340 |
+
|
341 |
+
|
342 |
+
def insert_model_recom_prompt(prompt: str = "", neg_prompt: str = "", model_name: str = "None"):
|
343 |
+
if not model_name or not enable_auto_recom_prompt: return prompt, neg_prompt
|
344 |
+
prompts = to_list(prompt)
|
345 |
+
neg_prompts = to_list(neg_prompt)
|
346 |
+
prompts = list_sub(prompts, animagine_ps + pony_ps + other_ps)
|
347 |
+
neg_prompts = list_sub(neg_prompts, animagine_nps + pony_nps + other_nps)
|
348 |
+
last_empty_p = [""] if not prompts and type != "None" else []
|
349 |
+
last_empty_np = [""] if not neg_prompts and type != "None" else []
|
350 |
+
ps = []
|
351 |
+
nps = []
|
352 |
+
if model_name in model_prompt_dict.keys():
|
353 |
+
ps = to_list(model_prompt_dict[model_name]["prompt"])
|
354 |
+
nps = to_list(model_prompt_dict[model_name]["negative_prompt"])
|
355 |
+
else:
|
356 |
+
ps = default_ps
|
357 |
+
nps = default_nps
|
358 |
+
prompts = prompts + ps
|
359 |
+
neg_prompts = neg_prompts + nps
|
360 |
+
prompt = ", ".join(list_uniq(prompts) + last_empty_p)
|
361 |
+
neg_prompt = ", ".join(list_uniq(neg_prompts) + last_empty_np)
|
362 |
+
return prompt, neg_prompt
|
363 |
+
|
364 |
+
|
365 |
+
tag_group_dict = load_dict_from_csv('tag_group.csv')
|
366 |
+
|
367 |
+
|
368 |
+
def remove_specific_prompt(input_prompt: str = "", keep_tags: str = "all"):
|
369 |
+
def is_dressed(tag):
|
370 |
+
import re
|
371 |
+
p = re.compile(r'dress|cloth|uniform|costume|vest|sweater|coat|shirt|jacket|blazer|apron|leotard|hood|sleeve|skirt|shorts|pant|loafer|ribbon|necktie|bow|collar|glove|sock|shoe|boots|wear|emblem')
|
372 |
+
return p.search(tag)
|
373 |
+
|
374 |
+
def is_background(tag):
|
375 |
+
import re
|
376 |
+
p = re.compile(r'background|outline|light|sky|build|day|screen|tree|city')
|
377 |
+
return p.search(tag)
|
378 |
+
|
379 |
+
un_tags = ['solo']
|
380 |
+
group_list = ['groups', 'body_parts', 'attire', 'posture', 'objects', 'creatures', 'locations', 'disambiguation_pages', 'commonly_misused_tags', 'phrases', 'verbs_and_gerunds', 'subjective', 'nudity', 'sex_objects', 'sex', 'sex_acts', 'image_composition', 'artistic_license', 'text', 'year_tags', 'metatags']
|
381 |
+
keep_group_dict = {
|
382 |
+
"body": ['groups', 'body_parts'],
|
383 |
+
"dress": ['groups', 'body_parts', 'attire'],
|
384 |
+
"all": group_list,
|
385 |
+
}
|
386 |
+
|
387 |
+
def is_necessary(tag, keep_tags, group_dict):
|
388 |
+
if keep_tags == "all":
|
389 |
+
return True
|
390 |
+
elif tag in un_tags or group_dict.get(tag, "") in explicit_group:
|
391 |
+
return False
|
392 |
+
elif keep_tags == "body" and is_dressed(tag):
|
393 |
+
return False
|
394 |
+
elif is_background(tag):
|
395 |
+
return False
|
396 |
+
else:
|
397 |
+
return True
|
398 |
+
|
399 |
+
if keep_tags == "all": return input_prompt
|
400 |
+
keep_group = keep_group_dict.get(keep_tags, keep_group_dict["body"])
|
401 |
+
explicit_group = list(set(group_list) ^ set(keep_group))
|
402 |
+
|
403 |
+
tags = input_prompt.split(",") if input_prompt else []
|
404 |
+
people_tags: list[str] = []
|
405 |
+
other_tags: list[str] = []
|
406 |
+
|
407 |
+
group_dict = tag_group_dict
|
408 |
+
for tag in tags:
|
409 |
+
tag = replace_underline(tag)
|
410 |
+
if tag in PEOPLE_TAGS:
|
411 |
+
people_tags.append(tag)
|
412 |
+
elif is_necessary(tag, keep_tags, group_dict):
|
413 |
+
other_tags.append(tag)
|
414 |
+
|
415 |
+
output_prompt = ", ".join(people_tags + other_tags)
|
416 |
+
|
417 |
+
return output_prompt
|
418 |
+
|
419 |
+
|
420 |
+
def sort_taglist(tags: list[str]):
|
421 |
+
if not tags: return []
|
422 |
+
character_tags: list[str] = []
|
423 |
+
series_tags: list[str] = []
|
424 |
+
people_tags: list[str] = []
|
425 |
+
group_list = ['groups', 'body_parts', 'attire', 'posture', 'objects', 'creatures', 'locations', 'disambiguation_pages', 'commonly_misused_tags', 'phrases', 'verbs_and_gerunds', 'subjective', 'nudity', 'sex_objects', 'sex', 'sex_acts', 'image_composition', 'artistic_license', 'text', 'year_tags', 'metatags']
|
426 |
+
group_tags = {}
|
427 |
+
other_tags: list[str] = []
|
428 |
+
rating_tags: list[str] = []
|
429 |
+
|
430 |
+
group_dict = tag_group_dict
|
431 |
+
group_set = set(group_dict.keys())
|
432 |
+
character_set = set(anime_series_dict.keys())
|
433 |
+
series_set = set(anime_series_dict.values())
|
434 |
+
rating_set = set(DANBOORU_TO_E621_RATING_MAP.keys()) | set(DANBOORU_TO_E621_RATING_MAP.values())
|
435 |
+
|
436 |
+
for tag in tags:
|
437 |
+
tag = replace_underline(tag)
|
438 |
+
if tag in PEOPLE_TAGS:
|
439 |
+
people_tags.append(tag)
|
440 |
+
elif tag in rating_set:
|
441 |
+
rating_tags.append(tag)
|
442 |
+
elif tag in group_set:
|
443 |
+
elem = group_dict[tag]
|
444 |
+
group_tags[elem] = group_tags[elem] + [tag] if elem in group_tags else [tag]
|
445 |
+
elif tag in character_set:
|
446 |
+
character_tags.append(tag)
|
447 |
+
elif tag in series_set:
|
448 |
+
series_tags.append(tag)
|
449 |
+
else:
|
450 |
+
other_tags.append(tag)
|
451 |
+
|
452 |
+
output_group_tags: list[str] = []
|
453 |
+
for k in group_list:
|
454 |
+
output_group_tags.extend(group_tags.get(k, []))
|
455 |
+
|
456 |
+
rating_tags = [rating_tags[0]] if rating_tags else []
|
457 |
+
rating_tags = ["explicit, nsfw"] if rating_tags and rating_tags[0] == "explicit" else rating_tags
|
458 |
+
|
459 |
+
output_tags = character_tags + series_tags + people_tags + output_group_tags + other_tags + rating_tags
|
460 |
+
|
461 |
+
return output_tags
|
462 |
+
|
463 |
+
|
464 |
+
def sort_tags(tags: str):
|
465 |
+
if not tags: return ""
|
466 |
+
taglist: list[str] = []
|
467 |
+
for tag in tags.split(","):
|
468 |
+
taglist.append(tag.strip())
|
469 |
+
taglist = list(filter(lambda x: x != "", taglist))
|
470 |
+
return ", ".join(sort_taglist(taglist))
|
471 |
+
|
472 |
+
|
473 |
+
def postprocess_results(results: dict[str, float], general_threshold: float, character_threshold: float):
|
474 |
+
results = {
|
475 |
+
k: v for k, v in sorted(results.items(), key=lambda item: item[1], reverse=True)
|
476 |
+
}
|
477 |
+
|
478 |
+
rating = {}
|
479 |
+
character = {}
|
480 |
+
general = {}
|
481 |
+
|
482 |
+
for k, v in results.items():
|
483 |
+
if k.startswith("rating:"):
|
484 |
+
rating[k.replace("rating:", "")] = v
|
485 |
+
continue
|
486 |
+
elif k.startswith("character:"):
|
487 |
+
character[k.replace("character:", "")] = v
|
488 |
+
continue
|
489 |
+
|
490 |
+
general[k] = v
|
491 |
+
|
492 |
+
character = {k: v for k, v in character.items() if v >= character_threshold}
|
493 |
+
general = {k: v for k, v in general.items() if v >= general_threshold}
|
494 |
+
|
495 |
+
return rating, character, general
|
496 |
+
|
497 |
+
|
498 |
+
def gen_prompt(rating: list[str], character: list[str], general: list[str]):
|
499 |
+
people_tags: list[str] = []
|
500 |
+
other_tags: list[str] = []
|
501 |
+
rating_tag = RATING_MAP[rating[0]]
|
502 |
+
|
503 |
+
for tag in general:
|
504 |
+
if tag in PEOPLE_TAGS:
|
505 |
+
people_tags.append(tag)
|
506 |
+
else:
|
507 |
+
other_tags.append(tag)
|
508 |
+
|
509 |
+
all_tags = people_tags + other_tags
|
510 |
+
|
511 |
+
return ", ".join(all_tags)
|
512 |
+
|
513 |
+
|
514 |
+
@spaces.GPU(duration=30)
|
515 |
+
def predict_tags(image: Image.Image, general_threshold: float = 0.3, character_threshold: float = 0.8):
|
516 |
+
inputs = wd_processor.preprocess(image, return_tensors="pt")
|
517 |
+
|
518 |
+
outputs = wd_model(**inputs.to(wd_model.device, wd_model.dtype))
|
519 |
+
logits = torch.sigmoid(outputs.logits[0]) # take the first logits
|
520 |
+
|
521 |
+
# get probabilities
|
522 |
+
if device != default_device: wd_model.to(device=device)
|
523 |
+
results = {
|
524 |
+
wd_model.config.id2label[i]: float(logit.float()) for i, logit in enumerate(logits)
|
525 |
+
}
|
526 |
+
if device != default_device: wd_model.to(device=default_device)
|
527 |
+
# rating, character, general
|
528 |
+
rating, character, general = postprocess_results(
|
529 |
+
results, general_threshold, character_threshold
|
530 |
+
)
|
531 |
+
prompt = gen_prompt(
|
532 |
+
list(rating.keys()), list(character.keys()), list(general.keys())
|
533 |
+
)
|
534 |
+
output_series_tag = ""
|
535 |
+
output_series_list = character_list_to_series_list(character.keys())
|
536 |
+
if output_series_list:
|
537 |
+
output_series_tag = output_series_list[0]
|
538 |
+
else:
|
539 |
+
output_series_tag = ""
|
540 |
+
return output_series_tag, ", ".join(character.keys()), prompt, gr.update(interactive=True)
|
541 |
+
|
542 |
+
|
543 |
+
def predict_tags_wd(image: Image.Image, input_tags: str, algo: list[str], general_threshold: float = 0.3,
|
544 |
+
character_threshold: float = 0.8, input_series: str = "", input_character: str = ""):
|
545 |
+
if not "Use WD Tagger" in algo and len(algo) != 0:
|
546 |
+
return input_series, input_character, input_tags, gr.update(interactive=True)
|
547 |
+
return predict_tags(image, general_threshold, character_threshold)
|
548 |
+
|
549 |
+
|
550 |
+
def compose_prompt_to_copy(character: str, series: str, general: str):
|
551 |
+
characters = character.split(",") if character else []
|
552 |
+
serieses = series.split(",") if series else []
|
553 |
+
generals = general.split(",") if general else []
|
554 |
+
tags = characters + serieses + generals
|
555 |
+
cprompt = ",".join(tags) if tags else ""
|
556 |
+
return cprompt
|
tagger/utils.py
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from dartrs.v2 import AspectRatioTag, LengthTag, RatingTag, IdentityTag
|
3 |
+
|
4 |
+
|
5 |
+
V2_ASPECT_RATIO_OPTIONS: list[AspectRatioTag] = [
|
6 |
+
"ultra_wide",
|
7 |
+
"wide",
|
8 |
+
"square",
|
9 |
+
"tall",
|
10 |
+
"ultra_tall",
|
11 |
+
]
|
12 |
+
V2_RATING_OPTIONS: list[RatingTag] = [
|
13 |
+
"sfw",
|
14 |
+
"general",
|
15 |
+
"sensitive",
|
16 |
+
"nsfw",
|
17 |
+
"questionable",
|
18 |
+
"explicit",
|
19 |
+
]
|
20 |
+
V2_LENGTH_OPTIONS: list[LengthTag] = [
|
21 |
+
"very_short",
|
22 |
+
"short",
|
23 |
+
"medium",
|
24 |
+
"long",
|
25 |
+
"very_long",
|
26 |
+
]
|
27 |
+
V2_IDENTITY_OPTIONS: list[IdentityTag] = [
|
28 |
+
"none",
|
29 |
+
"lax",
|
30 |
+
"strict",
|
31 |
+
]
|
32 |
+
|
33 |
+
|
34 |
+
# ref: https://qiita.com/tregu148/items/fccccbbc47d966dd2fc2
|
35 |
+
def gradio_copy_text(_text: None):
|
36 |
+
gr.Info("Copied!")
|
37 |
+
|
38 |
+
|
39 |
+
COPY_ACTION_JS = """\
|
40 |
+
(inputs, _outputs) => {
|
41 |
+
// inputs is the string value of the input_text
|
42 |
+
if (inputs.trim() !== "") {
|
43 |
+
navigator.clipboard.writeText(inputs);
|
44 |
+
}
|
45 |
+
}"""
|
46 |
+
|
47 |
+
|
48 |
+
def gradio_copy_prompt(prompt: str):
|
49 |
+
gr.Info("Copied!")
|
50 |
+
return prompt
|
tagger/v2.py
ADDED
@@ -0,0 +1,260 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time
|
2 |
+
import torch
|
3 |
+
from typing import Callable
|
4 |
+
from pathlib import Path
|
5 |
+
|
6 |
+
from dartrs.v2 import (
|
7 |
+
V2Model,
|
8 |
+
MixtralModel,
|
9 |
+
MistralModel,
|
10 |
+
compose_prompt,
|
11 |
+
LengthTag,
|
12 |
+
AspectRatioTag,
|
13 |
+
RatingTag,
|
14 |
+
IdentityTag,
|
15 |
+
)
|
16 |
+
from dartrs.dartrs import DartTokenizer
|
17 |
+
from dartrs.utils import get_generation_config
|
18 |
+
|
19 |
+
|
20 |
+
import gradio as gr
|
21 |
+
from gradio.components import Component
|
22 |
+
|
23 |
+
|
24 |
+
try:
|
25 |
+
from output import UpsamplingOutput
|
26 |
+
except:
|
27 |
+
from .output import UpsamplingOutput
|
28 |
+
|
29 |
+
|
30 |
+
V2_ALL_MODELS = {
|
31 |
+
"dart-v2-moe-sft": {
|
32 |
+
"repo": "p1atdev/dart-v2-moe-sft",
|
33 |
+
"type": "sft",
|
34 |
+
"class": MixtralModel,
|
35 |
+
},
|
36 |
+
"dart-v2-sft": {
|
37 |
+
"repo": "p1atdev/dart-v2-sft",
|
38 |
+
"type": "sft",
|
39 |
+
"class": MistralModel,
|
40 |
+
},
|
41 |
+
}
|
42 |
+
|
43 |
+
|
44 |
+
def prepare_models(model_config: dict):
|
45 |
+
model_name = model_config["repo"]
|
46 |
+
tokenizer = DartTokenizer.from_pretrained(model_name)
|
47 |
+
model = model_config["class"].from_pretrained(model_name)
|
48 |
+
|
49 |
+
return {
|
50 |
+
"tokenizer": tokenizer,
|
51 |
+
"model": model,
|
52 |
+
}
|
53 |
+
|
54 |
+
|
55 |
+
def normalize_tags(tokenizer: DartTokenizer, tags: str):
|
56 |
+
"""Just remove unk tokens."""
|
57 |
+
return ", ".join([tag for tag in tokenizer.tokenize(tags) if tag != "<|unk|>"])
|
58 |
+
|
59 |
+
|
60 |
+
@torch.no_grad()
|
61 |
+
def generate_tags(
|
62 |
+
model: V2Model,
|
63 |
+
tokenizer: DartTokenizer,
|
64 |
+
prompt: str,
|
65 |
+
ban_token_ids: list[int],
|
66 |
+
):
|
67 |
+
output = model.generate(
|
68 |
+
get_generation_config(
|
69 |
+
prompt,
|
70 |
+
tokenizer=tokenizer,
|
71 |
+
temperature=1,
|
72 |
+
top_p=0.9,
|
73 |
+
top_k=100,
|
74 |
+
max_new_tokens=256,
|
75 |
+
ban_token_ids=ban_token_ids,
|
76 |
+
),
|
77 |
+
)
|
78 |
+
|
79 |
+
return output
|
80 |
+
|
81 |
+
|
82 |
+
def _people_tag(noun: str, minimum: int = 1, maximum: int = 5):
|
83 |
+
return (
|
84 |
+
[f"1{noun}"]
|
85 |
+
+ [f"{num}{noun}s" for num in range(minimum + 1, maximum + 1)]
|
86 |
+
+ [f"{maximum+1}+{noun}s"]
|
87 |
+
)
|
88 |
+
|
89 |
+
|
90 |
+
PEOPLE_TAGS = (
|
91 |
+
_people_tag("girl") + _people_tag("boy") + _people_tag("other") + ["no humans"]
|
92 |
+
)
|
93 |
+
|
94 |
+
|
95 |
+
def gen_prompt_text(output: UpsamplingOutput):
|
96 |
+
# separate people tags (e.g. 1girl)
|
97 |
+
people_tags = []
|
98 |
+
other_general_tags = []
|
99 |
+
|
100 |
+
for tag in output.general_tags.split(","):
|
101 |
+
tag = tag.strip()
|
102 |
+
if tag in PEOPLE_TAGS:
|
103 |
+
people_tags.append(tag)
|
104 |
+
else:
|
105 |
+
other_general_tags.append(tag)
|
106 |
+
|
107 |
+
return ", ".join(
|
108 |
+
[
|
109 |
+
part.strip()
|
110 |
+
for part in [
|
111 |
+
*people_tags,
|
112 |
+
output.character_tags,
|
113 |
+
output.copyright_tags,
|
114 |
+
*other_general_tags,
|
115 |
+
output.upsampled_tags,
|
116 |
+
output.rating_tag,
|
117 |
+
]
|
118 |
+
if part.strip() != ""
|
119 |
+
]
|
120 |
+
)
|
121 |
+
|
122 |
+
|
123 |
+
def elapsed_time_format(elapsed_time: float) -> str:
|
124 |
+
return f"Elapsed: {elapsed_time:.2f} seconds"
|
125 |
+
|
126 |
+
|
127 |
+
def parse_upsampling_output(
|
128 |
+
upsampler: Callable[..., UpsamplingOutput],
|
129 |
+
):
|
130 |
+
def _parse_upsampling_output(*args) -> tuple[str, str, dict]:
|
131 |
+
output = upsampler(*args)
|
132 |
+
|
133 |
+
return (
|
134 |
+
gen_prompt_text(output),
|
135 |
+
elapsed_time_format(output.elapsed_time),
|
136 |
+
gr.update(interactive=True),
|
137 |
+
gr.update(interactive=True),
|
138 |
+
)
|
139 |
+
|
140 |
+
return _parse_upsampling_output
|
141 |
+
|
142 |
+
|
143 |
+
class V2UI:
|
144 |
+
model_name: str | None = None
|
145 |
+
model: V2Model
|
146 |
+
tokenizer: DartTokenizer
|
147 |
+
|
148 |
+
input_components: list[Component] = []
|
149 |
+
generate_btn: gr.Button
|
150 |
+
|
151 |
+
def on_generate(
|
152 |
+
self,
|
153 |
+
model_name: str,
|
154 |
+
copyright_tags: str,
|
155 |
+
character_tags: str,
|
156 |
+
general_tags: str,
|
157 |
+
rating_tag: RatingTag,
|
158 |
+
aspect_ratio_tag: AspectRatioTag,
|
159 |
+
length_tag: LengthTag,
|
160 |
+
identity_tag: IdentityTag,
|
161 |
+
ban_tags: str,
|
162 |
+
*args,
|
163 |
+
) -> UpsamplingOutput:
|
164 |
+
if self.model_name is None or self.model_name != model_name:
|
165 |
+
models = prepare_models(V2_ALL_MODELS[model_name])
|
166 |
+
self.model = models["model"]
|
167 |
+
self.tokenizer = models["tokenizer"]
|
168 |
+
self.model_name = model_name
|
169 |
+
|
170 |
+
# normalize tags
|
171 |
+
# copyright_tags = normalize_tags(self.tokenizer, copyright_tags)
|
172 |
+
# character_tags = normalize_tags(self.tokenizer, character_tags)
|
173 |
+
# general_tags = normalize_tags(self.tokenizer, general_tags)
|
174 |
+
|
175 |
+
ban_token_ids = self.tokenizer.encode(ban_tags.strip())
|
176 |
+
|
177 |
+
prompt = compose_prompt(
|
178 |
+
prompt=general_tags,
|
179 |
+
copyright=copyright_tags,
|
180 |
+
character=character_tags,
|
181 |
+
rating=rating_tag,
|
182 |
+
aspect_ratio=aspect_ratio_tag,
|
183 |
+
length=length_tag,
|
184 |
+
identity=identity_tag,
|
185 |
+
)
|
186 |
+
|
187 |
+
start = time.time()
|
188 |
+
upsampled_tags = generate_tags(
|
189 |
+
self.model,
|
190 |
+
self.tokenizer,
|
191 |
+
prompt,
|
192 |
+
ban_token_ids,
|
193 |
+
)
|
194 |
+
elapsed_time = time.time() - start
|
195 |
+
|
196 |
+
return UpsamplingOutput(
|
197 |
+
upsampled_tags=upsampled_tags,
|
198 |
+
copyright_tags=copyright_tags,
|
199 |
+
character_tags=character_tags,
|
200 |
+
general_tags=general_tags,
|
201 |
+
rating_tag=rating_tag,
|
202 |
+
aspect_ratio_tag=aspect_ratio_tag,
|
203 |
+
length_tag=length_tag,
|
204 |
+
identity_tag=identity_tag,
|
205 |
+
elapsed_time=elapsed_time,
|
206 |
+
)
|
207 |
+
|
208 |
+
|
209 |
+
def parse_upsampling_output_simple(upsampler: UpsamplingOutput):
|
210 |
+
return gen_prompt_text(upsampler)
|
211 |
+
|
212 |
+
|
213 |
+
v2 = V2UI()
|
214 |
+
|
215 |
+
|
216 |
+
def v2_upsampling_prompt(model: str = "dart-v2-moe-sft", copyright: str = "", character: str = "",
|
217 |
+
general_tags: str = "", rating: str = "nsfw", aspect_ratio: str = "square",
|
218 |
+
length: str = "very_long", identity: str = "lax", ban_tags: str = "censored"):
|
219 |
+
raw_prompt = parse_upsampling_output_simple(v2.on_generate(model, copyright, character, general_tags,
|
220 |
+
rating, aspect_ratio, length, identity, ban_tags))
|
221 |
+
return raw_prompt
|
222 |
+
|
223 |
+
|
224 |
+
def load_dict_from_csv(filename):
|
225 |
+
dict = {}
|
226 |
+
if not Path(filename).exists():
|
227 |
+
if Path('./tagger/', filename).exists(): filename = str(Path('./tagger/', filename))
|
228 |
+
else: return dict
|
229 |
+
try:
|
230 |
+
with open(filename, 'r', encoding="utf-8") as f:
|
231 |
+
lines = f.readlines()
|
232 |
+
except Exception:
|
233 |
+
print(f"Failed to open dictionary file: {filename}")
|
234 |
+
return dict
|
235 |
+
for line in lines:
|
236 |
+
parts = line.strip().split(',')
|
237 |
+
dict[parts[0]] = parts[1]
|
238 |
+
return dict
|
239 |
+
|
240 |
+
|
241 |
+
anime_series_dict = load_dict_from_csv('character_series_dict.csv')
|
242 |
+
|
243 |
+
|
244 |
+
def select_random_character(series: str, character: str):
|
245 |
+
from random import seed, randrange
|
246 |
+
seed()
|
247 |
+
character_list = list(anime_series_dict.keys())
|
248 |
+
character = character_list[randrange(len(character_list) - 1)]
|
249 |
+
series = anime_series_dict.get(character.split(",")[0].strip(), "")
|
250 |
+
return series, character
|
251 |
+
|
252 |
+
|
253 |
+
def v2_random_prompt(general_tags: str = "", copyright: str = "", character: str = "", rating: str = "nsfw",
|
254 |
+
aspect_ratio: str = "square", length: str = "very_long", identity: str = "lax",
|
255 |
+
ban_tags: str = "censored", model: str = "dart-v2-moe-sft"):
|
256 |
+
if copyright == "" and character == "":
|
257 |
+
copyright, character = select_random_character("", "")
|
258 |
+
raw_prompt = v2_upsampling_prompt(model, copyright, character, general_tags, rating,
|
259 |
+
aspect_ratio, length, identity, ban_tags)
|
260 |
+
return raw_prompt, copyright, character
|
utils.py
ADDED
@@ -0,0 +1,421 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import re
|
3 |
+
import gradio as gr
|
4 |
+
from constants import (
|
5 |
+
DIFFUSERS_FORMAT_LORAS,
|
6 |
+
CIVITAI_API_KEY,
|
7 |
+
HF_TOKEN,
|
8 |
+
MODEL_TYPE_CLASS,
|
9 |
+
DIRECTORY_LORAS,
|
10 |
+
)
|
11 |
+
from huggingface_hub import HfApi
|
12 |
+
from diffusers import DiffusionPipeline
|
13 |
+
from huggingface_hub import model_info as model_info_data
|
14 |
+
from diffusers.pipelines.pipeline_loading_utils import variant_compatible_siblings
|
15 |
+
from pathlib import PosixPath
|
16 |
+
from unidecode import unidecode
|
17 |
+
import urllib.parse
|
18 |
+
import copy
|
19 |
+
import requests
|
20 |
+
from requests.adapters import HTTPAdapter
|
21 |
+
from urllib3.util import Retry
|
22 |
+
|
23 |
+
USER_AGENT = 'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:127.0) Gecko/20100101 Firefox/127.0'
|
24 |
+
|
25 |
+
|
26 |
+
def request_json_data(url):
|
27 |
+
model_version_id = url.split('/')[-1]
|
28 |
+
if "?modelVersionId=" in model_version_id:
|
29 |
+
match = re.search(r'modelVersionId=(\d+)', url)
|
30 |
+
model_version_id = match.group(1)
|
31 |
+
|
32 |
+
endpoint_url = f"https://civitai.com/api/v1/model-versions/{model_version_id}"
|
33 |
+
|
34 |
+
params = {}
|
35 |
+
headers = {'User-Agent': USER_AGENT, 'content-type': 'application/json'}
|
36 |
+
session = requests.Session()
|
37 |
+
retries = Retry(total=5, backoff_factor=1, status_forcelist=[500, 502, 503, 504])
|
38 |
+
session.mount("https://", HTTPAdapter(max_retries=retries))
|
39 |
+
|
40 |
+
try:
|
41 |
+
result = session.get(endpoint_url, params=params, headers=headers, stream=True, timeout=(3.0, 15))
|
42 |
+
result.raise_for_status()
|
43 |
+
json_data = result.json()
|
44 |
+
return json_data if json_data else None
|
45 |
+
except Exception as e:
|
46 |
+
print(f"Error: {e}")
|
47 |
+
return None
|
48 |
+
|
49 |
+
|
50 |
+
class ModelInformation:
|
51 |
+
def __init__(self, json_data):
|
52 |
+
self.model_version_id = json_data.get("id", "")
|
53 |
+
self.model_id = json_data.get("modelId", "")
|
54 |
+
self.download_url = json_data.get("downloadUrl", "")
|
55 |
+
self.model_url = f"https://civitai.com/models/{self.model_id}?modelVersionId={self.model_version_id}"
|
56 |
+
self.filename_url = next(
|
57 |
+
(v.get("name", "") for v in json_data.get("files", []) if str(self.model_version_id) in v.get("downloadUrl", "")), ""
|
58 |
+
)
|
59 |
+
self.filename_url = self.filename_url if self.filename_url else ""
|
60 |
+
self.description = json_data.get("description", "")
|
61 |
+
if self.description is None: self.description = ""
|
62 |
+
self.model_name = json_data.get("model", {}).get("name", "")
|
63 |
+
self.model_type = json_data.get("model", {}).get("type", "")
|
64 |
+
self.nsfw = json_data.get("model", {}).get("nsfw", False)
|
65 |
+
self.poi = json_data.get("model", {}).get("poi", False)
|
66 |
+
self.images = [img.get("url", "") for img in json_data.get("images", [])]
|
67 |
+
self.example_prompt = json_data.get("trainedWords", [""])[0] if json_data.get("trainedWords") else ""
|
68 |
+
self.original_json = copy.deepcopy(json_data)
|
69 |
+
|
70 |
+
|
71 |
+
def retrieve_model_info(url):
|
72 |
+
json_data = request_json_data(url)
|
73 |
+
if not json_data:
|
74 |
+
return None
|
75 |
+
model_descriptor = ModelInformation(json_data)
|
76 |
+
return model_descriptor
|
77 |
+
|
78 |
+
|
79 |
+
def download_things(directory, url, hf_token="", civitai_api_key="", romanize=False):
|
80 |
+
url = url.strip()
|
81 |
+
downloaded_file_path = None
|
82 |
+
|
83 |
+
if "drive.google.com" in url:
|
84 |
+
original_dir = os.getcwd()
|
85 |
+
os.chdir(directory)
|
86 |
+
os.system(f"gdown --fuzzy {url}")
|
87 |
+
os.chdir(original_dir)
|
88 |
+
elif "huggingface.co" in url:
|
89 |
+
url = url.replace("?download=true", "")
|
90 |
+
# url = urllib.parse.quote(url, safe=':/') # fix encoding
|
91 |
+
if "/blob/" in url:
|
92 |
+
url = url.replace("/blob/", "/resolve/")
|
93 |
+
user_header = f'"Authorization: Bearer {hf_token}"'
|
94 |
+
|
95 |
+
filename = unidecode(url.split('/')[-1]) if romanize else url.split('/')[-1]
|
96 |
+
|
97 |
+
if hf_token:
|
98 |
+
os.system(f"aria2c --console-log-level=error --summary-interval=10 --header={user_header} -c -x 16 -k 1M -s 16 {url} -d {directory} -o {filename}")
|
99 |
+
else:
|
100 |
+
os.system(f"aria2c --optimize-concurrent-downloads --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 {url} -d {directory} -o {filename}")
|
101 |
+
|
102 |
+
downloaded_file_path = os.path.join(directory, filename)
|
103 |
+
|
104 |
+
elif "civitai.com" in url:
|
105 |
+
|
106 |
+
if not civitai_api_key:
|
107 |
+
print("\033[91mYou need an API key to download Civitai models.\033[0m")
|
108 |
+
|
109 |
+
model_profile = retrieve_model_info(url)
|
110 |
+
if model_profile.download_url and model_profile.filename_url:
|
111 |
+
url = model_profile.download_url
|
112 |
+
filename = unidecode(model_profile.filename_url) if romanize else model_profile.filename_url
|
113 |
+
else:
|
114 |
+
if "?" in url:
|
115 |
+
url = url.split("?")[0]
|
116 |
+
filename = ""
|
117 |
+
|
118 |
+
url_dl = url + f"?token={civitai_api_key}"
|
119 |
+
print(f"Filename: {filename}")
|
120 |
+
|
121 |
+
param_filename = ""
|
122 |
+
if filename:
|
123 |
+
param_filename = f"-o '{filename}'"
|
124 |
+
|
125 |
+
aria2_command = (
|
126 |
+
f'aria2c --console-log-level=error --summary-interval=10 -c -x 16 '
|
127 |
+
f'-k 1M -s 16 -d "{directory}" {param_filename} "{url_dl}"'
|
128 |
+
)
|
129 |
+
os.system(aria2_command)
|
130 |
+
|
131 |
+
if param_filename and os.path.exists(os.path.join(directory, filename)):
|
132 |
+
downloaded_file_path = os.path.join(directory, filename)
|
133 |
+
|
134 |
+
# # PLAN B
|
135 |
+
# # Follow the redirect to get the actual download URL
|
136 |
+
# curl_command = (
|
137 |
+
# f'curl -L -sI --connect-timeout 5 --max-time 5 '
|
138 |
+
# f'-H "Content-Type: application/json" '
|
139 |
+
# f'-H "Authorization: Bearer {civitai_api_key}" "{url}"'
|
140 |
+
# )
|
141 |
+
|
142 |
+
# headers = os.popen(curl_command).read()
|
143 |
+
|
144 |
+
# # Look for the redirected "Location" URL
|
145 |
+
# location_match = re.search(r'location: (.+)', headers, re.IGNORECASE)
|
146 |
+
|
147 |
+
# if location_match:
|
148 |
+
# redirect_url = location_match.group(1).strip()
|
149 |
+
|
150 |
+
# # Extract the filename from the redirect URL's "Content-Disposition"
|
151 |
+
# filename_match = re.search(r'filename%3D%22(.+?)%22', redirect_url)
|
152 |
+
# if filename_match:
|
153 |
+
# encoded_filename = filename_match.group(1)
|
154 |
+
# # Decode the URL-encoded filename
|
155 |
+
# decoded_filename = urllib.parse.unquote(encoded_filename)
|
156 |
+
|
157 |
+
# filename = unidecode(decoded_filename) if romanize else decoded_filename
|
158 |
+
# print(f"Filename: {filename}")
|
159 |
+
|
160 |
+
# aria2_command = (
|
161 |
+
# f'aria2c --console-log-level=error --summary-interval=10 -c -x 16 '
|
162 |
+
# f'-k 1M -s 16 -d "{directory}" -o "{filename}" "{redirect_url}"'
|
163 |
+
# )
|
164 |
+
# return_code = os.system(aria2_command)
|
165 |
+
|
166 |
+
# # if return_code != 0:
|
167 |
+
# # raise RuntimeError(f"Failed to download file: {filename}. Error code: {return_code}")
|
168 |
+
# downloaded_file_path = os.path.join(directory, filename)
|
169 |
+
# if not os.path.exists(downloaded_file_path):
|
170 |
+
# downloaded_file_path = None
|
171 |
+
|
172 |
+
# if not downloaded_file_path:
|
173 |
+
# # Old method
|
174 |
+
# if "?" in url:
|
175 |
+
# url = url.split("?")[0]
|
176 |
+
# url = url + f"?token={civitai_api_key}"
|
177 |
+
# os.system(f"aria2c --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 -d {directory} {url}")
|
178 |
+
|
179 |
+
else:
|
180 |
+
os.system(f"aria2c --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 -d {directory} {url}")
|
181 |
+
|
182 |
+
return downloaded_file_path
|
183 |
+
|
184 |
+
|
185 |
+
def get_model_list(directory_path):
|
186 |
+
model_list = []
|
187 |
+
valid_extensions = {'.ckpt', '.pt', '.pth', '.safetensors', '.bin'}
|
188 |
+
|
189 |
+
for filename in os.listdir(directory_path):
|
190 |
+
if os.path.splitext(filename)[1] in valid_extensions:
|
191 |
+
# name_without_extension = os.path.splitext(filename)[0]
|
192 |
+
file_path = os.path.join(directory_path, filename)
|
193 |
+
# model_list.append((name_without_extension, file_path))
|
194 |
+
model_list.append(file_path)
|
195 |
+
print('\033[34mFILE: ' + file_path + '\033[0m')
|
196 |
+
return model_list
|
197 |
+
|
198 |
+
|
199 |
+
def extract_parameters(input_string):
|
200 |
+
parameters = {}
|
201 |
+
input_string = input_string.replace("\n", "")
|
202 |
+
|
203 |
+
if "Negative prompt:" not in input_string:
|
204 |
+
if "Steps:" in input_string:
|
205 |
+
input_string = input_string.replace("Steps:", "Negative prompt: Steps:")
|
206 |
+
else:
|
207 |
+
print("Invalid metadata")
|
208 |
+
parameters["prompt"] = input_string
|
209 |
+
return parameters
|
210 |
+
|
211 |
+
parm = input_string.split("Negative prompt:")
|
212 |
+
parameters["prompt"] = parm[0].strip()
|
213 |
+
if "Steps:" not in parm[1]:
|
214 |
+
print("Steps not detected")
|
215 |
+
parameters["neg_prompt"] = parm[1].strip()
|
216 |
+
return parameters
|
217 |
+
parm = parm[1].split("Steps:")
|
218 |
+
parameters["neg_prompt"] = parm[0].strip()
|
219 |
+
input_string = "Steps:" + parm[1]
|
220 |
+
|
221 |
+
# Extracting Steps
|
222 |
+
steps_match = re.search(r'Steps: (\d+)', input_string)
|
223 |
+
if steps_match:
|
224 |
+
parameters['Steps'] = int(steps_match.group(1))
|
225 |
+
|
226 |
+
# Extracting Size
|
227 |
+
size_match = re.search(r'Size: (\d+x\d+)', input_string)
|
228 |
+
if size_match:
|
229 |
+
parameters['Size'] = size_match.group(1)
|
230 |
+
width, height = map(int, parameters['Size'].split('x'))
|
231 |
+
parameters['width'] = width
|
232 |
+
parameters['height'] = height
|
233 |
+
|
234 |
+
# Extracting other parameters
|
235 |
+
other_parameters = re.findall(r'(\w+): (.*?)(?=, \w+|$)', input_string)
|
236 |
+
for param in other_parameters:
|
237 |
+
parameters[param[0]] = param[1].strip('"')
|
238 |
+
|
239 |
+
return parameters
|
240 |
+
|
241 |
+
|
242 |
+
def get_my_lora(link_url, romanize):
|
243 |
+
l_name = ""
|
244 |
+
for url in [url.strip() for url in link_url.split(',')]:
|
245 |
+
if not os.path.exists(f"./loras/{url.split('/')[-1]}"):
|
246 |
+
l_name = download_things(DIRECTORY_LORAS, url, HF_TOKEN, CIVITAI_API_KEY, romanize)
|
247 |
+
new_lora_model_list = get_model_list(DIRECTORY_LORAS)
|
248 |
+
new_lora_model_list.insert(0, "None")
|
249 |
+
new_lora_model_list = new_lora_model_list + DIFFUSERS_FORMAT_LORAS
|
250 |
+
msg_lora = "Downloaded"
|
251 |
+
if l_name:
|
252 |
+
msg_lora += f": <b>{l_name}</b>"
|
253 |
+
print(msg_lora)
|
254 |
+
|
255 |
+
return gr.update(
|
256 |
+
choices=new_lora_model_list
|
257 |
+
), gr.update(
|
258 |
+
choices=new_lora_model_list
|
259 |
+
), gr.update(
|
260 |
+
choices=new_lora_model_list
|
261 |
+
), gr.update(
|
262 |
+
choices=new_lora_model_list
|
263 |
+
), gr.update(
|
264 |
+
choices=new_lora_model_list
|
265 |
+
), gr.update(
|
266 |
+
value=msg_lora
|
267 |
+
)
|
268 |
+
|
269 |
+
|
270 |
+
def info_html(json_data, title, subtitle):
|
271 |
+
return f"""
|
272 |
+
<div style='padding: 0; border-radius: 10px;'>
|
273 |
+
<p style='margin: 0; font-weight: bold;'>{title}</p>
|
274 |
+
<details>
|
275 |
+
<summary>Details</summary>
|
276 |
+
<p style='margin: 0; font-weight: bold;'>{subtitle}</p>
|
277 |
+
</details>
|
278 |
+
</div>
|
279 |
+
"""
|
280 |
+
|
281 |
+
|
282 |
+
def get_model_type(repo_id: str):
|
283 |
+
api = HfApi(token=os.environ.get("HF_TOKEN")) # if use private or gated model
|
284 |
+
default = "SD 1.5"
|
285 |
+
try:
|
286 |
+
model = api.model_info(repo_id=repo_id, timeout=5.0)
|
287 |
+
tags = model.tags
|
288 |
+
for tag in tags:
|
289 |
+
if tag in MODEL_TYPE_CLASS.keys(): return MODEL_TYPE_CLASS.get(tag, default)
|
290 |
+
except Exception:
|
291 |
+
return default
|
292 |
+
return default
|
293 |
+
|
294 |
+
|
295 |
+
def restart_space(repo_id: str, factory_reboot: bool):
|
296 |
+
api = HfApi(token=os.environ.get("HF_TOKEN"))
|
297 |
+
try:
|
298 |
+
runtime = api.get_space_runtime(repo_id=repo_id)
|
299 |
+
if runtime.stage == "RUNNING":
|
300 |
+
api.restart_space(repo_id=repo_id, factory_reboot=factory_reboot)
|
301 |
+
print(f"Restarting space: {repo_id}")
|
302 |
+
else:
|
303 |
+
print(f"Space {repo_id} is in stage: {runtime.stage}")
|
304 |
+
except Exception as e:
|
305 |
+
print(e)
|
306 |
+
|
307 |
+
|
308 |
+
def extract_exif_data(image):
|
309 |
+
if image is None: return ""
|
310 |
+
|
311 |
+
try:
|
312 |
+
metadata_keys = ['parameters', 'metadata', 'prompt', 'Comment']
|
313 |
+
|
314 |
+
for key in metadata_keys:
|
315 |
+
if key in image.info:
|
316 |
+
return image.info[key]
|
317 |
+
|
318 |
+
return str(image.info)
|
319 |
+
|
320 |
+
except Exception as e:
|
321 |
+
return f"Error extracting metadata: {str(e)}"
|
322 |
+
|
323 |
+
|
324 |
+
def create_mask_now(img, invert):
|
325 |
+
import numpy as np
|
326 |
+
import time
|
327 |
+
|
328 |
+
time.sleep(0.5)
|
329 |
+
|
330 |
+
transparent_image = img["layers"][0]
|
331 |
+
|
332 |
+
# Extract the alpha channel
|
333 |
+
alpha_channel = np.array(transparent_image)[:, :, 3]
|
334 |
+
|
335 |
+
# Create a binary mask by thresholding the alpha channel
|
336 |
+
binary_mask = alpha_channel > 1
|
337 |
+
|
338 |
+
if invert:
|
339 |
+
print("Invert")
|
340 |
+
# Invert the binary mask so that the drawn shape is white and the rest is black
|
341 |
+
binary_mask = np.invert(binary_mask)
|
342 |
+
|
343 |
+
# Convert the binary mask to a 3-channel RGB mask
|
344 |
+
rgb_mask = np.stack((binary_mask,) * 3, axis=-1)
|
345 |
+
|
346 |
+
# Convert the mask to uint8
|
347 |
+
rgb_mask = rgb_mask.astype(np.uint8) * 255
|
348 |
+
|
349 |
+
return img["background"], rgb_mask
|
350 |
+
|
351 |
+
|
352 |
+
def download_diffuser_repo(repo_name: str, model_type: str, revision: str = "main", token=True):
|
353 |
+
|
354 |
+
variant = None
|
355 |
+
if token is True and not os.environ.get("HF_TOKEN"):
|
356 |
+
token = None
|
357 |
+
|
358 |
+
if model_type == "SDXL":
|
359 |
+
info = model_info_data(
|
360 |
+
repo_name,
|
361 |
+
token=token,
|
362 |
+
revision=revision,
|
363 |
+
timeout=5.0,
|
364 |
+
)
|
365 |
+
|
366 |
+
filenames = {sibling.rfilename for sibling in info.siblings}
|
367 |
+
model_filenames, variant_filenames = variant_compatible_siblings(
|
368 |
+
filenames, variant="fp16"
|
369 |
+
)
|
370 |
+
|
371 |
+
if len(variant_filenames):
|
372 |
+
variant = "fp16"
|
373 |
+
|
374 |
+
cached_folder = DiffusionPipeline.download(
|
375 |
+
pretrained_model_name=repo_name,
|
376 |
+
force_download=False,
|
377 |
+
token=token,
|
378 |
+
revision=revision,
|
379 |
+
# mirror="https://hf-mirror.com",
|
380 |
+
variant=variant,
|
381 |
+
use_safetensors=True,
|
382 |
+
trust_remote_code=False,
|
383 |
+
timeout=5.0,
|
384 |
+
)
|
385 |
+
|
386 |
+
if isinstance(cached_folder, PosixPath):
|
387 |
+
cached_folder = cached_folder.as_posix()
|
388 |
+
|
389 |
+
# Task model
|
390 |
+
# from huggingface_hub import hf_hub_download
|
391 |
+
# hf_hub_download(
|
392 |
+
# task_model,
|
393 |
+
# filename="diffusion_pytorch_model.safetensors", # fix fp16 variant
|
394 |
+
# )
|
395 |
+
|
396 |
+
return cached_folder
|
397 |
+
|
398 |
+
|
399 |
+
def progress_step_bar(step, total):
|
400 |
+
# Calculate the percentage for the progress bar width
|
401 |
+
percentage = min(100, ((step / total) * 100))
|
402 |
+
|
403 |
+
return f"""
|
404 |
+
<div style="position: relative; width: 100%; background-color: gray; border-radius: 5px; overflow: hidden;">
|
405 |
+
<div style="width: {percentage}%; height: 17px; background-color: #800080; transition: width 0.5s;"></div>
|
406 |
+
<div style="position: absolute; width: 100%; text-align: center; color: white; top: 0; line-height: 19px; font-size: 13px;">
|
407 |
+
{int(percentage)}%
|
408 |
+
</div>
|
409 |
+
</div>
|
410 |
+
"""
|
411 |
+
|
412 |
+
|
413 |
+
def html_template_message(msg):
|
414 |
+
return f"""
|
415 |
+
<div style="position: relative; width: 100%; background-color: gray; border-radius: 5px; overflow: hidden;">
|
416 |
+
<div style="width: 0%; height: 17px; background-color: #800080; transition: width 0.5s;"></div>
|
417 |
+
<div style="position: absolute; width: 100%; text-align: center; color: white; top: 0; line-height: 19px; font-size: 14px; font-weight: bold; text-shadow: 1px 1px 2px black;">
|
418 |
+
{msg}
|
419 |
+
</div>
|
420 |
+
</div>
|
421 |
+
"""
|