John6666 commited on
Commit
8f8a974
·
verified ·
1 Parent(s): b7ca65a

Upload 22 files

Browse files
app.py CHANGED
@@ -1,231 +1,104 @@
1
  import spaces
2
- import gradio as gr
3
  import os
4
  from stablepy import Model_Diffusers
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  from stablepy.diffusers_vanilla.style_prompt_config import STYLE_NAMES
6
- from stablepy.diffusers_vanilla.constants import FLUX_CN_UNION_MODES
7
  import torch
8
  import re
9
- from huggingface_hub import HfApi
10
  from stablepy import (
11
- CONTROLNET_MODEL_IDS,
12
- VALID_TASKS,
13
- T2I_PREPROCESSOR_NAME,
14
- FLASH_LORA,
15
- SCHEDULER_CONFIG_MAP,
16
  scheduler_names,
17
- IP_ADAPTER_MODELS,
18
  IP_ADAPTERS_SD,
19
  IP_ADAPTERS_SDXL,
20
- REPO_IMAGE_ENCODER,
21
- ALL_PROMPT_WEIGHT_OPTIONS,
22
- SD15_TASKS,
23
- SDXL_TASKS,
24
  )
25
  import time
26
  from PIL import ImageFile
27
- #import urllib.parse
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
  ImageFile.LOAD_TRUNCATED_IMAGES = True
 
30
  print(os.getenv("SPACES_ZERO_GPU"))
31
 
32
- PREPROCESSOR_CONTROLNET = {
33
- "openpose": [
34
- "Openpose",
35
- "None",
36
- ],
37
- "scribble": [
38
- "HED",
39
- "PidiNet",
40
- "None",
41
- ],
42
- "softedge": [
43
- "PidiNet",
44
- "HED",
45
- "HED safe",
46
- "PidiNet safe",
47
- "None",
48
- ],
49
- "segmentation": [
50
- "UPerNet",
51
- "None",
52
- ],
53
- "depth": [
54
- "DPT",
55
- "Midas",
56
- "None",
57
- ],
58
- "normalbae": [
59
- "NormalBae",
60
- "None",
61
- ],
62
- "lineart": [
63
- "Lineart",
64
- "Lineart coarse",
65
- "Lineart (anime)",
66
- "None",
67
- "None (anime)",
68
- ],
69
- "lineart_anime": [
70
- "Lineart",
71
- "Lineart coarse",
72
- "Lineart (anime)",
73
- "None",
74
- "None (anime)",
75
- ],
76
- "shuffle": [
77
- "ContentShuffle",
78
- "None",
79
- ],
80
- "canny": [
81
- "Canny",
82
- "None",
83
- ],
84
- "mlsd": [
85
- "MLSD",
86
- "None",
87
- ],
88
- "ip2p": [
89
- "ip2p"
90
- ],
91
- "recolor": [
92
- "Recolor luminance",
93
- "Recolor intensity",
94
- "None",
95
- ],
96
- "tile": [
97
- "Mild Blur",
98
- "Moderate Blur",
99
- "Heavy Blur",
100
- "None",
101
- ],
102
- }
103
-
104
- TASK_STABLEPY = {
105
- 'txt2img': 'txt2img',
106
- 'img2img': 'img2img',
107
- 'inpaint': 'inpaint',
108
- # 'canny T2I Adapter': 'sdxl_canny_t2i', # NO HAVE STEP CALLBACK PARAMETERS SO NOT WORKS WITH DIFFUSERS 0.29.0
109
- # 'sketch T2I Adapter': 'sdxl_sketch_t2i',
110
- # 'lineart T2I Adapter': 'sdxl_lineart_t2i',
111
- # 'depth-midas T2I Adapter': 'sdxl_depth-midas_t2i',
112
- # 'openpose T2I Adapter': 'sdxl_openpose_t2i',
113
- 'openpose ControlNet': 'openpose',
114
- 'canny ControlNet': 'canny',
115
- 'mlsd ControlNet': 'mlsd',
116
- 'scribble ControlNet': 'scribble',
117
- 'softedge ControlNet': 'softedge',
118
- 'segmentation ControlNet': 'segmentation',
119
- 'depth ControlNet': 'depth',
120
- 'normalbae ControlNet': 'normalbae',
121
- 'lineart ControlNet': 'lineart',
122
- 'lineart_anime ControlNet': 'lineart_anime',
123
- 'shuffle ControlNet': 'shuffle',
124
- 'ip2p ControlNet': 'ip2p',
125
- 'optical pattern ControlNet': 'pattern',
126
- 'recolor ControlNet': 'recolor',
127
- 'tile ControlNet': 'tile',
128
- }
129
-
130
- TASK_MODEL_LIST = list(TASK_STABLEPY.keys())
131
-
132
- UPSCALER_DICT_GUI = {
133
- None: None,
134
- "Lanczos": "Lanczos",
135
- "Nearest": "Nearest",
136
- 'Latent': 'Latent',
137
- 'Latent (antialiased)': 'Latent (antialiased)',
138
- 'Latent (bicubic)': 'Latent (bicubic)',
139
- 'Latent (bicubic antialiased)': 'Latent (bicubic antialiased)',
140
- 'Latent (nearest)': 'Latent (nearest)',
141
- 'Latent (nearest-exact)': 'Latent (nearest-exact)',
142
- "RealESRGAN_x4plus": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
143
- "RealESRNet_x4plus": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth",
144
- "RealESRGAN_x4plus_anime_6B": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
145
- "RealESRGAN_x2plus": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
146
- "realesr-animevideov3": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth",
147
- "realesr-general-x4v3": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
148
- "realesr-general-wdn-x4v3": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth",
149
- "4x-UltraSharp": "https://huggingface.co/Shandypur/ESRGAN-4x-UltraSharp/resolve/main/4x-UltraSharp.pth",
150
- "4x_foolhardy_Remacri": "https://huggingface.co/FacehugmanIII/4x_foolhardy_Remacri/resolve/main/4x_foolhardy_Remacri.pth",
151
- "Remacri4xExtraSmoother": "https://huggingface.co/hollowstrawberry/upscalers-backup/resolve/main/ESRGAN/Remacri%204x%20ExtraSmoother.pth",
152
- "AnimeSharp4x": "https://huggingface.co/hollowstrawberry/upscalers-backup/resolve/main/ESRGAN/AnimeSharp%204x.pth",
153
- "lollypop": "https://huggingface.co/hollowstrawberry/upscalers-backup/resolve/main/ESRGAN/lollypop.pth",
154
- "RealisticRescaler4x": "https://huggingface.co/hollowstrawberry/upscalers-backup/resolve/main/ESRGAN/RealisticRescaler%204x.pth",
155
- "NickelbackFS4x": "https://huggingface.co/hollowstrawberry/upscalers-backup/resolve/main/ESRGAN/NickelbackFS%204x.pth"
156
- }
157
-
158
- UPSCALER_KEYS = list(UPSCALER_DICT_GUI.keys())
159
-
160
- def get_model_list(directory_path):
161
- model_list = []
162
- valid_extensions = {'.ckpt', '.pt', '.pth', '.safetensors', '.bin'}
163
-
164
- for filename in os.listdir(directory_path):
165
- if os.path.splitext(filename)[1] in valid_extensions:
166
- # name_without_extension = os.path.splitext(filename)[0]
167
- file_path = os.path.join(directory_path, filename)
168
- # model_list.append((name_without_extension, file_path))
169
- model_list.append(file_path)
170
- print('\033[34mFILE: ' + file_path + '\033[0m')
171
- return model_list
172
-
173
  ## BEGIN MOD
174
  from modutils import (list_uniq, download_private_repo, get_model_id_list, get_tupled_embed_list,
175
  get_lora_model_list, get_all_lora_tupled_list, update_loras, apply_lora_prompt, set_prompt_loras,
176
  get_my_lora, upload_file_lora, move_file_lora, search_civitai_lora, select_civitai_lora,
177
  update_civitai_selection, get_civitai_tag, CIVITAI_SORT, CIVITAI_PERIOD, CIVITAI_BASEMODEL,
178
  set_textual_inversion_prompt, get_model_pipeline, change_interface_mode, get_t2i_model_info,
179
- get_tupled_model_list, save_gallery_images, set_optimization, set_sampler_settings,
180
  set_quick_presets, process_style_prompt, optimization_list, save_images, download_things,
181
- preset_styles, preset_quality, preset_sampler_setting, translate_to_en)
182
  from env import (HF_TOKEN, CIVITAI_API_KEY, HF_LORA_ESSENTIAL_PRIVATE_REPO, HF_VAE_PRIVATE_REPO,
183
  HF_SDXL_EMBEDS_NEGATIVE_PRIVATE_REPO, HF_SDXL_EMBEDS_POSITIVE_PRIVATE_REPO,
184
- directory_models, directory_loras, directory_vaes, directory_embeds, directory_embeds_sdxl,
185
- directory_embeds_positive_sdxl, load_diffusers_format_model,
186
- download_model_list, download_lora_list, download_vae_list, download_embeds)
 
 
 
 
187
 
188
  # - **Download Models**
189
- download_model = ", ".join(download_model_list)
190
  # - **Download VAEs**
191
- download_vae = ", ".join(download_vae_list)
192
  # - **Download LoRAs**
193
- download_lora = ", ".join(download_lora_list)
194
-
195
- download_private_repo(HF_LORA_ESSENTIAL_PRIVATE_REPO, directory_loras, True)
196
- download_private_repo(HF_VAE_PRIVATE_REPO, directory_vaes, False)
197
-
198
- load_diffusers_format_model = list_uniq(get_model_id_list() + load_diffusers_format_model)
199
- ## END MOD
200
 
201
  # Download stuffs
202
- for url in [url.strip() for url in download_model.split(',')]:
203
  if not os.path.exists(f"./models/{url.split('/')[-1]}"):
204
- download_things(directory_models, url, HF_TOKEN, CIVITAI_API_KEY)
205
- for url in [url.strip() for url in download_vae.split(',')]:
206
  if not os.path.exists(f"./vaes/{url.split('/')[-1]}"):
207
- download_things(directory_vaes, url, HF_TOKEN, CIVITAI_API_KEY)
208
- for url in [url.strip() for url in download_lora.split(',')]:
209
  if not os.path.exists(f"./loras/{url.split('/')[-1]}"):
210
- download_things(directory_loras, url, HF_TOKEN, CIVITAI_API_KEY)
211
 
212
  # Download Embeddings
213
- for url_embed in download_embeds:
214
  if not os.path.exists(f"./embedings/{url_embed.split('/')[-1]}"):
215
- download_things(directory_embeds, url_embed, HF_TOKEN, CIVITAI_API_KEY)
216
 
217
  # Build list models
218
- embed_list = get_model_list(directory_embeds)
219
- model_list = get_model_list(directory_models)
220
- model_list = load_diffusers_format_model + model_list
221
- ## BEGIN MOD
222
  lora_model_list = get_lora_model_list()
223
- vae_model_list = get_model_list(directory_vaes)
224
  vae_model_list.insert(0, "None")
225
 
226
- download_private_repo(HF_SDXL_EMBEDS_NEGATIVE_PRIVATE_REPO, directory_embeds_sdxl, False)
227
- download_private_repo(HF_SDXL_EMBEDS_POSITIVE_PRIVATE_REPO, directory_embeds_positive_sdxl, False)
228
- embed_sdxl_list = get_model_list(directory_embeds_sdxl) + get_model_list(directory_embeds_positive_sdxl)
 
 
229
 
230
  def get_embed_list(pipeline_name):
231
  return get_tupled_embed_list(embed_sdxl_list if pipeline_name == "StableDiffusionXLPipeline" else embed_list)
@@ -248,12 +121,13 @@ warnings.filterwarnings(action="ignore", category=UserWarning, module="diffusers
248
  warnings.filterwarnings(action="ignore", category=FutureWarning, module="transformers")
249
  ## BEGIN MOD
250
  from stablepy import logger
251
- logger.setLevel(logging.CRITICAL)
 
252
 
253
- from v2 import V2_ALL_MODELS, v2_random_prompt, v2_upsampling_prompt
254
- from utils import (gradio_copy_text, COPY_ACTION_JS, gradio_copy_prompt,
255
  V2_ASPECT_RATIO_OPTIONS, V2_RATING_OPTIONS, V2_LENGTH_OPTIONS, V2_IDENTITY_OPTIONS)
256
- from tagger import (predict_tags_wd, convert_danbooru_to_e621_prompt,
257
  remove_specific_prompt, insert_recom_prompt, insert_model_recom_prompt,
258
  compose_prompt_to_copy, translate_prompt, select_random_character)
259
  def description_ui():
@@ -267,137 +141,91 @@ def description_ui():
267
  )
268
  ## END MOD
269
 
270
- msg_inc_vae = (
271
- "Use the right VAE for your model to maintain image quality. The wrong"
272
- " VAE can lead to poor results, like blurriness in the generated images."
273
- )
274
-
275
- SDXL_TASK = [k for k, v in TASK_STABLEPY.items() if v in SDXL_TASKS]
276
- SD_TASK = [k for k, v in TASK_STABLEPY.items() if v in SD15_TASKS]
277
- FLUX_TASK = list(TASK_STABLEPY.keys())[:3] + [k for k, v in TASK_STABLEPY.items() if v in FLUX_CN_UNION_MODES.keys()]
278
-
279
- MODEL_TYPE_TASK = {
280
- "SD 1.5": SD_TASK,
281
- "SDXL": SDXL_TASK,
282
- "FLUX": FLUX_TASK,
283
- }
284
-
285
- MODEL_TYPE_CLASS = {
286
- "diffusers:StableDiffusionPipeline": "SD 1.5",
287
- "diffusers:StableDiffusionXLPipeline": "SDXL",
288
- "diffusers:FluxPipeline": "FLUX",
289
- }
290
-
291
- POST_PROCESSING_SAMPLER = ["Use same sampler"] + scheduler_names[:-2]
292
-
293
- SUBTITLE_GUI = (
294
- "### This demo uses [diffusers](https://github.com/huggingface/diffusers)"
295
- " to perform different tasks in image generation."
296
- )
297
-
298
- def extract_parameters(input_string):
299
- parameters = {}
300
- input_string = input_string.replace("\n", "")
301
-
302
- if "Negative prompt:" not in input_string:
303
- if "Steps:" in input_string:
304
- input_string = input_string.replace("Steps:", "Negative prompt: Steps:")
305
- else:
306
- print("Invalid metadata")
307
- parameters["prompt"] = input_string
308
- return parameters
309
-
310
- parm = input_string.split("Negative prompt:")
311
- parameters["prompt"] = parm[0].strip()
312
- if "Steps:" not in parm[1]:
313
- print("Steps not detected")
314
- parameters["neg_prompt"] = parm[1].strip()
315
- return parameters
316
- parm = parm[1].split("Steps:")
317
- parameters["neg_prompt"] = parm[0].strip()
318
- input_string = "Steps:" + parm[1]
319
-
320
- # Extracting Steps
321
- steps_match = re.search(r'Steps: (\d+)', input_string)
322
- if steps_match:
323
- parameters['Steps'] = int(steps_match.group(1))
324
-
325
- # Extracting Size
326
- size_match = re.search(r'Size: (\d+x\d+)', input_string)
327
- if size_match:
328
- parameters['Size'] = size_match.group(1)
329
- width, height = map(int, parameters['Size'].split('x'))
330
- parameters['width'] = width
331
- parameters['height'] = height
332
-
333
- # Extracting other parameters
334
- other_parameters = re.findall(r'(\w+): (.*?)(?=, \w+|$)', input_string)
335
- for param in other_parameters:
336
- parameters[param[0]] = param[1].strip('"')
337
-
338
- return parameters
339
-
340
- def info_html(json_data, title, subtitle):
341
- return f"""
342
- <div style='padding: 0; border-radius: 10px;'>
343
- <p style='margin: 0; font-weight: bold;'>{title}</p>
344
- <details>
345
- <summary>Details</summary>
346
- <p style='margin: 0; font-weight: bold;'>{subtitle}</p>
347
- </details>
348
- </div>
349
- """
350
-
351
- def get_model_type(repo_id: str):
352
- api = HfApi(token=os.environ.get("HF_TOKEN")) # if use private or gated model
353
- default = "SD 1.5"
354
- try:
355
- model = api.model_info(repo_id=repo_id, timeout=5.0)
356
- tags = model.tags
357
- for tag in tags:
358
- if tag in MODEL_TYPE_CLASS.keys(): return MODEL_TYPE_CLASS.get(tag, default)
359
- except Exception:
360
- return default
361
- return default
362
-
363
  class GuiSD:
364
- def __init__(self):
365
  self.model = None
366
-
367
- print("Loading model...")
368
- self.model = Model_Diffusers(
369
- base_model_id="Lykon/dreamshaper-8",
370
- task_name="txt2img",
371
- vae_model=None,
372
- type_model_precision=torch.float16,
373
- retain_task_model_in_cache=False,
374
- device="cpu",
375
- )
376
- self.model.load_beta_styles()
377
- #self.model.device = torch.device("cpu") #
378
 
379
  def load_new_model(self, model_name, vae_model, task, progress=gr.Progress(track_tqdm=True)):
380
 
381
- yield f"Loading model: {model_name}"
382
-
383
  vae_model = vae_model if vae_model != "None" else None
384
  model_type = get_model_type(model_name)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
385
 
386
  if vae_model:
387
  vae_type = "SDXL" if "sdxl" in vae_model.lower() else "SD 1.5"
388
  if model_type != vae_type:
389
- gr.Warning(msg_inc_vae)
390
 
391
- self.model.device = torch.device("cpu")
392
- dtype_model = torch.bfloat16 if model_type == "FLUX" else torch.float16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
393
 
394
- self.model.load_pipe(
395
- model_name,
396
- task_name=TASK_STABLEPY[task],
397
- vae_model=vae_model if vae_model != "None" else None,
398
- type_model_precision=dtype_model,
399
- retain_task_model_in_cache=False,
400
- )
401
  yield f"Model loaded: {model_name}"
402
 
403
  #@spaces.GPU
@@ -506,18 +334,17 @@ class GuiSD:
506
  mode_ip2,
507
  scale_ip2,
508
  pag_scale,
509
- #progress=gr.Progress(track_tqdm=True),
510
  ):
511
- #progress(0, desc="Preparing inference...")
 
512
 
513
  vae_model = vae_model if vae_model != "None" else None
514
  loras_list = [lora1, lora2, lora3, lora4, lora5]
515
  vae_msg = f"VAE: {vae_model}" if vae_model else ""
516
  msg_lora = ""
517
 
518
- print("Config model:", model_name, vae_model, loras_list)
519
-
520
  ## BEGIN MOD
 
521
  global lora_model_list
522
  lora_model_list = get_lora_model_list()
523
  lora1, lora_scale1, lora2, lora_scale2, lora3, lora_scale3, lora4, lora_scale4, lora5, lora_scale5 = \
@@ -526,6 +353,8 @@ class GuiSD:
526
  prompt, neg_prompt = insert_model_recom_prompt(prompt, neg_prompt, model_name)
527
  ## END MOD
528
 
 
 
529
  task = TASK_STABLEPY[task]
530
 
531
  params_ip_img = []
@@ -548,7 +377,8 @@ class GuiSD:
548
  params_ip_mode.append(modeip)
549
  params_ip_scale.append(scaleip)
550
 
551
- self.model.stream_config(concurrency=5, latent_resize_by=1, vae_decoding=False)
 
552
 
553
  if task != "txt2img" and not image_control:
554
  raise ValueError("No control image found: To use this function, you have to upload an image in 'Image ControlNet/Inpaint/Img2img'")
@@ -621,15 +451,15 @@ class GuiSD:
621
  "high_threshold": high_threshold,
622
  "value_threshold": value_threshold,
623
  "distance_threshold": distance_threshold,
624
- "lora_A": lora1 if lora1 != "None" and lora1 != "" else None,
625
  "lora_scale_A": lora_scale1,
626
- "lora_B": lora2 if lora2 != "None" and lora2 != "" else None,
627
  "lora_scale_B": lora_scale2,
628
- "lora_C": lora3 if lora3 != "None" and lora3 != "" else None,
629
  "lora_scale_C": lora_scale3,
630
- "lora_D": lora4 if lora4 != "None" and lora4 != "" else None,
631
  "lora_scale_D": lora_scale4,
632
- "lora_E": lora5 if lora5 != "None" and lora5 != "" else None,
633
  "lora_scale_E": lora_scale5,
634
  ## BEGIN MOD
635
  "textual_inversion": get_embed_list(self.model.class_name) if textual_inversion else [],
@@ -679,19 +509,24 @@ class GuiSD:
679
  }
680
 
681
  self.model.device = torch.device("cuda:0")
682
- if hasattr(self.model.pipe, "transformer") and loras_list != ["None"] * 5 and loras_list != [""] * 5:
683
  self.model.pipe.transformer.to(self.model.device)
684
  print("transformer to cuda")
685
 
686
- #progress(0, desc="Preparation completed. Starting inference...")
687
-
688
- info_state = "PROCESSING "
689
  for img, seed, image_path, metadata in self.model(**pipe_params):
690
- info_state += ">"
 
691
  if image_path:
692
- info_state = f"COMPLETE. Seeds: {str(seed)}"
693
  if vae_msg:
694
- info_state = info_state + "<br>" + vae_msg
 
 
 
 
 
695
 
696
  for status, lora in zip(self.model.lora_status, self.model.lora_memory):
697
  if status:
@@ -700,9 +535,9 @@ class GuiSD:
700
  msg_lora += f"<br>Error with: {lora}"
701
 
702
  if msg_lora:
703
- info_state += msg_lora
704
 
705
- info_state = info_state + "<br>" + "GENERATION DATA:<br>" + metadata[0].replace("\n", "<br>") + "<br>-------<br>"
706
 
707
  download_links = "<br>".join(
708
  [
@@ -711,19 +546,16 @@ class GuiSD:
711
  ]
712
  )
713
  if save_generated_images:
714
- info_state += f"<br>{download_links}"
715
 
 
716
  img = save_images(img, metadata)
 
717
 
718
- yield img, info_state
719
-
720
- def update_task_options(model_name, task_name):
721
- new_choices = MODEL_TYPE_TASK[get_model_type(model_name)]
722
 
723
- if task_name not in new_choices:
724
- task_name = "txt2img"
725
 
726
- return gr.update(value=task_name, choices=new_choices)
727
 
728
  def dynamic_gpu_duration(func, duration, *args):
729
 
@@ -733,10 +565,12 @@ def dynamic_gpu_duration(func, duration, *args):
733
 
734
  return wrapped_func()
735
 
 
736
  @spaces.GPU
737
  def dummy_gpu():
738
  return None
739
 
 
740
  def sd_gen_generate_pipeline(*args):
741
 
742
  gpu_duration_arg = int(args[-1]) if args[-1] else 59
@@ -744,7 +578,7 @@ def sd_gen_generate_pipeline(*args):
744
  load_lora_cpu = args[-3]
745
  generation_args = args[:-3]
746
  lora_list = [
747
- None if item == "None" or item == "" else item
748
  for item in [args[7], args[9], args[11], args[13], args[15]]
749
  ]
750
  lora_status = [None] * 5
@@ -754,7 +588,7 @@ def sd_gen_generate_pipeline(*args):
754
  msg_load_lora = "Updating LoRAs in CPU (Slow but saves GPU usage)..."
755
 
756
  if lora_list != sd_gen.model.lora_memory and lora_list != [None] * 5:
757
- yield None, msg_load_lora
758
 
759
  # Load lora in CPU
760
  if load_lora_cpu:
@@ -780,14 +614,15 @@ def sd_gen_generate_pipeline(*args):
780
  )
781
  gr.Info(f"LoRAs in cache: {lora_cache_msg}")
782
 
783
- msg_request = f"Requesting {gpu_duration_arg}s. of GPU time"
 
784
  gr.Info(msg_request)
785
  print(msg_request)
786
-
787
- # yield from sd_gen.generate_pipeline(*generation_args)
788
 
789
  start_time = time.time()
790
 
 
791
  yield from dynamic_gpu_duration(
792
  sd_gen.generate_pipeline,
793
  gpu_duration_arg,
@@ -795,31 +630,19 @@ def sd_gen_generate_pipeline(*args):
795
  )
796
 
797
  end_time = time.time()
 
 
 
 
798
 
799
  if verbose_arg:
800
- execution_time = end_time - start_time
801
- msg_task_complete = (
802
- f"GPU task complete in: {round(execution_time, 0) + 1} seconds"
803
- )
804
  gr.Info(msg_task_complete)
805
  print(msg_task_complete)
806
 
807
- def extract_exif_data(image):
808
- if image is None: return ""
809
-
810
- try:
811
- metadata_keys = ['parameters', 'metadata', 'prompt', 'Comment']
812
-
813
- for key in metadata_keys:
814
- if key in image.info:
815
- return image.info[key]
816
-
817
- return str(image.info)
818
 
819
- except Exception as e:
820
- return f"Error extracting metadata: {str(e)}"
821
 
822
- @spaces.GPU(duration=20)
823
  def esrgan_upscale(image, upscaler_name, upscaler_size):
824
  if image is None: return None
825
 
@@ -841,17 +664,20 @@ def esrgan_upscale(image, upscaler_name, upscaler_size):
841
 
842
  return image_path
843
 
 
844
  dynamic_gpu_duration.zerogpu = True
845
  sd_gen_generate_pipeline.zerogpu = True
846
  sd_gen = GuiSD()
847
 
 
848
  ## BEGIN MOD
849
  CSS ="""
850
- .gradio-container, #main { width:100%; height:100%; max-width:100%; padding-left:0; padding-right:0; margin-left:0; margin-right:0; !important; }
851
- .contain { display:flex; flex-direction:column; !important; }
852
- #component-0 { width:100%; height:100%; !important; }
853
- #gallery { flex-grow:1; !important; }
854
- .lora { min-width:480px; !important; }
 
855
  #model-info { text-align:center; }
856
  .title { font-size: 3em; align-items: center; text-align: center; }
857
  .info { align-items: center; text-align: center; }
@@ -865,6 +691,15 @@ with gr.Blocks(theme='NoCrypt/miku@>=1.2.2', elem_id="main", fill_width=True, cs
865
  with gr.Tab("Generation"):
866
  with gr.Row():
867
  with gr.Column(scale=2):
 
 
 
 
 
 
 
 
 
868
  interface_mode_gui = gr.Radio(label="Quick settings", choices=["Simple", "Standard", "Fast", "LoRA"], value="Standard")
869
  with gr.Accordion("Model and Task", open=False) as menu_model:
870
  task_gui = gr.Dropdown(label="Task", choices=SDXL_TASK, value=TASK_MODEL_LIST[0])
@@ -928,7 +763,7 @@ with gr.Blocks(theme='NoCrypt/miku@>=1.2.2', elem_id="main", fill_width=True, cs
928
  [task_gui],
929
  )
930
 
931
- load_model_gui = gr.HTML()
932
 
933
  result_images = gr.Gallery(
934
  label="Generated images",
@@ -950,6 +785,13 @@ with gr.Blocks(theme='NoCrypt/miku@>=1.2.2', elem_id="main", fill_width=True, cs
950
 
951
  actual_task_info = gr.HTML()
952
 
 
 
 
 
 
 
 
953
  with gr.Row(equal_height=False, variant="default"):
954
  gpu_duration_gui = gr.Number(minimum=5, maximum=240, value=59, show_label=False, container=False, info="GPU time duration (seconds)")
955
  with gr.Column():
@@ -972,15 +814,7 @@ with gr.Blocks(theme='NoCrypt/miku@>=1.2.2', elem_id="main", fill_width=True, cs
972
  with gr.Row():
973
  sampler_gui = gr.Dropdown(label="Sampler", choices=scheduler_names, value="Euler a")
974
  vae_model_gui = gr.Dropdown(label="VAE Model", choices=vae_model_list, value=vae_model_list[0])
975
- prompt_s_options = [
976
- ("Compel format: (word)weight", "Compel"),
977
- ("Classic format: (word:weight)", "Classic"),
978
- ("Classic-original format: (word:weight)", "Classic-original"),
979
- ("Classic-no_norm format: (word:weight)", "Classic-no_norm"),
980
- ("Classic-ignore", "Classic-ignore"),
981
- ("None", "None"),
982
- ]
983
- prompt_syntax_gui = gr.Dropdown(label="Prompt Syntax", choices=prompt_s_options, value=prompt_s_options[1][1])
984
 
985
  with gr.Row(equal_height=False):
986
 
@@ -1130,8 +964,11 @@ with gr.Blocks(theme='NoCrypt/miku@>=1.2.2', elem_id="main", fill_width=True, cs
1130
  with gr.Accordion("Select from Gallery", open=False):
1131
  search_civitai_gallery_lora = gr.Gallery([], label="Results", allow_preview=False, columns=5, show_share_button=False, interactive=False)
1132
  search_civitai_result_lora = gr.Dropdown(label="Search Results", choices=[("", "")], value="", allow_custom_value=True, visible=False)
1133
- text_lora = gr.Textbox(label="LoRA URL", placeholder="https://civitai.com/api/download/models/28907", lines=1)
 
 
1134
  button_lora = gr.Button("Get and update lists of LoRAs")
 
1135
  with gr.Accordion("From Local", open=True, visible=True):
1136
  file_output_lora = gr.File(label="Uploaded LoRA", file_types=['.ckpt', '.pt', '.pth', '.safetensors', '.bin'], file_count="multiple", interactive=False, visible=False)
1137
  upload_button_lora = gr.UploadButton(label="Upload LoRA from your disk (very slow)", file_types=['.ckpt', '.pt', '.pth', '.safetensors', '.bin'], file_count="multiple")
@@ -1169,7 +1006,7 @@ with gr.Blocks(theme='NoCrypt/miku@>=1.2.2', elem_id="main", fill_width=True, cs
1169
  negative_prompt_ad_a_gui = gr.Textbox(label="Negative prompt", placeholder="Main negative prompt will be use", lines=3)
1170
  with gr.Row():
1171
  strength_ad_a_gui = gr.Number(label="Strength:", value=0.35, step=0.01, minimum=0.01, maximum=1.0)
1172
- face_detector_ad_a_gui = gr.Checkbox(label="Face detector", value=True)
1173
  person_detector_ad_a_gui = gr.Checkbox(label="Person detector", value=True)
1174
  hand_detector_ad_a_gui = gr.Checkbox(label="Hand detector", value=False)
1175
  with gr.Row():
@@ -1184,7 +1021,7 @@ with gr.Blocks(theme='NoCrypt/miku@>=1.2.2', elem_id="main", fill_width=True, cs
1184
  negative_prompt_ad_b_gui = gr.Textbox(label="Negative prompt", placeholder="Main negative prompt will be use", lines=3)
1185
  with gr.Row():
1186
  strength_ad_b_gui = gr.Number(label="Strength:", value=0.35, step=0.01, minimum=0.01, maximum=1.0)
1187
- face_detector_ad_b_gui = gr.Checkbox(label="Face detector", value=True)
1188
  person_detector_ad_b_gui = gr.Checkbox(label="Person detector", value=True)
1189
  hand_detector_ad_b_gui = gr.Checkbox(label="Hand detector", value=False)
1190
  with gr.Row():
@@ -1314,73 +1151,7 @@ with gr.Blocks(theme='NoCrypt/miku@>=1.2.2', elem_id="main", fill_width=True, cs
1314
 
1315
  with gr.Accordion("Examples and help", open=True, visible=True) as menu_example:
1316
  gr.Examples(
1317
- examples=[
1318
- [
1319
- "1girl, souryuu asuka langley, neon genesis evangelion, plugsuit, pilot suit, red bodysuit, sitting, crossing legs, black eye patch, cat hat, throne, symmetrical, looking down, from bottom, looking at viewer, outdoors, masterpiece, best quality, very aesthetic, absurdres",
1320
- "nsfw, lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]",
1321
- 1,
1322
- 30,
1323
- 7.5,
1324
- True,
1325
- -1,
1326
- "Euler a",
1327
- 1152,
1328
- 896,
1329
- "votepurchase/animagine-xl-3.1",
1330
- ],
1331
- [
1332
- "solo, princess Zelda OOT, score_9, score_8_up, score_8, medium breasts, cute, eyelashes, cute small face, long hair, crown braid, hairclip, pointy ears, soft curvy body, looking at viewer, smile, blush, white dress, medium body, (((holding the Master Sword))), standing, deep forest in the background",
1333
- "score_6, score_5, score_4, busty, ugly face, mutated hands, low res, blurry face, black and white,",
1334
- 1,
1335
- 30,
1336
- 5.,
1337
- True,
1338
- -1,
1339
- "Euler a",
1340
- 1024,
1341
- 1024,
1342
- "votepurchase/ponyDiffusionV6XL",
1343
- ],
1344
- [
1345
- "1girl, oomuro sakurako, yuru yuri, official art, school uniform, anime artwork, anime style, vibrant, studio anime, highly detailed, masterpiece, best quality, very aesthetic, absurdres",
1346
- "photo, deformed, black and white, realism, disfigured, low contrast, lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]",
1347
- 1,
1348
- 40,
1349
- 7.0,
1350
- True,
1351
- -1,
1352
- "Euler a",
1353
- 1024,
1354
- 1024,
1355
- "Raelina/Rae-Diffusion-XL-V2",
1356
- ],
1357
- [
1358
- "1girl, akaza akari, yuru yuri, official art, anime artwork, anime style, vibrant, studio anime, highly detailed, masterpiece, best quality, very aesthetic, absurdres",
1359
- "photo, deformed, black and white, realism, disfigured, low contrast, lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]",
1360
- 1,
1361
- 35,
1362
- 7.0,
1363
- True,
1364
- -1,
1365
- "Euler a",
1366
- 1024,
1367
- 1024,
1368
- "Raelina/Raemu-XL-V4",
1369
- ],
1370
- [
1371
- "yoshida yuuko, machikado mazoku, 1girl, solo, demon horns,horns, school uniform, long hair, open mouth, skirt, demon girl, ahoge, shiny, shiny hair, anime artwork",
1372
- "nsfw, lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]",
1373
- 1,
1374
- 50,
1375
- 7.,
1376
- True,
1377
- -1,
1378
- "Euler a",
1379
- 1024,
1380
- 1024,
1381
- "cagliostrolab/animagine-xl-3.1",
1382
- ],
1383
- ],
1384
  fn=sd_gen.generate_pipeline,
1385
  inputs=[
1386
  prompt_gui,
@@ -1395,16 +1166,12 @@ with gr.Blocks(theme='NoCrypt/miku@>=1.2.2', elem_id="main", fill_width=True, cs
1395
  img_width_gui,
1396
  model_name_gui,
1397
  ],
1398
- outputs=[result_images, actual_task_info],
1399
  cache_examples=False,
1400
  #elem_id="examples",
1401
  )
1402
 
1403
- gr.Markdown(
1404
- """### Resources
1405
- - You can also try the image generator in Colab’s free tier, which provides free GPU [link](https://github.com/R3gm/SD_diffusers_interactive).
1406
- """
1407
- )
1408
  ## END MOD
1409
 
1410
  with gr.Tab("Inpaint mask maker", render=True):
@@ -1563,7 +1330,7 @@ with gr.Blocks(theme='NoCrypt/miku@>=1.2.2', elem_id="main", fill_width=True, cs
1563
  )
1564
  search_civitai_result_lora.change(select_civitai_lora, [search_civitai_result_lora], [text_lora, search_civitai_desc_lora], queue=False, scroll_to_output=True)
1565
  search_civitai_gallery_lora.select(update_civitai_selection, None, [search_civitai_result_lora], queue=False, show_api=False)
1566
- button_lora.click(get_my_lora, [text_lora], [lora1_gui, lora2_gui, lora3_gui, lora4_gui, lora5_gui], scroll_to_output=True)
1567
  upload_button_lora.upload(upload_file_lora, [upload_button_lora], [file_output_lora, upload_button_lora]).success(
1568
  move_file_lora, [file_output_lora], [lora1_gui, lora2_gui, lora3_gui, lora4_gui, lora5_gui], scroll_to_output=True)
1569
 
@@ -1719,10 +1486,11 @@ with gr.Blocks(theme='NoCrypt/miku@>=1.2.2', elem_id="main", fill_width=True, cs
1719
  verbose_info_gui,
1720
  gpu_duration_gui,
1721
  ],
1722
- outputs=[result_images, actual_task_info],
1723
  queue=True,
1724
  show_progress="full",
1725
- ).success(save_gallery_images, [result_images], [result_images, result_images_files], queue=False, show_api=False)
 
1726
 
1727
  with gr.Tab("Danbooru Tags Transformer with WD Tagger", render=True):
1728
  with gr.Column(scale=2):
@@ -1818,5 +1586,5 @@ with gr.Blocks(theme='NoCrypt/miku@>=1.2.2', elem_id="main", fill_width=True, cs
1818
  gr.DuplicateButton(value="Duplicate Space for private use (This demo does not work on CPU. Requires GPU Space)")
1819
 
1820
  app.queue()
1821
- app.launch() # allowed_paths=["./images/"], show_error=True, debug=True
1822
  ## END MOD
 
1
  import spaces
 
2
  import os
3
  from stablepy import Model_Diffusers
4
+ from constants import (
5
+ PREPROCESSOR_CONTROLNET,
6
+ TASK_STABLEPY,
7
+ TASK_MODEL_LIST,
8
+ UPSCALER_DICT_GUI,
9
+ UPSCALER_KEYS,
10
+ PROMPT_W_OPTIONS,
11
+ WARNING_MSG_VAE,
12
+ SDXL_TASK,
13
+ MODEL_TYPE_TASK,
14
+ POST_PROCESSING_SAMPLER,
15
+
16
+ )
17
  from stablepy.diffusers_vanilla.style_prompt_config import STYLE_NAMES
 
18
  import torch
19
  import re
 
20
  from stablepy import (
 
 
 
 
 
21
  scheduler_names,
 
22
  IP_ADAPTERS_SD,
23
  IP_ADAPTERS_SDXL,
 
 
 
 
24
  )
25
  import time
26
  from PIL import ImageFile
27
+ from utils import (
28
+ get_model_list,
29
+ extract_parameters,
30
+ get_model_type,
31
+ extract_exif_data,
32
+ create_mask_now,
33
+ download_diffuser_repo,
34
+ progress_step_bar,
35
+ html_template_message,
36
+ )
37
+ from datetime import datetime
38
+ import gradio as gr
39
+ import logging
40
+ import diffusers
41
+ import warnings
42
+ from stablepy import logger
43
+ # import urllib.parse
44
 
45
  ImageFile.LOAD_TRUNCATED_IMAGES = True
46
+ # os.environ["PYTORCH_NO_CUDA_MEMORY_CACHING"] = "1"
47
  print(os.getenv("SPACES_ZERO_GPU"))
48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  ## BEGIN MOD
50
  from modutils import (list_uniq, download_private_repo, get_model_id_list, get_tupled_embed_list,
51
  get_lora_model_list, get_all_lora_tupled_list, update_loras, apply_lora_prompt, set_prompt_loras,
52
  get_my_lora, upload_file_lora, move_file_lora, search_civitai_lora, select_civitai_lora,
53
  update_civitai_selection, get_civitai_tag, CIVITAI_SORT, CIVITAI_PERIOD, CIVITAI_BASEMODEL,
54
  set_textual_inversion_prompt, get_model_pipeline, change_interface_mode, get_t2i_model_info,
55
+ get_tupled_model_list, save_gallery_images, save_gallery_history, set_optimization, set_sampler_settings,
56
  set_quick_presets, process_style_prompt, optimization_list, save_images, download_things,
57
+ preset_styles, preset_quality, preset_sampler_setting, translate_to_en, EXAMPLES_GUI, RESOURCES)
58
  from env import (HF_TOKEN, CIVITAI_API_KEY, HF_LORA_ESSENTIAL_PRIVATE_REPO, HF_VAE_PRIVATE_REPO,
59
  HF_SDXL_EMBEDS_NEGATIVE_PRIVATE_REPO, HF_SDXL_EMBEDS_POSITIVE_PRIVATE_REPO,
60
+ DIRECTORY_MODELS, DIRECTORY_LORAS, DIRECTORY_VAES, DIRECTORY_EMBEDS, DIRECTORY_EMBEDS_SDXL,
61
+ DIRECTORY_EMBEDS_POSITIVE_SDXL, LOAD_DIFFUSERS_FORMAT_MODEL,
62
+ DOWNLOAD_MODEL_LIST, DOWNLOAD_LORA_LIST, DOWNLOAD_VAE_LIST, DOWNLOAD_EMBEDS)
63
+
64
+ download_private_repo(HF_LORA_ESSENTIAL_PRIVATE_REPO, DIRECTORY_LORAS, True)
65
+ download_private_repo(HF_VAE_PRIVATE_REPO, DIRECTORY_VAES, False)
66
+ ## END MOD
67
 
68
  # - **Download Models**
69
+ DOWNLOAD_MODEL = ", ".join(DOWNLOAD_MODEL_LIST)
70
  # - **Download VAEs**
71
+ DOWNLOAD_VAE = ", ".join(DOWNLOAD_VAE_LIST)
72
  # - **Download LoRAs**
73
+ DOWNLOAD_LORA = ", ".join(DOWNLOAD_LORA_LIST)
 
 
 
 
 
 
74
 
75
  # Download stuffs
76
+ for url in [url.strip() for url in DOWNLOAD_MODEL.split(',')]:
77
  if not os.path.exists(f"./models/{url.split('/')[-1]}"):
78
+ download_things(DIRECTORY_MODELS, url, HF_TOKEN, CIVITAI_API_KEY)
79
+ for url in [url.strip() for url in DOWNLOAD_VAE.split(',')]:
80
  if not os.path.exists(f"./vaes/{url.split('/')[-1]}"):
81
+ download_things(DIRECTORY_VAES, url, HF_TOKEN, CIVITAI_API_KEY)
82
+ for url in [url.strip() for url in DOWNLOAD_LORA.split(',')]:
83
  if not os.path.exists(f"./loras/{url.split('/')[-1]}"):
84
+ download_things(DIRECTORY_LORAS, url, HF_TOKEN, CIVITAI_API_KEY)
85
 
86
  # Download Embeddings
87
+ for url_embed in DOWNLOAD_EMBEDS:
88
  if not os.path.exists(f"./embedings/{url_embed.split('/')[-1]}"):
89
+ download_things(DIRECTORY_EMBEDS, url_embed, HF_TOKEN, CIVITAI_API_KEY)
90
 
91
  # Build list models
92
+ embed_list = get_model_list(DIRECTORY_EMBEDS)
 
 
 
93
  lora_model_list = get_lora_model_list()
94
+ vae_model_list = get_model_list(DIRECTORY_VAES)
95
  vae_model_list.insert(0, "None")
96
 
97
+ ## BEGIN MOD
98
+ model_list = list_uniq(get_model_id_list() + LOAD_DIFFUSERS_FORMAT_MODEL + get_model_list(DIRECTORY_MODELS))
99
+ download_private_repo(HF_SDXL_EMBEDS_NEGATIVE_PRIVATE_REPO, DIRECTORY_EMBEDS_SDXL, False)
100
+ download_private_repo(HF_SDXL_EMBEDS_POSITIVE_PRIVATE_REPO, DIRECTORY_EMBEDS_POSITIVE_SDXL, False)
101
+ embed_sdxl_list = get_model_list(DIRECTORY_EMBEDS_SDXL) + get_model_list(DIRECTORY_EMBEDS_POSITIVE_SDXL)
102
 
103
  def get_embed_list(pipeline_name):
104
  return get_tupled_embed_list(embed_sdxl_list if pipeline_name == "StableDiffusionXLPipeline" else embed_list)
 
121
  warnings.filterwarnings(action="ignore", category=FutureWarning, module="transformers")
122
  ## BEGIN MOD
123
  from stablepy import logger
124
+ #logger.setLevel(logging.CRITICAL)
125
+ logger.setLevel(logging.DEBUG)
126
 
127
+ from tagger.v2 import V2_ALL_MODELS, v2_random_prompt, v2_upsampling_prompt
128
+ from tagger.utils import (gradio_copy_text, COPY_ACTION_JS, gradio_copy_prompt,
129
  V2_ASPECT_RATIO_OPTIONS, V2_RATING_OPTIONS, V2_LENGTH_OPTIONS, V2_IDENTITY_OPTIONS)
130
+ from tagger.tagger import (predict_tags_wd, convert_danbooru_to_e621_prompt,
131
  remove_specific_prompt, insert_recom_prompt, insert_model_recom_prompt,
132
  compose_prompt_to_copy, translate_prompt, select_random_character)
133
  def description_ui():
 
141
  )
142
  ## END MOD
143
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
  class GuiSD:
145
+ def __init__(self, stream=True):
146
  self.model = None
147
+ self.status_loading = False
148
+ self.sleep_loading = 4
149
+ self.last_load = datetime.now()
 
 
 
 
 
 
 
 
 
150
 
151
  def load_new_model(self, model_name, vae_model, task, progress=gr.Progress(track_tqdm=True)):
152
 
 
 
153
  vae_model = vae_model if vae_model != "None" else None
154
  model_type = get_model_type(model_name)
155
+ dtype_model = torch.bfloat16 if model_type == "FLUX" else torch.float16
156
+
157
+ if not os.path.exists(model_name):
158
+ _ = download_diffuser_repo(
159
+ repo_name=model_name,
160
+ model_type=model_type,
161
+ revision="main",
162
+ token=True,
163
+ )
164
+
165
+ for i in range(68):
166
+ if not self.status_loading:
167
+ self.status_loading = True
168
+ if i > 0:
169
+ time.sleep(self.sleep_loading)
170
+ print("Previous model ops...")
171
+ break
172
+ time.sleep(0.5)
173
+ print(f"Waiting queue {i}")
174
+ yield "Waiting queue"
175
+
176
+ self.status_loading = True
177
+
178
+ yield f"Loading model: {model_name}"
179
 
180
  if vae_model:
181
  vae_type = "SDXL" if "sdxl" in vae_model.lower() else "SD 1.5"
182
  if model_type != vae_type:
183
+ gr.Warning(WARNING_MSG_VAE)
184
 
185
+ print("Loading model...")
186
+
187
+ try:
188
+ start_time = time.time()
189
+
190
+ if self.model is None:
191
+ self.model = Model_Diffusers(
192
+ base_model_id=model_name,
193
+ task_name=TASK_STABLEPY[task],
194
+ vae_model=vae_model,
195
+ type_model_precision=dtype_model,
196
+ retain_task_model_in_cache=False,
197
+ device="cpu",
198
+ )
199
+ else:
200
+
201
+ if self.model.base_model_id != model_name:
202
+ load_now_time = datetime.now()
203
+ elapsed_time = max((load_now_time - self.last_load).total_seconds(), 0)
204
+
205
+ if elapsed_time <= 8:
206
+ print("Waiting for the previous model's time ops...")
207
+ time.sleep(8-elapsed_time)
208
+
209
+ self.model.device = torch.device("cpu")
210
+ self.model.load_pipe(
211
+ model_name,
212
+ task_name=TASK_STABLEPY[task],
213
+ vae_model=vae_model,
214
+ type_model_precision=dtype_model,
215
+ retain_task_model_in_cache=False,
216
+ )
217
+
218
+ end_time = time.time()
219
+ self.sleep_loading = max(min(int(end_time - start_time), 10), 4)
220
+ except Exception as e:
221
+ self.last_load = datetime.now()
222
+ self.status_loading = False
223
+ self.sleep_loading = 4
224
+ raise e
225
+
226
+ self.last_load = datetime.now()
227
+ self.status_loading = False
228
 
 
 
 
 
 
 
 
229
  yield f"Model loaded: {model_name}"
230
 
231
  #@spaces.GPU
 
334
  mode_ip2,
335
  scale_ip2,
336
  pag_scale,
 
337
  ):
338
+ info_state = html_template_message("Navigating latent space...")
339
+ yield info_state, gr.update(), gr.update()
340
 
341
  vae_model = vae_model if vae_model != "None" else None
342
  loras_list = [lora1, lora2, lora3, lora4, lora5]
343
  vae_msg = f"VAE: {vae_model}" if vae_model else ""
344
  msg_lora = ""
345
 
 
 
346
  ## BEGIN MOD
347
+ loras_list = [s if s else "None" for s in loras_list]
348
  global lora_model_list
349
  lora_model_list = get_lora_model_list()
350
  lora1, lora_scale1, lora2, lora_scale2, lora3, lora_scale3, lora4, lora_scale4, lora5, lora_scale5 = \
 
353
  prompt, neg_prompt = insert_model_recom_prompt(prompt, neg_prompt, model_name)
354
  ## END MOD
355
 
356
+ print("Config model:", model_name, vae_model, loras_list)
357
+
358
  task = TASK_STABLEPY[task]
359
 
360
  params_ip_img = []
 
377
  params_ip_mode.append(modeip)
378
  params_ip_scale.append(scaleip)
379
 
380
+ concurrency = 5
381
+ self.model.stream_config(concurrency=concurrency, latent_resize_by=1, vae_decoding=False)
382
 
383
  if task != "txt2img" and not image_control:
384
  raise ValueError("No control image found: To use this function, you have to upload an image in 'Image ControlNet/Inpaint/Img2img'")
 
451
  "high_threshold": high_threshold,
452
  "value_threshold": value_threshold,
453
  "distance_threshold": distance_threshold,
454
+ "lora_A": lora1 if lora1 != "None" else None,
455
  "lora_scale_A": lora_scale1,
456
+ "lora_B": lora2 if lora2 != "None" else None,
457
  "lora_scale_B": lora_scale2,
458
+ "lora_C": lora3 if lora3 != "None" else None,
459
  "lora_scale_C": lora_scale3,
460
+ "lora_D": lora4 if lora4 != "None" else None,
461
  "lora_scale_D": lora_scale4,
462
+ "lora_E": lora5 if lora5 != "None" else None,
463
  "lora_scale_E": lora_scale5,
464
  ## BEGIN MOD
465
  "textual_inversion": get_embed_list(self.model.class_name) if textual_inversion else [],
 
509
  }
510
 
511
  self.model.device = torch.device("cuda:0")
512
+ if hasattr(self.model.pipe, "transformer") and loras_list != ["None"] * 5:
513
  self.model.pipe.transformer.to(self.model.device)
514
  print("transformer to cuda")
515
 
516
+ actual_progress = 0
517
+ info_images = gr.update()
 
518
  for img, seed, image_path, metadata in self.model(**pipe_params):
519
+ info_state = progress_step_bar(actual_progress, steps)
520
+ actual_progress += concurrency
521
  if image_path:
522
+ info_images = f"Seeds: {str(seed)}"
523
  if vae_msg:
524
+ info_images = info_images + "<br>" + vae_msg
525
+
526
+ if "Cannot copy out of meta tensor; no data!" in self.model.last_lora_error:
527
+ msg_ram = "Unable to process the LoRAs due to high RAM usage; please try again later."
528
+ print(msg_ram)
529
+ msg_lora += f"<br>{msg_ram}"
530
 
531
  for status, lora in zip(self.model.lora_status, self.model.lora_memory):
532
  if status:
 
535
  msg_lora += f"<br>Error with: {lora}"
536
 
537
  if msg_lora:
538
+ info_images += msg_lora
539
 
540
+ info_images = info_images + "<br>" + "GENERATION DATA:<br>" + metadata[0].replace("\n", "<br>") + "<br>-------<br>"
541
 
542
  download_links = "<br>".join(
543
  [
 
546
  ]
547
  )
548
  if save_generated_images:
549
+ info_images += f"<br>{download_links}"
550
 
551
+ ## BEGIN MOD
552
  img = save_images(img, metadata)
553
+ ## END MOD
554
 
555
+ info_state = "COMPLETE"
 
 
 
556
 
557
+ yield info_state, img, info_images
 
558
 
 
559
 
560
  def dynamic_gpu_duration(func, duration, *args):
561
 
 
565
 
566
  return wrapped_func()
567
 
568
+
569
  @spaces.GPU
570
  def dummy_gpu():
571
  return None
572
 
573
+
574
  def sd_gen_generate_pipeline(*args):
575
 
576
  gpu_duration_arg = int(args[-1]) if args[-1] else 59
 
578
  load_lora_cpu = args[-3]
579
  generation_args = args[:-3]
580
  lora_list = [
581
+ None if item == "None" or item == "" else item # MOD
582
  for item in [args[7], args[9], args[11], args[13], args[15]]
583
  ]
584
  lora_status = [None] * 5
 
588
  msg_load_lora = "Updating LoRAs in CPU (Slow but saves GPU usage)..."
589
 
590
  if lora_list != sd_gen.model.lora_memory and lora_list != [None] * 5:
591
+ yield msg_load_lora, gr.update(), gr.update()
592
 
593
  # Load lora in CPU
594
  if load_lora_cpu:
 
614
  )
615
  gr.Info(f"LoRAs in cache: {lora_cache_msg}")
616
 
617
+ msg_request = f"Requesting {gpu_duration_arg}s. of GPU time.\nModel: {sd_gen.model.base_model_id}"
618
+ if verbose_arg:
619
  gr.Info(msg_request)
620
  print(msg_request)
621
+ yield msg_request.replace("\n", "<br>"), gr.update(), gr.update()
 
622
 
623
  start_time = time.time()
624
 
625
+ # yield from sd_gen.generate_pipeline(*generation_args)
626
  yield from dynamic_gpu_duration(
627
  sd_gen.generate_pipeline,
628
  gpu_duration_arg,
 
630
  )
631
 
632
  end_time = time.time()
633
+ execution_time = end_time - start_time
634
+ msg_task_complete = (
635
+ f"GPU task complete in: {int(round(execution_time, 0) + 1)} seconds"
636
+ )
637
 
638
  if verbose_arg:
 
 
 
 
639
  gr.Info(msg_task_complete)
640
  print(msg_task_complete)
641
 
642
+ yield msg_task_complete, gr.update(), gr.update()
 
 
 
 
 
 
 
 
 
 
643
 
 
 
644
 
645
+ @spaces.GPU(duration=15)
646
  def esrgan_upscale(image, upscaler_name, upscaler_size):
647
  if image is None: return None
648
 
 
664
 
665
  return image_path
666
 
667
+
668
  dynamic_gpu_duration.zerogpu = True
669
  sd_gen_generate_pipeline.zerogpu = True
670
  sd_gen = GuiSD()
671
 
672
+
673
  ## BEGIN MOD
674
  CSS ="""
675
+ .gradio-container, #main { width:100%; height:100%; max-width:100%; padding-left:0; padding-right:0; margin-left:0; margin-right:0; }
676
+ .contain { display:flex; flex-direction:column; }
677
+ #component-0 { width:100%; height:100%; }
678
+ #gallery { flex-grow:1; }
679
+ #load_model { height: 50px; }
680
+ .lora { min-width:480px; }
681
  #model-info { text-align:center; }
682
  .title { font-size: 3em; align-items: center; text-align: center; }
683
  .info { align-items: center; text-align: center; }
 
691
  with gr.Tab("Generation"):
692
  with gr.Row():
693
  with gr.Column(scale=2):
694
+
695
+ def update_task_options(model_name, task_name):
696
+ new_choices = MODEL_TYPE_TASK[get_model_type(model_name)]
697
+
698
+ if task_name not in new_choices:
699
+ task_name = "txt2img"
700
+
701
+ return gr.update(value=task_name, choices=new_choices)
702
+
703
  interface_mode_gui = gr.Radio(label="Quick settings", choices=["Simple", "Standard", "Fast", "LoRA"], value="Standard")
704
  with gr.Accordion("Model and Task", open=False) as menu_model:
705
  task_gui = gr.Dropdown(label="Task", choices=SDXL_TASK, value=TASK_MODEL_LIST[0])
 
763
  [task_gui],
764
  )
765
 
766
+ load_model_gui = gr.HTML(elem_id="load_model", elem_classes="contain")
767
 
768
  result_images = gr.Gallery(
769
  label="Generated images",
 
785
 
786
  actual_task_info = gr.HTML()
787
 
788
+ with gr.Accordion("History", open=False):
789
+ history_gallery = gr.Gallery(label="History", columns=6, object_fit="contain", format="png", interactive=False, show_share_button=False,
790
+ show_download_button=True)
791
+ history_files = gr.Files(interactive=False, visible=False)
792
+ history_clear_button = gr.Button(value="Clear History", variant="secondary")
793
+ history_clear_button.click(lambda: ([], []), None, [history_gallery, history_files], queue=False, show_api=False)
794
+
795
  with gr.Row(equal_height=False, variant="default"):
796
  gpu_duration_gui = gr.Number(minimum=5, maximum=240, value=59, show_label=False, container=False, info="GPU time duration (seconds)")
797
  with gr.Column():
 
814
  with gr.Row():
815
  sampler_gui = gr.Dropdown(label="Sampler", choices=scheduler_names, value="Euler a")
816
  vae_model_gui = gr.Dropdown(label="VAE Model", choices=vae_model_list, value=vae_model_list[0])
817
+ prompt_syntax_gui = gr.Dropdown(label="Prompt Syntax", choices=PROMPT_W_OPTIONS, value=PROMPT_W_OPTIONS[1][1])
 
 
 
 
 
 
 
 
818
 
819
  with gr.Row(equal_height=False):
820
 
 
964
  with gr.Accordion("Select from Gallery", open=False):
965
  search_civitai_gallery_lora = gr.Gallery([], label="Results", allow_preview=False, columns=5, show_share_button=False, interactive=False)
966
  search_civitai_result_lora = gr.Dropdown(label="Search Results", choices=[("", "")], value="", allow_custom_value=True, visible=False)
967
+ with gr.Row():
968
+ text_lora = gr.Textbox(label="LoRA URL", placeholder="https://civitai.com/api/download/models/28907", lines=1, scale=4)
969
+ romanize_text = gr.Checkbox(value=False, label="Transliterate name", scale=1)
970
  button_lora = gr.Button("Get and update lists of LoRAs")
971
+ new_lora_status = gr.HTML()
972
  with gr.Accordion("From Local", open=True, visible=True):
973
  file_output_lora = gr.File(label="Uploaded LoRA", file_types=['.ckpt', '.pt', '.pth', '.safetensors', '.bin'], file_count="multiple", interactive=False, visible=False)
974
  upload_button_lora = gr.UploadButton(label="Upload LoRA from your disk (very slow)", file_types=['.ckpt', '.pt', '.pth', '.safetensors', '.bin'], file_count="multiple")
 
1006
  negative_prompt_ad_a_gui = gr.Textbox(label="Negative prompt", placeholder="Main negative prompt will be use", lines=3)
1007
  with gr.Row():
1008
  strength_ad_a_gui = gr.Number(label="Strength:", value=0.35, step=0.01, minimum=0.01, maximum=1.0)
1009
+ face_detector_ad_a_gui = gr.Checkbox(label="Face detector", value=False)
1010
  person_detector_ad_a_gui = gr.Checkbox(label="Person detector", value=True)
1011
  hand_detector_ad_a_gui = gr.Checkbox(label="Hand detector", value=False)
1012
  with gr.Row():
 
1021
  negative_prompt_ad_b_gui = gr.Textbox(label="Negative prompt", placeholder="Main negative prompt will be use", lines=3)
1022
  with gr.Row():
1023
  strength_ad_b_gui = gr.Number(label="Strength:", value=0.35, step=0.01, minimum=0.01, maximum=1.0)
1024
+ face_detector_ad_b_gui = gr.Checkbox(label="Face detector", value=False)
1025
  person_detector_ad_b_gui = gr.Checkbox(label="Person detector", value=True)
1026
  hand_detector_ad_b_gui = gr.Checkbox(label="Hand detector", value=False)
1027
  with gr.Row():
 
1151
 
1152
  with gr.Accordion("Examples and help", open=True, visible=True) as menu_example:
1153
  gr.Examples(
1154
+ examples=EXAMPLES_GUI,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1155
  fn=sd_gen.generate_pipeline,
1156
  inputs=[
1157
  prompt_gui,
 
1166
  img_width_gui,
1167
  model_name_gui,
1168
  ],
1169
+ outputs=[load_model_gui, result_images, actual_task_info],
1170
  cache_examples=False,
1171
  #elem_id="examples",
1172
  )
1173
 
1174
+ gr.Markdown(RESOURCES)
 
 
 
 
1175
  ## END MOD
1176
 
1177
  with gr.Tab("Inpaint mask maker", render=True):
 
1330
  )
1331
  search_civitai_result_lora.change(select_civitai_lora, [search_civitai_result_lora], [text_lora, search_civitai_desc_lora], queue=False, scroll_to_output=True)
1332
  search_civitai_gallery_lora.select(update_civitai_selection, None, [search_civitai_result_lora], queue=False, show_api=False)
1333
+ button_lora.click(get_my_lora, [text_lora, romanize_text], [lora1_gui, lora2_gui, lora3_gui, lora4_gui, lora5_gui, new_lora_status], scroll_to_output=True)
1334
  upload_button_lora.upload(upload_file_lora, [upload_button_lora], [file_output_lora, upload_button_lora]).success(
1335
  move_file_lora, [file_output_lora], [lora1_gui, lora2_gui, lora3_gui, lora4_gui, lora5_gui], scroll_to_output=True)
1336
 
 
1486
  verbose_info_gui,
1487
  gpu_duration_gui,
1488
  ],
1489
+ outputs=[load_model_gui, result_images, actual_task_info],
1490
  queue=True,
1491
  show_progress="full",
1492
+ ).success(save_gallery_images, [result_images, model_name_gui], [result_images, result_images_files], queue=False, show_api=False)\
1493
+ .success(save_gallery_history, [result_images, result_images_files, history_gallery, history_files], [history_gallery, history_files], queue=False, show_api=False)
1494
 
1495
  with gr.Tab("Danbooru Tags Transformer with WD Tagger", render=True):
1496
  with gr.Column(scale=2):
 
1586
  gr.DuplicateButton(value="Duplicate Space for private use (This demo does not work on CPU. Requires GPU Space)")
1587
 
1588
  app.queue()
1589
+ app.launch(show_error=True, debug=True) # allowed_paths=["./images/"], show_error=True, debug=True
1590
  ## END MOD
constants.py ADDED
@@ -0,0 +1,453 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from stablepy.diffusers_vanilla.constants import FLUX_CN_UNION_MODES
3
+ from stablepy import (
4
+ scheduler_names,
5
+ SD15_TASKS,
6
+ SDXL_TASKS,
7
+ )
8
+
9
+ # - **Download Models**
10
+ DOWNLOAD_MODEL = "https://civitai.com/api/download/models/574369, https://huggingface.co/TechnoByte/MilkyWonderland/resolve/main/milkyWonderland_v40.safetensors"
11
+
12
+ # - **Download VAEs**
13
+ DOWNLOAD_VAE = "https://huggingface.co/nubby/blessed-sdxl-vae-fp16-fix/resolve/main/sdxl_vae-fp16fix-c-1.1-b-0.5.safetensors?download=true, https://huggingface.co/nubby/blessed-sdxl-vae-fp16-fix/resolve/main/sdxl_vae-fp16fix-blessed.safetensors?download=true, https://huggingface.co/digiplay/VAE/resolve/main/vividReal_v20.safetensors?download=true, https://huggingface.co/fp16-guy/anything_kl-f8-anime2_vae-ft-mse-840000-ema-pruned_blessed_clearvae_fp16_cleaned/resolve/main/vae-ft-mse-840000-ema-pruned_fp16.safetensors?download=true"
14
+
15
+ # - **Download LoRAs**
16
+ DOWNLOAD_LORA = "https://huggingface.co/Leopain/color/resolve/main/Coloring_book_-_LineArt.safetensors, https://civitai.com/api/download/models/135867, https://huggingface.co/Linaqruf/anime-detailer-xl-lora/resolve/main/anime-detailer-xl.safetensors?download=true, https://huggingface.co/Linaqruf/style-enhancer-xl-lora/resolve/main/style-enhancer-xl.safetensors?download=true, https://huggingface.co/ByteDance/Hyper-SD/resolve/main/Hyper-SD15-8steps-CFG-lora.safetensors?download=true, https://huggingface.co/ByteDance/Hyper-SD/resolve/main/Hyper-SDXL-8steps-CFG-lora.safetensors?download=true"
17
+
18
+ LOAD_DIFFUSERS_FORMAT_MODEL = [
19
+ 'stabilityai/stable-diffusion-xl-base-1.0',
20
+ 'black-forest-labs/FLUX.1-dev',
21
+ 'John6666/blue-pencil-flux1-v021-fp8-flux',
22
+ 'John6666/wai-ani-flux-v10forfp8-fp8-flux',
23
+ 'John6666/xe-anime-flux-v04-fp8-flux',
24
+ 'John6666/lyh-anime-flux-v2a1-fp8-flux',
25
+ 'John6666/carnival-unchained-v10-fp8-flux',
26
+ 'cagliostrolab/animagine-xl-3.1',
27
+ 'John6666/epicrealism-xl-v8kiss-sdxl',
28
+ 'misri/epicrealismXL_v7FinalDestination',
29
+ 'misri/juggernautXL_juggernautX',
30
+ 'misri/zavychromaxl_v80',
31
+ 'SG161222/RealVisXL_V4.0',
32
+ 'SG161222/RealVisXL_V5.0',
33
+ 'misri/newrealityxlAllInOne_Newreality40',
34
+ 'eienmojiki/Anything-XL',
35
+ 'eienmojiki/Starry-XL-v5.2',
36
+ 'gsdf/CounterfeitXL',
37
+ 'KBlueLeaf/Kohaku-XL-Zeta',
38
+ 'John6666/silvermoon-mix-01xl-v11-sdxl',
39
+ 'WhiteAiZ/autismmixSDXL_autismmixConfetti_diffusers',
40
+ 'kitty7779/ponyDiffusionV6XL',
41
+ 'GraydientPlatformAPI/aniverse-pony',
42
+ 'John6666/ras-real-anime-screencap-v1-sdxl',
43
+ 'John6666/duchaiten-pony-xl-no-score-v60-sdxl',
44
+ 'John6666/mistoon-anime-ponyalpha-sdxl',
45
+ 'John6666/3x3x3mixxl-v2-sdxl',
46
+ 'John6666/3x3x3mixxl-3dv01-sdxl',
47
+ 'John6666/ebara-mfcg-pony-mix-v12-sdxl',
48
+ 'John6666/t-ponynai3-v51-sdxl',
49
+ 'John6666/t-ponynai3-v65-sdxl',
50
+ 'John6666/prefect-pony-xl-v3-sdxl',
51
+ 'John6666/mala-anime-mix-nsfw-pony-xl-v5-sdxl',
52
+ 'John6666/wai-real-mix-v11-sdxl',
53
+ 'John6666/wai-c-v6-sdxl',
54
+ 'John6666/iniverse-mix-xl-sfwnsfw-pony-guofeng-v43-sdxl',
55
+ 'John6666/photo-realistic-pony-v5-sdxl',
56
+ 'John6666/pony-realism-v21main-sdxl',
57
+ 'John6666/pony-realism-v22main-sdxl',
58
+ 'John6666/cyberrealistic-pony-v63-sdxl',
59
+ 'John6666/cyberrealistic-pony-v64-sdxl',
60
+ 'John6666/cyberrealistic-pony-v65-sdxl',
61
+ 'GraydientPlatformAPI/realcartoon-pony-diffusion',
62
+ 'John6666/nova-anime-xl-pony-v5-sdxl',
63
+ 'John6666/autismmix-sdxl-autismmix-pony-sdxl',
64
+ 'John6666/aimz-dream-real-pony-mix-v3-sdxl',
65
+ 'John6666/duchaiten-pony-real-v11fix-sdxl',
66
+ 'John6666/duchaiten-pony-real-v20-sdxl',
67
+ 'yodayo-ai/kivotos-xl-2.0',
68
+ 'yodayo-ai/holodayo-xl-2.1',
69
+ 'yodayo-ai/clandestine-xl-1.0',
70
+ 'digiplay/majicMIX_sombre_v2',
71
+ 'digiplay/majicMIX_realistic_v6',
72
+ 'digiplay/majicMIX_realistic_v7',
73
+ 'digiplay/DreamShaper_8',
74
+ 'digiplay/BeautifulArt_v1',
75
+ 'digiplay/DarkSushi2.5D_v1',
76
+ 'digiplay/darkphoenix3D_v1.1',
77
+ 'digiplay/BeenYouLiteL11_diffusers',
78
+ 'Yntec/RevAnimatedV2Rebirth',
79
+ 'youknownothing/cyberrealistic_v50',
80
+ 'youknownothing/deliberate-v6',
81
+ 'GraydientPlatformAPI/deliberate-cyber3',
82
+ 'GraydientPlatformAPI/picx-real',
83
+ 'GraydientPlatformAPI/perfectworld6',
84
+ 'emilianJR/epiCRealism',
85
+ 'votepurchase/counterfeitV30_v30',
86
+ 'votepurchase/ChilloutMix',
87
+ 'Meina/MeinaMix_V11',
88
+ 'Meina/MeinaUnreal_V5',
89
+ 'Meina/MeinaPastel_V7',
90
+ 'GraydientPlatformAPI/realcartoon3d-17',
91
+ 'GraydientPlatformAPI/realcartoon-pixar11',
92
+ 'GraydientPlatformAPI/realcartoon-real17',
93
+ ]
94
+
95
+ DIFFUSERS_FORMAT_LORAS = [
96
+ "nerijs/animation2k-flux",
97
+ "XLabs-AI/flux-RealismLora",
98
+ ]
99
+
100
+ DOWNLOAD_EMBEDS = [
101
+ 'https://huggingface.co/datasets/Nerfgun3/bad_prompt/blob/main/bad_prompt_version2.pt',
102
+ 'https://huggingface.co/embed/negative/resolve/main/EasyNegativeV2.safetensors',
103
+ 'https://huggingface.co/embed/negative/resolve/main/bad-hands-5.pt',
104
+ ]
105
+
106
+ CIVITAI_API_KEY = os.environ.get("CIVITAI_API_KEY")
107
+ HF_TOKEN = os.environ.get("HF_READ_TOKEN")
108
+
109
+ DIRECTORY_MODELS = 'models'
110
+ DIRECTORY_LORAS = 'loras'
111
+ DIRECTORY_VAES = 'vaes'
112
+ DIRECTORY_EMBEDS = 'embedings'
113
+
114
+ PREPROCESSOR_CONTROLNET = {
115
+ "openpose": [
116
+ "Openpose",
117
+ "None",
118
+ ],
119
+ "scribble": [
120
+ "HED",
121
+ "PidiNet",
122
+ "None",
123
+ ],
124
+ "softedge": [
125
+ "PidiNet",
126
+ "HED",
127
+ "HED safe",
128
+ "PidiNet safe",
129
+ "None",
130
+ ],
131
+ "segmentation": [
132
+ "UPerNet",
133
+ "None",
134
+ ],
135
+ "depth": [
136
+ "DPT",
137
+ "Midas",
138
+ "None",
139
+ ],
140
+ "normalbae": [
141
+ "NormalBae",
142
+ "None",
143
+ ],
144
+ "lineart": [
145
+ "Lineart",
146
+ "Lineart coarse",
147
+ "Lineart (anime)",
148
+ "None",
149
+ "None (anime)",
150
+ ],
151
+ "lineart_anime": [
152
+ "Lineart",
153
+ "Lineart coarse",
154
+ "Lineart (anime)",
155
+ "None",
156
+ "None (anime)",
157
+ ],
158
+ "shuffle": [
159
+ "ContentShuffle",
160
+ "None",
161
+ ],
162
+ "canny": [
163
+ "Canny",
164
+ "None",
165
+ ],
166
+ "mlsd": [
167
+ "MLSD",
168
+ "None",
169
+ ],
170
+ "ip2p": [
171
+ "ip2p"
172
+ ],
173
+ "recolor": [
174
+ "Recolor luminance",
175
+ "Recolor intensity",
176
+ "None",
177
+ ],
178
+ "tile": [
179
+ "Mild Blur",
180
+ "Moderate Blur",
181
+ "Heavy Blur",
182
+ "None",
183
+ ],
184
+
185
+ }
186
+
187
+ TASK_STABLEPY = {
188
+ 'txt2img': 'txt2img',
189
+ 'img2img': 'img2img',
190
+ 'inpaint': 'inpaint',
191
+ # 'canny T2I Adapter': 'sdxl_canny_t2i', # NO HAVE STEP CALLBACK PARAMETERS SO NOT WORKS WITH DIFFUSERS 0.29.0
192
+ # 'sketch T2I Adapter': 'sdxl_sketch_t2i',
193
+ # 'lineart T2I Adapter': 'sdxl_lineart_t2i',
194
+ # 'depth-midas T2I Adapter': 'sdxl_depth-midas_t2i',
195
+ # 'openpose T2I Adapter': 'sdxl_openpose_t2i',
196
+ 'openpose ControlNet': 'openpose',
197
+ 'canny ControlNet': 'canny',
198
+ 'mlsd ControlNet': 'mlsd',
199
+ 'scribble ControlNet': 'scribble',
200
+ 'softedge ControlNet': 'softedge',
201
+ 'segmentation ControlNet': 'segmentation',
202
+ 'depth ControlNet': 'depth',
203
+ 'normalbae ControlNet': 'normalbae',
204
+ 'lineart ControlNet': 'lineart',
205
+ 'lineart_anime ControlNet': 'lineart_anime',
206
+ 'shuffle ControlNet': 'shuffle',
207
+ 'ip2p ControlNet': 'ip2p',
208
+ 'optical pattern ControlNet': 'pattern',
209
+ 'recolor ControlNet': 'recolor',
210
+ 'tile ControlNet': 'tile',
211
+ }
212
+
213
+ TASK_MODEL_LIST = list(TASK_STABLEPY.keys())
214
+
215
+ UPSCALER_DICT_GUI = {
216
+ None: None,
217
+ "Lanczos": "Lanczos",
218
+ "Nearest": "Nearest",
219
+ 'Latent': 'Latent',
220
+ 'Latent (antialiased)': 'Latent (antialiased)',
221
+ 'Latent (bicubic)': 'Latent (bicubic)',
222
+ 'Latent (bicubic antialiased)': 'Latent (bicubic antialiased)',
223
+ 'Latent (nearest)': 'Latent (nearest)',
224
+ 'Latent (nearest-exact)': 'Latent (nearest-exact)',
225
+ "RealESRGAN_x4plus": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
226
+ "RealESRNet_x4plus": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth",
227
+ "RealESRGAN_x4plus_anime_6B": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
228
+ "RealESRGAN_x2plus": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
229
+ "realesr-animevideov3": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth",
230
+ "realesr-general-x4v3": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
231
+ "realesr-general-wdn-x4v3": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth",
232
+ "4x-UltraSharp": "https://huggingface.co/Shandypur/ESRGAN-4x-UltraSharp/resolve/main/4x-UltraSharp.pth",
233
+ "4x_foolhardy_Remacri": "https://huggingface.co/FacehugmanIII/4x_foolhardy_Remacri/resolve/main/4x_foolhardy_Remacri.pth",
234
+ "Remacri4xExtraSmoother": "https://huggingface.co/hollowstrawberry/upscalers-backup/resolve/main/ESRGAN/Remacri%204x%20ExtraSmoother.pth",
235
+ "AnimeSharp4x": "https://huggingface.co/hollowstrawberry/upscalers-backup/resolve/main/ESRGAN/AnimeSharp%204x.pth",
236
+ "lollypop": "https://huggingface.co/hollowstrawberry/upscalers-backup/resolve/main/ESRGAN/lollypop.pth",
237
+ "RealisticRescaler4x": "https://huggingface.co/hollowstrawberry/upscalers-backup/resolve/main/ESRGAN/RealisticRescaler%204x.pth",
238
+ "NickelbackFS4x": "https://huggingface.co/hollowstrawberry/upscalers-backup/resolve/main/ESRGAN/NickelbackFS%204x.pth"
239
+ }
240
+
241
+ UPSCALER_KEYS = list(UPSCALER_DICT_GUI.keys())
242
+
243
+ PROMPT_W_OPTIONS = [
244
+ ("Compel format: (word)weight", "Compel"),
245
+ ("Classic format: (word:weight)", "Classic"),
246
+ ("Classic-original format: (word:weight)", "Classic-original"),
247
+ ("Classic-no_norm format: (word:weight)", "Classic-no_norm"),
248
+ ("Classic-ignore", "Classic-ignore"),
249
+ ("None", "None"),
250
+ ]
251
+
252
+ WARNING_MSG_VAE = (
253
+ "Use the right VAE for your model to maintain image quality. The wrong"
254
+ " VAE can lead to poor results, like blurriness in the generated images."
255
+ )
256
+
257
+ SDXL_TASK = [k for k, v in TASK_STABLEPY.items() if v in SDXL_TASKS]
258
+ SD_TASK = [k for k, v in TASK_STABLEPY.items() if v in SD15_TASKS]
259
+ FLUX_TASK = list(TASK_STABLEPY.keys())[:3] + [k for k, v in TASK_STABLEPY.items() if v in FLUX_CN_UNION_MODES.keys()]
260
+
261
+ MODEL_TYPE_TASK = {
262
+ "SD 1.5": SD_TASK,
263
+ "SDXL": SDXL_TASK,
264
+ "FLUX": FLUX_TASK,
265
+ }
266
+
267
+ MODEL_TYPE_CLASS = {
268
+ "diffusers:StableDiffusionPipeline": "SD 1.5",
269
+ "diffusers:StableDiffusionXLPipeline": "SDXL",
270
+ "diffusers:FluxPipeline": "FLUX",
271
+ }
272
+
273
+ POST_PROCESSING_SAMPLER = ["Use same sampler"] + scheduler_names[:-2]
274
+
275
+ SUBTITLE_GUI = (
276
+ "### This demo uses [diffusers](https://github.com/huggingface/diffusers)"
277
+ " to perform different tasks in image generation."
278
+ )
279
+
280
+ HELP_GUI = (
281
+ """### Help:
282
+ - The current space runs on a ZERO GPU which is assigned for approximately 60 seconds; Therefore, if you submit expensive tasks, the operation may be canceled upon reaching the maximum allowed time with 'GPU TASK ABORTED'.
283
+ - Distorted or strange images often result from high prompt weights, so it's best to use low weights and scales, and consider using Classic variants like 'Classic-original'.
284
+ - For better results with Pony Diffusion, try using sampler DPM++ 1s or DPM2 with Compel or Classic prompt weights.
285
+ """
286
+ )
287
+
288
+ EXAMPLES_GUI_HELP = (
289
+ """### The following examples perform specific tasks:
290
+ 1. Generation with SDXL and upscale
291
+ 2. Generation with FLUX dev
292
+ 3. ControlNet Canny SDXL
293
+ 4. Optical pattern (Optical illusion) SDXL
294
+ 5. Convert an image to a coloring drawing
295
+ 6. ControlNet OpenPose SD 1.5 and Latent upscale
296
+
297
+ - Different tasks can be performed, such as img2img or using the IP adapter, to preserve a person's appearance or a specific style based on an image.
298
+ """
299
+ )
300
+
301
+ EXAMPLES_GUI = [
302
+ [
303
+ "1girl, souryuu asuka langley, neon genesis evangelion, rebuild of evangelion, lance of longinus, cat hat, plugsuit, pilot suit, red bodysuit, sitting, crossed legs, black eye patch, throne, looking down, from bottom, looking at viewer, outdoors, (masterpiece), (best quality), (ultra-detailed), very aesthetic, illustration, disheveled hair, perfect composition, moist skin, intricate details",
304
+ "nfsw, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, unfinished, very displeasing, oldest, early, chromatic aberration, artistic error, scan, abstract",
305
+ 28,
306
+ 7.0,
307
+ -1,
308
+ "None",
309
+ 0.33,
310
+ "Euler a",
311
+ 1152,
312
+ 896,
313
+ "cagliostrolab/animagine-xl-3.1",
314
+ "txt2img",
315
+ "image.webp", # img conttol
316
+ 1024, # img resolution
317
+ 0.35, # strength
318
+ 1.0, # cn scale
319
+ 0.0, # cn start
320
+ 1.0, # cn end
321
+ "Classic",
322
+ "Nearest",
323
+ 45,
324
+ False,
325
+ ],
326
+ [
327
+ "a digital illustration of a movie poster titled 'Finding Emo', finding nemo parody poster, featuring a depressed cartoon clownfish with black emo hair, eyeliner, and piercings, bored expression, swimming in a dark underwater scene, in the background, movie title in a dripping, grungy font, moody blue and purple color palette",
328
+ "",
329
+ 24,
330
+ 3.5,
331
+ -1,
332
+ "None",
333
+ 0.33,
334
+ "Euler a",
335
+ 1152,
336
+ 896,
337
+ "black-forest-labs/FLUX.1-dev",
338
+ "txt2img",
339
+ None, # img conttol
340
+ 1024, # img resolution
341
+ 0.35, # strength
342
+ 1.0, # cn scale
343
+ 0.0, # cn start
344
+ 1.0, # cn end
345
+ "Classic",
346
+ None,
347
+ 70,
348
+ True,
349
+ ],
350
+ [
351
+ "((masterpiece)), best quality, blonde disco girl, detailed face, realistic face, realistic hair, dynamic pose, pink pvc, intergalactic disco background, pastel lights, dynamic contrast, airbrush, fine detail, 70s vibe, midriff",
352
+ "(worst quality:1.2), (bad quality:1.2), (poor quality:1.2), (missing fingers:1.2), bad-artist-anime, bad-artist, bad-picture-chill-75v",
353
+ 48,
354
+ 3.5,
355
+ -1,
356
+ "None",
357
+ 0.33,
358
+ "DPM++ 2M SDE Lu",
359
+ 1024,
360
+ 1024,
361
+ "misri/epicrealismXL_v7FinalDestination",
362
+ "canny ControlNet",
363
+ "image.webp", # img conttol
364
+ 1024, # img resolution
365
+ 0.35, # strength
366
+ 1.0, # cn scale
367
+ 0.0, # cn start
368
+ 1.0, # cn end
369
+ "Classic",
370
+ None,
371
+ 44,
372
+ False,
373
+ ],
374
+ [
375
+ "cinematic scenery old city ruins",
376
+ "(worst quality, low quality, illustration, 3d, 2d, painting, cartoons, sketch), (illustration, 3d, 2d, painting, cartoons, sketch, blurry, film grain, noise), (low quality, worst quality:1.2)",
377
+ 50,
378
+ 4.0,
379
+ -1,
380
+ "None",
381
+ 0.33,
382
+ "Euler a",
383
+ 1024,
384
+ 1024,
385
+ "misri/juggernautXL_juggernautX",
386
+ "optical pattern ControlNet",
387
+ "spiral_no_transparent.png", # img conttol
388
+ 1024, # img resolution
389
+ 0.35, # strength
390
+ 1.0, # cn scale
391
+ 0.05, # cn start
392
+ 0.75, # cn end
393
+ "Classic",
394
+ None,
395
+ 35,
396
+ False,
397
+ ],
398
+ [
399
+ "black and white, line art, coloring drawing, clean line art, black strokes, no background, white, black, free lines, black scribbles, on paper, A blend of comic book art and lineart full of black and white color, masterpiece, high-resolution, trending on Pixiv fan box, palette knife, brush strokes, two-dimensional, planar vector, T-shirt design, stickers, and T-shirt design, vector art, fantasy art, Adobe Illustrator, hand-painted, digital painting, low polygon, soft lighting, aerial view, isometric style, retro aesthetics, 8K resolution, black sketch lines, monochrome, invert color",
400
+ "color, red, green, yellow, colored, duplicate, blurry, abstract, disfigured, deformed, animated, toy, figure, framed, 3d, bad art, poorly drawn, extra limbs, close up, b&w, weird colors, blurry, watermark, blur haze, 2 heads, long neck, watermark, elongated body, cropped image, out of frame, draft, deformed hands, twisted fingers, double image, malformed hands, multiple heads, extra limb, ugly, poorly drawn hands, missing limb, cut-off, over satured, grain, lowères, bad anatomy, poorly drawn face, mutation, mutated, floating limbs, disconnected limbs, out of focus, long body, disgusting, extra fingers, groos proportions, missing arms, mutated hands, cloned face, missing legs, ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, blurry, bad anatomy, blurred, watermark, grainy, signature, cut off, draft, deformed, blurry, bad anatomy, disfigured, poorly drawn face, mutation, bluelish, blue",
401
+ 20,
402
+ 4.0,
403
+ -1,
404
+ "loras/Coloring_book_-_LineArt.safetensors",
405
+ 1.0,
406
+ "DPM++ 2M SDE Karras",
407
+ 1024,
408
+ 1024,
409
+ "cagliostrolab/animagine-xl-3.1",
410
+ "lineart ControlNet",
411
+ "color_image.png", # img conttol
412
+ 896, # img resolution
413
+ 0.35, # strength
414
+ 1.0, # cn scale
415
+ 0.0, # cn start
416
+ 1.0, # cn end
417
+ "Compel",
418
+ None,
419
+ 35,
420
+ False,
421
+ ],
422
+ [
423
+ "1girl,face,curly hair,red hair,white background,",
424
+ "(worst quality:2),(low quality:2),(normal quality:2),lowres,watermark,",
425
+ 38,
426
+ 5.0,
427
+ -1,
428
+ "None",
429
+ 0.33,
430
+ "DPM++ 2M SDE Karras",
431
+ 512,
432
+ 512,
433
+ "digiplay/majicMIX_realistic_v7",
434
+ "openpose ControlNet",
435
+ "image.webp", # img conttol
436
+ 1024, # img resolution
437
+ 0.35, # strength
438
+ 1.0, # cn scale
439
+ 0.0, # cn start
440
+ 0.9, # cn end
441
+ "Compel",
442
+ "Latent (antialiased)",
443
+ 46,
444
+ False,
445
+ ],
446
+ ]
447
+
448
+ RESOURCES = (
449
+ """### Resources
450
+ - John6666's space has some great features you might find helpful [link](https://huggingface.co/spaces/John6666/DiffuseCraftMod).
451
+ - You can also try the image generator in Colab’s free tier, which provides free GPU [link](https://github.com/R3gm/SD_diffusers_interactive).
452
+ """
453
+ )
env.py CHANGED
@@ -2,10 +2,10 @@ import os
2
 
3
  CIVITAI_API_KEY = os.environ.get("CIVITAI_API_KEY")
4
  HF_TOKEN = os.environ.get("HF_TOKEN")
5
- hf_read_token = os.environ.get('HF_READ_TOKEN') # only use for private repo
6
 
7
  # - **List Models**
8
- load_diffusers_format_model = [
9
  'stabilityai/stable-diffusion-xl-base-1.0',
10
  'John6666/blue-pencil-flux1-v021-fp8-flux',
11
  'John6666/wai-ani-flux-v10forfp8-fp8-flux',
@@ -106,11 +106,11 @@ HF_MODEL_USER_EX = ["John6666"] # sorted by a special rule
106
 
107
 
108
  # - **Download Models**
109
- download_model_list = [
110
  ]
111
 
112
  # - **Download VAEs**
113
- download_vae_list = [
114
  'https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/resolve/main/sdxl.vae.safetensors?download=true',
115
  'https://huggingface.co/nubby/blessed-sdxl-vae-fp16-fix/resolve/main/sdxl_vae-fp16fix-c-1.1-b-0.5.safetensors?download=true',
116
  'https://huggingface.co/nubby/blessed-sdxl-vae-fp16-fix/resolve/main/sdxl_vae-fp16fix-blessed.safetensors?download=true',
@@ -119,29 +119,26 @@ download_vae_list = [
119
  ]
120
 
121
  # - **Download LoRAs**
122
- download_lora_list = [
123
  ]
124
 
125
  # Download Embeddings
126
- download_embeds = [
127
  'https://huggingface.co/datasets/Nerfgun3/bad_prompt/blob/main/bad_prompt_version2.pt',
128
  'https://huggingface.co/embed/negative/resolve/main/EasyNegativeV2.safetensors',
129
  'https://huggingface.co/embed/negative/resolve/main/bad-hands-5.pt',
130
  ]
131
 
132
- directory_models = 'models'
133
- os.makedirs(directory_models, exist_ok=True)
134
- directory_loras = 'loras'
135
- os.makedirs(directory_loras, exist_ok=True)
136
- directory_vaes = 'vaes'
137
- os.makedirs(directory_vaes, exist_ok=True)
138
- directory_embeds = 'embedings'
139
- os.makedirs(directory_embeds, exist_ok=True)
140
 
141
- directory_embeds_sdxl = 'embedings_xl'
142
- os.makedirs(directory_embeds_sdxl, exist_ok=True)
143
- directory_embeds_positive_sdxl = 'embedings_xl/positive'
144
- os.makedirs(directory_embeds_positive_sdxl, exist_ok=True)
145
 
146
  HF_LORA_PRIVATE_REPOS1 = ['John6666/loratest1', 'John6666/loratest3', 'John6666/loratest4', 'John6666/loratest6']
147
  HF_LORA_PRIVATE_REPOS2 = ['John6666/loratest10', 'John6666/loratest11','John6666/loratest'] # to be sorted as 1 repo
 
2
 
3
  CIVITAI_API_KEY = os.environ.get("CIVITAI_API_KEY")
4
  HF_TOKEN = os.environ.get("HF_TOKEN")
5
+ HF_READ_TOKEN = os.environ.get('HF_READ_TOKEN') # only use for private repo
6
 
7
  # - **List Models**
8
+ LOAD_DIFFUSERS_FORMAT_MODEL = [
9
  'stabilityai/stable-diffusion-xl-base-1.0',
10
  'John6666/blue-pencil-flux1-v021-fp8-flux',
11
  'John6666/wai-ani-flux-v10forfp8-fp8-flux',
 
106
 
107
 
108
  # - **Download Models**
109
+ DOWNLOAD_MODEL_LIST = [
110
  ]
111
 
112
  # - **Download VAEs**
113
+ DOWNLOAD_VAE_LIST = [
114
  'https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/resolve/main/sdxl.vae.safetensors?download=true',
115
  'https://huggingface.co/nubby/blessed-sdxl-vae-fp16-fix/resolve/main/sdxl_vae-fp16fix-c-1.1-b-0.5.safetensors?download=true',
116
  'https://huggingface.co/nubby/blessed-sdxl-vae-fp16-fix/resolve/main/sdxl_vae-fp16fix-blessed.safetensors?download=true',
 
119
  ]
120
 
121
  # - **Download LoRAs**
122
+ DOWNLOAD_LORA_LIST = [
123
  ]
124
 
125
  # Download Embeddings
126
+ DOWNLOAD_EMBEDS = [
127
  'https://huggingface.co/datasets/Nerfgun3/bad_prompt/blob/main/bad_prompt_version2.pt',
128
  'https://huggingface.co/embed/negative/resolve/main/EasyNegativeV2.safetensors',
129
  'https://huggingface.co/embed/negative/resolve/main/bad-hands-5.pt',
130
  ]
131
 
132
+ DIRECTORY_MODELS = 'models'
133
+ DIRECTORY_LORAS = 'loras'
134
+ DIRECTORY_VAES = 'vaes'
135
+ DIRECTORY_EMBEDS = 'embedings'
136
+ DIRECTORY_EMBEDS_SDXL = 'embedings_xl'
137
+ DIRECTORY_EMBEDS_POSITIVE_SDXL = 'embedings_xl/positive'
 
 
138
 
139
+ directories = [DIRECTORY_MODELS, DIRECTORY_LORAS, DIRECTORY_VAES, DIRECTORY_EMBEDS, DIRECTORY_EMBEDS_SDXL, DIRECTORY_EMBEDS_POSITIVE_SDXL]
140
+ for directory in directories:
141
+ os.makedirs(directory, exist_ok=True)
 
142
 
143
  HF_LORA_PRIVATE_REPOS1 = ['John6666/loratest1', 'John6666/loratest3', 'John6666/loratest4', 'John6666/loratest6']
144
  HF_LORA_PRIVATE_REPOS2 = ['John6666/loratest10', 'John6666/loratest11','John6666/loratest'] # to be sorted as 1 repo
modutils.py CHANGED
@@ -5,6 +5,7 @@ import os
5
  import re
6
  from pathlib import Path
7
  from PIL import Image
 
8
  import shutil
9
  import requests
10
  from requests.adapters import HTTPAdapter
@@ -12,11 +13,16 @@ from urllib3.util import Retry
12
  import urllib.parse
13
  import pandas as pd
14
  from huggingface_hub import HfApi, HfFolder, hf_hub_download, snapshot_download
 
 
 
 
 
15
 
16
 
17
  from env import (HF_LORA_PRIVATE_REPOS1, HF_LORA_PRIVATE_REPOS2,
18
  HF_MODEL_USER_EX, HF_MODEL_USER_LIKES, DIFFUSERS_FORMAT_LORAS,
19
- directory_loras, hf_read_token, HF_TOKEN, CIVITAI_API_KEY)
20
 
21
 
22
  MODEL_TYPE_DICT = {
@@ -46,7 +52,6 @@ def is_repo_name(s):
46
  return re.fullmatch(r'^[^/]+?/[^/]+?$', s)
47
 
48
 
49
- from translatepy import Translator
50
  translator = Translator()
51
  def translate_to_en(input: str):
52
  try:
@@ -64,6 +69,7 @@ def get_local_model_list(dir_path):
64
  if file.suffix in valid_extensions:
65
  file_path = str(Path(f"{dir_path}/{file.name}"))
66
  model_list.append(file_path)
 
67
  return model_list
68
 
69
 
@@ -98,21 +104,81 @@ def split_hf_url(url: str):
98
  print(e)
99
 
100
 
101
- def download_hf_file(directory, url, progress=gr.Progress(track_tqdm=True)):
102
- hf_token = get_token()
103
  repo_id, filename, subfolder, repo_type = split_hf_url(url)
 
 
 
104
  try:
105
- print(f"Downloading {url} to {directory}")
106
- if subfolder is not None: path = hf_hub_download(repo_id=repo_id, filename=filename, subfolder=subfolder, repo_type=repo_type, local_dir=directory, token=hf_token)
107
- else: path = hf_hub_download(repo_id=repo_id, filename=filename, repo_type=repo_type, local_dir=directory, token=hf_token)
108
  return path
109
  except Exception as e:
110
- print(f"Failed to download: {e}")
111
  return None
112
 
113
 
114
- def download_things(directory, url, hf_token="", civitai_api_key=""):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
  url = url.strip()
 
 
116
  if "drive.google.com" in url:
117
  original_dir = os.getcwd()
118
  os.chdir(directory)
@@ -123,18 +189,48 @@ def download_things(directory, url, hf_token="", civitai_api_key=""):
123
  # url = urllib.parse.quote(url, safe=':/') # fix encoding
124
  if "/blob/" in url:
125
  url = url.replace("/blob/", "/resolve/")
126
- download_hf_file(directory, url)
 
 
 
 
 
 
127
  elif "civitai.com" in url:
128
- if "?" in url:
129
- url = url.split("?")[0]
130
- if civitai_api_key:
131
- url = url + f"?token={civitai_api_key}"
132
- os.system(f"aria2c --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 -d {directory} {url}")
133
- else:
134
  print("\033[91mYou need an API key to download Civitai models.\033[0m")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
  else:
136
  os.system(f"aria2c --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 -d {directory} {url}")
137
 
 
 
138
 
139
  def get_download_file(temp_dir, url, civitai_key="", progress=gr.Progress(track_tqdm=True)):
140
  if not "http" in url and is_repo_name(url) and not Path(url).exists():
@@ -173,7 +269,7 @@ def to_lora_key(path: str):
173
 
174
  def to_lora_path(key: str):
175
  if Path(key).is_file(): return key
176
- path = Path(f"{directory_loras}/{escape_lora_basename(key)}.safetensors")
177
  return str(path)
178
 
179
 
@@ -203,25 +299,21 @@ def save_images(images: list[Image.Image], metadatas: list[str]):
203
  raise Exception(f"Failed to save image file:") from e
204
 
205
 
206
- def save_gallery_images(images, progress=gr.Progress(track_tqdm=True)):
207
- from datetime import datetime, timezone, timedelta
208
  progress(0, desc="Updating gallery...")
209
- dt_now = datetime.now(timezone(timedelta(hours=9)))
210
- basename = dt_now.strftime('%Y%m%d_%H%M%S_')
211
- i = 1
212
- if not images: return images, gr.update(visible=False)
213
  output_images = []
214
  output_paths = []
215
- for image in images:
216
- filename = basename + str(i) + ".png"
217
- i += 1
218
  oldpath = Path(image[0])
219
  newpath = oldpath
220
  try:
221
  if oldpath.exists():
222
  newpath = oldpath.resolve().rename(Path(filename).resolve())
223
  except Exception as e:
224
- print(e)
225
  finally:
226
  output_paths.append(str(newpath))
227
  output_images.append((str(newpath), str(filename)))
@@ -229,10 +321,47 @@ def save_gallery_images(images, progress=gr.Progress(track_tqdm=True)):
229
  return gr.update(value=output_images), gr.update(value=output_paths, visible=True)
230
 
231
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
232
  def download_private_repo(repo_id, dir_path, is_replace):
233
- if not hf_read_token: return
234
  try:
235
- snapshot_download(repo_id=repo_id, local_dir=dir_path, allow_patterns=['*.ckpt', '*.pt', '*.pth', '*.safetensors', '*.bin'], use_auth_token=hf_read_token)
236
  except Exception as e:
237
  print(f"Error: Failed to download {repo_id}.")
238
  print(e)
@@ -250,9 +379,9 @@ private_model_path_repo_dict = {} # {"local filepath": "huggingface repo_id", ..
250
  def get_private_model_list(repo_id, dir_path):
251
  global private_model_path_repo_dict
252
  api = HfApi()
253
- if not hf_read_token: return []
254
  try:
255
- files = api.list_repo_files(repo_id, token=hf_read_token)
256
  except Exception as e:
257
  print(f"Error: Failed to list {repo_id}.")
258
  print(e)
@@ -270,11 +399,11 @@ def get_private_model_list(repo_id, dir_path):
270
  def download_private_file(repo_id, path, is_replace):
271
  file = Path(path)
272
  newpath = Path(f'{file.parent.name}/{escape_lora_basename(file.stem)}{file.suffix}') if is_replace else file
273
- if not hf_read_token or newpath.exists(): return
274
  filename = file.name
275
  dirname = file.parent.name
276
  try:
277
- hf_hub_download(repo_id=repo_id, filename=filename, local_dir=dirname, use_auth_token=hf_read_token)
278
  except Exception as e:
279
  print(f"Error: Failed to download {filename}.")
280
  print(e)
@@ -404,9 +533,9 @@ def get_private_lora_model_lists():
404
  models1 = []
405
  models2 = []
406
  for repo in HF_LORA_PRIVATE_REPOS1:
407
- models1.extend(get_private_model_list(repo, directory_loras))
408
  for repo in HF_LORA_PRIVATE_REPOS2:
409
- models2.extend(get_private_model_list(repo, directory_loras))
410
  models = list_uniq(models1 + sorted(models2))
411
  private_lora_model_list = models.copy()
412
  return models
@@ -451,7 +580,7 @@ def get_civitai_info(path):
451
 
452
 
453
  def get_lora_model_list():
454
- loras = list_uniq(get_private_lora_model_lists() + DIFFUSERS_FORMAT_LORAS + get_local_model_list(directory_loras))
455
  loras.insert(0, "None")
456
  loras.insert(0, "")
457
  return loras
@@ -503,14 +632,14 @@ def update_lora_dict(path):
503
  def download_lora(dl_urls: str):
504
  global loras_url_to_path_dict
505
  dl_path = ""
506
- before = get_local_model_list(directory_loras)
507
  urls = []
508
  for url in [url.strip() for url in dl_urls.split(',')]:
509
- local_path = f"{directory_loras}/{url.split('/')[-1]}"
510
  if not Path(local_path).exists():
511
- download_things(directory_loras, url, HF_TOKEN, CIVITAI_API_KEY)
512
  urls.append(url)
513
- after = get_local_model_list(directory_loras)
514
  new_files = list_sub(after, before)
515
  i = 0
516
  for file in new_files:
@@ -761,12 +890,14 @@ def update_loras(prompt, prompt_syntax, lora1, lora1_wt, lora2, lora2_wt, lora3,
761
  gr.update(value=tag5, label=label5, visible=on5), gr.update(visible=on5), gr.update(value=md5, visible=on5)
762
 
763
 
764
- def get_my_lora(link_url):
765
- before = get_local_model_list(directory_loras)
 
 
766
  for url in [url.strip() for url in link_url.split(',')]:
767
- if not Path(f"{directory_loras}/{url.split('/')[-1]}").exists():
768
- download_things(directory_loras, url, HF_TOKEN, CIVITAI_API_KEY)
769
- after = get_local_model_list(directory_loras)
770
  new_files = list_sub(after, before)
771
  for file in new_files:
772
  path = Path(file)
@@ -774,11 +905,16 @@ def get_my_lora(link_url):
774
  new_path = Path(f'{path.parent.name}/{escape_lora_basename(path.stem)}{path.suffix}')
775
  path.resolve().rename(new_path.resolve())
776
  update_lora_dict(str(new_path))
 
777
  new_lora_model_list = get_lora_model_list()
778
  new_lora_tupled_list = get_all_lora_tupled_list()
779
-
 
 
 
 
780
  return gr.update(
781
- choices=new_lora_tupled_list, value=new_lora_model_list[-1]
782
  ), gr.update(
783
  choices=new_lora_tupled_list
784
  ), gr.update(
@@ -787,6 +923,8 @@ def get_my_lora(link_url):
787
  choices=new_lora_tupled_list
788
  ), gr.update(
789
  choices=new_lora_tupled_list
 
 
790
  )
791
 
792
 
@@ -794,12 +932,12 @@ def upload_file_lora(files, progress=gr.Progress(track_tqdm=True)):
794
  progress(0, desc="Uploading...")
795
  file_paths = [file.name for file in files]
796
  progress(1, desc="Uploaded.")
797
- return gr.update(value=file_paths, visible=True), gr.update(visible=True)
798
 
799
 
800
  def move_file_lora(filepaths):
801
  for file in filepaths:
802
- path = Path(shutil.move(Path(file).resolve(), Path(f"./{directory_loras}").resolve()))
803
  newpath = Path(f'{path.parent.name}/{escape_lora_basename(path.stem)}{path.suffix}')
804
  path.resolve().rename(newpath.resolve())
805
  update_lora_dict(str(newpath))
@@ -941,7 +1079,7 @@ def update_civitai_selection(evt: gr.SelectData):
941
  selected = civitai_last_choices[selected_index][1]
942
  return gr.update(value=selected)
943
  except Exception:
944
- return gr.update(visible=True)
945
 
946
 
947
  def select_civitai_lora(search_result):
@@ -1425,3 +1563,78 @@ def get_model_pipeline(repo_id: str):
1425
  else:
1426
  return default
1427
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  import re
6
  from pathlib import Path
7
  from PIL import Image
8
+ import numpy as np
9
  import shutil
10
  import requests
11
  from requests.adapters import HTTPAdapter
 
13
  import urllib.parse
14
  import pandas as pd
15
  from huggingface_hub import HfApi, HfFolder, hf_hub_download, snapshot_download
16
+ from translatepy import Translator
17
+ from unidecode import unidecode
18
+ import copy
19
+ from datetime import datetime, timezone, timedelta
20
+ FILENAME_TIMEZONE = timezone(timedelta(hours=9)) # JST
21
 
22
 
23
  from env import (HF_LORA_PRIVATE_REPOS1, HF_LORA_PRIVATE_REPOS2,
24
  HF_MODEL_USER_EX, HF_MODEL_USER_LIKES, DIFFUSERS_FORMAT_LORAS,
25
+ DIRECTORY_LORAS, HF_READ_TOKEN, HF_TOKEN, CIVITAI_API_KEY)
26
 
27
 
28
  MODEL_TYPE_DICT = {
 
52
  return re.fullmatch(r'^[^/]+?/[^/]+?$', s)
53
 
54
 
 
55
  translator = Translator()
56
  def translate_to_en(input: str):
57
  try:
 
69
  if file.suffix in valid_extensions:
70
  file_path = str(Path(f"{dir_path}/{file.name}"))
71
  model_list.append(file_path)
72
+ #print('\033[34mFILE: ' + file_path + '\033[0m')
73
  return model_list
74
 
75
 
 
104
  print(e)
105
 
106
 
107
+ def download_hf_file(directory, url, force_filename="", hf_token="", progress=gr.Progress(track_tqdm=True)):
 
108
  repo_id, filename, subfolder, repo_type = split_hf_url(url)
109
+ kwargs = {}
110
+ if subfolder is not None: kwargs["subfolder"] = subfolder
111
+ if force_filename: kwargs["force_filename"] = force_filename
112
  try:
113
+ print(f"Start downloading: {url} to {directory}")
114
+ path = hf_hub_download(repo_id=repo_id, filename=filename, repo_type=repo_type, local_dir=directory, token=hf_token, **kwargs)
 
115
  return path
116
  except Exception as e:
117
+ print(f"Download failed: {url} {e}")
118
  return None
119
 
120
 
121
+ USER_AGENT = 'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:127.0) Gecko/20100101 Firefox/127.0'
122
+
123
+
124
+ def request_json_data(url):
125
+ model_version_id = url.split('/')[-1]
126
+ if "?modelVersionId=" in model_version_id:
127
+ match = re.search(r'modelVersionId=(\d+)', url)
128
+ model_version_id = match.group(1)
129
+
130
+ endpoint_url = f"https://civitai.com/api/v1/model-versions/{model_version_id}"
131
+
132
+ params = {}
133
+ headers = {'User-Agent': USER_AGENT, 'content-type': 'application/json'}
134
+ session = requests.Session()
135
+ retries = Retry(total=5, backoff_factor=1, status_forcelist=[500, 502, 503, 504])
136
+ session.mount("https://", HTTPAdapter(max_retries=retries))
137
+
138
+ try:
139
+ result = session.get(endpoint_url, params=params, headers=headers, stream=True, timeout=(3.0, 15))
140
+ result.raise_for_status()
141
+ json_data = result.json()
142
+ return json_data if json_data else None
143
+ except Exception as e:
144
+ print(f"Error: {e}")
145
+ return None
146
+
147
+
148
+ class ModelInformation:
149
+ def __init__(self, json_data):
150
+ self.model_version_id = json_data.get("id", "")
151
+ self.model_id = json_data.get("modelId", "")
152
+ self.download_url = json_data.get("downloadUrl", "")
153
+ self.model_url = f"https://civitai.com/models/{self.model_id}?modelVersionId={self.model_version_id}"
154
+ self.filename_url = next(
155
+ (v.get("name", "") for v in json_data.get("files", []) if str(self.model_version_id) in v.get("downloadUrl", "")), ""
156
+ )
157
+ self.filename_url = self.filename_url if self.filename_url else ""
158
+ self.description = json_data.get("description", "")
159
+ if self.description is None: self.description = ""
160
+ self.model_name = json_data.get("model", {}).get("name", "")
161
+ self.model_type = json_data.get("model", {}).get("type", "")
162
+ self.nsfw = json_data.get("model", {}).get("nsfw", False)
163
+ self.poi = json_data.get("model", {}).get("poi", False)
164
+ self.images = [img.get("url", "") for img in json_data.get("images", [])]
165
+ self.example_prompt = json_data.get("trainedWords", [""])[0] if json_data.get("trainedWords") else ""
166
+ self.original_json = copy.deepcopy(json_data)
167
+
168
+
169
+ def retrieve_model_info(url):
170
+ json_data = request_json_data(url)
171
+ if not json_data:
172
+ return None
173
+ model_descriptor = ModelInformation(json_data)
174
+ return model_descriptor
175
+
176
+
177
+ def download_things(directory, url, hf_token="", civitai_api_key="", romanize=False):
178
+ hf_token = get_token()
179
  url = url.strip()
180
+ downloaded_file_path = None
181
+
182
  if "drive.google.com" in url:
183
  original_dir = os.getcwd()
184
  os.chdir(directory)
 
189
  # url = urllib.parse.quote(url, safe=':/') # fix encoding
190
  if "/blob/" in url:
191
  url = url.replace("/blob/", "/resolve/")
192
+
193
+ filename = unidecode(url.split('/')[-1]) if romanize else url.split('/')[-1]
194
+
195
+ download_hf_file(directory, url, filename, hf_token)
196
+
197
+ downloaded_file_path = os.path.join(directory, filename)
198
+
199
  elif "civitai.com" in url:
200
+
201
+ if not civitai_api_key:
 
 
 
 
202
  print("\033[91mYou need an API key to download Civitai models.\033[0m")
203
+
204
+ model_profile = retrieve_model_info(url)
205
+ if model_profile.download_url and model_profile.filename_url:
206
+ url = model_profile.download_url
207
+ filename = unidecode(model_profile.filename_url) if romanize else model_profile.filename_url
208
+ else:
209
+ if "?" in url:
210
+ url = url.split("?")[0]
211
+ filename = ""
212
+
213
+ url_dl = url + f"?token={civitai_api_key}"
214
+ print(f"Filename: {filename}")
215
+
216
+ param_filename = ""
217
+ if filename:
218
+ param_filename = f"-o '{filename}'"
219
+
220
+ aria2_command = (
221
+ f'aria2c --console-log-level=error --summary-interval=10 -c -x 16 '
222
+ f'-k 1M -s 16 -d "{directory}" {param_filename} "{url_dl}"'
223
+ )
224
+ os.system(aria2_command)
225
+
226
+ if param_filename and os.path.exists(os.path.join(directory, filename)):
227
+ downloaded_file_path = os.path.join(directory, filename)
228
+
229
  else:
230
  os.system(f"aria2c --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 -d {directory} {url}")
231
 
232
+ return downloaded_file_path
233
+
234
 
235
  def get_download_file(temp_dir, url, civitai_key="", progress=gr.Progress(track_tqdm=True)):
236
  if not "http" in url and is_repo_name(url) and not Path(url).exists():
 
269
 
270
  def to_lora_path(key: str):
271
  if Path(key).is_file(): return key
272
+ path = Path(f"{DIRECTORY_LORAS}/{escape_lora_basename(key)}.safetensors")
273
  return str(path)
274
 
275
 
 
299
  raise Exception(f"Failed to save image file:") from e
300
 
301
 
302
+ def save_gallery_images(images, model_name="", progress=gr.Progress(track_tqdm=True)):
 
303
  progress(0, desc="Updating gallery...")
304
+ basename = f"{model_name.split('/')[-1]}_{datetime.now(FILENAME_TIMEZONE).strftime('%Y%m%d_%H%M%S')}_"
305
+ if not images: return images, gr.update()
 
 
306
  output_images = []
307
  output_paths = []
308
+ for i, image in enumerate(images):
309
+ filename = f"{basename}{str(i + 1)}.png"
 
310
  oldpath = Path(image[0])
311
  newpath = oldpath
312
  try:
313
  if oldpath.exists():
314
  newpath = oldpath.resolve().rename(Path(filename).resolve())
315
  except Exception as e:
316
+ print(e)
317
  finally:
318
  output_paths.append(str(newpath))
319
  output_images.append((str(newpath), str(filename)))
 
321
  return gr.update(value=output_images), gr.update(value=output_paths, visible=True)
322
 
323
 
324
+ def save_gallery_history(images, files, history_gallery, history_files, progress=gr.Progress(track_tqdm=True)):
325
+ if not images or not files: return gr.update(), gr.update()
326
+ if not history_gallery: history_gallery = []
327
+ if not history_files: history_files = []
328
+ output_gallery = images + history_gallery
329
+ output_files = files + history_files
330
+ return gr.update(value=output_gallery), gr.update(value=output_files, visible=True)
331
+
332
+
333
+ def save_image_history(image, gallery, files, model_name: str, progress=gr.Progress(track_tqdm=True)):
334
+ if not gallery: gallery = []
335
+ if not files: files = []
336
+ try:
337
+ basename = f"{model_name.split('/')[-1]}_{datetime.now(FILENAME_TIMEZONE).strftime('%Y%m%d_%H%M%S')}"
338
+ if image is None or not isinstance(image, (str, Image.Image, np.ndarray, tuple)): return gr.update(), gr.update()
339
+ filename = f"{basename}.png"
340
+ if isinstance(image, tuple): image = image[0]
341
+ if isinstance(image, str): oldpath = image
342
+ elif isinstance(image, Image.Image):
343
+ oldpath = "temp.png"
344
+ image.save(oldpath)
345
+ elif isinstance(image, np.ndarray):
346
+ oldpath = "temp.png"
347
+ Image.fromarray(image).convert('RGBA').save(oldpath)
348
+ oldpath = Path(oldpath)
349
+ newpath = oldpath
350
+ if oldpath.exists():
351
+ shutil.copy(oldpath.resolve(), Path(filename).resolve())
352
+ newpath = Path(filename).resolve()
353
+ files.insert(0, str(newpath))
354
+ gallery.insert(0, (str(newpath), str(filename)))
355
+ except Exception as e:
356
+ print(e)
357
+ finally:
358
+ return gr.update(value=gallery), gr.update(value=files, visible=True)
359
+
360
+
361
  def download_private_repo(repo_id, dir_path, is_replace):
362
+ if not HF_READ_TOKEN: return
363
  try:
364
+ snapshot_download(repo_id=repo_id, local_dir=dir_path, allow_patterns=['*.ckpt', '*.pt', '*.pth', '*.safetensors', '*.bin'], token=HF_READ_TOKEN)
365
  except Exception as e:
366
  print(f"Error: Failed to download {repo_id}.")
367
  print(e)
 
379
  def get_private_model_list(repo_id, dir_path):
380
  global private_model_path_repo_dict
381
  api = HfApi()
382
+ if not HF_READ_TOKEN: return []
383
  try:
384
+ files = api.list_repo_files(repo_id, token=HF_READ_TOKEN)
385
  except Exception as e:
386
  print(f"Error: Failed to list {repo_id}.")
387
  print(e)
 
399
  def download_private_file(repo_id, path, is_replace):
400
  file = Path(path)
401
  newpath = Path(f'{file.parent.name}/{escape_lora_basename(file.stem)}{file.suffix}') if is_replace else file
402
+ if not HF_READ_TOKEN or newpath.exists(): return
403
  filename = file.name
404
  dirname = file.parent.name
405
  try:
406
+ hf_hub_download(repo_id=repo_id, filename=filename, local_dir=dirname, token=HF_READ_TOKEN)
407
  except Exception as e:
408
  print(f"Error: Failed to download {filename}.")
409
  print(e)
 
533
  models1 = []
534
  models2 = []
535
  for repo in HF_LORA_PRIVATE_REPOS1:
536
+ models1.extend(get_private_model_list(repo, DIRECTORY_LORAS))
537
  for repo in HF_LORA_PRIVATE_REPOS2:
538
+ models2.extend(get_private_model_list(repo, DIRECTORY_LORAS))
539
  models = list_uniq(models1 + sorted(models2))
540
  private_lora_model_list = models.copy()
541
  return models
 
580
 
581
 
582
  def get_lora_model_list():
583
+ loras = list_uniq(get_private_lora_model_lists() + DIFFUSERS_FORMAT_LORAS + get_local_model_list(DIRECTORY_LORAS))
584
  loras.insert(0, "None")
585
  loras.insert(0, "")
586
  return loras
 
632
  def download_lora(dl_urls: str):
633
  global loras_url_to_path_dict
634
  dl_path = ""
635
+ before = get_local_model_list(DIRECTORY_LORAS)
636
  urls = []
637
  for url in [url.strip() for url in dl_urls.split(',')]:
638
+ local_path = f"{DIRECTORY_LORAS}/{url.split('/')[-1]}"
639
  if not Path(local_path).exists():
640
+ download_things(DIRECTORY_LORAS, url, HF_TOKEN, CIVITAI_API_KEY)
641
  urls.append(url)
642
+ after = get_local_model_list(DIRECTORY_LORAS)
643
  new_files = list_sub(after, before)
644
  i = 0
645
  for file in new_files:
 
890
  gr.update(value=tag5, label=label5, visible=on5), gr.update(visible=on5), gr.update(value=md5, visible=on5)
891
 
892
 
893
+ def get_my_lora(link_url, romanize):
894
+ l_name = ""
895
+ l_path = ""
896
+ before = get_local_model_list(DIRECTORY_LORAS)
897
  for url in [url.strip() for url in link_url.split(',')]:
898
+ if not Path(f"{DIRECTORY_LORAS}/{url.split('/')[-1]}").exists():
899
+ l_name = download_things(DIRECTORY_LORAS, url, HF_TOKEN, CIVITAI_API_KEY, romanize)
900
+ after = get_local_model_list(DIRECTORY_LORAS)
901
  new_files = list_sub(after, before)
902
  for file in new_files:
903
  path = Path(file)
 
905
  new_path = Path(f'{path.parent.name}/{escape_lora_basename(path.stem)}{path.suffix}')
906
  path.resolve().rename(new_path.resolve())
907
  update_lora_dict(str(new_path))
908
+ l_path = str(new_path)
909
  new_lora_model_list = get_lora_model_list()
910
  new_lora_tupled_list = get_all_lora_tupled_list()
911
+ msg_lora = "Downloaded"
912
+ if l_name:
913
+ msg_lora += f": <b>{l_name}</b>"
914
+ print(msg_lora)
915
+
916
  return gr.update(
917
+ choices=new_lora_tupled_list, value=l_path
918
  ), gr.update(
919
  choices=new_lora_tupled_list
920
  ), gr.update(
 
923
  choices=new_lora_tupled_list
924
  ), gr.update(
925
  choices=new_lora_tupled_list
926
+ ), gr.update(
927
+ value=msg_lora
928
  )
929
 
930
 
 
932
  progress(0, desc="Uploading...")
933
  file_paths = [file.name for file in files]
934
  progress(1, desc="Uploaded.")
935
+ return gr.update(value=file_paths, visible=True), gr.update()
936
 
937
 
938
  def move_file_lora(filepaths):
939
  for file in filepaths:
940
+ path = Path(shutil.move(Path(file).resolve(), Path(f"./{DIRECTORY_LORAS}").resolve()))
941
  newpath = Path(f'{path.parent.name}/{escape_lora_basename(path.stem)}{path.suffix}')
942
  path.resolve().rename(newpath.resolve())
943
  update_lora_dict(str(newpath))
 
1079
  selected = civitai_last_choices[selected_index][1]
1080
  return gr.update(value=selected)
1081
  except Exception:
1082
+ return gr.update()
1083
 
1084
 
1085
  def select_civitai_lora(search_result):
 
1563
  else:
1564
  return default
1565
 
1566
+
1567
+ EXAMPLES_GUI = [
1568
+ [
1569
+ "1girl, souryuu asuka langley, neon genesis evangelion, plugsuit, pilot suit, red bodysuit, sitting, crossing legs, black eye patch, cat hat, throne, symmetrical, looking down, from bottom, looking at viewer, outdoors, masterpiece, best quality, very aesthetic, absurdres",
1570
+ "nsfw, lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]",
1571
+ 1,
1572
+ 30,
1573
+ 7.5,
1574
+ True,
1575
+ -1,
1576
+ "Euler a",
1577
+ 1152,
1578
+ 896,
1579
+ "votepurchase/animagine-xl-3.1",
1580
+ ],
1581
+ [
1582
+ "solo, princess Zelda OOT, score_9, score_8_up, score_8, medium breasts, cute, eyelashes, cute small face, long hair, crown braid, hairclip, pointy ears, soft curvy body, looking at viewer, smile, blush, white dress, medium body, (((holding the Master Sword))), standing, deep forest in the background",
1583
+ "score_6, score_5, score_4, busty, ugly face, mutated hands, low res, blurry face, black and white,",
1584
+ 1,
1585
+ 30,
1586
+ 5.,
1587
+ True,
1588
+ -1,
1589
+ "Euler a",
1590
+ 1024,
1591
+ 1024,
1592
+ "votepurchase/ponyDiffusionV6XL",
1593
+ ],
1594
+ [
1595
+ "1girl, oomuro sakurako, yuru yuri, official art, school uniform, anime artwork, anime style, vibrant, studio anime, highly detailed, masterpiece, best quality, very aesthetic, absurdres",
1596
+ "photo, deformed, black and white, realism, disfigured, low contrast, lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]",
1597
+ 1,
1598
+ 40,
1599
+ 7.0,
1600
+ True,
1601
+ -1,
1602
+ "Euler a",
1603
+ 1024,
1604
+ 1024,
1605
+ "Raelina/Rae-Diffusion-XL-V2",
1606
+ ],
1607
+ [
1608
+ "1girl, akaza akari, yuru yuri, official art, anime artwork, anime style, vibrant, studio anime, highly detailed, masterpiece, best quality, very aesthetic, absurdres",
1609
+ "photo, deformed, black and white, realism, disfigured, low contrast, lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]",
1610
+ 1,
1611
+ 35,
1612
+ 7.0,
1613
+ True,
1614
+ -1,
1615
+ "Euler a",
1616
+ 1024,
1617
+ 1024,
1618
+ "Raelina/Raemu-XL-V4",
1619
+ ],
1620
+ [
1621
+ "yoshida yuuko, machikado mazoku, 1girl, solo, demon horns,horns, school uniform, long hair, open mouth, skirt, demon girl, ahoge, shiny, shiny hair, anime artwork",
1622
+ "nsfw, lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]",
1623
+ 1,
1624
+ 50,
1625
+ 7.,
1626
+ True,
1627
+ -1,
1628
+ "Euler a",
1629
+ 1024,
1630
+ 1024,
1631
+ "cagliostrolab/animagine-xl-3.1",
1632
+ ],
1633
+ ]
1634
+
1635
+
1636
+ RESOURCES = (
1637
+ """### Resources
1638
+ - You can also try the image generator in Colab’s free tier, which provides free GPU [link](https://github.com/R3gm/SD_diffusers_interactive).
1639
+ """
1640
+ )
requirements.txt CHANGED
@@ -12,3 +12,4 @@ translatepy
12
  timm
13
  rapidfuzz
14
  sentencepiece
 
 
12
  timm
13
  rapidfuzz
14
  sentencepiece
15
+ unidecode
tagger/character_series_dict.csv ADDED
The diff for this file is too large to render. See raw diff
 
tagger/danbooru_e621.csv ADDED
The diff for this file is too large to render. See raw diff
 
tagger/output.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+
3
+
4
+ @dataclass
5
+ class UpsamplingOutput:
6
+ upsampled_tags: str
7
+
8
+ copyright_tags: str
9
+ character_tags: str
10
+ general_tags: str
11
+ rating_tag: str
12
+ aspect_ratio_tag: str
13
+ length_tag: str
14
+ identity_tag: str
15
+
16
+ elapsed_time: float = 0.0
tagger/tag_group.csv ADDED
The diff for this file is too large to render. See raw diff
 
tagger/tagger.py ADDED
@@ -0,0 +1,556 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import spaces
2
+ from PIL import Image
3
+ import torch
4
+ import gradio as gr
5
+ from transformers import AutoImageProcessor, AutoModelForImageClassification
6
+ from pathlib import Path
7
+
8
+
9
+ WD_MODEL_NAMES = ["p1atdev/wd-swinv2-tagger-v3-hf"]
10
+ WD_MODEL_NAME = WD_MODEL_NAMES[0]
11
+
12
+ device = "cuda" if torch.cuda.is_available() else "cpu"
13
+ default_device = device
14
+
15
+ try:
16
+ wd_model = AutoModelForImageClassification.from_pretrained(WD_MODEL_NAME, trust_remote_code=True).to(default_device).eval()
17
+ wd_processor = AutoImageProcessor.from_pretrained(WD_MODEL_NAME, trust_remote_code=True)
18
+ except Exception as e:
19
+ print(e)
20
+ wd_model = wd_processor = None
21
+
22
+ def _people_tag(noun: str, minimum: int = 1, maximum: int = 5):
23
+ return (
24
+ [f"1{noun}"]
25
+ + [f"{num}{noun}s" for num in range(minimum + 1, maximum + 1)]
26
+ + [f"{maximum+1}+{noun}s"]
27
+ )
28
+
29
+
30
+ PEOPLE_TAGS = (
31
+ _people_tag("girl") + _people_tag("boy") + _people_tag("other") + ["no humans"]
32
+ )
33
+
34
+
35
+ RATING_MAP = {
36
+ "sfw": "safe",
37
+ "general": "safe",
38
+ "sensitive": "sensitive",
39
+ "questionable": "nsfw",
40
+ "explicit": "explicit, nsfw",
41
+ }
42
+ DANBOORU_TO_E621_RATING_MAP = {
43
+ "sfw": "rating_safe",
44
+ "general": "rating_safe",
45
+ "safe": "rating_safe",
46
+ "sensitive": "rating_safe",
47
+ "nsfw": "rating_explicit",
48
+ "explicit, nsfw": "rating_explicit",
49
+ "explicit": "rating_explicit",
50
+ "rating:safe": "rating_safe",
51
+ "rating:general": "rating_safe",
52
+ "rating:sensitive": "rating_safe",
53
+ "rating:questionable, nsfw": "rating_explicit",
54
+ "rating:explicit, nsfw": "rating_explicit",
55
+ }
56
+
57
+
58
+ # https://github.com/toriato/stable-diffusion-webui-wd14-tagger/blob/a9eacb1eff904552d3012babfa28b57e1d3e295c/tagger/ui.py#L368
59
+ kaomojis = [
60
+ "0_0",
61
+ "(o)_(o)",
62
+ "+_+",
63
+ "+_-",
64
+ "._.",
65
+ "<o>_<o>",
66
+ "<|>_<|>",
67
+ "=_=",
68
+ ">_<",
69
+ "3_3",
70
+ "6_9",
71
+ ">_o",
72
+ "@_@",
73
+ "^_^",
74
+ "o_o",
75
+ "u_u",
76
+ "x_x",
77
+ "|_|",
78
+ "||_||",
79
+ ]
80
+
81
+
82
+ def replace_underline(x: str):
83
+ return x.strip().replace("_", " ") if x not in kaomojis else x.strip()
84
+
85
+
86
+ def to_list(s):
87
+ return [x.strip() for x in s.split(",") if not s == ""]
88
+
89
+
90
+ def list_sub(a, b):
91
+ return [e for e in a if e not in b]
92
+
93
+
94
+ def list_uniq(l):
95
+ return sorted(set(l), key=l.index)
96
+
97
+
98
+ def load_dict_from_csv(filename):
99
+ dict = {}
100
+ if not Path(filename).exists():
101
+ if Path('./tagger/', filename).exists(): filename = str(Path('./tagger/', filename))
102
+ else: return dict
103
+ try:
104
+ with open(filename, 'r', encoding="utf-8") as f:
105
+ lines = f.readlines()
106
+ except Exception:
107
+ print(f"Failed to open dictionary file: {filename}")
108
+ return dict
109
+ for line in lines:
110
+ parts = line.strip().split(',')
111
+ dict[parts[0]] = parts[1]
112
+ return dict
113
+
114
+
115
+ anime_series_dict = load_dict_from_csv('character_series_dict.csv')
116
+
117
+
118
+ def character_list_to_series_list(character_list):
119
+ output_series_tag = []
120
+ series_tag = ""
121
+ series_dict = anime_series_dict
122
+ for tag in character_list:
123
+ series_tag = series_dict.get(tag, "")
124
+ if tag.endswith(")"):
125
+ tags = tag.split("(")
126
+ character_tag = "(".join(tags[:-1])
127
+ if character_tag.endswith(" "):
128
+ character_tag = character_tag[:-1]
129
+ series_tag = tags[-1].replace(")", "")
130
+
131
+ if series_tag:
132
+ output_series_tag.append(series_tag)
133
+
134
+ return output_series_tag
135
+
136
+
137
+ def select_random_character(series: str, character: str):
138
+ from random import seed, randrange
139
+ seed()
140
+ character_list = list(anime_series_dict.keys())
141
+ character = character_list[randrange(len(character_list) - 1)]
142
+ series = anime_series_dict.get(character.split(",")[0].strip(), "")
143
+ return series, character
144
+
145
+
146
+ def danbooru_to_e621(dtag, e621_dict):
147
+ def d_to_e(match, e621_dict):
148
+ dtag = match.group(0)
149
+ etag = e621_dict.get(replace_underline(dtag), "")
150
+ if etag:
151
+ return etag
152
+ else:
153
+ return dtag
154
+
155
+ import re
156
+ tag = re.sub(r'[\w ]+', lambda wrapper: d_to_e(wrapper, e621_dict), dtag, 2)
157
+ return tag
158
+
159
+
160
+ danbooru_to_e621_dict = load_dict_from_csv('danbooru_e621.csv')
161
+
162
+
163
+ def convert_danbooru_to_e621_prompt(input_prompt: str = "", prompt_type: str = "danbooru"):
164
+ if prompt_type == "danbooru": return input_prompt
165
+ tags = input_prompt.split(",") if input_prompt else []
166
+ people_tags: list[str] = []
167
+ other_tags: list[str] = []
168
+ rating_tags: list[str] = []
169
+
170
+ e621_dict = danbooru_to_e621_dict
171
+ for tag in tags:
172
+ tag = replace_underline(tag)
173
+ tag = danbooru_to_e621(tag, e621_dict)
174
+ if tag in PEOPLE_TAGS:
175
+ people_tags.append(tag)
176
+ elif tag in DANBOORU_TO_E621_RATING_MAP.keys():
177
+ rating_tags.append(DANBOORU_TO_E621_RATING_MAP.get(tag.replace(" ",""), ""))
178
+ else:
179
+ other_tags.append(tag)
180
+
181
+ rating_tags = sorted(set(rating_tags), key=rating_tags.index)
182
+ rating_tags = [rating_tags[0]] if rating_tags else []
183
+ rating_tags = ["explicit, nsfw"] if rating_tags and rating_tags[0] == "explicit" else rating_tags
184
+
185
+ output_prompt = ", ".join(people_tags + other_tags + rating_tags)
186
+
187
+ return output_prompt
188
+
189
+
190
+ from translatepy import Translator
191
+ translator = Translator()
192
+ def translate_prompt_old(prompt: str = ""):
193
+ def translate_to_english(input: str):
194
+ try:
195
+ output = str(translator.translate(input, 'English'))
196
+ except Exception as e:
197
+ output = input
198
+ print(e)
199
+ return output
200
+
201
+ def is_japanese(s):
202
+ import unicodedata
203
+ for ch in s:
204
+ name = unicodedata.name(ch, "")
205
+ if "CJK UNIFIED" in name or "HIRAGANA" in name or "KATAKANA" in name:
206
+ return True
207
+ return False
208
+
209
+ def to_list(s):
210
+ return [x.strip() for x in s.split(",")]
211
+
212
+ prompts = to_list(prompt)
213
+ outputs = []
214
+ for p in prompts:
215
+ p = translate_to_english(p) if is_japanese(p) else p
216
+ outputs.append(p)
217
+
218
+ return ", ".join(outputs)
219
+
220
+
221
+ def translate_prompt(input: str):
222
+ try:
223
+ output = str(translator.translate(input, 'English'))
224
+ except Exception as e:
225
+ output = input
226
+ print(e)
227
+ return output
228
+
229
+
230
+ def translate_prompt_to_ja(prompt: str = ""):
231
+ def translate_to_japanese(input: str):
232
+ try:
233
+ output = str(translator.translate(input, 'Japanese'))
234
+ except Exception as e:
235
+ output = input
236
+ print(e)
237
+ return output
238
+
239
+ def is_japanese(s):
240
+ import unicodedata
241
+ for ch in s:
242
+ name = unicodedata.name(ch, "")
243
+ if "CJK UNIFIED" in name or "HIRAGANA" in name or "KATAKANA" in name:
244
+ return True
245
+ return False
246
+
247
+ def to_list(s):
248
+ return [x.strip() for x in s.split(",")]
249
+
250
+ prompts = to_list(prompt)
251
+ outputs = []
252
+ for p in prompts:
253
+ p = translate_to_japanese(p) if not is_japanese(p) else p
254
+ outputs.append(p)
255
+
256
+ return ", ".join(outputs)
257
+
258
+
259
+ def tags_to_ja(itag, dict):
260
+ def t_to_j(match, dict):
261
+ tag = match.group(0)
262
+ ja = dict.get(replace_underline(tag), "")
263
+ if ja:
264
+ return ja
265
+ else:
266
+ return tag
267
+
268
+ import re
269
+ tag = re.sub(r'[\w ]+', lambda wrapper: t_to_j(wrapper, dict), itag, 2)
270
+
271
+ return tag
272
+
273
+
274
+ def convert_tags_to_ja(input_prompt: str = ""):
275
+ tags = input_prompt.split(",") if input_prompt else []
276
+ out_tags = []
277
+
278
+ tags_to_ja_dict = load_dict_from_csv('all_tags_ja_ext.csv')
279
+ dict = tags_to_ja_dict
280
+ for tag in tags:
281
+ tag = replace_underline(tag)
282
+ tag = tags_to_ja(tag, dict)
283
+ out_tags.append(tag)
284
+
285
+ return ", ".join(out_tags)
286
+
287
+
288
+ enable_auto_recom_prompt = True
289
+
290
+
291
+ animagine_ps = to_list("masterpiece, best quality, very aesthetic, absurdres")
292
+ animagine_nps = to_list("lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]")
293
+ pony_ps = to_list("score_9, score_8_up, score_7_up, masterpiece, best quality, very aesthetic, absurdres")
294
+ pony_nps = to_list("source_pony, score_6, score_5, score_4, busty, ugly face, mutated hands, low res, blurry face, black and white, the simpsons, overwatch, apex legends")
295
+ other_ps = to_list("anime artwork, anime style, studio anime, highly detailed, cinematic photo, 35mm photograph, film, bokeh, professional, 4k, highly detailed")
296
+ other_nps = to_list("photo, deformed, black and white, realism, disfigured, low contrast, drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly")
297
+ default_ps = to_list("highly detailed, masterpiece, best quality, very aesthetic, absurdres")
298
+ default_nps = to_list("score_6, score_5, score_4, lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]")
299
+ def insert_recom_prompt(prompt: str = "", neg_prompt: str = "", type: str = "None"):
300
+ global enable_auto_recom_prompt
301
+ prompts = to_list(prompt)
302
+ neg_prompts = to_list(neg_prompt)
303
+
304
+ prompts = list_sub(prompts, animagine_ps + pony_ps)
305
+ neg_prompts = list_sub(neg_prompts, animagine_nps + pony_nps)
306
+
307
+ last_empty_p = [""] if not prompts and type != "None" else []
308
+ last_empty_np = [""] if not neg_prompts and type != "None" else []
309
+
310
+ if type == "Auto":
311
+ enable_auto_recom_prompt = True
312
+ else:
313
+ enable_auto_recom_prompt = False
314
+ if type == "Animagine":
315
+ prompts = prompts + animagine_ps
316
+ neg_prompts = neg_prompts + animagine_nps
317
+ elif type == "Pony":
318
+ prompts = prompts + pony_ps
319
+ neg_prompts = neg_prompts + pony_nps
320
+
321
+ prompt = ", ".join(list_uniq(prompts) + last_empty_p)
322
+ neg_prompt = ", ".join(list_uniq(neg_prompts) + last_empty_np)
323
+
324
+ return prompt, neg_prompt
325
+
326
+
327
+ def load_model_prompt_dict():
328
+ import json
329
+ dict = {}
330
+ path = 'model_dict.json' if Path('model_dict.json').exists() else './tagger/model_dict.json'
331
+ try:
332
+ with open('model_dict.json', encoding='utf-8') as f:
333
+ dict = json.load(f)
334
+ except Exception:
335
+ pass
336
+ return dict
337
+
338
+
339
+ model_prompt_dict = load_model_prompt_dict()
340
+
341
+
342
+ def insert_model_recom_prompt(prompt: str = "", neg_prompt: str = "", model_name: str = "None"):
343
+ if not model_name or not enable_auto_recom_prompt: return prompt, neg_prompt
344
+ prompts = to_list(prompt)
345
+ neg_prompts = to_list(neg_prompt)
346
+ prompts = list_sub(prompts, animagine_ps + pony_ps + other_ps)
347
+ neg_prompts = list_sub(neg_prompts, animagine_nps + pony_nps + other_nps)
348
+ last_empty_p = [""] if not prompts and type != "None" else []
349
+ last_empty_np = [""] if not neg_prompts and type != "None" else []
350
+ ps = []
351
+ nps = []
352
+ if model_name in model_prompt_dict.keys():
353
+ ps = to_list(model_prompt_dict[model_name]["prompt"])
354
+ nps = to_list(model_prompt_dict[model_name]["negative_prompt"])
355
+ else:
356
+ ps = default_ps
357
+ nps = default_nps
358
+ prompts = prompts + ps
359
+ neg_prompts = neg_prompts + nps
360
+ prompt = ", ".join(list_uniq(prompts) + last_empty_p)
361
+ neg_prompt = ", ".join(list_uniq(neg_prompts) + last_empty_np)
362
+ return prompt, neg_prompt
363
+
364
+
365
+ tag_group_dict = load_dict_from_csv('tag_group.csv')
366
+
367
+
368
+ def remove_specific_prompt(input_prompt: str = "", keep_tags: str = "all"):
369
+ def is_dressed(tag):
370
+ import re
371
+ p = re.compile(r'dress|cloth|uniform|costume|vest|sweater|coat|shirt|jacket|blazer|apron|leotard|hood|sleeve|skirt|shorts|pant|loafer|ribbon|necktie|bow|collar|glove|sock|shoe|boots|wear|emblem')
372
+ return p.search(tag)
373
+
374
+ def is_background(tag):
375
+ import re
376
+ p = re.compile(r'background|outline|light|sky|build|day|screen|tree|city')
377
+ return p.search(tag)
378
+
379
+ un_tags = ['solo']
380
+ group_list = ['groups', 'body_parts', 'attire', 'posture', 'objects', 'creatures', 'locations', 'disambiguation_pages', 'commonly_misused_tags', 'phrases', 'verbs_and_gerunds', 'subjective', 'nudity', 'sex_objects', 'sex', 'sex_acts', 'image_composition', 'artistic_license', 'text', 'year_tags', 'metatags']
381
+ keep_group_dict = {
382
+ "body": ['groups', 'body_parts'],
383
+ "dress": ['groups', 'body_parts', 'attire'],
384
+ "all": group_list,
385
+ }
386
+
387
+ def is_necessary(tag, keep_tags, group_dict):
388
+ if keep_tags == "all":
389
+ return True
390
+ elif tag in un_tags or group_dict.get(tag, "") in explicit_group:
391
+ return False
392
+ elif keep_tags == "body" and is_dressed(tag):
393
+ return False
394
+ elif is_background(tag):
395
+ return False
396
+ else:
397
+ return True
398
+
399
+ if keep_tags == "all": return input_prompt
400
+ keep_group = keep_group_dict.get(keep_tags, keep_group_dict["body"])
401
+ explicit_group = list(set(group_list) ^ set(keep_group))
402
+
403
+ tags = input_prompt.split(",") if input_prompt else []
404
+ people_tags: list[str] = []
405
+ other_tags: list[str] = []
406
+
407
+ group_dict = tag_group_dict
408
+ for tag in tags:
409
+ tag = replace_underline(tag)
410
+ if tag in PEOPLE_TAGS:
411
+ people_tags.append(tag)
412
+ elif is_necessary(tag, keep_tags, group_dict):
413
+ other_tags.append(tag)
414
+
415
+ output_prompt = ", ".join(people_tags + other_tags)
416
+
417
+ return output_prompt
418
+
419
+
420
+ def sort_taglist(tags: list[str]):
421
+ if not tags: return []
422
+ character_tags: list[str] = []
423
+ series_tags: list[str] = []
424
+ people_tags: list[str] = []
425
+ group_list = ['groups', 'body_parts', 'attire', 'posture', 'objects', 'creatures', 'locations', 'disambiguation_pages', 'commonly_misused_tags', 'phrases', 'verbs_and_gerunds', 'subjective', 'nudity', 'sex_objects', 'sex', 'sex_acts', 'image_composition', 'artistic_license', 'text', 'year_tags', 'metatags']
426
+ group_tags = {}
427
+ other_tags: list[str] = []
428
+ rating_tags: list[str] = []
429
+
430
+ group_dict = tag_group_dict
431
+ group_set = set(group_dict.keys())
432
+ character_set = set(anime_series_dict.keys())
433
+ series_set = set(anime_series_dict.values())
434
+ rating_set = set(DANBOORU_TO_E621_RATING_MAP.keys()) | set(DANBOORU_TO_E621_RATING_MAP.values())
435
+
436
+ for tag in tags:
437
+ tag = replace_underline(tag)
438
+ if tag in PEOPLE_TAGS:
439
+ people_tags.append(tag)
440
+ elif tag in rating_set:
441
+ rating_tags.append(tag)
442
+ elif tag in group_set:
443
+ elem = group_dict[tag]
444
+ group_tags[elem] = group_tags[elem] + [tag] if elem in group_tags else [tag]
445
+ elif tag in character_set:
446
+ character_tags.append(tag)
447
+ elif tag in series_set:
448
+ series_tags.append(tag)
449
+ else:
450
+ other_tags.append(tag)
451
+
452
+ output_group_tags: list[str] = []
453
+ for k in group_list:
454
+ output_group_tags.extend(group_tags.get(k, []))
455
+
456
+ rating_tags = [rating_tags[0]] if rating_tags else []
457
+ rating_tags = ["explicit, nsfw"] if rating_tags and rating_tags[0] == "explicit" else rating_tags
458
+
459
+ output_tags = character_tags + series_tags + people_tags + output_group_tags + other_tags + rating_tags
460
+
461
+ return output_tags
462
+
463
+
464
+ def sort_tags(tags: str):
465
+ if not tags: return ""
466
+ taglist: list[str] = []
467
+ for tag in tags.split(","):
468
+ taglist.append(tag.strip())
469
+ taglist = list(filter(lambda x: x != "", taglist))
470
+ return ", ".join(sort_taglist(taglist))
471
+
472
+
473
+ def postprocess_results(results: dict[str, float], general_threshold: float, character_threshold: float):
474
+ results = {
475
+ k: v for k, v in sorted(results.items(), key=lambda item: item[1], reverse=True)
476
+ }
477
+
478
+ rating = {}
479
+ character = {}
480
+ general = {}
481
+
482
+ for k, v in results.items():
483
+ if k.startswith("rating:"):
484
+ rating[k.replace("rating:", "")] = v
485
+ continue
486
+ elif k.startswith("character:"):
487
+ character[k.replace("character:", "")] = v
488
+ continue
489
+
490
+ general[k] = v
491
+
492
+ character = {k: v for k, v in character.items() if v >= character_threshold}
493
+ general = {k: v for k, v in general.items() if v >= general_threshold}
494
+
495
+ return rating, character, general
496
+
497
+
498
+ def gen_prompt(rating: list[str], character: list[str], general: list[str]):
499
+ people_tags: list[str] = []
500
+ other_tags: list[str] = []
501
+ rating_tag = RATING_MAP[rating[0]]
502
+
503
+ for tag in general:
504
+ if tag in PEOPLE_TAGS:
505
+ people_tags.append(tag)
506
+ else:
507
+ other_tags.append(tag)
508
+
509
+ all_tags = people_tags + other_tags
510
+
511
+ return ", ".join(all_tags)
512
+
513
+
514
+ @spaces.GPU(duration=30)
515
+ def predict_tags(image: Image.Image, general_threshold: float = 0.3, character_threshold: float = 0.8):
516
+ inputs = wd_processor.preprocess(image, return_tensors="pt")
517
+
518
+ outputs = wd_model(**inputs.to(wd_model.device, wd_model.dtype))
519
+ logits = torch.sigmoid(outputs.logits[0]) # take the first logits
520
+
521
+ # get probabilities
522
+ if device != default_device: wd_model.to(device=device)
523
+ results = {
524
+ wd_model.config.id2label[i]: float(logit.float()) for i, logit in enumerate(logits)
525
+ }
526
+ if device != default_device: wd_model.to(device=default_device)
527
+ # rating, character, general
528
+ rating, character, general = postprocess_results(
529
+ results, general_threshold, character_threshold
530
+ )
531
+ prompt = gen_prompt(
532
+ list(rating.keys()), list(character.keys()), list(general.keys())
533
+ )
534
+ output_series_tag = ""
535
+ output_series_list = character_list_to_series_list(character.keys())
536
+ if output_series_list:
537
+ output_series_tag = output_series_list[0]
538
+ else:
539
+ output_series_tag = ""
540
+ return output_series_tag, ", ".join(character.keys()), prompt, gr.update(interactive=True)
541
+
542
+
543
+ def predict_tags_wd(image: Image.Image, input_tags: str, algo: list[str], general_threshold: float = 0.3,
544
+ character_threshold: float = 0.8, input_series: str = "", input_character: str = ""):
545
+ if not "Use WD Tagger" in algo and len(algo) != 0:
546
+ return input_series, input_character, input_tags, gr.update(interactive=True)
547
+ return predict_tags(image, general_threshold, character_threshold)
548
+
549
+
550
+ def compose_prompt_to_copy(character: str, series: str, general: str):
551
+ characters = character.split(",") if character else []
552
+ serieses = series.split(",") if series else []
553
+ generals = general.split(",") if general else []
554
+ tags = characters + serieses + generals
555
+ cprompt = ",".join(tags) if tags else ""
556
+ return cprompt
tagger/utils.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from dartrs.v2 import AspectRatioTag, LengthTag, RatingTag, IdentityTag
3
+
4
+
5
+ V2_ASPECT_RATIO_OPTIONS: list[AspectRatioTag] = [
6
+ "ultra_wide",
7
+ "wide",
8
+ "square",
9
+ "tall",
10
+ "ultra_tall",
11
+ ]
12
+ V2_RATING_OPTIONS: list[RatingTag] = [
13
+ "sfw",
14
+ "general",
15
+ "sensitive",
16
+ "nsfw",
17
+ "questionable",
18
+ "explicit",
19
+ ]
20
+ V2_LENGTH_OPTIONS: list[LengthTag] = [
21
+ "very_short",
22
+ "short",
23
+ "medium",
24
+ "long",
25
+ "very_long",
26
+ ]
27
+ V2_IDENTITY_OPTIONS: list[IdentityTag] = [
28
+ "none",
29
+ "lax",
30
+ "strict",
31
+ ]
32
+
33
+
34
+ # ref: https://qiita.com/tregu148/items/fccccbbc47d966dd2fc2
35
+ def gradio_copy_text(_text: None):
36
+ gr.Info("Copied!")
37
+
38
+
39
+ COPY_ACTION_JS = """\
40
+ (inputs, _outputs) => {
41
+ // inputs is the string value of the input_text
42
+ if (inputs.trim() !== "") {
43
+ navigator.clipboard.writeText(inputs);
44
+ }
45
+ }"""
46
+
47
+
48
+ def gradio_copy_prompt(prompt: str):
49
+ gr.Info("Copied!")
50
+ return prompt
tagger/v2.py ADDED
@@ -0,0 +1,260 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import torch
3
+ from typing import Callable
4
+ from pathlib import Path
5
+
6
+ from dartrs.v2 import (
7
+ V2Model,
8
+ MixtralModel,
9
+ MistralModel,
10
+ compose_prompt,
11
+ LengthTag,
12
+ AspectRatioTag,
13
+ RatingTag,
14
+ IdentityTag,
15
+ )
16
+ from dartrs.dartrs import DartTokenizer
17
+ from dartrs.utils import get_generation_config
18
+
19
+
20
+ import gradio as gr
21
+ from gradio.components import Component
22
+
23
+
24
+ try:
25
+ from output import UpsamplingOutput
26
+ except:
27
+ from .output import UpsamplingOutput
28
+
29
+
30
+ V2_ALL_MODELS = {
31
+ "dart-v2-moe-sft": {
32
+ "repo": "p1atdev/dart-v2-moe-sft",
33
+ "type": "sft",
34
+ "class": MixtralModel,
35
+ },
36
+ "dart-v2-sft": {
37
+ "repo": "p1atdev/dart-v2-sft",
38
+ "type": "sft",
39
+ "class": MistralModel,
40
+ },
41
+ }
42
+
43
+
44
+ def prepare_models(model_config: dict):
45
+ model_name = model_config["repo"]
46
+ tokenizer = DartTokenizer.from_pretrained(model_name)
47
+ model = model_config["class"].from_pretrained(model_name)
48
+
49
+ return {
50
+ "tokenizer": tokenizer,
51
+ "model": model,
52
+ }
53
+
54
+
55
+ def normalize_tags(tokenizer: DartTokenizer, tags: str):
56
+ """Just remove unk tokens."""
57
+ return ", ".join([tag for tag in tokenizer.tokenize(tags) if tag != "<|unk|>"])
58
+
59
+
60
+ @torch.no_grad()
61
+ def generate_tags(
62
+ model: V2Model,
63
+ tokenizer: DartTokenizer,
64
+ prompt: str,
65
+ ban_token_ids: list[int],
66
+ ):
67
+ output = model.generate(
68
+ get_generation_config(
69
+ prompt,
70
+ tokenizer=tokenizer,
71
+ temperature=1,
72
+ top_p=0.9,
73
+ top_k=100,
74
+ max_new_tokens=256,
75
+ ban_token_ids=ban_token_ids,
76
+ ),
77
+ )
78
+
79
+ return output
80
+
81
+
82
+ def _people_tag(noun: str, minimum: int = 1, maximum: int = 5):
83
+ return (
84
+ [f"1{noun}"]
85
+ + [f"{num}{noun}s" for num in range(minimum + 1, maximum + 1)]
86
+ + [f"{maximum+1}+{noun}s"]
87
+ )
88
+
89
+
90
+ PEOPLE_TAGS = (
91
+ _people_tag("girl") + _people_tag("boy") + _people_tag("other") + ["no humans"]
92
+ )
93
+
94
+
95
+ def gen_prompt_text(output: UpsamplingOutput):
96
+ # separate people tags (e.g. 1girl)
97
+ people_tags = []
98
+ other_general_tags = []
99
+
100
+ for tag in output.general_tags.split(","):
101
+ tag = tag.strip()
102
+ if tag in PEOPLE_TAGS:
103
+ people_tags.append(tag)
104
+ else:
105
+ other_general_tags.append(tag)
106
+
107
+ return ", ".join(
108
+ [
109
+ part.strip()
110
+ for part in [
111
+ *people_tags,
112
+ output.character_tags,
113
+ output.copyright_tags,
114
+ *other_general_tags,
115
+ output.upsampled_tags,
116
+ output.rating_tag,
117
+ ]
118
+ if part.strip() != ""
119
+ ]
120
+ )
121
+
122
+
123
+ def elapsed_time_format(elapsed_time: float) -> str:
124
+ return f"Elapsed: {elapsed_time:.2f} seconds"
125
+
126
+
127
+ def parse_upsampling_output(
128
+ upsampler: Callable[..., UpsamplingOutput],
129
+ ):
130
+ def _parse_upsampling_output(*args) -> tuple[str, str, dict]:
131
+ output = upsampler(*args)
132
+
133
+ return (
134
+ gen_prompt_text(output),
135
+ elapsed_time_format(output.elapsed_time),
136
+ gr.update(interactive=True),
137
+ gr.update(interactive=True),
138
+ )
139
+
140
+ return _parse_upsampling_output
141
+
142
+
143
+ class V2UI:
144
+ model_name: str | None = None
145
+ model: V2Model
146
+ tokenizer: DartTokenizer
147
+
148
+ input_components: list[Component] = []
149
+ generate_btn: gr.Button
150
+
151
+ def on_generate(
152
+ self,
153
+ model_name: str,
154
+ copyright_tags: str,
155
+ character_tags: str,
156
+ general_tags: str,
157
+ rating_tag: RatingTag,
158
+ aspect_ratio_tag: AspectRatioTag,
159
+ length_tag: LengthTag,
160
+ identity_tag: IdentityTag,
161
+ ban_tags: str,
162
+ *args,
163
+ ) -> UpsamplingOutput:
164
+ if self.model_name is None or self.model_name != model_name:
165
+ models = prepare_models(V2_ALL_MODELS[model_name])
166
+ self.model = models["model"]
167
+ self.tokenizer = models["tokenizer"]
168
+ self.model_name = model_name
169
+
170
+ # normalize tags
171
+ # copyright_tags = normalize_tags(self.tokenizer, copyright_tags)
172
+ # character_tags = normalize_tags(self.tokenizer, character_tags)
173
+ # general_tags = normalize_tags(self.tokenizer, general_tags)
174
+
175
+ ban_token_ids = self.tokenizer.encode(ban_tags.strip())
176
+
177
+ prompt = compose_prompt(
178
+ prompt=general_tags,
179
+ copyright=copyright_tags,
180
+ character=character_tags,
181
+ rating=rating_tag,
182
+ aspect_ratio=aspect_ratio_tag,
183
+ length=length_tag,
184
+ identity=identity_tag,
185
+ )
186
+
187
+ start = time.time()
188
+ upsampled_tags = generate_tags(
189
+ self.model,
190
+ self.tokenizer,
191
+ prompt,
192
+ ban_token_ids,
193
+ )
194
+ elapsed_time = time.time() - start
195
+
196
+ return UpsamplingOutput(
197
+ upsampled_tags=upsampled_tags,
198
+ copyright_tags=copyright_tags,
199
+ character_tags=character_tags,
200
+ general_tags=general_tags,
201
+ rating_tag=rating_tag,
202
+ aspect_ratio_tag=aspect_ratio_tag,
203
+ length_tag=length_tag,
204
+ identity_tag=identity_tag,
205
+ elapsed_time=elapsed_time,
206
+ )
207
+
208
+
209
+ def parse_upsampling_output_simple(upsampler: UpsamplingOutput):
210
+ return gen_prompt_text(upsampler)
211
+
212
+
213
+ v2 = V2UI()
214
+
215
+
216
+ def v2_upsampling_prompt(model: str = "dart-v2-moe-sft", copyright: str = "", character: str = "",
217
+ general_tags: str = "", rating: str = "nsfw", aspect_ratio: str = "square",
218
+ length: str = "very_long", identity: str = "lax", ban_tags: str = "censored"):
219
+ raw_prompt = parse_upsampling_output_simple(v2.on_generate(model, copyright, character, general_tags,
220
+ rating, aspect_ratio, length, identity, ban_tags))
221
+ return raw_prompt
222
+
223
+
224
+ def load_dict_from_csv(filename):
225
+ dict = {}
226
+ if not Path(filename).exists():
227
+ if Path('./tagger/', filename).exists(): filename = str(Path('./tagger/', filename))
228
+ else: return dict
229
+ try:
230
+ with open(filename, 'r', encoding="utf-8") as f:
231
+ lines = f.readlines()
232
+ except Exception:
233
+ print(f"Failed to open dictionary file: {filename}")
234
+ return dict
235
+ for line in lines:
236
+ parts = line.strip().split(',')
237
+ dict[parts[0]] = parts[1]
238
+ return dict
239
+
240
+
241
+ anime_series_dict = load_dict_from_csv('character_series_dict.csv')
242
+
243
+
244
+ def select_random_character(series: str, character: str):
245
+ from random import seed, randrange
246
+ seed()
247
+ character_list = list(anime_series_dict.keys())
248
+ character = character_list[randrange(len(character_list) - 1)]
249
+ series = anime_series_dict.get(character.split(",")[0].strip(), "")
250
+ return series, character
251
+
252
+
253
+ def v2_random_prompt(general_tags: str = "", copyright: str = "", character: str = "", rating: str = "nsfw",
254
+ aspect_ratio: str = "square", length: str = "very_long", identity: str = "lax",
255
+ ban_tags: str = "censored", model: str = "dart-v2-moe-sft"):
256
+ if copyright == "" and character == "":
257
+ copyright, character = select_random_character("", "")
258
+ raw_prompt = v2_upsampling_prompt(model, copyright, character, general_tags, rating,
259
+ aspect_ratio, length, identity, ban_tags)
260
+ return raw_prompt, copyright, character
utils.py ADDED
@@ -0,0 +1,421 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import gradio as gr
4
+ from constants import (
5
+ DIFFUSERS_FORMAT_LORAS,
6
+ CIVITAI_API_KEY,
7
+ HF_TOKEN,
8
+ MODEL_TYPE_CLASS,
9
+ DIRECTORY_LORAS,
10
+ )
11
+ from huggingface_hub import HfApi
12
+ from diffusers import DiffusionPipeline
13
+ from huggingface_hub import model_info as model_info_data
14
+ from diffusers.pipelines.pipeline_loading_utils import variant_compatible_siblings
15
+ from pathlib import PosixPath
16
+ from unidecode import unidecode
17
+ import urllib.parse
18
+ import copy
19
+ import requests
20
+ from requests.adapters import HTTPAdapter
21
+ from urllib3.util import Retry
22
+
23
+ USER_AGENT = 'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:127.0) Gecko/20100101 Firefox/127.0'
24
+
25
+
26
+ def request_json_data(url):
27
+ model_version_id = url.split('/')[-1]
28
+ if "?modelVersionId=" in model_version_id:
29
+ match = re.search(r'modelVersionId=(\d+)', url)
30
+ model_version_id = match.group(1)
31
+
32
+ endpoint_url = f"https://civitai.com/api/v1/model-versions/{model_version_id}"
33
+
34
+ params = {}
35
+ headers = {'User-Agent': USER_AGENT, 'content-type': 'application/json'}
36
+ session = requests.Session()
37
+ retries = Retry(total=5, backoff_factor=1, status_forcelist=[500, 502, 503, 504])
38
+ session.mount("https://", HTTPAdapter(max_retries=retries))
39
+
40
+ try:
41
+ result = session.get(endpoint_url, params=params, headers=headers, stream=True, timeout=(3.0, 15))
42
+ result.raise_for_status()
43
+ json_data = result.json()
44
+ return json_data if json_data else None
45
+ except Exception as e:
46
+ print(f"Error: {e}")
47
+ return None
48
+
49
+
50
+ class ModelInformation:
51
+ def __init__(self, json_data):
52
+ self.model_version_id = json_data.get("id", "")
53
+ self.model_id = json_data.get("modelId", "")
54
+ self.download_url = json_data.get("downloadUrl", "")
55
+ self.model_url = f"https://civitai.com/models/{self.model_id}?modelVersionId={self.model_version_id}"
56
+ self.filename_url = next(
57
+ (v.get("name", "") for v in json_data.get("files", []) if str(self.model_version_id) in v.get("downloadUrl", "")), ""
58
+ )
59
+ self.filename_url = self.filename_url if self.filename_url else ""
60
+ self.description = json_data.get("description", "")
61
+ if self.description is None: self.description = ""
62
+ self.model_name = json_data.get("model", {}).get("name", "")
63
+ self.model_type = json_data.get("model", {}).get("type", "")
64
+ self.nsfw = json_data.get("model", {}).get("nsfw", False)
65
+ self.poi = json_data.get("model", {}).get("poi", False)
66
+ self.images = [img.get("url", "") for img in json_data.get("images", [])]
67
+ self.example_prompt = json_data.get("trainedWords", [""])[0] if json_data.get("trainedWords") else ""
68
+ self.original_json = copy.deepcopy(json_data)
69
+
70
+
71
+ def retrieve_model_info(url):
72
+ json_data = request_json_data(url)
73
+ if not json_data:
74
+ return None
75
+ model_descriptor = ModelInformation(json_data)
76
+ return model_descriptor
77
+
78
+
79
+ def download_things(directory, url, hf_token="", civitai_api_key="", romanize=False):
80
+ url = url.strip()
81
+ downloaded_file_path = None
82
+
83
+ if "drive.google.com" in url:
84
+ original_dir = os.getcwd()
85
+ os.chdir(directory)
86
+ os.system(f"gdown --fuzzy {url}")
87
+ os.chdir(original_dir)
88
+ elif "huggingface.co" in url:
89
+ url = url.replace("?download=true", "")
90
+ # url = urllib.parse.quote(url, safe=':/') # fix encoding
91
+ if "/blob/" in url:
92
+ url = url.replace("/blob/", "/resolve/")
93
+ user_header = f'"Authorization: Bearer {hf_token}"'
94
+
95
+ filename = unidecode(url.split('/')[-1]) if romanize else url.split('/')[-1]
96
+
97
+ if hf_token:
98
+ os.system(f"aria2c --console-log-level=error --summary-interval=10 --header={user_header} -c -x 16 -k 1M -s 16 {url} -d {directory} -o {filename}")
99
+ else:
100
+ os.system(f"aria2c --optimize-concurrent-downloads --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 {url} -d {directory} -o {filename}")
101
+
102
+ downloaded_file_path = os.path.join(directory, filename)
103
+
104
+ elif "civitai.com" in url:
105
+
106
+ if not civitai_api_key:
107
+ print("\033[91mYou need an API key to download Civitai models.\033[0m")
108
+
109
+ model_profile = retrieve_model_info(url)
110
+ if model_profile.download_url and model_profile.filename_url:
111
+ url = model_profile.download_url
112
+ filename = unidecode(model_profile.filename_url) if romanize else model_profile.filename_url
113
+ else:
114
+ if "?" in url:
115
+ url = url.split("?")[0]
116
+ filename = ""
117
+
118
+ url_dl = url + f"?token={civitai_api_key}"
119
+ print(f"Filename: {filename}")
120
+
121
+ param_filename = ""
122
+ if filename:
123
+ param_filename = f"-o '{filename}'"
124
+
125
+ aria2_command = (
126
+ f'aria2c --console-log-level=error --summary-interval=10 -c -x 16 '
127
+ f'-k 1M -s 16 -d "{directory}" {param_filename} "{url_dl}"'
128
+ )
129
+ os.system(aria2_command)
130
+
131
+ if param_filename and os.path.exists(os.path.join(directory, filename)):
132
+ downloaded_file_path = os.path.join(directory, filename)
133
+
134
+ # # PLAN B
135
+ # # Follow the redirect to get the actual download URL
136
+ # curl_command = (
137
+ # f'curl -L -sI --connect-timeout 5 --max-time 5 '
138
+ # f'-H "Content-Type: application/json" '
139
+ # f'-H "Authorization: Bearer {civitai_api_key}" "{url}"'
140
+ # )
141
+
142
+ # headers = os.popen(curl_command).read()
143
+
144
+ # # Look for the redirected "Location" URL
145
+ # location_match = re.search(r'location: (.+)', headers, re.IGNORECASE)
146
+
147
+ # if location_match:
148
+ # redirect_url = location_match.group(1).strip()
149
+
150
+ # # Extract the filename from the redirect URL's "Content-Disposition"
151
+ # filename_match = re.search(r'filename%3D%22(.+?)%22', redirect_url)
152
+ # if filename_match:
153
+ # encoded_filename = filename_match.group(1)
154
+ # # Decode the URL-encoded filename
155
+ # decoded_filename = urllib.parse.unquote(encoded_filename)
156
+
157
+ # filename = unidecode(decoded_filename) if romanize else decoded_filename
158
+ # print(f"Filename: {filename}")
159
+
160
+ # aria2_command = (
161
+ # f'aria2c --console-log-level=error --summary-interval=10 -c -x 16 '
162
+ # f'-k 1M -s 16 -d "{directory}" -o "{filename}" "{redirect_url}"'
163
+ # )
164
+ # return_code = os.system(aria2_command)
165
+
166
+ # # if return_code != 0:
167
+ # # raise RuntimeError(f"Failed to download file: {filename}. Error code: {return_code}")
168
+ # downloaded_file_path = os.path.join(directory, filename)
169
+ # if not os.path.exists(downloaded_file_path):
170
+ # downloaded_file_path = None
171
+
172
+ # if not downloaded_file_path:
173
+ # # Old method
174
+ # if "?" in url:
175
+ # url = url.split("?")[0]
176
+ # url = url + f"?token={civitai_api_key}"
177
+ # os.system(f"aria2c --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 -d {directory} {url}")
178
+
179
+ else:
180
+ os.system(f"aria2c --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 -d {directory} {url}")
181
+
182
+ return downloaded_file_path
183
+
184
+
185
+ def get_model_list(directory_path):
186
+ model_list = []
187
+ valid_extensions = {'.ckpt', '.pt', '.pth', '.safetensors', '.bin'}
188
+
189
+ for filename in os.listdir(directory_path):
190
+ if os.path.splitext(filename)[1] in valid_extensions:
191
+ # name_without_extension = os.path.splitext(filename)[0]
192
+ file_path = os.path.join(directory_path, filename)
193
+ # model_list.append((name_without_extension, file_path))
194
+ model_list.append(file_path)
195
+ print('\033[34mFILE: ' + file_path + '\033[0m')
196
+ return model_list
197
+
198
+
199
+ def extract_parameters(input_string):
200
+ parameters = {}
201
+ input_string = input_string.replace("\n", "")
202
+
203
+ if "Negative prompt:" not in input_string:
204
+ if "Steps:" in input_string:
205
+ input_string = input_string.replace("Steps:", "Negative prompt: Steps:")
206
+ else:
207
+ print("Invalid metadata")
208
+ parameters["prompt"] = input_string
209
+ return parameters
210
+
211
+ parm = input_string.split("Negative prompt:")
212
+ parameters["prompt"] = parm[0].strip()
213
+ if "Steps:" not in parm[1]:
214
+ print("Steps not detected")
215
+ parameters["neg_prompt"] = parm[1].strip()
216
+ return parameters
217
+ parm = parm[1].split("Steps:")
218
+ parameters["neg_prompt"] = parm[0].strip()
219
+ input_string = "Steps:" + parm[1]
220
+
221
+ # Extracting Steps
222
+ steps_match = re.search(r'Steps: (\d+)', input_string)
223
+ if steps_match:
224
+ parameters['Steps'] = int(steps_match.group(1))
225
+
226
+ # Extracting Size
227
+ size_match = re.search(r'Size: (\d+x\d+)', input_string)
228
+ if size_match:
229
+ parameters['Size'] = size_match.group(1)
230
+ width, height = map(int, parameters['Size'].split('x'))
231
+ parameters['width'] = width
232
+ parameters['height'] = height
233
+
234
+ # Extracting other parameters
235
+ other_parameters = re.findall(r'(\w+): (.*?)(?=, \w+|$)', input_string)
236
+ for param in other_parameters:
237
+ parameters[param[0]] = param[1].strip('"')
238
+
239
+ return parameters
240
+
241
+
242
+ def get_my_lora(link_url, romanize):
243
+ l_name = ""
244
+ for url in [url.strip() for url in link_url.split(',')]:
245
+ if not os.path.exists(f"./loras/{url.split('/')[-1]}"):
246
+ l_name = download_things(DIRECTORY_LORAS, url, HF_TOKEN, CIVITAI_API_KEY, romanize)
247
+ new_lora_model_list = get_model_list(DIRECTORY_LORAS)
248
+ new_lora_model_list.insert(0, "None")
249
+ new_lora_model_list = new_lora_model_list + DIFFUSERS_FORMAT_LORAS
250
+ msg_lora = "Downloaded"
251
+ if l_name:
252
+ msg_lora += f": <b>{l_name}</b>"
253
+ print(msg_lora)
254
+
255
+ return gr.update(
256
+ choices=new_lora_model_list
257
+ ), gr.update(
258
+ choices=new_lora_model_list
259
+ ), gr.update(
260
+ choices=new_lora_model_list
261
+ ), gr.update(
262
+ choices=new_lora_model_list
263
+ ), gr.update(
264
+ choices=new_lora_model_list
265
+ ), gr.update(
266
+ value=msg_lora
267
+ )
268
+
269
+
270
+ def info_html(json_data, title, subtitle):
271
+ return f"""
272
+ <div style='padding: 0; border-radius: 10px;'>
273
+ <p style='margin: 0; font-weight: bold;'>{title}</p>
274
+ <details>
275
+ <summary>Details</summary>
276
+ <p style='margin: 0; font-weight: bold;'>{subtitle}</p>
277
+ </details>
278
+ </div>
279
+ """
280
+
281
+
282
+ def get_model_type(repo_id: str):
283
+ api = HfApi(token=os.environ.get("HF_TOKEN")) # if use private or gated model
284
+ default = "SD 1.5"
285
+ try:
286
+ model = api.model_info(repo_id=repo_id, timeout=5.0)
287
+ tags = model.tags
288
+ for tag in tags:
289
+ if tag in MODEL_TYPE_CLASS.keys(): return MODEL_TYPE_CLASS.get(tag, default)
290
+ except Exception:
291
+ return default
292
+ return default
293
+
294
+
295
+ def restart_space(repo_id: str, factory_reboot: bool):
296
+ api = HfApi(token=os.environ.get("HF_TOKEN"))
297
+ try:
298
+ runtime = api.get_space_runtime(repo_id=repo_id)
299
+ if runtime.stage == "RUNNING":
300
+ api.restart_space(repo_id=repo_id, factory_reboot=factory_reboot)
301
+ print(f"Restarting space: {repo_id}")
302
+ else:
303
+ print(f"Space {repo_id} is in stage: {runtime.stage}")
304
+ except Exception as e:
305
+ print(e)
306
+
307
+
308
+ def extract_exif_data(image):
309
+ if image is None: return ""
310
+
311
+ try:
312
+ metadata_keys = ['parameters', 'metadata', 'prompt', 'Comment']
313
+
314
+ for key in metadata_keys:
315
+ if key in image.info:
316
+ return image.info[key]
317
+
318
+ return str(image.info)
319
+
320
+ except Exception as e:
321
+ return f"Error extracting metadata: {str(e)}"
322
+
323
+
324
+ def create_mask_now(img, invert):
325
+ import numpy as np
326
+ import time
327
+
328
+ time.sleep(0.5)
329
+
330
+ transparent_image = img["layers"][0]
331
+
332
+ # Extract the alpha channel
333
+ alpha_channel = np.array(transparent_image)[:, :, 3]
334
+
335
+ # Create a binary mask by thresholding the alpha channel
336
+ binary_mask = alpha_channel > 1
337
+
338
+ if invert:
339
+ print("Invert")
340
+ # Invert the binary mask so that the drawn shape is white and the rest is black
341
+ binary_mask = np.invert(binary_mask)
342
+
343
+ # Convert the binary mask to a 3-channel RGB mask
344
+ rgb_mask = np.stack((binary_mask,) * 3, axis=-1)
345
+
346
+ # Convert the mask to uint8
347
+ rgb_mask = rgb_mask.astype(np.uint8) * 255
348
+
349
+ return img["background"], rgb_mask
350
+
351
+
352
+ def download_diffuser_repo(repo_name: str, model_type: str, revision: str = "main", token=True):
353
+
354
+ variant = None
355
+ if token is True and not os.environ.get("HF_TOKEN"):
356
+ token = None
357
+
358
+ if model_type == "SDXL":
359
+ info = model_info_data(
360
+ repo_name,
361
+ token=token,
362
+ revision=revision,
363
+ timeout=5.0,
364
+ )
365
+
366
+ filenames = {sibling.rfilename for sibling in info.siblings}
367
+ model_filenames, variant_filenames = variant_compatible_siblings(
368
+ filenames, variant="fp16"
369
+ )
370
+
371
+ if len(variant_filenames):
372
+ variant = "fp16"
373
+
374
+ cached_folder = DiffusionPipeline.download(
375
+ pretrained_model_name=repo_name,
376
+ force_download=False,
377
+ token=token,
378
+ revision=revision,
379
+ # mirror="https://hf-mirror.com",
380
+ variant=variant,
381
+ use_safetensors=True,
382
+ trust_remote_code=False,
383
+ timeout=5.0,
384
+ )
385
+
386
+ if isinstance(cached_folder, PosixPath):
387
+ cached_folder = cached_folder.as_posix()
388
+
389
+ # Task model
390
+ # from huggingface_hub import hf_hub_download
391
+ # hf_hub_download(
392
+ # task_model,
393
+ # filename="diffusion_pytorch_model.safetensors", # fix fp16 variant
394
+ # )
395
+
396
+ return cached_folder
397
+
398
+
399
+ def progress_step_bar(step, total):
400
+ # Calculate the percentage for the progress bar width
401
+ percentage = min(100, ((step / total) * 100))
402
+
403
+ return f"""
404
+ <div style="position: relative; width: 100%; background-color: gray; border-radius: 5px; overflow: hidden;">
405
+ <div style="width: {percentage}%; height: 17px; background-color: #800080; transition: width 0.5s;"></div>
406
+ <div style="position: absolute; width: 100%; text-align: center; color: white; top: 0; line-height: 19px; font-size: 13px;">
407
+ {int(percentage)}%
408
+ </div>
409
+ </div>
410
+ """
411
+
412
+
413
+ def html_template_message(msg):
414
+ return f"""
415
+ <div style="position: relative; width: 100%; background-color: gray; border-radius: 5px; overflow: hidden;">
416
+ <div style="width: 0%; height: 17px; background-color: #800080; transition: width 0.5s;"></div>
417
+ <div style="position: absolute; width: 100%; text-align: center; color: white; top: 0; line-height: 19px; font-size: 14px; font-weight: bold; text-shadow: 1px 1px 2px black;">
418
+ {msg}
419
+ </div>
420
+ </div>
421
+ """