Spaces:
				
			
			
	
			
			
		Running
		
			on 
			
			Zero
	
	
	
			
			
	
	
	
	
		
		
		Running
		
			on 
			
			Zero
	Upload 22 files
Browse files- app.py +220 -452
- constants.py +453 -0
- env.py +15 -18
- modutils.py +263 -50
- requirements.txt +1 -0
- tagger/character_series_dict.csv +0 -0
- tagger/danbooru_e621.csv +0 -0
- tagger/output.py +16 -0
- tagger/tag_group.csv +0 -0
- tagger/tagger.py +556 -0
- tagger/utils.py +50 -0
- tagger/v2.py +260 -0
- utils.py +421 -0
    	
        app.py
    CHANGED
    
    | @@ -1,231 +1,104 @@ | |
| 1 | 
             
            import spaces
         | 
| 2 | 
            -
            import gradio as gr
         | 
| 3 | 
             
            import os
         | 
| 4 | 
             
            from stablepy import Model_Diffusers
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 5 | 
             
            from stablepy.diffusers_vanilla.style_prompt_config import STYLE_NAMES
         | 
| 6 | 
            -
            from stablepy.diffusers_vanilla.constants import FLUX_CN_UNION_MODES
         | 
| 7 | 
             
            import torch
         | 
| 8 | 
             
            import re
         | 
| 9 | 
            -
            from huggingface_hub import HfApi
         | 
| 10 | 
             
            from stablepy import (
         | 
| 11 | 
            -
                CONTROLNET_MODEL_IDS,
         | 
| 12 | 
            -
                VALID_TASKS,
         | 
| 13 | 
            -
                T2I_PREPROCESSOR_NAME,
         | 
| 14 | 
            -
                FLASH_LORA,
         | 
| 15 | 
            -
                SCHEDULER_CONFIG_MAP,
         | 
| 16 | 
             
                scheduler_names,
         | 
| 17 | 
            -
                IP_ADAPTER_MODELS,
         | 
| 18 | 
             
                IP_ADAPTERS_SD,
         | 
| 19 | 
             
                IP_ADAPTERS_SDXL,
         | 
| 20 | 
            -
                REPO_IMAGE_ENCODER,
         | 
| 21 | 
            -
                ALL_PROMPT_WEIGHT_OPTIONS,
         | 
| 22 | 
            -
                SD15_TASKS,
         | 
| 23 | 
            -
                SDXL_TASKS,
         | 
| 24 | 
             
            )
         | 
| 25 | 
             
            import time
         | 
| 26 | 
             
            from PIL import ImageFile
         | 
| 27 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 28 |  | 
| 29 | 
             
            ImageFile.LOAD_TRUNCATED_IMAGES = True
         | 
|  | |
| 30 | 
             
            print(os.getenv("SPACES_ZERO_GPU"))
         | 
| 31 |  | 
| 32 | 
            -
            PREPROCESSOR_CONTROLNET = {
         | 
| 33 | 
            -
              "openpose": [
         | 
| 34 | 
            -
                "Openpose",
         | 
| 35 | 
            -
                "None",
         | 
| 36 | 
            -
              ],
         | 
| 37 | 
            -
              "scribble": [
         | 
| 38 | 
            -
                "HED",
         | 
| 39 | 
            -
                "PidiNet",
         | 
| 40 | 
            -
                "None",
         | 
| 41 | 
            -
              ],
         | 
| 42 | 
            -
              "softedge": [
         | 
| 43 | 
            -
                "PidiNet",
         | 
| 44 | 
            -
                "HED",
         | 
| 45 | 
            -
                "HED safe",
         | 
| 46 | 
            -
                "PidiNet safe",
         | 
| 47 | 
            -
                "None",
         | 
| 48 | 
            -
              ],
         | 
| 49 | 
            -
              "segmentation": [
         | 
| 50 | 
            -
                "UPerNet",
         | 
| 51 | 
            -
                "None",
         | 
| 52 | 
            -
              ],
         | 
| 53 | 
            -
              "depth": [
         | 
| 54 | 
            -
                "DPT",
         | 
| 55 | 
            -
                "Midas",
         | 
| 56 | 
            -
                "None",
         | 
| 57 | 
            -
              ],
         | 
| 58 | 
            -
              "normalbae": [
         | 
| 59 | 
            -
                "NormalBae",
         | 
| 60 | 
            -
                "None",
         | 
| 61 | 
            -
              ],
         | 
| 62 | 
            -
              "lineart": [
         | 
| 63 | 
            -
                "Lineart",
         | 
| 64 | 
            -
                "Lineart coarse",
         | 
| 65 | 
            -
                "Lineart (anime)",
         | 
| 66 | 
            -
                "None",
         | 
| 67 | 
            -
                "None (anime)",
         | 
| 68 | 
            -
              ],
         | 
| 69 | 
            -
              "lineart_anime": [
         | 
| 70 | 
            -
                "Lineart",
         | 
| 71 | 
            -
                "Lineart coarse",
         | 
| 72 | 
            -
                "Lineart (anime)",
         | 
| 73 | 
            -
                "None",
         | 
| 74 | 
            -
                "None (anime)",
         | 
| 75 | 
            -
              ],
         | 
| 76 | 
            -
              "shuffle": [
         | 
| 77 | 
            -
                "ContentShuffle",
         | 
| 78 | 
            -
                "None",
         | 
| 79 | 
            -
              ],
         | 
| 80 | 
            -
              "canny": [
         | 
| 81 | 
            -
                "Canny",
         | 
| 82 | 
            -
                "None",
         | 
| 83 | 
            -
              ],
         | 
| 84 | 
            -
              "mlsd": [
         | 
| 85 | 
            -
                "MLSD",
         | 
| 86 | 
            -
                "None",
         | 
| 87 | 
            -
              ],
         | 
| 88 | 
            -
              "ip2p": [
         | 
| 89 | 
            -
                "ip2p"
         | 
| 90 | 
            -
              ],
         | 
| 91 | 
            -
              "recolor": [
         | 
| 92 | 
            -
                "Recolor luminance",
         | 
| 93 | 
            -
                "Recolor intensity",
         | 
| 94 | 
            -
                "None",
         | 
| 95 | 
            -
              ],
         | 
| 96 | 
            -
              "tile": [
         | 
| 97 | 
            -
                "Mild Blur",
         | 
| 98 | 
            -
                "Moderate Blur",
         | 
| 99 | 
            -
                "Heavy Blur",
         | 
| 100 | 
            -
                "None",
         | 
| 101 | 
            -
              ],
         | 
| 102 | 
            -
            }
         | 
| 103 | 
            -
             | 
| 104 | 
            -
            TASK_STABLEPY = {
         | 
| 105 | 
            -
                'txt2img': 'txt2img',
         | 
| 106 | 
            -
                'img2img': 'img2img',
         | 
| 107 | 
            -
                'inpaint': 'inpaint',
         | 
| 108 | 
            -
                # 'canny T2I Adapter': 'sdxl_canny_t2i',  # NO HAVE STEP CALLBACK PARAMETERS SO NOT WORKS WITH DIFFUSERS 0.29.0
         | 
| 109 | 
            -
                # 'sketch  T2I Adapter': 'sdxl_sketch_t2i',
         | 
| 110 | 
            -
                # 'lineart  T2I Adapter': 'sdxl_lineart_t2i',
         | 
| 111 | 
            -
                # 'depth-midas  T2I Adapter': 'sdxl_depth-midas_t2i',
         | 
| 112 | 
            -
                # 'openpose  T2I Adapter': 'sdxl_openpose_t2i',
         | 
| 113 | 
            -
                'openpose ControlNet': 'openpose',
         | 
| 114 | 
            -
                'canny ControlNet': 'canny',
         | 
| 115 | 
            -
                'mlsd ControlNet': 'mlsd',
         | 
| 116 | 
            -
                'scribble ControlNet': 'scribble',
         | 
| 117 | 
            -
                'softedge ControlNet': 'softedge',
         | 
| 118 | 
            -
                'segmentation ControlNet': 'segmentation',
         | 
| 119 | 
            -
                'depth ControlNet': 'depth',
         | 
| 120 | 
            -
                'normalbae ControlNet': 'normalbae',
         | 
| 121 | 
            -
                'lineart ControlNet': 'lineart',
         | 
| 122 | 
            -
                'lineart_anime ControlNet': 'lineart_anime',
         | 
| 123 | 
            -
                'shuffle ControlNet': 'shuffle',
         | 
| 124 | 
            -
                'ip2p ControlNet': 'ip2p',
         | 
| 125 | 
            -
                'optical pattern ControlNet': 'pattern',
         | 
| 126 | 
            -
                'recolor ControlNet': 'recolor',
         | 
| 127 | 
            -
                'tile ControlNet': 'tile',
         | 
| 128 | 
            -
            }
         | 
| 129 | 
            -
             | 
| 130 | 
            -
            TASK_MODEL_LIST = list(TASK_STABLEPY.keys())
         | 
| 131 | 
            -
             | 
| 132 | 
            -
            UPSCALER_DICT_GUI = {
         | 
| 133 | 
            -
                None: None,
         | 
| 134 | 
            -
                "Lanczos": "Lanczos",
         | 
| 135 | 
            -
                "Nearest": "Nearest",
         | 
| 136 | 
            -
                'Latent': 'Latent',
         | 
| 137 | 
            -
                'Latent (antialiased)': 'Latent (antialiased)',
         | 
| 138 | 
            -
                'Latent (bicubic)': 'Latent (bicubic)',
         | 
| 139 | 
            -
                'Latent (bicubic antialiased)': 'Latent (bicubic antialiased)',
         | 
| 140 | 
            -
                'Latent (nearest)': 'Latent (nearest)',
         | 
| 141 | 
            -
                'Latent (nearest-exact)': 'Latent (nearest-exact)',
         | 
| 142 | 
            -
                "RealESRGAN_x4plus": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
         | 
| 143 | 
            -
                "RealESRNet_x4plus": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth",
         | 
| 144 | 
            -
                "RealESRGAN_x4plus_anime_6B": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
         | 
| 145 | 
            -
                "RealESRGAN_x2plus": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
         | 
| 146 | 
            -
                "realesr-animevideov3": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth",
         | 
| 147 | 
            -
                "realesr-general-x4v3": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
         | 
| 148 | 
            -
                "realesr-general-wdn-x4v3": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth",
         | 
| 149 | 
            -
                "4x-UltraSharp": "https://huggingface.co/Shandypur/ESRGAN-4x-UltraSharp/resolve/main/4x-UltraSharp.pth",
         | 
| 150 | 
            -
                "4x_foolhardy_Remacri": "https://huggingface.co/FacehugmanIII/4x_foolhardy_Remacri/resolve/main/4x_foolhardy_Remacri.pth",
         | 
| 151 | 
            -
                "Remacri4xExtraSmoother": "https://huggingface.co/hollowstrawberry/upscalers-backup/resolve/main/ESRGAN/Remacri%204x%20ExtraSmoother.pth",
         | 
| 152 | 
            -
                "AnimeSharp4x": "https://huggingface.co/hollowstrawberry/upscalers-backup/resolve/main/ESRGAN/AnimeSharp%204x.pth",
         | 
| 153 | 
            -
                "lollypop": "https://huggingface.co/hollowstrawberry/upscalers-backup/resolve/main/ESRGAN/lollypop.pth",
         | 
| 154 | 
            -
                "RealisticRescaler4x": "https://huggingface.co/hollowstrawberry/upscalers-backup/resolve/main/ESRGAN/RealisticRescaler%204x.pth",
         | 
| 155 | 
            -
                "NickelbackFS4x": "https://huggingface.co/hollowstrawberry/upscalers-backup/resolve/main/ESRGAN/NickelbackFS%204x.pth"
         | 
| 156 | 
            -
            }
         | 
| 157 | 
            -
             | 
| 158 | 
            -
            UPSCALER_KEYS = list(UPSCALER_DICT_GUI.keys())
         | 
| 159 | 
            -
             | 
| 160 | 
            -
            def get_model_list(directory_path):
         | 
| 161 | 
            -
                model_list = []
         | 
| 162 | 
            -
                valid_extensions = {'.ckpt', '.pt', '.pth', '.safetensors', '.bin'}
         | 
| 163 | 
            -
             | 
| 164 | 
            -
                for filename in os.listdir(directory_path):
         | 
| 165 | 
            -
                    if os.path.splitext(filename)[1] in valid_extensions:
         | 
| 166 | 
            -
                        # name_without_extension = os.path.splitext(filename)[0]
         | 
| 167 | 
            -
                        file_path = os.path.join(directory_path, filename)
         | 
| 168 | 
            -
                        # model_list.append((name_without_extension, file_path))
         | 
| 169 | 
            -
                        model_list.append(file_path)
         | 
| 170 | 
            -
                        print('\033[34mFILE: ' + file_path + '\033[0m')
         | 
| 171 | 
            -
                return model_list
         | 
| 172 | 
            -
             | 
| 173 | 
             
            ## BEGIN MOD
         | 
| 174 | 
             
            from modutils import (list_uniq, download_private_repo, get_model_id_list, get_tupled_embed_list,
         | 
| 175 | 
             
                get_lora_model_list, get_all_lora_tupled_list, update_loras, apply_lora_prompt, set_prompt_loras,
         | 
| 176 | 
             
                get_my_lora, upload_file_lora, move_file_lora, search_civitai_lora, select_civitai_lora,
         | 
| 177 | 
             
                update_civitai_selection, get_civitai_tag, CIVITAI_SORT, CIVITAI_PERIOD, CIVITAI_BASEMODEL,
         | 
| 178 | 
             
                set_textual_inversion_prompt, get_model_pipeline, change_interface_mode, get_t2i_model_info,
         | 
| 179 | 
            -
                get_tupled_model_list, save_gallery_images, set_optimization, set_sampler_settings,
         | 
| 180 | 
             
                set_quick_presets, process_style_prompt, optimization_list, save_images, download_things,
         | 
| 181 | 
            -
                preset_styles, preset_quality, preset_sampler_setting, translate_to_en)
         | 
| 182 | 
             
            from env import (HF_TOKEN, CIVITAI_API_KEY, HF_LORA_ESSENTIAL_PRIVATE_REPO, HF_VAE_PRIVATE_REPO,
         | 
| 183 | 
             
                HF_SDXL_EMBEDS_NEGATIVE_PRIVATE_REPO, HF_SDXL_EMBEDS_POSITIVE_PRIVATE_REPO,
         | 
| 184 | 
            -
                 | 
| 185 | 
            -
                 | 
| 186 | 
            -
                 | 
|  | |
|  | |
|  | |
|  | |
| 187 |  | 
| 188 | 
             
            # - **Download Models**
         | 
| 189 | 
            -
             | 
| 190 | 
             
            # - **Download VAEs**
         | 
| 191 | 
            -
             | 
| 192 | 
             
            # - **Download LoRAs**
         | 
| 193 | 
            -
             | 
| 194 | 
            -
             | 
| 195 | 
            -
            download_private_repo(HF_LORA_ESSENTIAL_PRIVATE_REPO, directory_loras, True)
         | 
| 196 | 
            -
            download_private_repo(HF_VAE_PRIVATE_REPO, directory_vaes, False)
         | 
| 197 | 
            -
             | 
| 198 | 
            -
            load_diffusers_format_model = list_uniq(get_model_id_list() + load_diffusers_format_model)
         | 
| 199 | 
            -
            ## END MOD
         | 
| 200 |  | 
| 201 | 
             
            # Download stuffs
         | 
| 202 | 
            -
            for url in [url.strip() for url in  | 
| 203 | 
             
                if not os.path.exists(f"./models/{url.split('/')[-1]}"):
         | 
| 204 | 
            -
                    download_things( | 
| 205 | 
            -
            for url in [url.strip() for url in  | 
| 206 | 
             
                if not os.path.exists(f"./vaes/{url.split('/')[-1]}"):
         | 
| 207 | 
            -
                    download_things( | 
| 208 | 
            -
            for url in [url.strip() for url in  | 
| 209 | 
             
                if not os.path.exists(f"./loras/{url.split('/')[-1]}"):
         | 
| 210 | 
            -
                    download_things( | 
| 211 |  | 
| 212 | 
             
            # Download Embeddings
         | 
| 213 | 
            -
            for url_embed in  | 
| 214 | 
             
                if not os.path.exists(f"./embedings/{url_embed.split('/')[-1]}"):
         | 
| 215 | 
            -
                    download_things( | 
| 216 |  | 
| 217 | 
             
            # Build list models
         | 
| 218 | 
            -
            embed_list = get_model_list( | 
| 219 | 
            -
            model_list = get_model_list(directory_models)
         | 
| 220 | 
            -
            model_list = load_diffusers_format_model + model_list
         | 
| 221 | 
            -
            ## BEGIN MOD
         | 
| 222 | 
             
            lora_model_list = get_lora_model_list()
         | 
| 223 | 
            -
            vae_model_list = get_model_list( | 
| 224 | 
             
            vae_model_list.insert(0, "None")
         | 
| 225 |  | 
| 226 | 
            -
             | 
| 227 | 
            -
             | 
| 228 | 
            -
             | 
|  | |
|  | |
| 229 |  | 
| 230 | 
             
            def get_embed_list(pipeline_name):
         | 
| 231 | 
             
                return get_tupled_embed_list(embed_sdxl_list if pipeline_name == "StableDiffusionXLPipeline" else embed_list)
         | 
| @@ -248,12 +121,13 @@ warnings.filterwarnings(action="ignore", category=UserWarning, module="diffusers | |
| 248 | 
             
            warnings.filterwarnings(action="ignore", category=FutureWarning, module="transformers")
         | 
| 249 | 
             
            ## BEGIN MOD
         | 
| 250 | 
             
            from stablepy import logger
         | 
| 251 | 
            -
            logger.setLevel(logging.CRITICAL)
         | 
|  | |
| 252 |  | 
| 253 | 
            -
            from v2 import V2_ALL_MODELS, v2_random_prompt, v2_upsampling_prompt
         | 
| 254 | 
            -
            from utils import (gradio_copy_text, COPY_ACTION_JS, gradio_copy_prompt,
         | 
| 255 | 
             
                V2_ASPECT_RATIO_OPTIONS, V2_RATING_OPTIONS, V2_LENGTH_OPTIONS, V2_IDENTITY_OPTIONS)
         | 
| 256 | 
            -
            from tagger import (predict_tags_wd, convert_danbooru_to_e621_prompt,
         | 
| 257 | 
             
                remove_specific_prompt, insert_recom_prompt, insert_model_recom_prompt,
         | 
| 258 | 
             
                compose_prompt_to_copy, translate_prompt, select_random_character)
         | 
| 259 | 
             
            def description_ui():
         | 
| @@ -267,137 +141,91 @@ def description_ui(): | |
| 267 | 
             
                )
         | 
| 268 | 
             
            ## END MOD
         | 
| 269 |  | 
| 270 | 
            -
            msg_inc_vae = (
         | 
| 271 | 
            -
                "Use the right VAE for your model to maintain image quality. The wrong"
         | 
| 272 | 
            -
                " VAE can lead to poor results, like blurriness in the generated images."
         | 
| 273 | 
            -
            )
         | 
| 274 | 
            -
             | 
| 275 | 
            -
            SDXL_TASK = [k for k, v in TASK_STABLEPY.items() if v in SDXL_TASKS]
         | 
| 276 | 
            -
            SD_TASK = [k for k, v in TASK_STABLEPY.items() if v in SD15_TASKS]
         | 
| 277 | 
            -
            FLUX_TASK = list(TASK_STABLEPY.keys())[:3] + [k for k, v in TASK_STABLEPY.items() if v in FLUX_CN_UNION_MODES.keys()]
         | 
| 278 | 
            -
             | 
| 279 | 
            -
            MODEL_TYPE_TASK = {
         | 
| 280 | 
            -
                "SD 1.5": SD_TASK,
         | 
| 281 | 
            -
                "SDXL": SDXL_TASK,
         | 
| 282 | 
            -
                "FLUX": FLUX_TASK,
         | 
| 283 | 
            -
            }
         | 
| 284 | 
            -
             | 
| 285 | 
            -
            MODEL_TYPE_CLASS = {
         | 
| 286 | 
            -
                "diffusers:StableDiffusionPipeline": "SD 1.5",
         | 
| 287 | 
            -
                "diffusers:StableDiffusionXLPipeline": "SDXL",
         | 
| 288 | 
            -
                "diffusers:FluxPipeline": "FLUX",
         | 
| 289 | 
            -
            }
         | 
| 290 | 
            -
             | 
| 291 | 
            -
            POST_PROCESSING_SAMPLER = ["Use same sampler"] + scheduler_names[:-2]
         | 
| 292 | 
            -
             | 
| 293 | 
            -
            SUBTITLE_GUI = (
         | 
| 294 | 
            -
                "### This demo uses [diffusers](https://github.com/huggingface/diffusers)"
         | 
| 295 | 
            -
                " to perform different tasks in image generation."
         | 
| 296 | 
            -
            )
         | 
| 297 | 
            -
             | 
| 298 | 
            -
            def extract_parameters(input_string):
         | 
| 299 | 
            -
                parameters = {}
         | 
| 300 | 
            -
                input_string = input_string.replace("\n", "")
         | 
| 301 | 
            -
             | 
| 302 | 
            -
                if "Negative prompt:" not in input_string:
         | 
| 303 | 
            -
                    if "Steps:" in input_string:
         | 
| 304 | 
            -
                        input_string = input_string.replace("Steps:", "Negative prompt: Steps:")
         | 
| 305 | 
            -
                    else:
         | 
| 306 | 
            -
                        print("Invalid metadata")
         | 
| 307 | 
            -
                        parameters["prompt"] = input_string
         | 
| 308 | 
            -
                        return parameters
         | 
| 309 | 
            -
             | 
| 310 | 
            -
                parm = input_string.split("Negative prompt:")
         | 
| 311 | 
            -
                parameters["prompt"] = parm[0].strip()
         | 
| 312 | 
            -
                if "Steps:" not in parm[1]:
         | 
| 313 | 
            -
                    print("Steps not detected")
         | 
| 314 | 
            -
                    parameters["neg_prompt"] = parm[1].strip()
         | 
| 315 | 
            -
                    return parameters
         | 
| 316 | 
            -
                parm = parm[1].split("Steps:")
         | 
| 317 | 
            -
                parameters["neg_prompt"] = parm[0].strip()
         | 
| 318 | 
            -
                input_string = "Steps:" + parm[1]
         | 
| 319 | 
            -
             | 
| 320 | 
            -
                # Extracting Steps
         | 
| 321 | 
            -
                steps_match = re.search(r'Steps: (\d+)', input_string)
         | 
| 322 | 
            -
                if steps_match:
         | 
| 323 | 
            -
                    parameters['Steps'] = int(steps_match.group(1))
         | 
| 324 | 
            -
             | 
| 325 | 
            -
                # Extracting Size
         | 
| 326 | 
            -
                size_match = re.search(r'Size: (\d+x\d+)', input_string)
         | 
| 327 | 
            -
                if size_match:
         | 
| 328 | 
            -
                    parameters['Size'] = size_match.group(1)
         | 
| 329 | 
            -
                    width, height = map(int, parameters['Size'].split('x'))
         | 
| 330 | 
            -
                    parameters['width'] = width
         | 
| 331 | 
            -
                    parameters['height'] = height
         | 
| 332 | 
            -
             | 
| 333 | 
            -
                # Extracting other parameters
         | 
| 334 | 
            -
                other_parameters = re.findall(r'(\w+): (.*?)(?=, \w+|$)', input_string)
         | 
| 335 | 
            -
                for param in other_parameters:
         | 
| 336 | 
            -
                    parameters[param[0]] = param[1].strip('"')
         | 
| 337 | 
            -
             | 
| 338 | 
            -
                return parameters
         | 
| 339 | 
            -
             | 
| 340 | 
            -
            def info_html(json_data, title, subtitle):
         | 
| 341 | 
            -
                return f"""
         | 
| 342 | 
            -
                    <div style='padding: 0; border-radius: 10px;'>
         | 
| 343 | 
            -
                        <p style='margin: 0; font-weight: bold;'>{title}</p>
         | 
| 344 | 
            -
                        <details>
         | 
| 345 | 
            -
                            <summary>Details</summary>
         | 
| 346 | 
            -
                            <p style='margin: 0; font-weight: bold;'>{subtitle}</p>
         | 
| 347 | 
            -
                        </details>
         | 
| 348 | 
            -
                    </div>
         | 
| 349 | 
            -
                    """
         | 
| 350 | 
            -
             | 
| 351 | 
            -
            def get_model_type(repo_id: str):
         | 
| 352 | 
            -
                api = HfApi(token=os.environ.get("HF_TOKEN"))  # if use private or gated model
         | 
| 353 | 
            -
                default = "SD 1.5"
         | 
| 354 | 
            -
                try:
         | 
| 355 | 
            -
                    model = api.model_info(repo_id=repo_id, timeout=5.0)
         | 
| 356 | 
            -
                    tags = model.tags
         | 
| 357 | 
            -
                    for tag in tags:
         | 
| 358 | 
            -
                        if tag in MODEL_TYPE_CLASS.keys(): return MODEL_TYPE_CLASS.get(tag, default)
         | 
| 359 | 
            -
                except Exception:
         | 
| 360 | 
            -
                    return default
         | 
| 361 | 
            -
                return default
         | 
| 362 | 
            -
             | 
| 363 | 
             
            class GuiSD:
         | 
| 364 | 
            -
                def __init__(self):
         | 
| 365 | 
             
                    self.model = None
         | 
| 366 | 
            -
             | 
| 367 | 
            -
                     | 
| 368 | 
            -
                    self. | 
| 369 | 
            -
                        base_model_id="Lykon/dreamshaper-8",
         | 
| 370 | 
            -
                        task_name="txt2img",
         | 
| 371 | 
            -
                        vae_model=None,
         | 
| 372 | 
            -
                        type_model_precision=torch.float16,
         | 
| 373 | 
            -
                        retain_task_model_in_cache=False,
         | 
| 374 | 
            -
                        device="cpu",
         | 
| 375 | 
            -
                    )
         | 
| 376 | 
            -
                    self.model.load_beta_styles()
         | 
| 377 | 
            -
                    #self.model.device = torch.device("cpu") #
         | 
| 378 |  | 
| 379 | 
             
                def load_new_model(self, model_name, vae_model, task, progress=gr.Progress(track_tqdm=True)):
         | 
| 380 |  | 
| 381 | 
            -
                    yield f"Loading model: {model_name}"
         | 
| 382 | 
            -
                    
         | 
| 383 | 
             
                    vae_model = vae_model if vae_model != "None" else None
         | 
| 384 | 
             
                    model_type = get_model_type(model_name)
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 385 |  | 
| 386 | 
             
                    if vae_model:
         | 
| 387 | 
             
                        vae_type = "SDXL" if "sdxl" in vae_model.lower() else "SD 1.5"
         | 
| 388 | 
             
                        if model_type != vae_type:
         | 
| 389 | 
            -
                            gr.Warning( | 
| 390 |  | 
| 391 | 
            -
                     | 
| 392 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 393 |  | 
| 394 | 
            -
                    self.model.load_pipe(
         | 
| 395 | 
            -
                        model_name,
         | 
| 396 | 
            -
                        task_name=TASK_STABLEPY[task],
         | 
| 397 | 
            -
                        vae_model=vae_model if vae_model != "None" else None,
         | 
| 398 | 
            -
                        type_model_precision=dtype_model,
         | 
| 399 | 
            -
                        retain_task_model_in_cache=False,
         | 
| 400 | 
            -
                    )
         | 
| 401 | 
             
                    yield f"Model loaded: {model_name}"
         | 
| 402 |  | 
| 403 | 
             
                #@spaces.GPU
         | 
| @@ -506,18 +334,17 @@ class GuiSD: | |
| 506 | 
             
                    mode_ip2,
         | 
| 507 | 
             
                    scale_ip2,
         | 
| 508 | 
             
                    pag_scale,
         | 
| 509 | 
            -
                    #progress=gr.Progress(track_tqdm=True),
         | 
| 510 | 
             
                ):
         | 
| 511 | 
            -
                     | 
|  | |
| 512 |  | 
| 513 | 
             
                    vae_model = vae_model if vae_model != "None" else None
         | 
| 514 | 
             
                    loras_list = [lora1, lora2, lora3, lora4, lora5]
         | 
| 515 | 
             
                    vae_msg = f"VAE: {vae_model}" if vae_model else ""
         | 
| 516 | 
             
                    msg_lora = ""
         | 
| 517 |  | 
| 518 | 
            -
                    print("Config model:", model_name, vae_model, loras_list)
         | 
| 519 | 
            -
             | 
| 520 | 
             
            ## BEGIN MOD
         | 
|  | |
| 521 | 
             
                    global lora_model_list
         | 
| 522 | 
             
                    lora_model_list = get_lora_model_list()
         | 
| 523 | 
             
                    lora1, lora_scale1, lora2, lora_scale2, lora3, lora_scale3, lora4, lora_scale4, lora5, lora_scale5 = \
         | 
| @@ -526,6 +353,8 @@ class GuiSD: | |
| 526 | 
             
                    prompt, neg_prompt = insert_model_recom_prompt(prompt, neg_prompt, model_name)
         | 
| 527 | 
             
            ## END MOD
         | 
| 528 |  | 
|  | |
|  | |
| 529 | 
             
                    task = TASK_STABLEPY[task]
         | 
| 530 |  | 
| 531 | 
             
                    params_ip_img = []
         | 
| @@ -548,7 +377,8 @@ class GuiSD: | |
| 548 | 
             
                            params_ip_mode.append(modeip)
         | 
| 549 | 
             
                            params_ip_scale.append(scaleip)
         | 
| 550 |  | 
| 551 | 
            -
                     | 
|  | |
| 552 |  | 
| 553 | 
             
                    if task != "txt2img" and not image_control:
         | 
| 554 | 
             
                        raise ValueError("No control image found: To use this function, you have to upload an image in 'Image ControlNet/Inpaint/Img2img'")
         | 
| @@ -621,15 +451,15 @@ class GuiSD: | |
| 621 | 
             
                        "high_threshold": high_threshold,
         | 
| 622 | 
             
                        "value_threshold": value_threshold,
         | 
| 623 | 
             
                        "distance_threshold": distance_threshold,
         | 
| 624 | 
            -
                        "lora_A": lora1 if lora1 != "None"  | 
| 625 | 
             
                        "lora_scale_A": lora_scale1,
         | 
| 626 | 
            -
                        "lora_B": lora2 if lora2 != "None"  | 
| 627 | 
             
                        "lora_scale_B": lora_scale2,
         | 
| 628 | 
            -
                        "lora_C": lora3 if lora3 != "None"  | 
| 629 | 
             
                        "lora_scale_C": lora_scale3,
         | 
| 630 | 
            -
                        "lora_D": lora4 if lora4 != "None"  | 
| 631 | 
             
                        "lora_scale_D": lora_scale4,
         | 
| 632 | 
            -
                        "lora_E": lora5 if lora5 != "None"  | 
| 633 | 
             
                        "lora_scale_E": lora_scale5,
         | 
| 634 | 
             
            ## BEGIN MOD
         | 
| 635 | 
             
                        "textual_inversion": get_embed_list(self.model.class_name) if textual_inversion else [],
         | 
| @@ -679,19 +509,24 @@ class GuiSD: | |
| 679 | 
             
                    }
         | 
| 680 |  | 
| 681 | 
             
                    self.model.device = torch.device("cuda:0")
         | 
| 682 | 
            -
                    if hasattr(self.model.pipe, "transformer") and loras_list != ["None"] * 5 | 
| 683 | 
             
                        self.model.pipe.transformer.to(self.model.device)
         | 
| 684 | 
             
                        print("transformer to cuda")
         | 
| 685 |  | 
| 686 | 
            -
                     | 
| 687 | 
            -
             | 
| 688 | 
            -
                    info_state = "PROCESSING "
         | 
| 689 | 
             
                    for img, seed, image_path, metadata in self.model(**pipe_params):
         | 
| 690 | 
            -
                        info_state  | 
|  | |
| 691 | 
             
                        if image_path:
         | 
| 692 | 
            -
                             | 
| 693 | 
             
                            if vae_msg:
         | 
| 694 | 
            -
                                 | 
|  | |
|  | |
|  | |
|  | |
|  | |
| 695 |  | 
| 696 | 
             
                            for status, lora in zip(self.model.lora_status, self.model.lora_memory):
         | 
| 697 | 
             
                                if status:
         | 
| @@ -700,9 +535,9 @@ class GuiSD: | |
| 700 | 
             
                                    msg_lora += f"<br>Error with: {lora}"
         | 
| 701 |  | 
| 702 | 
             
                            if msg_lora:
         | 
| 703 | 
            -
                                 | 
| 704 |  | 
| 705 | 
            -
                             | 
| 706 |  | 
| 707 | 
             
                            download_links = "<br>".join(
         | 
| 708 | 
             
                                [
         | 
| @@ -711,19 +546,16 @@ class GuiSD: | |
| 711 | 
             
                                ]
         | 
| 712 | 
             
                            )
         | 
| 713 | 
             
                            if save_generated_images:
         | 
| 714 | 
            -
                                 | 
| 715 |  | 
|  | |
| 716 | 
             
                            img = save_images(img, metadata)
         | 
|  | |
| 717 |  | 
| 718 | 
            -
             | 
| 719 | 
            -
             | 
| 720 | 
            -
            def update_task_options(model_name, task_name):
         | 
| 721 | 
            -
                new_choices = MODEL_TYPE_TASK[get_model_type(model_name)]
         | 
| 722 |  | 
| 723 | 
            -
             | 
| 724 | 
            -
                    task_name = "txt2img"
         | 
| 725 |  | 
| 726 | 
            -
                return gr.update(value=task_name, choices=new_choices)
         | 
| 727 |  | 
| 728 | 
             
            def dynamic_gpu_duration(func, duration, *args):
         | 
| 729 |  | 
| @@ -733,10 +565,12 @@ def dynamic_gpu_duration(func, duration, *args): | |
| 733 |  | 
| 734 | 
             
                return wrapped_func()
         | 
| 735 |  | 
|  | |
| 736 | 
             
            @spaces.GPU
         | 
| 737 | 
             
            def dummy_gpu():
         | 
| 738 | 
             
                return None
         | 
| 739 |  | 
|  | |
| 740 | 
             
            def sd_gen_generate_pipeline(*args):
         | 
| 741 |  | 
| 742 | 
             
                gpu_duration_arg = int(args[-1]) if args[-1] else 59
         | 
| @@ -744,7 +578,7 @@ def sd_gen_generate_pipeline(*args): | |
| 744 | 
             
                load_lora_cpu = args[-3]
         | 
| 745 | 
             
                generation_args = args[:-3]
         | 
| 746 | 
             
                lora_list = [
         | 
| 747 | 
            -
                    None if item == "None" or item == "" else item
         | 
| 748 | 
             
                    for item in [args[7], args[9], args[11], args[13], args[15]]
         | 
| 749 | 
             
                ]
         | 
| 750 | 
             
                lora_status = [None] * 5
         | 
| @@ -754,7 +588,7 @@ def sd_gen_generate_pipeline(*args): | |
| 754 | 
             
                    msg_load_lora = "Updating LoRAs in CPU (Slow but saves GPU usage)..."
         | 
| 755 |  | 
| 756 | 
             
                if lora_list != sd_gen.model.lora_memory and lora_list != [None] * 5:
         | 
| 757 | 
            -
                    yield  | 
| 758 |  | 
| 759 | 
             
                # Load lora in CPU
         | 
| 760 | 
             
                if load_lora_cpu:
         | 
| @@ -780,14 +614,15 @@ def sd_gen_generate_pipeline(*args): | |
| 780 | 
             
                        )
         | 
| 781 | 
             
                        gr.Info(f"LoRAs in cache: {lora_cache_msg}")
         | 
| 782 |  | 
| 783 | 
            -
             | 
|  | |
| 784 | 
             
                    gr.Info(msg_request)
         | 
| 785 | 
             
                    print(msg_request)
         | 
| 786 | 
            -
             | 
| 787 | 
            -
                # yield from sd_gen.generate_pipeline(*generation_args)
         | 
| 788 |  | 
| 789 | 
             
                start_time = time.time()
         | 
| 790 |  | 
|  | |
| 791 | 
             
                yield from dynamic_gpu_duration(
         | 
| 792 | 
             
                    sd_gen.generate_pipeline,
         | 
| 793 | 
             
                    gpu_duration_arg,
         | 
| @@ -795,31 +630,19 @@ def sd_gen_generate_pipeline(*args): | |
| 795 | 
             
                )
         | 
| 796 |  | 
| 797 | 
             
                end_time = time.time()
         | 
|  | |
|  | |
|  | |
|  | |
| 798 |  | 
| 799 | 
             
                if verbose_arg:
         | 
| 800 | 
            -
                    execution_time = end_time - start_time
         | 
| 801 | 
            -
                    msg_task_complete = (
         | 
| 802 | 
            -
                        f"GPU task complete in: {round(execution_time, 0) + 1} seconds"
         | 
| 803 | 
            -
                    )
         | 
| 804 | 
             
                    gr.Info(msg_task_complete)
         | 
| 805 | 
             
                    print(msg_task_complete)
         | 
| 806 |  | 
| 807 | 
            -
             | 
| 808 | 
            -
                if image is None: return ""
         | 
| 809 | 
            -
             | 
| 810 | 
            -
                try:
         | 
| 811 | 
            -
                    metadata_keys = ['parameters', 'metadata', 'prompt', 'Comment']
         | 
| 812 | 
            -
             | 
| 813 | 
            -
                    for key in metadata_keys:
         | 
| 814 | 
            -
                        if key in image.info:
         | 
| 815 | 
            -
                            return image.info[key]
         | 
| 816 | 
            -
             | 
| 817 | 
            -
                    return str(image.info)
         | 
| 818 |  | 
| 819 | 
            -
                except Exception as e:
         | 
| 820 | 
            -
                    return f"Error extracting metadata: {str(e)}"
         | 
| 821 |  | 
| 822 | 
            -
            @spaces.GPU(duration= | 
| 823 | 
             
            def esrgan_upscale(image, upscaler_name, upscaler_size):
         | 
| 824 | 
             
                if image is None: return None
         | 
| 825 |  | 
| @@ -841,17 +664,20 @@ def esrgan_upscale(image, upscaler_name, upscaler_size): | |
| 841 |  | 
| 842 | 
             
                return image_path
         | 
| 843 |  | 
|  | |
| 844 | 
             
            dynamic_gpu_duration.zerogpu = True
         | 
| 845 | 
             
            sd_gen_generate_pipeline.zerogpu = True
         | 
| 846 | 
             
            sd_gen = GuiSD()
         | 
| 847 |  | 
|  | |
| 848 | 
             
            ## BEGIN MOD
         | 
| 849 | 
             
            CSS ="""
         | 
| 850 | 
            -
            .gradio-container, #main { width:100%; height:100%; max-width:100%; padding-left:0; padding-right:0; margin-left:0; margin-right:0;  | 
| 851 | 
            -
            .contain { display:flex; flex-direction:column;  | 
| 852 | 
            -
            #component-0 { width:100%; height:100%;  | 
| 853 | 
            -
            #gallery { flex-grow:1;  | 
| 854 | 
            -
             | 
|  | |
| 855 | 
             
            #model-info { text-align:center; }
         | 
| 856 | 
             
            .title { font-size: 3em; align-items: center; text-align: center; }
         | 
| 857 | 
             
            .info { align-items: center; text-align: center; }
         | 
| @@ -865,6 +691,15 @@ with gr.Blocks(theme='NoCrypt/miku@>=1.2.2', elem_id="main", fill_width=True, cs | |
| 865 | 
             
                    with gr.Tab("Generation"):
         | 
| 866 | 
             
                        with gr.Row():
         | 
| 867 | 
             
                            with gr.Column(scale=2):
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 868 | 
             
                                interface_mode_gui = gr.Radio(label="Quick settings", choices=["Simple", "Standard", "Fast", "LoRA"], value="Standard")
         | 
| 869 | 
             
                                with gr.Accordion("Model and Task", open=False) as menu_model:
         | 
| 870 | 
             
                                    task_gui = gr.Dropdown(label="Task", choices=SDXL_TASK, value=TASK_MODEL_LIST[0])
         | 
| @@ -928,7 +763,7 @@ with gr.Blocks(theme='NoCrypt/miku@>=1.2.2', elem_id="main", fill_width=True, cs | |
| 928 | 
             
                                    [task_gui],
         | 
| 929 | 
             
                                )
         | 
| 930 |  | 
| 931 | 
            -
                                load_model_gui = gr.HTML()
         | 
| 932 |  | 
| 933 | 
             
                                result_images = gr.Gallery(
         | 
| 934 | 
             
                                    label="Generated images",
         | 
| @@ -950,6 +785,13 @@ with gr.Blocks(theme='NoCrypt/miku@>=1.2.2', elem_id="main", fill_width=True, cs | |
| 950 |  | 
| 951 | 
             
                                actual_task_info = gr.HTML()
         | 
| 952 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 953 | 
             
                                with gr.Row(equal_height=False, variant="default"):
         | 
| 954 | 
             
                                    gpu_duration_gui = gr.Number(minimum=5, maximum=240, value=59, show_label=False, container=False, info="GPU time duration (seconds)")
         | 
| 955 | 
             
                                    with gr.Column():
         | 
| @@ -972,15 +814,7 @@ with gr.Blocks(theme='NoCrypt/miku@>=1.2.2', elem_id="main", fill_width=True, cs | |
| 972 | 
             
                                    with gr.Row():
         | 
| 973 | 
             
                                        sampler_gui = gr.Dropdown(label="Sampler", choices=scheduler_names, value="Euler a")
         | 
| 974 | 
             
                                        vae_model_gui = gr.Dropdown(label="VAE Model", choices=vae_model_list, value=vae_model_list[0])
         | 
| 975 | 
            -
                                         | 
| 976 | 
            -
                                            ("Compel format: (word)weight", "Compel"),
         | 
| 977 | 
            -
                                            ("Classic format: (word:weight)", "Classic"),
         | 
| 978 | 
            -
                                            ("Classic-original format: (word:weight)", "Classic-original"),
         | 
| 979 | 
            -
                                            ("Classic-no_norm format: (word:weight)", "Classic-no_norm"),
         | 
| 980 | 
            -
                                            ("Classic-ignore", "Classic-ignore"),
         | 
| 981 | 
            -
                                            ("None", "None"),
         | 
| 982 | 
            -
                                        ]
         | 
| 983 | 
            -
                                        prompt_syntax_gui = gr.Dropdown(label="Prompt Syntax", choices=prompt_s_options, value=prompt_s_options[1][1])
         | 
| 984 |  | 
| 985 | 
             
                                with gr.Row(equal_height=False):
         | 
| 986 |  | 
| @@ -1130,8 +964,11 @@ with gr.Blocks(theme='NoCrypt/miku@>=1.2.2', elem_id="main", fill_width=True, cs | |
| 1130 | 
             
                                        with gr.Accordion("Select from Gallery", open=False):
         | 
| 1131 | 
             
                                            search_civitai_gallery_lora = gr.Gallery([], label="Results", allow_preview=False, columns=5, show_share_button=False, interactive=False)
         | 
| 1132 | 
             
                                        search_civitai_result_lora = gr.Dropdown(label="Search Results", choices=[("", "")], value="", allow_custom_value=True, visible=False)
         | 
| 1133 | 
            -
                                         | 
|  | |
|  | |
| 1134 | 
             
                                        button_lora = gr.Button("Get and update lists of LoRAs")
         | 
|  | |
| 1135 | 
             
                                    with gr.Accordion("From Local", open=True, visible=True):
         | 
| 1136 | 
             
                                        file_output_lora = gr.File(label="Uploaded LoRA", file_types=['.ckpt', '.pt', '.pth', '.safetensors', '.bin'], file_count="multiple", interactive=False, visible=False)
         | 
| 1137 | 
             
                                        upload_button_lora = gr.UploadButton(label="Upload LoRA from your disk (very slow)", file_types=['.ckpt', '.pt', '.pth', '.safetensors', '.bin'], file_count="multiple")
         | 
| @@ -1169,7 +1006,7 @@ with gr.Blocks(theme='NoCrypt/miku@>=1.2.2', elem_id="main", fill_width=True, cs | |
| 1169 | 
             
                                            negative_prompt_ad_a_gui = gr.Textbox(label="Negative prompt", placeholder="Main negative prompt will be use", lines=3)
         | 
| 1170 | 
             
                                            with gr.Row():
         | 
| 1171 | 
             
                                                strength_ad_a_gui = gr.Number(label="Strength:", value=0.35, step=0.01, minimum=0.01, maximum=1.0)
         | 
| 1172 | 
            -
                                                face_detector_ad_a_gui = gr.Checkbox(label="Face detector", value= | 
| 1173 | 
             
                                                person_detector_ad_a_gui = gr.Checkbox(label="Person detector", value=True)
         | 
| 1174 | 
             
                                                hand_detector_ad_a_gui = gr.Checkbox(label="Hand detector", value=False)
         | 
| 1175 | 
             
                                            with gr.Row():
         | 
| @@ -1184,7 +1021,7 @@ with gr.Blocks(theme='NoCrypt/miku@>=1.2.2', elem_id="main", fill_width=True, cs | |
| 1184 | 
             
                                            negative_prompt_ad_b_gui = gr.Textbox(label="Negative prompt", placeholder="Main negative prompt will be use", lines=3)
         | 
| 1185 | 
             
                                            with gr.Row():
         | 
| 1186 | 
             
                                                strength_ad_b_gui = gr.Number(label="Strength:", value=0.35, step=0.01, minimum=0.01, maximum=1.0)
         | 
| 1187 | 
            -
                                                face_detector_ad_b_gui = gr.Checkbox(label="Face detector", value= | 
| 1188 | 
             
                                                person_detector_ad_b_gui = gr.Checkbox(label="Person detector", value=True)
         | 
| 1189 | 
             
                                                hand_detector_ad_b_gui = gr.Checkbox(label="Hand detector", value=False)
         | 
| 1190 | 
             
                                            with gr.Row():
         | 
| @@ -1314,73 +1151,7 @@ with gr.Blocks(theme='NoCrypt/miku@>=1.2.2', elem_id="main", fill_width=True, cs | |
| 1314 |  | 
| 1315 | 
             
                        with gr.Accordion("Examples and help", open=True, visible=True) as menu_example:
         | 
| 1316 | 
             
                            gr.Examples(
         | 
| 1317 | 
            -
                                examples= | 
| 1318 | 
            -
                                    [
         | 
| 1319 | 
            -
                                        "1girl, souryuu asuka langley, neon genesis evangelion, plugsuit, pilot suit, red bodysuit, sitting, crossing legs, black eye patch, cat hat, throne, symmetrical, looking down, from bottom, looking at viewer, outdoors, masterpiece, best quality, very aesthetic, absurdres",
         | 
| 1320 | 
            -
                                        "nsfw, lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]",
         | 
| 1321 | 
            -
                                        1,
         | 
| 1322 | 
            -
                                        30,
         | 
| 1323 | 
            -
                                        7.5,
         | 
| 1324 | 
            -
                                        True,
         | 
| 1325 | 
            -
                                        -1,
         | 
| 1326 | 
            -
                                        "Euler a",
         | 
| 1327 | 
            -
                                        1152,
         | 
| 1328 | 
            -
                                        896,
         | 
| 1329 | 
            -
                                        "votepurchase/animagine-xl-3.1",
         | 
| 1330 | 
            -
                                    ],
         | 
| 1331 | 
            -
                                    [
         | 
| 1332 | 
            -
                                        "solo, princess Zelda OOT, score_9, score_8_up, score_8, medium breasts, cute, eyelashes, cute small face, long hair, crown braid, hairclip, pointy ears, soft curvy body, looking at viewer, smile, blush, white dress, medium body, (((holding the Master Sword))), standing, deep forest in the background",
         | 
| 1333 | 
            -
                                        "score_6, score_5, score_4, busty, ugly face, mutated hands, low res, blurry face, black and white,",
         | 
| 1334 | 
            -
                                        1,
         | 
| 1335 | 
            -
                                        30,
         | 
| 1336 | 
            -
                                        5.,
         | 
| 1337 | 
            -
                                        True,
         | 
| 1338 | 
            -
                                        -1,
         | 
| 1339 | 
            -
                                        "Euler a",
         | 
| 1340 | 
            -
                                        1024,
         | 
| 1341 | 
            -
                                        1024,
         | 
| 1342 | 
            -
                                        "votepurchase/ponyDiffusionV6XL",
         | 
| 1343 | 
            -
                                    ],
         | 
| 1344 | 
            -
                                    [
         | 
| 1345 | 
            -
                                        "1girl, oomuro sakurako, yuru yuri, official art, school uniform, anime artwork, anime style, vibrant, studio anime, highly detailed, masterpiece, best quality, very aesthetic, absurdres",
         | 
| 1346 | 
            -
                                        "photo, deformed, black and white, realism, disfigured, low contrast, lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]",
         | 
| 1347 | 
            -
                                        1,
         | 
| 1348 | 
            -
                                        40,
         | 
| 1349 | 
            -
                                        7.0,
         | 
| 1350 | 
            -
                                        True,
         | 
| 1351 | 
            -
                                        -1,
         | 
| 1352 | 
            -
                                        "Euler a",
         | 
| 1353 | 
            -
                                        1024,
         | 
| 1354 | 
            -
                                        1024,
         | 
| 1355 | 
            -
                                        "Raelina/Rae-Diffusion-XL-V2",
         | 
| 1356 | 
            -
                                    ],
         | 
| 1357 | 
            -
                                    [
         | 
| 1358 | 
            -
                                        "1girl, akaza akari, yuru yuri, official art, anime artwork, anime style, vibrant, studio anime, highly detailed, masterpiece, best quality, very aesthetic, absurdres",
         | 
| 1359 | 
            -
                                        "photo, deformed, black and white, realism, disfigured, low contrast, lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]",
         | 
| 1360 | 
            -
                                        1,
         | 
| 1361 | 
            -
                                        35,
         | 
| 1362 | 
            -
                                        7.0,
         | 
| 1363 | 
            -
                                        True,
         | 
| 1364 | 
            -
                                        -1,
         | 
| 1365 | 
            -
                                        "Euler a",
         | 
| 1366 | 
            -
                                        1024,
         | 
| 1367 | 
            -
                                        1024,
         | 
| 1368 | 
            -
                                        "Raelina/Raemu-XL-V4",
         | 
| 1369 | 
            -
                                    ],
         | 
| 1370 | 
            -
                                    [
         | 
| 1371 | 
            -
                                        "yoshida yuuko, machikado mazoku, 1girl, solo, demon horns,horns, school uniform, long hair, open mouth, skirt, demon girl, ahoge, shiny, shiny hair, anime artwork",
         | 
| 1372 | 
            -
                                        "nsfw, lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]",
         | 
| 1373 | 
            -
                                        1,
         | 
| 1374 | 
            -
                                        50,
         | 
| 1375 | 
            -
                                        7.,
         | 
| 1376 | 
            -
                                        True,
         | 
| 1377 | 
            -
                                        -1,
         | 
| 1378 | 
            -
                                        "Euler a",
         | 
| 1379 | 
            -
                                        1024,
         | 
| 1380 | 
            -
                                        1024,
         | 
| 1381 | 
            -
                                        "cagliostrolab/animagine-xl-3.1",
         | 
| 1382 | 
            -
                                    ],
         | 
| 1383 | 
            -
                                ],
         | 
| 1384 | 
             
                                fn=sd_gen.generate_pipeline,
         | 
| 1385 | 
             
                                inputs=[
         | 
| 1386 | 
             
                                    prompt_gui,
         | 
| @@ -1395,16 +1166,12 @@ with gr.Blocks(theme='NoCrypt/miku@>=1.2.2', elem_id="main", fill_width=True, cs | |
| 1395 | 
             
                                    img_width_gui,
         | 
| 1396 | 
             
                                    model_name_gui,
         | 
| 1397 | 
             
                                ],
         | 
| 1398 | 
            -
                                outputs=[result_images, actual_task_info],
         | 
| 1399 | 
             
                                cache_examples=False,
         | 
| 1400 | 
             
                                #elem_id="examples",
         | 
| 1401 | 
             
                            )
         | 
| 1402 |  | 
| 1403 | 
            -
                        gr.Markdown(
         | 
| 1404 | 
            -
                            """### Resources
         | 
| 1405 | 
            -
                            - You can also try the image generator in Colab’s free tier, which provides free GPU [link](https://github.com/R3gm/SD_diffusers_interactive).
         | 
| 1406 | 
            -
                            """
         | 
| 1407 | 
            -
                        )
         | 
| 1408 | 
             
            ## END MOD
         | 
| 1409 |  | 
| 1410 | 
             
                    with gr.Tab("Inpaint mask maker", render=True):
         | 
| @@ -1563,7 +1330,7 @@ with gr.Blocks(theme='NoCrypt/miku@>=1.2.2', elem_id="main", fill_width=True, cs | |
| 1563 | 
             
                    )
         | 
| 1564 | 
             
                    search_civitai_result_lora.change(select_civitai_lora, [search_civitai_result_lora], [text_lora, search_civitai_desc_lora], queue=False, scroll_to_output=True)
         | 
| 1565 | 
             
                    search_civitai_gallery_lora.select(update_civitai_selection, None, [search_civitai_result_lora], queue=False, show_api=False)
         | 
| 1566 | 
            -
                    button_lora.click(get_my_lora, [text_lora], [lora1_gui, lora2_gui, lora3_gui, lora4_gui, lora5_gui], scroll_to_output=True)
         | 
| 1567 | 
             
                    upload_button_lora.upload(upload_file_lora, [upload_button_lora], [file_output_lora, upload_button_lora]).success(
         | 
| 1568 | 
             
                        move_file_lora, [file_output_lora], [lora1_gui, lora2_gui, lora3_gui, lora4_gui, lora5_gui], scroll_to_output=True)
         | 
| 1569 |  | 
| @@ -1719,10 +1486,11 @@ with gr.Blocks(theme='NoCrypt/miku@>=1.2.2', elem_id="main", fill_width=True, cs | |
| 1719 | 
             
                            verbose_info_gui,
         | 
| 1720 | 
             
                            gpu_duration_gui,
         | 
| 1721 | 
             
                        ],
         | 
| 1722 | 
            -
                        outputs=[result_images, actual_task_info], 
         | 
| 1723 | 
             
                        queue=True,
         | 
| 1724 | 
             
                        show_progress="full",
         | 
| 1725 | 
            -
                    ).success(save_gallery_images, [result_images], [result_images, result_images_files], queue=False, show_api=False)
         | 
|  | |
| 1726 |  | 
| 1727 | 
             
                    with gr.Tab("Danbooru Tags Transformer with WD Tagger", render=True):
         | 
| 1728 | 
             
                        with gr.Column(scale=2):
         | 
| @@ -1818,5 +1586,5 @@ with gr.Blocks(theme='NoCrypt/miku@>=1.2.2', elem_id="main", fill_width=True, cs | |
| 1818 | 
             
                gr.DuplicateButton(value="Duplicate Space for private use (This demo does not work on CPU. Requires GPU Space)")
         | 
| 1819 |  | 
| 1820 | 
             
            app.queue()
         | 
| 1821 | 
            -
            app.launch() # allowed_paths=["./images/"], show_error=True, debug=True
         | 
| 1822 | 
             
            ## END MOD
         | 
|  | |
| 1 | 
             
            import spaces
         | 
|  | |
| 2 | 
             
            import os
         | 
| 3 | 
             
            from stablepy import Model_Diffusers
         | 
| 4 | 
            +
            from constants import (
         | 
| 5 | 
            +
                PREPROCESSOR_CONTROLNET,
         | 
| 6 | 
            +
                TASK_STABLEPY,
         | 
| 7 | 
            +
                TASK_MODEL_LIST,
         | 
| 8 | 
            +
                UPSCALER_DICT_GUI,
         | 
| 9 | 
            +
                UPSCALER_KEYS,
         | 
| 10 | 
            +
                PROMPT_W_OPTIONS,
         | 
| 11 | 
            +
                WARNING_MSG_VAE,
         | 
| 12 | 
            +
                SDXL_TASK,
         | 
| 13 | 
            +
                MODEL_TYPE_TASK,
         | 
| 14 | 
            +
                POST_PROCESSING_SAMPLER,
         | 
| 15 | 
            +
             | 
| 16 | 
            +
            )
         | 
| 17 | 
             
            from stablepy.diffusers_vanilla.style_prompt_config import STYLE_NAMES
         | 
|  | |
| 18 | 
             
            import torch
         | 
| 19 | 
             
            import re
         | 
|  | |
| 20 | 
             
            from stablepy import (
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
| 21 | 
             
                scheduler_names,
         | 
|  | |
| 22 | 
             
                IP_ADAPTERS_SD,
         | 
| 23 | 
             
                IP_ADAPTERS_SDXL,
         | 
|  | |
|  | |
|  | |
|  | |
| 24 | 
             
            )
         | 
| 25 | 
             
            import time
         | 
| 26 | 
             
            from PIL import ImageFile
         | 
| 27 | 
            +
            from utils import (
         | 
| 28 | 
            +
                get_model_list,
         | 
| 29 | 
            +
                extract_parameters,
         | 
| 30 | 
            +
                get_model_type,
         | 
| 31 | 
            +
                extract_exif_data,
         | 
| 32 | 
            +
                create_mask_now,
         | 
| 33 | 
            +
                download_diffuser_repo,
         | 
| 34 | 
            +
                progress_step_bar,
         | 
| 35 | 
            +
                html_template_message,
         | 
| 36 | 
            +
            )
         | 
| 37 | 
            +
            from datetime import datetime
         | 
| 38 | 
            +
            import gradio as gr
         | 
| 39 | 
            +
            import logging
         | 
| 40 | 
            +
            import diffusers
         | 
| 41 | 
            +
            import warnings
         | 
| 42 | 
            +
            from stablepy import logger
         | 
| 43 | 
            +
            # import urllib.parse
         | 
| 44 |  | 
| 45 | 
             
            ImageFile.LOAD_TRUNCATED_IMAGES = True
         | 
| 46 | 
            +
            # os.environ["PYTORCH_NO_CUDA_MEMORY_CACHING"] = "1"
         | 
| 47 | 
             
            print(os.getenv("SPACES_ZERO_GPU"))
         | 
| 48 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 49 | 
             
            ## BEGIN MOD
         | 
| 50 | 
             
            from modutils import (list_uniq, download_private_repo, get_model_id_list, get_tupled_embed_list,
         | 
| 51 | 
             
                get_lora_model_list, get_all_lora_tupled_list, update_loras, apply_lora_prompt, set_prompt_loras,
         | 
| 52 | 
             
                get_my_lora, upload_file_lora, move_file_lora, search_civitai_lora, select_civitai_lora,
         | 
| 53 | 
             
                update_civitai_selection, get_civitai_tag, CIVITAI_SORT, CIVITAI_PERIOD, CIVITAI_BASEMODEL,
         | 
| 54 | 
             
                set_textual_inversion_prompt, get_model_pipeline, change_interface_mode, get_t2i_model_info,
         | 
| 55 | 
            +
                get_tupled_model_list, save_gallery_images, save_gallery_history, set_optimization, set_sampler_settings,
         | 
| 56 | 
             
                set_quick_presets, process_style_prompt, optimization_list, save_images, download_things,
         | 
| 57 | 
            +
                preset_styles, preset_quality, preset_sampler_setting, translate_to_en, EXAMPLES_GUI, RESOURCES)
         | 
| 58 | 
             
            from env import (HF_TOKEN, CIVITAI_API_KEY, HF_LORA_ESSENTIAL_PRIVATE_REPO, HF_VAE_PRIVATE_REPO,
         | 
| 59 | 
             
                HF_SDXL_EMBEDS_NEGATIVE_PRIVATE_REPO, HF_SDXL_EMBEDS_POSITIVE_PRIVATE_REPO,
         | 
| 60 | 
            +
                DIRECTORY_MODELS, DIRECTORY_LORAS, DIRECTORY_VAES, DIRECTORY_EMBEDS, DIRECTORY_EMBEDS_SDXL,
         | 
| 61 | 
            +
                DIRECTORY_EMBEDS_POSITIVE_SDXL, LOAD_DIFFUSERS_FORMAT_MODEL,
         | 
| 62 | 
            +
                DOWNLOAD_MODEL_LIST, DOWNLOAD_LORA_LIST, DOWNLOAD_VAE_LIST, DOWNLOAD_EMBEDS)
         | 
| 63 | 
            +
             | 
| 64 | 
            +
            download_private_repo(HF_LORA_ESSENTIAL_PRIVATE_REPO, DIRECTORY_LORAS, True)
         | 
| 65 | 
            +
            download_private_repo(HF_VAE_PRIVATE_REPO, DIRECTORY_VAES, False)
         | 
| 66 | 
            +
            ## END MOD
         | 
| 67 |  | 
| 68 | 
             
            # - **Download Models**
         | 
| 69 | 
            +
            DOWNLOAD_MODEL = ", ".join(DOWNLOAD_MODEL_LIST)
         | 
| 70 | 
             
            # - **Download VAEs**
         | 
| 71 | 
            +
            DOWNLOAD_VAE = ", ".join(DOWNLOAD_VAE_LIST)
         | 
| 72 | 
             
            # - **Download LoRAs**
         | 
| 73 | 
            +
            DOWNLOAD_LORA = ", ".join(DOWNLOAD_LORA_LIST)
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 74 |  | 
| 75 | 
             
            # Download stuffs
         | 
| 76 | 
            +
            for url in [url.strip() for url in DOWNLOAD_MODEL.split(',')]:
         | 
| 77 | 
             
                if not os.path.exists(f"./models/{url.split('/')[-1]}"):
         | 
| 78 | 
            +
                    download_things(DIRECTORY_MODELS, url, HF_TOKEN, CIVITAI_API_KEY)
         | 
| 79 | 
            +
            for url in [url.strip() for url in DOWNLOAD_VAE.split(',')]:
         | 
| 80 | 
             
                if not os.path.exists(f"./vaes/{url.split('/')[-1]}"):
         | 
| 81 | 
            +
                    download_things(DIRECTORY_VAES, url, HF_TOKEN, CIVITAI_API_KEY)
         | 
| 82 | 
            +
            for url in [url.strip() for url in DOWNLOAD_LORA.split(',')]:
         | 
| 83 | 
             
                if not os.path.exists(f"./loras/{url.split('/')[-1]}"):
         | 
| 84 | 
            +
                    download_things(DIRECTORY_LORAS, url, HF_TOKEN, CIVITAI_API_KEY)
         | 
| 85 |  | 
| 86 | 
             
            # Download Embeddings
         | 
| 87 | 
            +
            for url_embed in DOWNLOAD_EMBEDS:
         | 
| 88 | 
             
                if not os.path.exists(f"./embedings/{url_embed.split('/')[-1]}"):
         | 
| 89 | 
            +
                    download_things(DIRECTORY_EMBEDS, url_embed, HF_TOKEN, CIVITAI_API_KEY)
         | 
| 90 |  | 
| 91 | 
             
            # Build list models
         | 
| 92 | 
            +
            embed_list = get_model_list(DIRECTORY_EMBEDS)
         | 
|  | |
|  | |
|  | |
| 93 | 
             
            lora_model_list = get_lora_model_list()
         | 
| 94 | 
            +
            vae_model_list = get_model_list(DIRECTORY_VAES)
         | 
| 95 | 
             
            vae_model_list.insert(0, "None")
         | 
| 96 |  | 
| 97 | 
            +
            ## BEGIN MOD
         | 
| 98 | 
            +
            model_list = list_uniq(get_model_id_list() + LOAD_DIFFUSERS_FORMAT_MODEL + get_model_list(DIRECTORY_MODELS))
         | 
| 99 | 
            +
            download_private_repo(HF_SDXL_EMBEDS_NEGATIVE_PRIVATE_REPO, DIRECTORY_EMBEDS_SDXL, False)
         | 
| 100 | 
            +
            download_private_repo(HF_SDXL_EMBEDS_POSITIVE_PRIVATE_REPO, DIRECTORY_EMBEDS_POSITIVE_SDXL, False)
         | 
| 101 | 
            +
            embed_sdxl_list = get_model_list(DIRECTORY_EMBEDS_SDXL) + get_model_list(DIRECTORY_EMBEDS_POSITIVE_SDXL)
         | 
| 102 |  | 
| 103 | 
             
            def get_embed_list(pipeline_name):
         | 
| 104 | 
             
                return get_tupled_embed_list(embed_sdxl_list if pipeline_name == "StableDiffusionXLPipeline" else embed_list)
         | 
|  | |
| 121 | 
             
            warnings.filterwarnings(action="ignore", category=FutureWarning, module="transformers")
         | 
| 122 | 
             
            ## BEGIN MOD
         | 
| 123 | 
             
            from stablepy import logger
         | 
| 124 | 
            +
            #logger.setLevel(logging.CRITICAL)
         | 
| 125 | 
            +
            logger.setLevel(logging.DEBUG)
         | 
| 126 |  | 
| 127 | 
            +
            from tagger.v2 import V2_ALL_MODELS, v2_random_prompt, v2_upsampling_prompt
         | 
| 128 | 
            +
            from tagger.utils import (gradio_copy_text, COPY_ACTION_JS, gradio_copy_prompt,
         | 
| 129 | 
             
                V2_ASPECT_RATIO_OPTIONS, V2_RATING_OPTIONS, V2_LENGTH_OPTIONS, V2_IDENTITY_OPTIONS)
         | 
| 130 | 
            +
            from tagger.tagger import (predict_tags_wd, convert_danbooru_to_e621_prompt,
         | 
| 131 | 
             
                remove_specific_prompt, insert_recom_prompt, insert_model_recom_prompt,
         | 
| 132 | 
             
                compose_prompt_to_copy, translate_prompt, select_random_character)
         | 
| 133 | 
             
            def description_ui():
         | 
|  | |
| 141 | 
             
                )
         | 
| 142 | 
             
            ## END MOD
         | 
| 143 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 144 | 
             
            class GuiSD:
         | 
| 145 | 
            +
                def __init__(self, stream=True):
         | 
| 146 | 
             
                    self.model = None
         | 
| 147 | 
            +
                    self.status_loading = False
         | 
| 148 | 
            +
                    self.sleep_loading = 4
         | 
| 149 | 
            +
                    self.last_load = datetime.now()
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 150 |  | 
| 151 | 
             
                def load_new_model(self, model_name, vae_model, task, progress=gr.Progress(track_tqdm=True)):
         | 
| 152 |  | 
|  | |
|  | |
| 153 | 
             
                    vae_model = vae_model if vae_model != "None" else None
         | 
| 154 | 
             
                    model_type = get_model_type(model_name)
         | 
| 155 | 
            +
                    dtype_model = torch.bfloat16 if model_type == "FLUX" else torch.float16
         | 
| 156 | 
            +
             | 
| 157 | 
            +
                    if not os.path.exists(model_name):
         | 
| 158 | 
            +
                        _ = download_diffuser_repo(
         | 
| 159 | 
            +
                            repo_name=model_name,
         | 
| 160 | 
            +
                            model_type=model_type,
         | 
| 161 | 
            +
                            revision="main",
         | 
| 162 | 
            +
                            token=True,
         | 
| 163 | 
            +
                        )
         | 
| 164 | 
            +
             | 
| 165 | 
            +
                    for i in range(68):
         | 
| 166 | 
            +
                        if not self.status_loading:
         | 
| 167 | 
            +
                            self.status_loading = True
         | 
| 168 | 
            +
                            if i > 0:
         | 
| 169 | 
            +
                                time.sleep(self.sleep_loading)
         | 
| 170 | 
            +
                                print("Previous model ops...")
         | 
| 171 | 
            +
                            break
         | 
| 172 | 
            +
                        time.sleep(0.5)
         | 
| 173 | 
            +
                        print(f"Waiting queue {i}")
         | 
| 174 | 
            +
                        yield "Waiting queue"
         | 
| 175 | 
            +
             | 
| 176 | 
            +
                    self.status_loading = True
         | 
| 177 | 
            +
             | 
| 178 | 
            +
                    yield f"Loading model: {model_name}"
         | 
| 179 |  | 
| 180 | 
             
                    if vae_model:
         | 
| 181 | 
             
                        vae_type = "SDXL" if "sdxl" in vae_model.lower() else "SD 1.5"
         | 
| 182 | 
             
                        if model_type != vae_type:
         | 
| 183 | 
            +
                            gr.Warning(WARNING_MSG_VAE)
         | 
| 184 |  | 
| 185 | 
            +
                    print("Loading model...")
         | 
| 186 | 
            +
             | 
| 187 | 
            +
                    try:
         | 
| 188 | 
            +
                        start_time = time.time()
         | 
| 189 | 
            +
             | 
| 190 | 
            +
                        if self.model is None:
         | 
| 191 | 
            +
                            self.model = Model_Diffusers(
         | 
| 192 | 
            +
                                base_model_id=model_name,
         | 
| 193 | 
            +
                                task_name=TASK_STABLEPY[task],
         | 
| 194 | 
            +
                                vae_model=vae_model,
         | 
| 195 | 
            +
                                type_model_precision=dtype_model,
         | 
| 196 | 
            +
                                retain_task_model_in_cache=False,
         | 
| 197 | 
            +
                                device="cpu",
         | 
| 198 | 
            +
                            )
         | 
| 199 | 
            +
                        else:
         | 
| 200 | 
            +
             | 
| 201 | 
            +
                            if self.model.base_model_id != model_name:
         | 
| 202 | 
            +
                                load_now_time = datetime.now()
         | 
| 203 | 
            +
                                elapsed_time = max((load_now_time - self.last_load).total_seconds(), 0)
         | 
| 204 | 
            +
             | 
| 205 | 
            +
                                if elapsed_time <= 8:
         | 
| 206 | 
            +
                                    print("Waiting for the previous model's time ops...")
         | 
| 207 | 
            +
                                    time.sleep(8-elapsed_time)
         | 
| 208 | 
            +
             | 
| 209 | 
            +
                            self.model.device = torch.device("cpu")
         | 
| 210 | 
            +
                            self.model.load_pipe(
         | 
| 211 | 
            +
                                model_name,
         | 
| 212 | 
            +
                                task_name=TASK_STABLEPY[task],
         | 
| 213 | 
            +
                                vae_model=vae_model,
         | 
| 214 | 
            +
                                type_model_precision=dtype_model,
         | 
| 215 | 
            +
                                retain_task_model_in_cache=False,
         | 
| 216 | 
            +
                            )
         | 
| 217 | 
            +
             | 
| 218 | 
            +
                        end_time = time.time()
         | 
| 219 | 
            +
                        self.sleep_loading = max(min(int(end_time - start_time), 10), 4)
         | 
| 220 | 
            +
                    except Exception as e:
         | 
| 221 | 
            +
                        self.last_load = datetime.now()
         | 
| 222 | 
            +
                        self.status_loading = False
         | 
| 223 | 
            +
                        self.sleep_loading = 4
         | 
| 224 | 
            +
                        raise e
         | 
| 225 | 
            +
             | 
| 226 | 
            +
                    self.last_load = datetime.now()
         | 
| 227 | 
            +
                    self.status_loading = False
         | 
| 228 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 229 | 
             
                    yield f"Model loaded: {model_name}"
         | 
| 230 |  | 
| 231 | 
             
                #@spaces.GPU
         | 
|  | |
| 334 | 
             
                    mode_ip2,
         | 
| 335 | 
             
                    scale_ip2,
         | 
| 336 | 
             
                    pag_scale,
         | 
|  | |
| 337 | 
             
                ):
         | 
| 338 | 
            +
                    info_state = html_template_message("Navigating latent space...")
         | 
| 339 | 
            +
                    yield info_state, gr.update(), gr.update()
         | 
| 340 |  | 
| 341 | 
             
                    vae_model = vae_model if vae_model != "None" else None
         | 
| 342 | 
             
                    loras_list = [lora1, lora2, lora3, lora4, lora5]
         | 
| 343 | 
             
                    vae_msg = f"VAE: {vae_model}" if vae_model else ""
         | 
| 344 | 
             
                    msg_lora = ""
         | 
| 345 |  | 
|  | |
|  | |
| 346 | 
             
            ## BEGIN MOD
         | 
| 347 | 
            +
                    loras_list = [s if s else "None" for s in loras_list]
         | 
| 348 | 
             
                    global lora_model_list
         | 
| 349 | 
             
                    lora_model_list = get_lora_model_list()
         | 
| 350 | 
             
                    lora1, lora_scale1, lora2, lora_scale2, lora3, lora_scale3, lora4, lora_scale4, lora5, lora_scale5 = \
         | 
|  | |
| 353 | 
             
                    prompt, neg_prompt = insert_model_recom_prompt(prompt, neg_prompt, model_name)
         | 
| 354 | 
             
            ## END MOD
         | 
| 355 |  | 
| 356 | 
            +
                    print("Config model:", model_name, vae_model, loras_list)
         | 
| 357 | 
            +
             | 
| 358 | 
             
                    task = TASK_STABLEPY[task]
         | 
| 359 |  | 
| 360 | 
             
                    params_ip_img = []
         | 
|  | |
| 377 | 
             
                            params_ip_mode.append(modeip)
         | 
| 378 | 
             
                            params_ip_scale.append(scaleip)
         | 
| 379 |  | 
| 380 | 
            +
                    concurrency = 5
         | 
| 381 | 
            +
                    self.model.stream_config(concurrency=concurrency, latent_resize_by=1, vae_decoding=False)
         | 
| 382 |  | 
| 383 | 
             
                    if task != "txt2img" and not image_control:
         | 
| 384 | 
             
                        raise ValueError("No control image found: To use this function, you have to upload an image in 'Image ControlNet/Inpaint/Img2img'")
         | 
|  | |
| 451 | 
             
                        "high_threshold": high_threshold,
         | 
| 452 | 
             
                        "value_threshold": value_threshold,
         | 
| 453 | 
             
                        "distance_threshold": distance_threshold,
         | 
| 454 | 
            +
                        "lora_A": lora1 if lora1 != "None" else None,
         | 
| 455 | 
             
                        "lora_scale_A": lora_scale1,
         | 
| 456 | 
            +
                        "lora_B": lora2 if lora2 != "None" else None,
         | 
| 457 | 
             
                        "lora_scale_B": lora_scale2,
         | 
| 458 | 
            +
                        "lora_C": lora3 if lora3 != "None" else None,
         | 
| 459 | 
             
                        "lora_scale_C": lora_scale3,
         | 
| 460 | 
            +
                        "lora_D": lora4 if lora4 != "None" else None,
         | 
| 461 | 
             
                        "lora_scale_D": lora_scale4,
         | 
| 462 | 
            +
                        "lora_E": lora5 if lora5 != "None" else None,
         | 
| 463 | 
             
                        "lora_scale_E": lora_scale5,
         | 
| 464 | 
             
            ## BEGIN MOD
         | 
| 465 | 
             
                        "textual_inversion": get_embed_list(self.model.class_name) if textual_inversion else [],
         | 
|  | |
| 509 | 
             
                    }
         | 
| 510 |  | 
| 511 | 
             
                    self.model.device = torch.device("cuda:0")
         | 
| 512 | 
            +
                    if hasattr(self.model.pipe, "transformer") and loras_list != ["None"] * 5:
         | 
| 513 | 
             
                        self.model.pipe.transformer.to(self.model.device)
         | 
| 514 | 
             
                        print("transformer to cuda")
         | 
| 515 |  | 
| 516 | 
            +
                    actual_progress = 0
         | 
| 517 | 
            +
                    info_images = gr.update()
         | 
|  | |
| 518 | 
             
                    for img, seed, image_path, metadata in self.model(**pipe_params):
         | 
| 519 | 
            +
                        info_state = progress_step_bar(actual_progress, steps)
         | 
| 520 | 
            +
                        actual_progress += concurrency
         | 
| 521 | 
             
                        if image_path:
         | 
| 522 | 
            +
                            info_images = f"Seeds: {str(seed)}"
         | 
| 523 | 
             
                            if vae_msg:
         | 
| 524 | 
            +
                                info_images = info_images + "<br>" + vae_msg
         | 
| 525 | 
            +
             | 
| 526 | 
            +
                            if "Cannot copy out of meta tensor; no data!" in self.model.last_lora_error:
         | 
| 527 | 
            +
                                msg_ram = "Unable to process the LoRAs due to high RAM usage; please try again later."
         | 
| 528 | 
            +
                                print(msg_ram)
         | 
| 529 | 
            +
                                msg_lora += f"<br>{msg_ram}"
         | 
| 530 |  | 
| 531 | 
             
                            for status, lora in zip(self.model.lora_status, self.model.lora_memory):
         | 
| 532 | 
             
                                if status:
         | 
|  | |
| 535 | 
             
                                    msg_lora += f"<br>Error with: {lora}"
         | 
| 536 |  | 
| 537 | 
             
                            if msg_lora:
         | 
| 538 | 
            +
                                info_images += msg_lora
         | 
| 539 |  | 
| 540 | 
            +
                            info_images = info_images + "<br>" + "GENERATION DATA:<br>" + metadata[0].replace("\n", "<br>") + "<br>-------<br>"
         | 
| 541 |  | 
| 542 | 
             
                            download_links = "<br>".join(
         | 
| 543 | 
             
                                [
         | 
|  | |
| 546 | 
             
                                ]
         | 
| 547 | 
             
                            )
         | 
| 548 | 
             
                            if save_generated_images:
         | 
| 549 | 
            +
                                info_images += f"<br>{download_links}"
         | 
| 550 |  | 
| 551 | 
            +
            ## BEGIN MOD
         | 
| 552 | 
             
                            img = save_images(img, metadata)
         | 
| 553 | 
            +
            ## END MOD
         | 
| 554 |  | 
| 555 | 
            +
                            info_state = "COMPLETE"
         | 
|  | |
|  | |
|  | |
| 556 |  | 
| 557 | 
            +
                        yield info_state, img, info_images
         | 
|  | |
| 558 |  | 
|  | |
| 559 |  | 
| 560 | 
             
            def dynamic_gpu_duration(func, duration, *args):
         | 
| 561 |  | 
|  | |
| 565 |  | 
| 566 | 
             
                return wrapped_func()
         | 
| 567 |  | 
| 568 | 
            +
             | 
| 569 | 
             
            @spaces.GPU
         | 
| 570 | 
             
            def dummy_gpu():
         | 
| 571 | 
             
                return None
         | 
| 572 |  | 
| 573 | 
            +
             | 
| 574 | 
             
            def sd_gen_generate_pipeline(*args):
         | 
| 575 |  | 
| 576 | 
             
                gpu_duration_arg = int(args[-1]) if args[-1] else 59
         | 
|  | |
| 578 | 
             
                load_lora_cpu = args[-3]
         | 
| 579 | 
             
                generation_args = args[:-3]
         | 
| 580 | 
             
                lora_list = [
         | 
| 581 | 
            +
                    None if item == "None" or item == "" else item # MOD
         | 
| 582 | 
             
                    for item in [args[7], args[9], args[11], args[13], args[15]]
         | 
| 583 | 
             
                ]
         | 
| 584 | 
             
                lora_status = [None] * 5
         | 
|  | |
| 588 | 
             
                    msg_load_lora = "Updating LoRAs in CPU (Slow but saves GPU usage)..."
         | 
| 589 |  | 
| 590 | 
             
                if lora_list != sd_gen.model.lora_memory and lora_list != [None] * 5:
         | 
| 591 | 
            +
                    yield msg_load_lora, gr.update(), gr.update()
         | 
| 592 |  | 
| 593 | 
             
                # Load lora in CPU
         | 
| 594 | 
             
                if load_lora_cpu:
         | 
|  | |
| 614 | 
             
                        )
         | 
| 615 | 
             
                        gr.Info(f"LoRAs in cache: {lora_cache_msg}")
         | 
| 616 |  | 
| 617 | 
            +
                msg_request = f"Requesting {gpu_duration_arg}s. of GPU time.\nModel: {sd_gen.model.base_model_id}"
         | 
| 618 | 
            +
                if verbose_arg:
         | 
| 619 | 
             
                    gr.Info(msg_request)
         | 
| 620 | 
             
                    print(msg_request)
         | 
| 621 | 
            +
                yield msg_request.replace("\n", "<br>"), gr.update(), gr.update()
         | 
|  | |
| 622 |  | 
| 623 | 
             
                start_time = time.time()
         | 
| 624 |  | 
| 625 | 
            +
                # yield from sd_gen.generate_pipeline(*generation_args)
         | 
| 626 | 
             
                yield from dynamic_gpu_duration(
         | 
| 627 | 
             
                    sd_gen.generate_pipeline,
         | 
| 628 | 
             
                    gpu_duration_arg,
         | 
|  | |
| 630 | 
             
                )
         | 
| 631 |  | 
| 632 | 
             
                end_time = time.time()
         | 
| 633 | 
            +
                execution_time = end_time - start_time
         | 
| 634 | 
            +
                msg_task_complete = (
         | 
| 635 | 
            +
                    f"GPU task complete in: {int(round(execution_time, 0) + 1)} seconds"
         | 
| 636 | 
            +
                )
         | 
| 637 |  | 
| 638 | 
             
                if verbose_arg:
         | 
|  | |
|  | |
|  | |
|  | |
| 639 | 
             
                    gr.Info(msg_task_complete)
         | 
| 640 | 
             
                    print(msg_task_complete)
         | 
| 641 |  | 
| 642 | 
            +
                yield msg_task_complete, gr.update(), gr.update()
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 643 |  | 
|  | |
|  | |
| 644 |  | 
| 645 | 
            +
            @spaces.GPU(duration=15)
         | 
| 646 | 
             
            def esrgan_upscale(image, upscaler_name, upscaler_size):
         | 
| 647 | 
             
                if image is None: return None
         | 
| 648 |  | 
|  | |
| 664 |  | 
| 665 | 
             
                return image_path
         | 
| 666 |  | 
| 667 | 
            +
             | 
| 668 | 
             
            dynamic_gpu_duration.zerogpu = True
         | 
| 669 | 
             
            sd_gen_generate_pipeline.zerogpu = True
         | 
| 670 | 
             
            sd_gen = GuiSD()
         | 
| 671 |  | 
| 672 | 
            +
             | 
| 673 | 
             
            ## BEGIN MOD
         | 
| 674 | 
             
            CSS ="""
         | 
| 675 | 
            +
            .gradio-container, #main { width:100%; height:100%; max-width:100%; padding-left:0; padding-right:0; margin-left:0; margin-right:0; }
         | 
| 676 | 
            +
            .contain { display:flex; flex-direction:column; }
         | 
| 677 | 
            +
            #component-0 { width:100%; height:100%; }
         | 
| 678 | 
            +
            #gallery { flex-grow:1; }
         | 
| 679 | 
            +
            #load_model { height: 50px; }
         | 
| 680 | 
            +
            .lora { min-width:480px; }
         | 
| 681 | 
             
            #model-info { text-align:center; }
         | 
| 682 | 
             
            .title { font-size: 3em; align-items: center; text-align: center; }
         | 
| 683 | 
             
            .info { align-items: center; text-align: center; }
         | 
|  | |
| 691 | 
             
                    with gr.Tab("Generation"):
         | 
| 692 | 
             
                        with gr.Row():
         | 
| 693 | 
             
                            with gr.Column(scale=2):
         | 
| 694 | 
            +
             | 
| 695 | 
            +
                                def update_task_options(model_name, task_name):
         | 
| 696 | 
            +
                                    new_choices = MODEL_TYPE_TASK[get_model_type(model_name)]
         | 
| 697 | 
            +
             | 
| 698 | 
            +
                                    if task_name not in new_choices:
         | 
| 699 | 
            +
                                        task_name = "txt2img"
         | 
| 700 | 
            +
             | 
| 701 | 
            +
                                    return gr.update(value=task_name, choices=new_choices)
         | 
| 702 | 
            +
                                
         | 
| 703 | 
             
                                interface_mode_gui = gr.Radio(label="Quick settings", choices=["Simple", "Standard", "Fast", "LoRA"], value="Standard")
         | 
| 704 | 
             
                                with gr.Accordion("Model and Task", open=False) as menu_model:
         | 
| 705 | 
             
                                    task_gui = gr.Dropdown(label="Task", choices=SDXL_TASK, value=TASK_MODEL_LIST[0])
         | 
|  | |
| 763 | 
             
                                    [task_gui],
         | 
| 764 | 
             
                                )
         | 
| 765 |  | 
| 766 | 
            +
                                load_model_gui = gr.HTML(elem_id="load_model", elem_classes="contain")
         | 
| 767 |  | 
| 768 | 
             
                                result_images = gr.Gallery(
         | 
| 769 | 
             
                                    label="Generated images",
         | 
|  | |
| 785 |  | 
| 786 | 
             
                                actual_task_info = gr.HTML()
         | 
| 787 |  | 
| 788 | 
            +
                                with gr.Accordion("History", open=False):
         | 
| 789 | 
            +
                                    history_gallery = gr.Gallery(label="History", columns=6, object_fit="contain", format="png", interactive=False, show_share_button=False,
         | 
| 790 | 
            +
                                    show_download_button=True)
         | 
| 791 | 
            +
                                    history_files = gr.Files(interactive=False, visible=False)
         | 
| 792 | 
            +
                                    history_clear_button = gr.Button(value="Clear History", variant="secondary")
         | 
| 793 | 
            +
                                    history_clear_button.click(lambda: ([], []), None, [history_gallery, history_files], queue=False, show_api=False)
         | 
| 794 | 
            +
             | 
| 795 | 
             
                                with gr.Row(equal_height=False, variant="default"):
         | 
| 796 | 
             
                                    gpu_duration_gui = gr.Number(minimum=5, maximum=240, value=59, show_label=False, container=False, info="GPU time duration (seconds)")
         | 
| 797 | 
             
                                    with gr.Column():
         | 
|  | |
| 814 | 
             
                                    with gr.Row():
         | 
| 815 | 
             
                                        sampler_gui = gr.Dropdown(label="Sampler", choices=scheduler_names, value="Euler a")
         | 
| 816 | 
             
                                        vae_model_gui = gr.Dropdown(label="VAE Model", choices=vae_model_list, value=vae_model_list[0])
         | 
| 817 | 
            +
                                        prompt_syntax_gui = gr.Dropdown(label="Prompt Syntax", choices=PROMPT_W_OPTIONS, value=PROMPT_W_OPTIONS[1][1])
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 818 |  | 
| 819 | 
             
                                with gr.Row(equal_height=False):
         | 
| 820 |  | 
|  | |
| 964 | 
             
                                        with gr.Accordion("Select from Gallery", open=False):
         | 
| 965 | 
             
                                            search_civitai_gallery_lora = gr.Gallery([], label="Results", allow_preview=False, columns=5, show_share_button=False, interactive=False)
         | 
| 966 | 
             
                                        search_civitai_result_lora = gr.Dropdown(label="Search Results", choices=[("", "")], value="", allow_custom_value=True, visible=False)
         | 
| 967 | 
            +
                                        with gr.Row():
         | 
| 968 | 
            +
                                            text_lora = gr.Textbox(label="LoRA URL", placeholder="https://civitai.com/api/download/models/28907", lines=1, scale=4)
         | 
| 969 | 
            +
                                            romanize_text = gr.Checkbox(value=False, label="Transliterate name", scale=1)
         | 
| 970 | 
             
                                        button_lora = gr.Button("Get and update lists of LoRAs")
         | 
| 971 | 
            +
                                        new_lora_status = gr.HTML()
         | 
| 972 | 
             
                                    with gr.Accordion("From Local", open=True, visible=True):
         | 
| 973 | 
             
                                        file_output_lora = gr.File(label="Uploaded LoRA", file_types=['.ckpt', '.pt', '.pth', '.safetensors', '.bin'], file_count="multiple", interactive=False, visible=False)
         | 
| 974 | 
             
                                        upload_button_lora = gr.UploadButton(label="Upload LoRA from your disk (very slow)", file_types=['.ckpt', '.pt', '.pth', '.safetensors', '.bin'], file_count="multiple")
         | 
|  | |
| 1006 | 
             
                                            negative_prompt_ad_a_gui = gr.Textbox(label="Negative prompt", placeholder="Main negative prompt will be use", lines=3)
         | 
| 1007 | 
             
                                            with gr.Row():
         | 
| 1008 | 
             
                                                strength_ad_a_gui = gr.Number(label="Strength:", value=0.35, step=0.01, minimum=0.01, maximum=1.0)
         | 
| 1009 | 
            +
                                                face_detector_ad_a_gui = gr.Checkbox(label="Face detector", value=False)
         | 
| 1010 | 
             
                                                person_detector_ad_a_gui = gr.Checkbox(label="Person detector", value=True)
         | 
| 1011 | 
             
                                                hand_detector_ad_a_gui = gr.Checkbox(label="Hand detector", value=False)
         | 
| 1012 | 
             
                                            with gr.Row():
         | 
|  | |
| 1021 | 
             
                                            negative_prompt_ad_b_gui = gr.Textbox(label="Negative prompt", placeholder="Main negative prompt will be use", lines=3)
         | 
| 1022 | 
             
                                            with gr.Row():
         | 
| 1023 | 
             
                                                strength_ad_b_gui = gr.Number(label="Strength:", value=0.35, step=0.01, minimum=0.01, maximum=1.0)
         | 
| 1024 | 
            +
                                                face_detector_ad_b_gui = gr.Checkbox(label="Face detector", value=False)
         | 
| 1025 | 
             
                                                person_detector_ad_b_gui = gr.Checkbox(label="Person detector", value=True)
         | 
| 1026 | 
             
                                                hand_detector_ad_b_gui = gr.Checkbox(label="Hand detector", value=False)
         | 
| 1027 | 
             
                                            with gr.Row():
         | 
|  | |
| 1151 |  | 
| 1152 | 
             
                        with gr.Accordion("Examples and help", open=True, visible=True) as menu_example:
         | 
| 1153 | 
             
                            gr.Examples(
         | 
| 1154 | 
            +
                                examples=EXAMPLES_GUI,
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 1155 | 
             
                                fn=sd_gen.generate_pipeline,
         | 
| 1156 | 
             
                                inputs=[
         | 
| 1157 | 
             
                                    prompt_gui,
         | 
|  | |
| 1166 | 
             
                                    img_width_gui,
         | 
| 1167 | 
             
                                    model_name_gui,
         | 
| 1168 | 
             
                                ],
         | 
| 1169 | 
            +
                                outputs=[load_model_gui, result_images, actual_task_info],
         | 
| 1170 | 
             
                                cache_examples=False,
         | 
| 1171 | 
             
                                #elem_id="examples",
         | 
| 1172 | 
             
                            )
         | 
| 1173 |  | 
| 1174 | 
            +
                        gr.Markdown(RESOURCES)
         | 
|  | |
|  | |
|  | |
|  | |
| 1175 | 
             
            ## END MOD
         | 
| 1176 |  | 
| 1177 | 
             
                    with gr.Tab("Inpaint mask maker", render=True):
         | 
|  | |
| 1330 | 
             
                    )
         | 
| 1331 | 
             
                    search_civitai_result_lora.change(select_civitai_lora, [search_civitai_result_lora], [text_lora, search_civitai_desc_lora], queue=False, scroll_to_output=True)
         | 
| 1332 | 
             
                    search_civitai_gallery_lora.select(update_civitai_selection, None, [search_civitai_result_lora], queue=False, show_api=False)
         | 
| 1333 | 
            +
                    button_lora.click(get_my_lora, [text_lora, romanize_text], [lora1_gui, lora2_gui, lora3_gui, lora4_gui, lora5_gui, new_lora_status], scroll_to_output=True)
         | 
| 1334 | 
             
                    upload_button_lora.upload(upload_file_lora, [upload_button_lora], [file_output_lora, upload_button_lora]).success(
         | 
| 1335 | 
             
                        move_file_lora, [file_output_lora], [lora1_gui, lora2_gui, lora3_gui, lora4_gui, lora5_gui], scroll_to_output=True)
         | 
| 1336 |  | 
|  | |
| 1486 | 
             
                            verbose_info_gui,
         | 
| 1487 | 
             
                            gpu_duration_gui,
         | 
| 1488 | 
             
                        ],
         | 
| 1489 | 
            +
                        outputs=[load_model_gui, result_images, actual_task_info], 
         | 
| 1490 | 
             
                        queue=True,
         | 
| 1491 | 
             
                        show_progress="full",
         | 
| 1492 | 
            +
                    ).success(save_gallery_images, [result_images, model_name_gui], [result_images, result_images_files], queue=False, show_api=False)\
         | 
| 1493 | 
            +
                    .success(save_gallery_history, [result_images, result_images_files, history_gallery, history_files], [history_gallery, history_files], queue=False, show_api=False)
         | 
| 1494 |  | 
| 1495 | 
             
                    with gr.Tab("Danbooru Tags Transformer with WD Tagger", render=True):
         | 
| 1496 | 
             
                        with gr.Column(scale=2):
         | 
|  | |
| 1586 | 
             
                gr.DuplicateButton(value="Duplicate Space for private use (This demo does not work on CPU. Requires GPU Space)")
         | 
| 1587 |  | 
| 1588 | 
             
            app.queue()
         | 
| 1589 | 
            +
            app.launch(show_error=True, debug=True) # allowed_paths=["./images/"], show_error=True, debug=True
         | 
| 1590 | 
             
            ## END MOD
         | 
    	
        constants.py
    ADDED
    
    | @@ -0,0 +1,453 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import os
         | 
| 2 | 
            +
            from stablepy.diffusers_vanilla.constants import FLUX_CN_UNION_MODES
         | 
| 3 | 
            +
            from stablepy import (
         | 
| 4 | 
            +
                scheduler_names,
         | 
| 5 | 
            +
                SD15_TASKS,
         | 
| 6 | 
            +
                SDXL_TASKS,
         | 
| 7 | 
            +
            )
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            # - **Download Models**
         | 
| 10 | 
            +
            DOWNLOAD_MODEL = "https://civitai.com/api/download/models/574369, https://huggingface.co/TechnoByte/MilkyWonderland/resolve/main/milkyWonderland_v40.safetensors"
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            # - **Download VAEs**
         | 
| 13 | 
            +
            DOWNLOAD_VAE = "https://huggingface.co/nubby/blessed-sdxl-vae-fp16-fix/resolve/main/sdxl_vae-fp16fix-c-1.1-b-0.5.safetensors?download=true, https://huggingface.co/nubby/blessed-sdxl-vae-fp16-fix/resolve/main/sdxl_vae-fp16fix-blessed.safetensors?download=true, https://huggingface.co/digiplay/VAE/resolve/main/vividReal_v20.safetensors?download=true, https://huggingface.co/fp16-guy/anything_kl-f8-anime2_vae-ft-mse-840000-ema-pruned_blessed_clearvae_fp16_cleaned/resolve/main/vae-ft-mse-840000-ema-pruned_fp16.safetensors?download=true"
         | 
| 14 | 
            +
             | 
| 15 | 
            +
            # - **Download LoRAs**
         | 
| 16 | 
            +
            DOWNLOAD_LORA = "https://huggingface.co/Leopain/color/resolve/main/Coloring_book_-_LineArt.safetensors, https://civitai.com/api/download/models/135867, https://huggingface.co/Linaqruf/anime-detailer-xl-lora/resolve/main/anime-detailer-xl.safetensors?download=true, https://huggingface.co/Linaqruf/style-enhancer-xl-lora/resolve/main/style-enhancer-xl.safetensors?download=true, https://huggingface.co/ByteDance/Hyper-SD/resolve/main/Hyper-SD15-8steps-CFG-lora.safetensors?download=true, https://huggingface.co/ByteDance/Hyper-SD/resolve/main/Hyper-SDXL-8steps-CFG-lora.safetensors?download=true"
         | 
| 17 | 
            +
             | 
| 18 | 
            +
            LOAD_DIFFUSERS_FORMAT_MODEL = [
         | 
| 19 | 
            +
                'stabilityai/stable-diffusion-xl-base-1.0',
         | 
| 20 | 
            +
                'black-forest-labs/FLUX.1-dev',
         | 
| 21 | 
            +
                'John6666/blue-pencil-flux1-v021-fp8-flux',
         | 
| 22 | 
            +
                'John6666/wai-ani-flux-v10forfp8-fp8-flux',
         | 
| 23 | 
            +
                'John6666/xe-anime-flux-v04-fp8-flux',
         | 
| 24 | 
            +
                'John6666/lyh-anime-flux-v2a1-fp8-flux',
         | 
| 25 | 
            +
                'John6666/carnival-unchained-v10-fp8-flux',
         | 
| 26 | 
            +
                'cagliostrolab/animagine-xl-3.1',
         | 
| 27 | 
            +
                'John6666/epicrealism-xl-v8kiss-sdxl',
         | 
| 28 | 
            +
                'misri/epicrealismXL_v7FinalDestination',
         | 
| 29 | 
            +
                'misri/juggernautXL_juggernautX',
         | 
| 30 | 
            +
                'misri/zavychromaxl_v80',
         | 
| 31 | 
            +
                'SG161222/RealVisXL_V4.0',
         | 
| 32 | 
            +
                'SG161222/RealVisXL_V5.0',
         | 
| 33 | 
            +
                'misri/newrealityxlAllInOne_Newreality40',
         | 
| 34 | 
            +
                'eienmojiki/Anything-XL',
         | 
| 35 | 
            +
                'eienmojiki/Starry-XL-v5.2',
         | 
| 36 | 
            +
                'gsdf/CounterfeitXL',
         | 
| 37 | 
            +
                'KBlueLeaf/Kohaku-XL-Zeta',
         | 
| 38 | 
            +
                'John6666/silvermoon-mix-01xl-v11-sdxl',
         | 
| 39 | 
            +
                'WhiteAiZ/autismmixSDXL_autismmixConfetti_diffusers',
         | 
| 40 | 
            +
                'kitty7779/ponyDiffusionV6XL',
         | 
| 41 | 
            +
                'GraydientPlatformAPI/aniverse-pony',
         | 
| 42 | 
            +
                'John6666/ras-real-anime-screencap-v1-sdxl',
         | 
| 43 | 
            +
                'John6666/duchaiten-pony-xl-no-score-v60-sdxl',
         | 
| 44 | 
            +
                'John6666/mistoon-anime-ponyalpha-sdxl',
         | 
| 45 | 
            +
                'John6666/3x3x3mixxl-v2-sdxl',
         | 
| 46 | 
            +
                'John6666/3x3x3mixxl-3dv01-sdxl',
         | 
| 47 | 
            +
                'John6666/ebara-mfcg-pony-mix-v12-sdxl',
         | 
| 48 | 
            +
                'John6666/t-ponynai3-v51-sdxl',
         | 
| 49 | 
            +
                'John6666/t-ponynai3-v65-sdxl',
         | 
| 50 | 
            +
                'John6666/prefect-pony-xl-v3-sdxl',
         | 
| 51 | 
            +
                'John6666/mala-anime-mix-nsfw-pony-xl-v5-sdxl',
         | 
| 52 | 
            +
                'John6666/wai-real-mix-v11-sdxl',
         | 
| 53 | 
            +
                'John6666/wai-c-v6-sdxl',
         | 
| 54 | 
            +
                'John6666/iniverse-mix-xl-sfwnsfw-pony-guofeng-v43-sdxl',
         | 
| 55 | 
            +
                'John6666/photo-realistic-pony-v5-sdxl',
         | 
| 56 | 
            +
                'John6666/pony-realism-v21main-sdxl',
         | 
| 57 | 
            +
                'John6666/pony-realism-v22main-sdxl',
         | 
| 58 | 
            +
                'John6666/cyberrealistic-pony-v63-sdxl',
         | 
| 59 | 
            +
                'John6666/cyberrealistic-pony-v64-sdxl',
         | 
| 60 | 
            +
                'John6666/cyberrealistic-pony-v65-sdxl',
         | 
| 61 | 
            +
                'GraydientPlatformAPI/realcartoon-pony-diffusion',
         | 
| 62 | 
            +
                'John6666/nova-anime-xl-pony-v5-sdxl',
         | 
| 63 | 
            +
                'John6666/autismmix-sdxl-autismmix-pony-sdxl',
         | 
| 64 | 
            +
                'John6666/aimz-dream-real-pony-mix-v3-sdxl',
         | 
| 65 | 
            +
                'John6666/duchaiten-pony-real-v11fix-sdxl',
         | 
| 66 | 
            +
                'John6666/duchaiten-pony-real-v20-sdxl',
         | 
| 67 | 
            +
                'yodayo-ai/kivotos-xl-2.0',
         | 
| 68 | 
            +
                'yodayo-ai/holodayo-xl-2.1',
         | 
| 69 | 
            +
                'yodayo-ai/clandestine-xl-1.0',
         | 
| 70 | 
            +
                'digiplay/majicMIX_sombre_v2',
         | 
| 71 | 
            +
                'digiplay/majicMIX_realistic_v6',
         | 
| 72 | 
            +
                'digiplay/majicMIX_realistic_v7',
         | 
| 73 | 
            +
                'digiplay/DreamShaper_8',
         | 
| 74 | 
            +
                'digiplay/BeautifulArt_v1',
         | 
| 75 | 
            +
                'digiplay/DarkSushi2.5D_v1',
         | 
| 76 | 
            +
                'digiplay/darkphoenix3D_v1.1',
         | 
| 77 | 
            +
                'digiplay/BeenYouLiteL11_diffusers',
         | 
| 78 | 
            +
                'Yntec/RevAnimatedV2Rebirth',
         | 
| 79 | 
            +
                'youknownothing/cyberrealistic_v50',
         | 
| 80 | 
            +
                'youknownothing/deliberate-v6',
         | 
| 81 | 
            +
                'GraydientPlatformAPI/deliberate-cyber3',
         | 
| 82 | 
            +
                'GraydientPlatformAPI/picx-real',
         | 
| 83 | 
            +
                'GraydientPlatformAPI/perfectworld6',
         | 
| 84 | 
            +
                'emilianJR/epiCRealism',
         | 
| 85 | 
            +
                'votepurchase/counterfeitV30_v30',
         | 
| 86 | 
            +
                'votepurchase/ChilloutMix',
         | 
| 87 | 
            +
                'Meina/MeinaMix_V11',
         | 
| 88 | 
            +
                'Meina/MeinaUnreal_V5',
         | 
| 89 | 
            +
                'Meina/MeinaPastel_V7',
         | 
| 90 | 
            +
                'GraydientPlatformAPI/realcartoon3d-17',
         | 
| 91 | 
            +
                'GraydientPlatformAPI/realcartoon-pixar11',
         | 
| 92 | 
            +
                'GraydientPlatformAPI/realcartoon-real17',
         | 
| 93 | 
            +
            ]
         | 
| 94 | 
            +
             | 
| 95 | 
            +
            DIFFUSERS_FORMAT_LORAS = [
         | 
| 96 | 
            +
                "nerijs/animation2k-flux",
         | 
| 97 | 
            +
                "XLabs-AI/flux-RealismLora",
         | 
| 98 | 
            +
            ]
         | 
| 99 | 
            +
             | 
| 100 | 
            +
            DOWNLOAD_EMBEDS = [
         | 
| 101 | 
            +
                'https://huggingface.co/datasets/Nerfgun3/bad_prompt/blob/main/bad_prompt_version2.pt',
         | 
| 102 | 
            +
                'https://huggingface.co/embed/negative/resolve/main/EasyNegativeV2.safetensors',
         | 
| 103 | 
            +
                'https://huggingface.co/embed/negative/resolve/main/bad-hands-5.pt',
         | 
| 104 | 
            +
                ]
         | 
| 105 | 
            +
             | 
| 106 | 
            +
            CIVITAI_API_KEY = os.environ.get("CIVITAI_API_KEY")
         | 
| 107 | 
            +
            HF_TOKEN = os.environ.get("HF_READ_TOKEN")
         | 
| 108 | 
            +
             | 
| 109 | 
            +
            DIRECTORY_MODELS = 'models'
         | 
| 110 | 
            +
            DIRECTORY_LORAS = 'loras'
         | 
| 111 | 
            +
            DIRECTORY_VAES = 'vaes'
         | 
| 112 | 
            +
            DIRECTORY_EMBEDS = 'embedings'
         | 
| 113 | 
            +
             | 
| 114 | 
            +
            PREPROCESSOR_CONTROLNET = {
         | 
| 115 | 
            +
              "openpose": [
         | 
| 116 | 
            +
                "Openpose",
         | 
| 117 | 
            +
                "None",
         | 
| 118 | 
            +
              ],
         | 
| 119 | 
            +
              "scribble": [
         | 
| 120 | 
            +
                "HED",
         | 
| 121 | 
            +
                "PidiNet",
         | 
| 122 | 
            +
                "None",
         | 
| 123 | 
            +
              ],
         | 
| 124 | 
            +
              "softedge": [
         | 
| 125 | 
            +
                "PidiNet",
         | 
| 126 | 
            +
                "HED",
         | 
| 127 | 
            +
                "HED safe",
         | 
| 128 | 
            +
                "PidiNet safe",
         | 
| 129 | 
            +
                "None",
         | 
| 130 | 
            +
              ],
         | 
| 131 | 
            +
              "segmentation": [
         | 
| 132 | 
            +
                "UPerNet",
         | 
| 133 | 
            +
                "None",
         | 
| 134 | 
            +
              ],
         | 
| 135 | 
            +
              "depth": [
         | 
| 136 | 
            +
                "DPT",
         | 
| 137 | 
            +
                "Midas",
         | 
| 138 | 
            +
                "None",
         | 
| 139 | 
            +
              ],
         | 
| 140 | 
            +
              "normalbae": [
         | 
| 141 | 
            +
                "NormalBae",
         | 
| 142 | 
            +
                "None",
         | 
| 143 | 
            +
              ],
         | 
| 144 | 
            +
              "lineart": [
         | 
| 145 | 
            +
                "Lineart",
         | 
| 146 | 
            +
                "Lineart coarse",
         | 
| 147 | 
            +
                "Lineart (anime)",
         | 
| 148 | 
            +
                "None",
         | 
| 149 | 
            +
                "None (anime)",
         | 
| 150 | 
            +
              ],
         | 
| 151 | 
            +
              "lineart_anime": [
         | 
| 152 | 
            +
                "Lineart",
         | 
| 153 | 
            +
                "Lineart coarse",
         | 
| 154 | 
            +
                "Lineart (anime)",
         | 
| 155 | 
            +
                "None",
         | 
| 156 | 
            +
                "None (anime)",
         | 
| 157 | 
            +
              ],
         | 
| 158 | 
            +
              "shuffle": [
         | 
| 159 | 
            +
                "ContentShuffle",
         | 
| 160 | 
            +
                "None",
         | 
| 161 | 
            +
              ],
         | 
| 162 | 
            +
              "canny": [
         | 
| 163 | 
            +
                "Canny",
         | 
| 164 | 
            +
                "None",
         | 
| 165 | 
            +
              ],
         | 
| 166 | 
            +
              "mlsd": [
         | 
| 167 | 
            +
                "MLSD",
         | 
| 168 | 
            +
                "None",
         | 
| 169 | 
            +
              ],
         | 
| 170 | 
            +
              "ip2p": [
         | 
| 171 | 
            +
                "ip2p"
         | 
| 172 | 
            +
              ],
         | 
| 173 | 
            +
              "recolor": [
         | 
| 174 | 
            +
                "Recolor luminance",
         | 
| 175 | 
            +
                "Recolor intensity",
         | 
| 176 | 
            +
                "None",
         | 
| 177 | 
            +
              ],
         | 
| 178 | 
            +
              "tile": [
         | 
| 179 | 
            +
                "Mild Blur",
         | 
| 180 | 
            +
                "Moderate Blur",
         | 
| 181 | 
            +
                "Heavy Blur",
         | 
| 182 | 
            +
                "None",
         | 
| 183 | 
            +
              ],
         | 
| 184 | 
            +
             | 
| 185 | 
            +
            }
         | 
| 186 | 
            +
             | 
| 187 | 
            +
            TASK_STABLEPY = {
         | 
| 188 | 
            +
                'txt2img': 'txt2img',
         | 
| 189 | 
            +
                'img2img': 'img2img',
         | 
| 190 | 
            +
                'inpaint': 'inpaint',
         | 
| 191 | 
            +
                # 'canny T2I Adapter': 'sdxl_canny_t2i',  # NO HAVE STEP CALLBACK PARAMETERS SO NOT WORKS WITH DIFFUSERS 0.29.0
         | 
| 192 | 
            +
                # 'sketch  T2I Adapter': 'sdxl_sketch_t2i',
         | 
| 193 | 
            +
                # 'lineart  T2I Adapter': 'sdxl_lineart_t2i',
         | 
| 194 | 
            +
                # 'depth-midas  T2I Adapter': 'sdxl_depth-midas_t2i',
         | 
| 195 | 
            +
                # 'openpose  T2I Adapter': 'sdxl_openpose_t2i',
         | 
| 196 | 
            +
                'openpose ControlNet': 'openpose',
         | 
| 197 | 
            +
                'canny ControlNet': 'canny',
         | 
| 198 | 
            +
                'mlsd ControlNet': 'mlsd',
         | 
| 199 | 
            +
                'scribble ControlNet': 'scribble',
         | 
| 200 | 
            +
                'softedge ControlNet': 'softedge',
         | 
| 201 | 
            +
                'segmentation ControlNet': 'segmentation',
         | 
| 202 | 
            +
                'depth ControlNet': 'depth',
         | 
| 203 | 
            +
                'normalbae ControlNet': 'normalbae',
         | 
| 204 | 
            +
                'lineart ControlNet': 'lineart',
         | 
| 205 | 
            +
                'lineart_anime ControlNet': 'lineart_anime',
         | 
| 206 | 
            +
                'shuffle ControlNet': 'shuffle',
         | 
| 207 | 
            +
                'ip2p ControlNet': 'ip2p',
         | 
| 208 | 
            +
                'optical pattern ControlNet': 'pattern',
         | 
| 209 | 
            +
                'recolor ControlNet': 'recolor',
         | 
| 210 | 
            +
                'tile ControlNet': 'tile',
         | 
| 211 | 
            +
            }
         | 
| 212 | 
            +
             | 
| 213 | 
            +
            TASK_MODEL_LIST = list(TASK_STABLEPY.keys())
         | 
| 214 | 
            +
             | 
| 215 | 
            +
            UPSCALER_DICT_GUI = {
         | 
| 216 | 
            +
                None: None,
         | 
| 217 | 
            +
                "Lanczos": "Lanczos",
         | 
| 218 | 
            +
                "Nearest": "Nearest",
         | 
| 219 | 
            +
                'Latent': 'Latent',
         | 
| 220 | 
            +
                'Latent (antialiased)': 'Latent (antialiased)',
         | 
| 221 | 
            +
                'Latent (bicubic)': 'Latent (bicubic)',
         | 
| 222 | 
            +
                'Latent (bicubic antialiased)': 'Latent (bicubic antialiased)',
         | 
| 223 | 
            +
                'Latent (nearest)': 'Latent (nearest)',
         | 
| 224 | 
            +
                'Latent (nearest-exact)': 'Latent (nearest-exact)',
         | 
| 225 | 
            +
                "RealESRGAN_x4plus": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
         | 
| 226 | 
            +
                "RealESRNet_x4plus": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth",
         | 
| 227 | 
            +
                "RealESRGAN_x4plus_anime_6B": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
         | 
| 228 | 
            +
                "RealESRGAN_x2plus": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
         | 
| 229 | 
            +
                "realesr-animevideov3": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth",
         | 
| 230 | 
            +
                "realesr-general-x4v3": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
         | 
| 231 | 
            +
                "realesr-general-wdn-x4v3": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth",
         | 
| 232 | 
            +
                "4x-UltraSharp": "https://huggingface.co/Shandypur/ESRGAN-4x-UltraSharp/resolve/main/4x-UltraSharp.pth",
         | 
| 233 | 
            +
                "4x_foolhardy_Remacri": "https://huggingface.co/FacehugmanIII/4x_foolhardy_Remacri/resolve/main/4x_foolhardy_Remacri.pth",
         | 
| 234 | 
            +
                "Remacri4xExtraSmoother": "https://huggingface.co/hollowstrawberry/upscalers-backup/resolve/main/ESRGAN/Remacri%204x%20ExtraSmoother.pth",
         | 
| 235 | 
            +
                "AnimeSharp4x": "https://huggingface.co/hollowstrawberry/upscalers-backup/resolve/main/ESRGAN/AnimeSharp%204x.pth",
         | 
| 236 | 
            +
                "lollypop": "https://huggingface.co/hollowstrawberry/upscalers-backup/resolve/main/ESRGAN/lollypop.pth",
         | 
| 237 | 
            +
                "RealisticRescaler4x": "https://huggingface.co/hollowstrawberry/upscalers-backup/resolve/main/ESRGAN/RealisticRescaler%204x.pth",
         | 
| 238 | 
            +
                "NickelbackFS4x": "https://huggingface.co/hollowstrawberry/upscalers-backup/resolve/main/ESRGAN/NickelbackFS%204x.pth"
         | 
| 239 | 
            +
            }
         | 
| 240 | 
            +
             | 
| 241 | 
            +
            UPSCALER_KEYS = list(UPSCALER_DICT_GUI.keys())
         | 
| 242 | 
            +
             | 
| 243 | 
            +
            PROMPT_W_OPTIONS = [
         | 
| 244 | 
            +
                ("Compel format: (word)weight", "Compel"),
         | 
| 245 | 
            +
                ("Classic format: (word:weight)", "Classic"),
         | 
| 246 | 
            +
                ("Classic-original format: (word:weight)", "Classic-original"),
         | 
| 247 | 
            +
                ("Classic-no_norm format: (word:weight)", "Classic-no_norm"),
         | 
| 248 | 
            +
                ("Classic-ignore", "Classic-ignore"),
         | 
| 249 | 
            +
                ("None", "None"),
         | 
| 250 | 
            +
            ]
         | 
| 251 | 
            +
             | 
| 252 | 
            +
            WARNING_MSG_VAE = (
         | 
| 253 | 
            +
                "Use the right VAE for your model to maintain image quality. The wrong"
         | 
| 254 | 
            +
                " VAE can lead to poor results, like blurriness in the generated images."
         | 
| 255 | 
            +
            )
         | 
| 256 | 
            +
             | 
| 257 | 
            +
            SDXL_TASK = [k for k, v in TASK_STABLEPY.items() if v in SDXL_TASKS]
         | 
| 258 | 
            +
            SD_TASK = [k for k, v in TASK_STABLEPY.items() if v in SD15_TASKS]
         | 
| 259 | 
            +
            FLUX_TASK = list(TASK_STABLEPY.keys())[:3] + [k for k, v in TASK_STABLEPY.items() if v in FLUX_CN_UNION_MODES.keys()]
         | 
| 260 | 
            +
             | 
| 261 | 
            +
            MODEL_TYPE_TASK = {
         | 
| 262 | 
            +
                "SD 1.5": SD_TASK,
         | 
| 263 | 
            +
                "SDXL": SDXL_TASK,
         | 
| 264 | 
            +
                "FLUX": FLUX_TASK,
         | 
| 265 | 
            +
            }
         | 
| 266 | 
            +
             | 
| 267 | 
            +
            MODEL_TYPE_CLASS = {
         | 
| 268 | 
            +
                "diffusers:StableDiffusionPipeline": "SD 1.5",
         | 
| 269 | 
            +
                "diffusers:StableDiffusionXLPipeline": "SDXL",
         | 
| 270 | 
            +
                "diffusers:FluxPipeline": "FLUX",
         | 
| 271 | 
            +
            }
         | 
| 272 | 
            +
             | 
| 273 | 
            +
            POST_PROCESSING_SAMPLER = ["Use same sampler"] + scheduler_names[:-2]
         | 
| 274 | 
            +
             | 
| 275 | 
            +
            SUBTITLE_GUI = (
         | 
| 276 | 
            +
                "### This demo uses [diffusers](https://github.com/huggingface/diffusers)"
         | 
| 277 | 
            +
                " to perform different tasks in image generation."
         | 
| 278 | 
            +
            )
         | 
| 279 | 
            +
             | 
| 280 | 
            +
            HELP_GUI = (
         | 
| 281 | 
            +
                """### Help:
         | 
| 282 | 
            +
                - The current space runs on a ZERO GPU which is assigned for approximately 60 seconds; Therefore, if you submit expensive tasks, the operation may be canceled upon reaching the maximum allowed time with 'GPU TASK ABORTED'.
         | 
| 283 | 
            +
                - Distorted or strange images often result from high prompt weights, so it's best to use low weights and scales, and consider using Classic variants like 'Classic-original'.
         | 
| 284 | 
            +
                - For better results with Pony Diffusion, try using sampler DPM++ 1s or DPM2 with Compel or Classic prompt weights.
         | 
| 285 | 
            +
                """
         | 
| 286 | 
            +
            )
         | 
| 287 | 
            +
             | 
| 288 | 
            +
            EXAMPLES_GUI_HELP = (
         | 
| 289 | 
            +
                """### The following examples perform specific tasks:
         | 
| 290 | 
            +
                1. Generation with SDXL and upscale
         | 
| 291 | 
            +
                2. Generation with FLUX dev
         | 
| 292 | 
            +
                3. ControlNet Canny SDXL
         | 
| 293 | 
            +
                4. Optical pattern (Optical illusion) SDXL
         | 
| 294 | 
            +
                5. Convert an image to a coloring drawing
         | 
| 295 | 
            +
                6. ControlNet OpenPose SD 1.5 and Latent upscale
         | 
| 296 | 
            +
             | 
| 297 | 
            +
                - Different tasks can be performed, such as img2img or using the IP adapter, to preserve a person's appearance or a specific style based on an image.
         | 
| 298 | 
            +
                """
         | 
| 299 | 
            +
            )
         | 
| 300 | 
            +
             | 
| 301 | 
            +
            EXAMPLES_GUI = [
         | 
| 302 | 
            +
                [
         | 
| 303 | 
            +
                    "1girl, souryuu asuka langley, neon genesis evangelion, rebuild of evangelion, lance of longinus, cat hat, plugsuit, pilot suit, red bodysuit, sitting, crossed legs, black eye patch, throne, looking down, from bottom, looking at viewer, outdoors, (masterpiece), (best quality), (ultra-detailed), very aesthetic, illustration, disheveled hair, perfect composition, moist skin, intricate details",
         | 
| 304 | 
            +
                    "nfsw, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, unfinished, very displeasing, oldest, early, chromatic aberration, artistic error, scan, abstract",
         | 
| 305 | 
            +
                    28,
         | 
| 306 | 
            +
                    7.0,
         | 
| 307 | 
            +
                    -1,
         | 
| 308 | 
            +
                    "None",
         | 
| 309 | 
            +
                    0.33,
         | 
| 310 | 
            +
                    "Euler a",
         | 
| 311 | 
            +
                    1152,
         | 
| 312 | 
            +
                    896,
         | 
| 313 | 
            +
                    "cagliostrolab/animagine-xl-3.1",
         | 
| 314 | 
            +
                    "txt2img",
         | 
| 315 | 
            +
                    "image.webp",  # img conttol
         | 
| 316 | 
            +
                    1024,  # img resolution
         | 
| 317 | 
            +
                    0.35,  # strength
         | 
| 318 | 
            +
                    1.0,  # cn scale
         | 
| 319 | 
            +
                    0.0,  # cn start
         | 
| 320 | 
            +
                    1.0,  # cn end
         | 
| 321 | 
            +
                    "Classic",
         | 
| 322 | 
            +
                    "Nearest",
         | 
| 323 | 
            +
                    45,
         | 
| 324 | 
            +
                    False,
         | 
| 325 | 
            +
                ],
         | 
| 326 | 
            +
                [
         | 
| 327 | 
            +
                    "a digital illustration of a movie poster titled 'Finding Emo', finding nemo parody poster, featuring a depressed cartoon clownfish with black emo hair, eyeliner, and piercings, bored expression, swimming in a dark underwater scene, in the background, movie title in a dripping, grungy font, moody blue and purple color palette",
         | 
| 328 | 
            +
                    "",
         | 
| 329 | 
            +
                    24,
         | 
| 330 | 
            +
                    3.5,
         | 
| 331 | 
            +
                    -1,
         | 
| 332 | 
            +
                    "None",
         | 
| 333 | 
            +
                    0.33,
         | 
| 334 | 
            +
                    "Euler a",
         | 
| 335 | 
            +
                    1152,
         | 
| 336 | 
            +
                    896,
         | 
| 337 | 
            +
                    "black-forest-labs/FLUX.1-dev",
         | 
| 338 | 
            +
                    "txt2img",
         | 
| 339 | 
            +
                    None,  # img conttol
         | 
| 340 | 
            +
                    1024,  # img resolution
         | 
| 341 | 
            +
                    0.35,  # strength
         | 
| 342 | 
            +
                    1.0,  # cn scale
         | 
| 343 | 
            +
                    0.0,  # cn start
         | 
| 344 | 
            +
                    1.0,  # cn end
         | 
| 345 | 
            +
                    "Classic",
         | 
| 346 | 
            +
                    None,
         | 
| 347 | 
            +
                    70,
         | 
| 348 | 
            +
                    True,
         | 
| 349 | 
            +
                ],
         | 
| 350 | 
            +
                [
         | 
| 351 | 
            +
                    "((masterpiece)), best quality, blonde disco girl, detailed face, realistic face, realistic hair, dynamic pose, pink pvc, intergalactic disco background, pastel lights, dynamic contrast, airbrush, fine detail, 70s vibe, midriff",
         | 
| 352 | 
            +
                    "(worst quality:1.2), (bad quality:1.2), (poor quality:1.2), (missing fingers:1.2), bad-artist-anime, bad-artist, bad-picture-chill-75v",
         | 
| 353 | 
            +
                    48,
         | 
| 354 | 
            +
                    3.5,
         | 
| 355 | 
            +
                    -1,
         | 
| 356 | 
            +
                    "None",
         | 
| 357 | 
            +
                    0.33,
         | 
| 358 | 
            +
                    "DPM++ 2M SDE Lu",
         | 
| 359 | 
            +
                    1024,
         | 
| 360 | 
            +
                    1024,
         | 
| 361 | 
            +
                    "misri/epicrealismXL_v7FinalDestination",
         | 
| 362 | 
            +
                    "canny ControlNet",
         | 
| 363 | 
            +
                    "image.webp",  # img conttol
         | 
| 364 | 
            +
                    1024,  # img resolution
         | 
| 365 | 
            +
                    0.35,  # strength
         | 
| 366 | 
            +
                    1.0,  # cn scale
         | 
| 367 | 
            +
                    0.0,  # cn start
         | 
| 368 | 
            +
                    1.0,  # cn end
         | 
| 369 | 
            +
                    "Classic",
         | 
| 370 | 
            +
                    None,
         | 
| 371 | 
            +
                    44,
         | 
| 372 | 
            +
                    False,
         | 
| 373 | 
            +
                ],
         | 
| 374 | 
            +
                [
         | 
| 375 | 
            +
                    "cinematic scenery old city ruins",
         | 
| 376 | 
            +
                    "(worst quality, low quality, illustration, 3d, 2d, painting, cartoons, sketch), (illustration, 3d, 2d, painting, cartoons, sketch, blurry, film grain, noise), (low quality, worst quality:1.2)",
         | 
| 377 | 
            +
                    50,
         | 
| 378 | 
            +
                    4.0,
         | 
| 379 | 
            +
                    -1,
         | 
| 380 | 
            +
                    "None",
         | 
| 381 | 
            +
                    0.33,
         | 
| 382 | 
            +
                    "Euler a",
         | 
| 383 | 
            +
                    1024,
         | 
| 384 | 
            +
                    1024,
         | 
| 385 | 
            +
                    "misri/juggernautXL_juggernautX",
         | 
| 386 | 
            +
                    "optical pattern ControlNet",
         | 
| 387 | 
            +
                    "spiral_no_transparent.png",  # img conttol
         | 
| 388 | 
            +
                    1024,  # img resolution
         | 
| 389 | 
            +
                    0.35,  # strength
         | 
| 390 | 
            +
                    1.0,  # cn scale
         | 
| 391 | 
            +
                    0.05,  # cn start
         | 
| 392 | 
            +
                    0.75,  # cn end
         | 
| 393 | 
            +
                    "Classic",
         | 
| 394 | 
            +
                    None,
         | 
| 395 | 
            +
                    35,
         | 
| 396 | 
            +
                    False,
         | 
| 397 | 
            +
                ],
         | 
| 398 | 
            +
                [
         | 
| 399 | 
            +
                    "black and white, line art, coloring drawing, clean line art, black strokes, no background, white, black, free lines, black scribbles, on paper, A blend of comic book art and lineart full of black and white color, masterpiece, high-resolution, trending on Pixiv fan box, palette knife, brush strokes, two-dimensional, planar vector, T-shirt design, stickers, and T-shirt design, vector art, fantasy art, Adobe Illustrator, hand-painted, digital painting, low polygon, soft lighting, aerial view, isometric style, retro aesthetics, 8K resolution, black sketch lines, monochrome, invert color",
         | 
| 400 | 
            +
                    "color, red, green, yellow, colored, duplicate, blurry, abstract, disfigured, deformed, animated, toy, figure, framed, 3d, bad art, poorly drawn, extra limbs, close up, b&w, weird colors, blurry, watermark, blur haze, 2 heads, long neck, watermark, elongated body, cropped image, out of frame, draft, deformed hands, twisted fingers, double image, malformed hands, multiple heads, extra limb, ugly, poorly drawn hands, missing limb, cut-off, over satured, grain, lowères, bad anatomy, poorly drawn face, mutation, mutated, floating limbs, disconnected limbs, out of focus, long body, disgusting, extra fingers, groos proportions, missing arms, mutated hands, cloned face, missing legs, ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, blurry, bad anatomy, blurred, watermark, grainy, signature, cut off, draft, deformed, blurry, bad anatomy, disfigured, poorly drawn face, mutation, bluelish, blue",
         | 
| 401 | 
            +
                    20,
         | 
| 402 | 
            +
                    4.0,
         | 
| 403 | 
            +
                    -1,
         | 
| 404 | 
            +
                    "loras/Coloring_book_-_LineArt.safetensors",
         | 
| 405 | 
            +
                    1.0,
         | 
| 406 | 
            +
                    "DPM++ 2M SDE Karras",
         | 
| 407 | 
            +
                    1024,
         | 
| 408 | 
            +
                    1024,
         | 
| 409 | 
            +
                    "cagliostrolab/animagine-xl-3.1",
         | 
| 410 | 
            +
                    "lineart ControlNet",
         | 
| 411 | 
            +
                    "color_image.png",  # img conttol
         | 
| 412 | 
            +
                    896,  # img resolution
         | 
| 413 | 
            +
                    0.35,  # strength
         | 
| 414 | 
            +
                    1.0,  # cn scale
         | 
| 415 | 
            +
                    0.0,  # cn start
         | 
| 416 | 
            +
                    1.0,  # cn end
         | 
| 417 | 
            +
                    "Compel",
         | 
| 418 | 
            +
                    None,
         | 
| 419 | 
            +
                    35,
         | 
| 420 | 
            +
                    False,
         | 
| 421 | 
            +
                ],
         | 
| 422 | 
            +
                [
         | 
| 423 | 
            +
                    "1girl,face,curly hair,red hair,white background,",
         | 
| 424 | 
            +
                    "(worst quality:2),(low quality:2),(normal quality:2),lowres,watermark,",
         | 
| 425 | 
            +
                    38,
         | 
| 426 | 
            +
                    5.0,
         | 
| 427 | 
            +
                    -1,
         | 
| 428 | 
            +
                    "None",
         | 
| 429 | 
            +
                    0.33,
         | 
| 430 | 
            +
                    "DPM++ 2M SDE Karras",
         | 
| 431 | 
            +
                    512,
         | 
| 432 | 
            +
                    512,
         | 
| 433 | 
            +
                    "digiplay/majicMIX_realistic_v7",
         | 
| 434 | 
            +
                    "openpose ControlNet",
         | 
| 435 | 
            +
                    "image.webp",  # img conttol
         | 
| 436 | 
            +
                    1024,  # img resolution
         | 
| 437 | 
            +
                    0.35,  # strength
         | 
| 438 | 
            +
                    1.0,  # cn scale
         | 
| 439 | 
            +
                    0.0,  # cn start
         | 
| 440 | 
            +
                    0.9,  # cn end
         | 
| 441 | 
            +
                    "Compel",
         | 
| 442 | 
            +
                    "Latent (antialiased)",
         | 
| 443 | 
            +
                    46,
         | 
| 444 | 
            +
                    False,
         | 
| 445 | 
            +
                ],
         | 
| 446 | 
            +
            ]
         | 
| 447 | 
            +
             | 
| 448 | 
            +
            RESOURCES = (
         | 
| 449 | 
            +
                """### Resources
         | 
| 450 | 
            +
                - John6666's space has some great features you might find helpful [link](https://huggingface.co/spaces/John6666/DiffuseCraftMod).
         | 
| 451 | 
            +
                - You can also try the image generator in Colab’s free tier, which provides free GPU [link](https://github.com/R3gm/SD_diffusers_interactive).
         | 
| 452 | 
            +
                """
         | 
| 453 | 
            +
            )
         | 
    	
        env.py
    CHANGED
    
    | @@ -2,10 +2,10 @@ import os | |
| 2 |  | 
| 3 | 
             
            CIVITAI_API_KEY = os.environ.get("CIVITAI_API_KEY")
         | 
| 4 | 
             
            HF_TOKEN = os.environ.get("HF_TOKEN")
         | 
| 5 | 
            -
             | 
| 6 |  | 
| 7 | 
             
            # - **List Models**
         | 
| 8 | 
            -
             | 
| 9 | 
             
                'stabilityai/stable-diffusion-xl-base-1.0',
         | 
| 10 | 
             
                'John6666/blue-pencil-flux1-v021-fp8-flux',
         | 
| 11 | 
             
                'John6666/wai-ani-flux-v10forfp8-fp8-flux',
         | 
| @@ -106,11 +106,11 @@ HF_MODEL_USER_EX = ["John6666"] # sorted by a special rule | |
| 106 |  | 
| 107 |  | 
| 108 | 
             
            # - **Download Models**
         | 
| 109 | 
            -
             | 
| 110 | 
             
            ]
         | 
| 111 |  | 
| 112 | 
             
            # - **Download VAEs**
         | 
| 113 | 
            -
             | 
| 114 | 
             
                'https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/resolve/main/sdxl.vae.safetensors?download=true',
         | 
| 115 | 
             
                'https://huggingface.co/nubby/blessed-sdxl-vae-fp16-fix/resolve/main/sdxl_vae-fp16fix-c-1.1-b-0.5.safetensors?download=true',
         | 
| 116 | 
             
                'https://huggingface.co/nubby/blessed-sdxl-vae-fp16-fix/resolve/main/sdxl_vae-fp16fix-blessed.safetensors?download=true',
         | 
| @@ -119,29 +119,26 @@ download_vae_list = [ | |
| 119 | 
             
            ]
         | 
| 120 |  | 
| 121 | 
             
            # - **Download LoRAs**
         | 
| 122 | 
            -
             | 
| 123 | 
             
            ]
         | 
| 124 |  | 
| 125 | 
             
            # Download Embeddings
         | 
| 126 | 
            -
             | 
| 127 | 
             
                'https://huggingface.co/datasets/Nerfgun3/bad_prompt/blob/main/bad_prompt_version2.pt',
         | 
| 128 | 
             
                'https://huggingface.co/embed/negative/resolve/main/EasyNegativeV2.safetensors',
         | 
| 129 | 
             
                'https://huggingface.co/embed/negative/resolve/main/bad-hands-5.pt',
         | 
| 130 | 
             
            ]
         | 
| 131 |  | 
| 132 | 
            -
             | 
| 133 | 
            -
             | 
| 134 | 
            -
             | 
| 135 | 
            -
             | 
| 136 | 
            -
             | 
| 137 | 
            -
             | 
| 138 | 
            -
            directory_embeds = 'embedings'
         | 
| 139 | 
            -
            os.makedirs(directory_embeds, exist_ok=True)
         | 
| 140 |  | 
| 141 | 
            -
             | 
| 142 | 
            -
             | 
| 143 | 
            -
             | 
| 144 | 
            -
            os.makedirs(directory_embeds_positive_sdxl, exist_ok=True)
         | 
| 145 |  | 
| 146 | 
             
            HF_LORA_PRIVATE_REPOS1 = ['John6666/loratest1', 'John6666/loratest3', 'John6666/loratest4', 'John6666/loratest6']
         | 
| 147 | 
             
            HF_LORA_PRIVATE_REPOS2 = ['John6666/loratest10', 'John6666/loratest11','John6666/loratest'] # to be sorted as 1 repo
         | 
|  | |
| 2 |  | 
| 3 | 
             
            CIVITAI_API_KEY = os.environ.get("CIVITAI_API_KEY")
         | 
| 4 | 
             
            HF_TOKEN = os.environ.get("HF_TOKEN")
         | 
| 5 | 
            +
            HF_READ_TOKEN = os.environ.get('HF_READ_TOKEN') # only use for private repo
         | 
| 6 |  | 
| 7 | 
             
            # - **List Models**
         | 
| 8 | 
            +
            LOAD_DIFFUSERS_FORMAT_MODEL = [
         | 
| 9 | 
             
                'stabilityai/stable-diffusion-xl-base-1.0',
         | 
| 10 | 
             
                'John6666/blue-pencil-flux1-v021-fp8-flux',
         | 
| 11 | 
             
                'John6666/wai-ani-flux-v10forfp8-fp8-flux',
         | 
|  | |
| 106 |  | 
| 107 |  | 
| 108 | 
             
            # - **Download Models**
         | 
| 109 | 
            +
            DOWNLOAD_MODEL_LIST = [
         | 
| 110 | 
             
            ]
         | 
| 111 |  | 
| 112 | 
             
            # - **Download VAEs**
         | 
| 113 | 
            +
            DOWNLOAD_VAE_LIST = [
         | 
| 114 | 
             
                'https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/resolve/main/sdxl.vae.safetensors?download=true',
         | 
| 115 | 
             
                'https://huggingface.co/nubby/blessed-sdxl-vae-fp16-fix/resolve/main/sdxl_vae-fp16fix-c-1.1-b-0.5.safetensors?download=true',
         | 
| 116 | 
             
                'https://huggingface.co/nubby/blessed-sdxl-vae-fp16-fix/resolve/main/sdxl_vae-fp16fix-blessed.safetensors?download=true',
         | 
|  | |
| 119 | 
             
            ]
         | 
| 120 |  | 
| 121 | 
             
            # - **Download LoRAs**
         | 
| 122 | 
            +
            DOWNLOAD_LORA_LIST = [
         | 
| 123 | 
             
            ]
         | 
| 124 |  | 
| 125 | 
             
            # Download Embeddings
         | 
| 126 | 
            +
            DOWNLOAD_EMBEDS = [
         | 
| 127 | 
             
                'https://huggingface.co/datasets/Nerfgun3/bad_prompt/blob/main/bad_prompt_version2.pt',
         | 
| 128 | 
             
                'https://huggingface.co/embed/negative/resolve/main/EasyNegativeV2.safetensors',
         | 
| 129 | 
             
                'https://huggingface.co/embed/negative/resolve/main/bad-hands-5.pt',
         | 
| 130 | 
             
            ]
         | 
| 131 |  | 
| 132 | 
            +
            DIRECTORY_MODELS = 'models'
         | 
| 133 | 
            +
            DIRECTORY_LORAS = 'loras'
         | 
| 134 | 
            +
            DIRECTORY_VAES = 'vaes'
         | 
| 135 | 
            +
            DIRECTORY_EMBEDS = 'embedings'
         | 
| 136 | 
            +
            DIRECTORY_EMBEDS_SDXL = 'embedings_xl'
         | 
| 137 | 
            +
            DIRECTORY_EMBEDS_POSITIVE_SDXL = 'embedings_xl/positive'
         | 
|  | |
|  | |
| 138 |  | 
| 139 | 
            +
            directories = [DIRECTORY_MODELS, DIRECTORY_LORAS, DIRECTORY_VAES, DIRECTORY_EMBEDS, DIRECTORY_EMBEDS_SDXL, DIRECTORY_EMBEDS_POSITIVE_SDXL]
         | 
| 140 | 
            +
            for directory in directories:
         | 
| 141 | 
            +
                os.makedirs(directory, exist_ok=True)
         | 
|  | |
| 142 |  | 
| 143 | 
             
            HF_LORA_PRIVATE_REPOS1 = ['John6666/loratest1', 'John6666/loratest3', 'John6666/loratest4', 'John6666/loratest6']
         | 
| 144 | 
             
            HF_LORA_PRIVATE_REPOS2 = ['John6666/loratest10', 'John6666/loratest11','John6666/loratest'] # to be sorted as 1 repo
         | 
    	
        modutils.py
    CHANGED
    
    | @@ -5,6 +5,7 @@ import os | |
| 5 | 
             
            import re
         | 
| 6 | 
             
            from pathlib import Path
         | 
| 7 | 
             
            from PIL import Image
         | 
|  | |
| 8 | 
             
            import shutil
         | 
| 9 | 
             
            import requests
         | 
| 10 | 
             
            from requests.adapters import HTTPAdapter
         | 
| @@ -12,11 +13,16 @@ from urllib3.util import Retry | |
| 12 | 
             
            import urllib.parse
         | 
| 13 | 
             
            import pandas as pd
         | 
| 14 | 
             
            from huggingface_hub import HfApi, HfFolder, hf_hub_download, snapshot_download
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
| 15 |  | 
| 16 |  | 
| 17 | 
             
            from env import (HF_LORA_PRIVATE_REPOS1, HF_LORA_PRIVATE_REPOS2,
         | 
| 18 | 
             
                HF_MODEL_USER_EX, HF_MODEL_USER_LIKES, DIFFUSERS_FORMAT_LORAS,
         | 
| 19 | 
            -
                 | 
| 20 |  | 
| 21 |  | 
| 22 | 
             
            MODEL_TYPE_DICT = {
         | 
| @@ -46,7 +52,6 @@ def is_repo_name(s): | |
| 46 | 
             
                return re.fullmatch(r'^[^/]+?/[^/]+?$', s)
         | 
| 47 |  | 
| 48 |  | 
| 49 | 
            -
            from translatepy import Translator
         | 
| 50 | 
             
            translator = Translator()
         | 
| 51 | 
             
            def translate_to_en(input: str):
         | 
| 52 | 
             
                try:
         | 
| @@ -64,6 +69,7 @@ def get_local_model_list(dir_path): | |
| 64 | 
             
                    if file.suffix in valid_extensions:
         | 
| 65 | 
             
                        file_path = str(Path(f"{dir_path}/{file.name}"))
         | 
| 66 | 
             
                        model_list.append(file_path)
         | 
|  | |
| 67 | 
             
                return model_list
         | 
| 68 |  | 
| 69 |  | 
| @@ -98,21 +104,81 @@ def split_hf_url(url: str): | |
| 98 | 
             
                    print(e)
         | 
| 99 |  | 
| 100 |  | 
| 101 | 
            -
            def download_hf_file(directory, url, progress=gr.Progress(track_tqdm=True)):
         | 
| 102 | 
            -
                hf_token = get_token()
         | 
| 103 | 
             
                repo_id, filename, subfolder, repo_type = split_hf_url(url)
         | 
|  | |
|  | |
|  | |
| 104 | 
             
                try:
         | 
| 105 | 
            -
                    print(f" | 
| 106 | 
            -
                     | 
| 107 | 
            -
                    else: path = hf_hub_download(repo_id=repo_id, filename=filename, repo_type=repo_type, local_dir=directory, token=hf_token)
         | 
| 108 | 
             
                    return path
         | 
| 109 | 
             
                except Exception as e:
         | 
| 110 | 
            -
                    print(f" | 
| 111 | 
             
                    return None
         | 
| 112 |  | 
| 113 |  | 
| 114 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 115 | 
             
                url = url.strip()
         | 
|  | |
|  | |
| 116 | 
             
                if "drive.google.com" in url:
         | 
| 117 | 
             
                    original_dir = os.getcwd()
         | 
| 118 | 
             
                    os.chdir(directory)
         | 
| @@ -123,18 +189,48 @@ def download_things(directory, url, hf_token="", civitai_api_key=""): | |
| 123 | 
             
                    # url = urllib.parse.quote(url, safe=':/')  # fix encoding
         | 
| 124 | 
             
                    if "/blob/" in url:
         | 
| 125 | 
             
                        url = url.replace("/blob/", "/resolve/")
         | 
| 126 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 127 | 
             
                elif "civitai.com" in url:
         | 
| 128 | 
            -
             | 
| 129 | 
            -
             | 
| 130 | 
            -
                    if civitai_api_key:
         | 
| 131 | 
            -
                        url = url + f"?token={civitai_api_key}"
         | 
| 132 | 
            -
                        os.system(f"aria2c --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 -d {directory} {url}")
         | 
| 133 | 
            -
                    else:
         | 
| 134 | 
             
                        print("\033[91mYou need an API key to download Civitai models.\033[0m")
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 135 | 
             
                else:
         | 
| 136 | 
             
                    os.system(f"aria2c --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 -d {directory} {url}")
         | 
| 137 |  | 
|  | |
|  | |
| 138 |  | 
| 139 | 
             
            def get_download_file(temp_dir, url, civitai_key="", progress=gr.Progress(track_tqdm=True)):
         | 
| 140 | 
             
                if not "http" in url and is_repo_name(url) and not Path(url).exists():
         | 
| @@ -173,7 +269,7 @@ def to_lora_key(path: str): | |
| 173 |  | 
| 174 | 
             
            def to_lora_path(key: str):
         | 
| 175 | 
             
                if Path(key).is_file(): return key
         | 
| 176 | 
            -
                path = Path(f"{ | 
| 177 | 
             
                return str(path)
         | 
| 178 |  | 
| 179 |  | 
| @@ -203,25 +299,21 @@ def save_images(images: list[Image.Image], metadatas: list[str]): | |
| 203 | 
             
                    raise Exception(f"Failed to save image file:") from e
         | 
| 204 |  | 
| 205 |  | 
| 206 | 
            -
            def save_gallery_images(images, progress=gr.Progress(track_tqdm=True)):
         | 
| 207 | 
            -
                from datetime import datetime, timezone, timedelta
         | 
| 208 | 
             
                progress(0, desc="Updating gallery...")
         | 
| 209 | 
            -
                 | 
| 210 | 
            -
                 | 
| 211 | 
            -
                i = 1
         | 
| 212 | 
            -
                if not images: return images, gr.update(visible=False)
         | 
| 213 | 
             
                output_images = []
         | 
| 214 | 
             
                output_paths = []
         | 
| 215 | 
            -
                for image in images:
         | 
| 216 | 
            -
                    filename = basename | 
| 217 | 
            -
                    i += 1
         | 
| 218 | 
             
                    oldpath = Path(image[0])
         | 
| 219 | 
             
                    newpath = oldpath
         | 
| 220 | 
             
                    try:
         | 
| 221 | 
             
                        if oldpath.exists():
         | 
| 222 | 
             
                            newpath = oldpath.resolve().rename(Path(filename).resolve())
         | 
| 223 | 
             
                    except Exception as e:
         | 
| 224 | 
            -
             | 
| 225 | 
             
                    finally: 
         | 
| 226 | 
             
                        output_paths.append(str(newpath))
         | 
| 227 | 
             
                        output_images.append((str(newpath), str(filename)))
         | 
| @@ -229,10 +321,47 @@ def save_gallery_images(images, progress=gr.Progress(track_tqdm=True)): | |
| 229 | 
             
                return gr.update(value=output_images), gr.update(value=output_paths, visible=True)
         | 
| 230 |  | 
| 231 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 232 | 
             
            def download_private_repo(repo_id, dir_path, is_replace):
         | 
| 233 | 
            -
                if not  | 
| 234 | 
             
                try:
         | 
| 235 | 
            -
                    snapshot_download(repo_id=repo_id, local_dir=dir_path, allow_patterns=['*.ckpt', '*.pt', '*.pth', '*.safetensors', '*.bin'],  | 
| 236 | 
             
                except Exception as e:
         | 
| 237 | 
             
                    print(f"Error: Failed to download {repo_id}.")
         | 
| 238 | 
             
                    print(e)
         | 
| @@ -250,9 +379,9 @@ private_model_path_repo_dict = {} # {"local filepath": "huggingface repo_id", .. | |
| 250 | 
             
            def get_private_model_list(repo_id, dir_path):
         | 
| 251 | 
             
                global private_model_path_repo_dict
         | 
| 252 | 
             
                api = HfApi()
         | 
| 253 | 
            -
                if not  | 
| 254 | 
             
                try:
         | 
| 255 | 
            -
                    files = api.list_repo_files(repo_id, token= | 
| 256 | 
             
                except Exception as e:
         | 
| 257 | 
             
                    print(f"Error: Failed to list {repo_id}.")
         | 
| 258 | 
             
                    print(e)
         | 
| @@ -270,11 +399,11 @@ def get_private_model_list(repo_id, dir_path): | |
| 270 | 
             
            def download_private_file(repo_id, path, is_replace):
         | 
| 271 | 
             
                file = Path(path)
         | 
| 272 | 
             
                newpath = Path(f'{file.parent.name}/{escape_lora_basename(file.stem)}{file.suffix}') if is_replace else file
         | 
| 273 | 
            -
                if not  | 
| 274 | 
             
                filename = file.name
         | 
| 275 | 
             
                dirname = file.parent.name
         | 
| 276 | 
             
                try:
         | 
| 277 | 
            -
                    hf_hub_download(repo_id=repo_id, filename=filename, local_dir=dirname,  | 
| 278 | 
             
                except Exception as e:
         | 
| 279 | 
             
                    print(f"Error: Failed to download {filename}.")
         | 
| 280 | 
             
                    print(e)
         | 
| @@ -404,9 +533,9 @@ def get_private_lora_model_lists(): | |
| 404 | 
             
                models1 = []
         | 
| 405 | 
             
                models2 = []
         | 
| 406 | 
             
                for repo in HF_LORA_PRIVATE_REPOS1:
         | 
| 407 | 
            -
                    models1.extend(get_private_model_list(repo,  | 
| 408 | 
             
                for repo in HF_LORA_PRIVATE_REPOS2:
         | 
| 409 | 
            -
                    models2.extend(get_private_model_list(repo,  | 
| 410 | 
             
                models = list_uniq(models1 + sorted(models2))
         | 
| 411 | 
             
                private_lora_model_list = models.copy()
         | 
| 412 | 
             
                return models
         | 
| @@ -451,7 +580,7 @@ def get_civitai_info(path): | |
| 451 |  | 
| 452 |  | 
| 453 | 
             
            def get_lora_model_list():
         | 
| 454 | 
            -
                loras = list_uniq(get_private_lora_model_lists() + DIFFUSERS_FORMAT_LORAS + get_local_model_list( | 
| 455 | 
             
                loras.insert(0, "None")
         | 
| 456 | 
             
                loras.insert(0, "")
         | 
| 457 | 
             
                return loras
         | 
| @@ -503,14 +632,14 @@ def update_lora_dict(path): | |
| 503 | 
             
            def download_lora(dl_urls: str):
         | 
| 504 | 
             
                global loras_url_to_path_dict
         | 
| 505 | 
             
                dl_path = ""
         | 
| 506 | 
            -
                before = get_local_model_list( | 
| 507 | 
             
                urls = []
         | 
| 508 | 
             
                for url in [url.strip() for url in dl_urls.split(',')]:
         | 
| 509 | 
            -
                    local_path = f"{ | 
| 510 | 
             
                    if not Path(local_path).exists():
         | 
| 511 | 
            -
                        download_things( | 
| 512 | 
             
                        urls.append(url)
         | 
| 513 | 
            -
                after = get_local_model_list( | 
| 514 | 
             
                new_files = list_sub(after, before)
         | 
| 515 | 
             
                i = 0
         | 
| 516 | 
             
                for file in new_files:
         | 
| @@ -761,12 +890,14 @@ def update_loras(prompt, prompt_syntax, lora1, lora1_wt, lora2, lora2_wt, lora3, | |
| 761 | 
             
                 gr.update(value=tag5, label=label5, visible=on5), gr.update(visible=on5), gr.update(value=md5, visible=on5)
         | 
| 762 |  | 
| 763 |  | 
| 764 | 
            -
            def get_my_lora(link_url):
         | 
| 765 | 
            -
                 | 
|  | |
|  | |
| 766 | 
             
                for url in [url.strip() for url in link_url.split(',')]:
         | 
| 767 | 
            -
                    if not Path(f"{ | 
| 768 | 
            -
                        download_things( | 
| 769 | 
            -
                after = get_local_model_list( | 
| 770 | 
             
                new_files = list_sub(after, before)
         | 
| 771 | 
             
                for file in new_files:
         | 
| 772 | 
             
                    path = Path(file)
         | 
| @@ -774,11 +905,16 @@ def get_my_lora(link_url): | |
| 774 | 
             
                        new_path = Path(f'{path.parent.name}/{escape_lora_basename(path.stem)}{path.suffix}')
         | 
| 775 | 
             
                        path.resolve().rename(new_path.resolve())
         | 
| 776 | 
             
                        update_lora_dict(str(new_path))
         | 
|  | |
| 777 | 
             
                new_lora_model_list = get_lora_model_list()
         | 
| 778 | 
             
                new_lora_tupled_list = get_all_lora_tupled_list()
         | 
| 779 | 
            -
                
         | 
|  | |
|  | |
|  | |
|  | |
| 780 | 
             
                return gr.update(
         | 
| 781 | 
            -
                    choices=new_lora_tupled_list, value= | 
| 782 | 
             
                ), gr.update(
         | 
| 783 | 
             
                    choices=new_lora_tupled_list
         | 
| 784 | 
             
                ), gr.update(
         | 
| @@ -787,6 +923,8 @@ def get_my_lora(link_url): | |
| 787 | 
             
                    choices=new_lora_tupled_list
         | 
| 788 | 
             
                ), gr.update(
         | 
| 789 | 
             
                    choices=new_lora_tupled_list
         | 
|  | |
|  | |
| 790 | 
             
                )
         | 
| 791 |  | 
| 792 |  | 
| @@ -794,12 +932,12 @@ def upload_file_lora(files, progress=gr.Progress(track_tqdm=True)): | |
| 794 | 
             
                progress(0, desc="Uploading...")
         | 
| 795 | 
             
                file_paths = [file.name for file in files]
         | 
| 796 | 
             
                progress(1, desc="Uploaded.")
         | 
| 797 | 
            -
                return gr.update(value=file_paths, visible=True), gr.update( | 
| 798 |  | 
| 799 |  | 
| 800 | 
             
            def move_file_lora(filepaths):
         | 
| 801 | 
             
                for file in filepaths:
         | 
| 802 | 
            -
                    path = Path(shutil.move(Path(file).resolve(), Path(f"./{ | 
| 803 | 
             
                    newpath = Path(f'{path.parent.name}/{escape_lora_basename(path.stem)}{path.suffix}')
         | 
| 804 | 
             
                    path.resolve().rename(newpath.resolve())
         | 
| 805 | 
             
                    update_lora_dict(str(newpath))
         | 
| @@ -941,7 +1079,7 @@ def update_civitai_selection(evt: gr.SelectData): | |
| 941 | 
             
                    selected = civitai_last_choices[selected_index][1]
         | 
| 942 | 
             
                    return gr.update(value=selected)
         | 
| 943 | 
             
                except Exception:
         | 
| 944 | 
            -
                    return gr.update( | 
| 945 |  | 
| 946 |  | 
| 947 | 
             
            def select_civitai_lora(search_result):
         | 
| @@ -1425,3 +1563,78 @@ def get_model_pipeline(repo_id: str): | |
| 1425 | 
             
                else:
         | 
| 1426 | 
             
                    return default
         | 
| 1427 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 5 | 
             
            import re
         | 
| 6 | 
             
            from pathlib import Path
         | 
| 7 | 
             
            from PIL import Image
         | 
| 8 | 
            +
            import numpy as np
         | 
| 9 | 
             
            import shutil
         | 
| 10 | 
             
            import requests
         | 
| 11 | 
             
            from requests.adapters import HTTPAdapter
         | 
|  | |
| 13 | 
             
            import urllib.parse
         | 
| 14 | 
             
            import pandas as pd
         | 
| 15 | 
             
            from huggingface_hub import HfApi, HfFolder, hf_hub_download, snapshot_download
         | 
| 16 | 
            +
            from translatepy import Translator
         | 
| 17 | 
            +
            from unidecode import unidecode
         | 
| 18 | 
            +
            import copy
         | 
| 19 | 
            +
            from datetime import datetime, timezone, timedelta
         | 
| 20 | 
            +
            FILENAME_TIMEZONE = timezone(timedelta(hours=9)) # JST
         | 
| 21 |  | 
| 22 |  | 
| 23 | 
             
            from env import (HF_LORA_PRIVATE_REPOS1, HF_LORA_PRIVATE_REPOS2,
         | 
| 24 | 
             
                HF_MODEL_USER_EX, HF_MODEL_USER_LIKES, DIFFUSERS_FORMAT_LORAS,
         | 
| 25 | 
            +
                DIRECTORY_LORAS, HF_READ_TOKEN, HF_TOKEN, CIVITAI_API_KEY)
         | 
| 26 |  | 
| 27 |  | 
| 28 | 
             
            MODEL_TYPE_DICT = {
         | 
|  | |
| 52 | 
             
                return re.fullmatch(r'^[^/]+?/[^/]+?$', s)
         | 
| 53 |  | 
| 54 |  | 
|  | |
| 55 | 
             
            translator = Translator()
         | 
| 56 | 
             
            def translate_to_en(input: str):
         | 
| 57 | 
             
                try:
         | 
|  | |
| 69 | 
             
                    if file.suffix in valid_extensions:
         | 
| 70 | 
             
                        file_path = str(Path(f"{dir_path}/{file.name}"))
         | 
| 71 | 
             
                        model_list.append(file_path)
         | 
| 72 | 
            +
                        #print('\033[34mFILE: ' + file_path + '\033[0m')
         | 
| 73 | 
             
                return model_list
         | 
| 74 |  | 
| 75 |  | 
|  | |
| 104 | 
             
                    print(e)
         | 
| 105 |  | 
| 106 |  | 
| 107 | 
            +
            def download_hf_file(directory, url, force_filename="", hf_token="", progress=gr.Progress(track_tqdm=True)):
         | 
|  | |
| 108 | 
             
                repo_id, filename, subfolder, repo_type = split_hf_url(url)
         | 
| 109 | 
            +
                kwargs = {}
         | 
| 110 | 
            +
                if subfolder is not None: kwargs["subfolder"] = subfolder
         | 
| 111 | 
            +
                if force_filename: kwargs["force_filename"] = force_filename
         | 
| 112 | 
             
                try:
         | 
| 113 | 
            +
                    print(f"Start downloading: {url} to {directory}")
         | 
| 114 | 
            +
                    path = hf_hub_download(repo_id=repo_id, filename=filename, repo_type=repo_type, local_dir=directory, token=hf_token, **kwargs)
         | 
|  | |
| 115 | 
             
                    return path
         | 
| 116 | 
             
                except Exception as e:
         | 
| 117 | 
            +
                    print(f"Download failed: {url} {e}")
         | 
| 118 | 
             
                    return None
         | 
| 119 |  | 
| 120 |  | 
| 121 | 
            +
            USER_AGENT = 'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:127.0) Gecko/20100101 Firefox/127.0'
         | 
| 122 | 
            +
             | 
| 123 | 
            +
             | 
| 124 | 
            +
            def request_json_data(url):
         | 
| 125 | 
            +
                model_version_id = url.split('/')[-1]
         | 
| 126 | 
            +
                if "?modelVersionId=" in model_version_id:
         | 
| 127 | 
            +
                    match = re.search(r'modelVersionId=(\d+)', url)
         | 
| 128 | 
            +
                    model_version_id = match.group(1)
         | 
| 129 | 
            +
             | 
| 130 | 
            +
                endpoint_url = f"https://civitai.com/api/v1/model-versions/{model_version_id}"
         | 
| 131 | 
            +
             | 
| 132 | 
            +
                params = {}
         | 
| 133 | 
            +
                headers = {'User-Agent': USER_AGENT, 'content-type': 'application/json'}
         | 
| 134 | 
            +
                session = requests.Session()
         | 
| 135 | 
            +
                retries = Retry(total=5, backoff_factor=1, status_forcelist=[500, 502, 503, 504])
         | 
| 136 | 
            +
                session.mount("https://", HTTPAdapter(max_retries=retries))
         | 
| 137 | 
            +
             | 
| 138 | 
            +
                try:
         | 
| 139 | 
            +
                    result = session.get(endpoint_url, params=params, headers=headers, stream=True, timeout=(3.0, 15))
         | 
| 140 | 
            +
                    result.raise_for_status()
         | 
| 141 | 
            +
                    json_data = result.json()
         | 
| 142 | 
            +
                    return json_data if json_data else None
         | 
| 143 | 
            +
                except Exception as e:
         | 
| 144 | 
            +
                    print(f"Error: {e}")
         | 
| 145 | 
            +
                    return None
         | 
| 146 | 
            +
             | 
| 147 | 
            +
             | 
| 148 | 
            +
            class ModelInformation:
         | 
| 149 | 
            +
                def __init__(self, json_data):
         | 
| 150 | 
            +
                    self.model_version_id = json_data.get("id", "")
         | 
| 151 | 
            +
                    self.model_id = json_data.get("modelId", "")
         | 
| 152 | 
            +
                    self.download_url = json_data.get("downloadUrl", "")
         | 
| 153 | 
            +
                    self.model_url = f"https://civitai.com/models/{self.model_id}?modelVersionId={self.model_version_id}"
         | 
| 154 | 
            +
                    self.filename_url = next(
         | 
| 155 | 
            +
                        (v.get("name", "") for v in json_data.get("files", []) if str(self.model_version_id) in v.get("downloadUrl", "")), ""
         | 
| 156 | 
            +
                    )
         | 
| 157 | 
            +
                    self.filename_url = self.filename_url if self.filename_url else ""
         | 
| 158 | 
            +
                    self.description = json_data.get("description", "")
         | 
| 159 | 
            +
                    if self.description is None: self.description = ""
         | 
| 160 | 
            +
                    self.model_name = json_data.get("model", {}).get("name", "")
         | 
| 161 | 
            +
                    self.model_type = json_data.get("model", {}).get("type", "")
         | 
| 162 | 
            +
                    self.nsfw = json_data.get("model", {}).get("nsfw", False)
         | 
| 163 | 
            +
                    self.poi = json_data.get("model", {}).get("poi", False)
         | 
| 164 | 
            +
                    self.images = [img.get("url", "") for img in json_data.get("images", [])]
         | 
| 165 | 
            +
                    self.example_prompt = json_data.get("trainedWords", [""])[0] if json_data.get("trainedWords") else ""
         | 
| 166 | 
            +
                    self.original_json = copy.deepcopy(json_data)
         | 
| 167 | 
            +
             | 
| 168 | 
            +
             | 
| 169 | 
            +
            def retrieve_model_info(url):
         | 
| 170 | 
            +
                json_data = request_json_data(url)
         | 
| 171 | 
            +
                if not json_data:
         | 
| 172 | 
            +
                    return None
         | 
| 173 | 
            +
                model_descriptor = ModelInformation(json_data)
         | 
| 174 | 
            +
                return model_descriptor
         | 
| 175 | 
            +
             | 
| 176 | 
            +
             | 
| 177 | 
            +
            def download_things(directory, url, hf_token="", civitai_api_key="", romanize=False):
         | 
| 178 | 
            +
                hf_token = get_token()
         | 
| 179 | 
             
                url = url.strip()
         | 
| 180 | 
            +
                downloaded_file_path = None
         | 
| 181 | 
            +
             | 
| 182 | 
             
                if "drive.google.com" in url:
         | 
| 183 | 
             
                    original_dir = os.getcwd()
         | 
| 184 | 
             
                    os.chdir(directory)
         | 
|  | |
| 189 | 
             
                    # url = urllib.parse.quote(url, safe=':/')  # fix encoding
         | 
| 190 | 
             
                    if "/blob/" in url:
         | 
| 191 | 
             
                        url = url.replace("/blob/", "/resolve/")
         | 
| 192 | 
            +
             | 
| 193 | 
            +
                    filename = unidecode(url.split('/')[-1]) if romanize else url.split('/')[-1]
         | 
| 194 | 
            +
             | 
| 195 | 
            +
                    download_hf_file(directory, url, filename, hf_token)
         | 
| 196 | 
            +
             | 
| 197 | 
            +
                    downloaded_file_path = os.path.join(directory, filename)
         | 
| 198 | 
            +
             | 
| 199 | 
             
                elif "civitai.com" in url:
         | 
| 200 | 
            +
             | 
| 201 | 
            +
                    if not civitai_api_key:
         | 
|  | |
|  | |
|  | |
|  | |
| 202 | 
             
                        print("\033[91mYou need an API key to download Civitai models.\033[0m")
         | 
| 203 | 
            +
             | 
| 204 | 
            +
                    model_profile = retrieve_model_info(url)
         | 
| 205 | 
            +
                    if model_profile.download_url and model_profile.filename_url:
         | 
| 206 | 
            +
                        url = model_profile.download_url
         | 
| 207 | 
            +
                        filename = unidecode(model_profile.filename_url) if romanize else model_profile.filename_url
         | 
| 208 | 
            +
                    else:
         | 
| 209 | 
            +
                        if "?" in url:
         | 
| 210 | 
            +
                            url = url.split("?")[0]
         | 
| 211 | 
            +
                        filename = ""
         | 
| 212 | 
            +
             | 
| 213 | 
            +
                    url_dl = url + f"?token={civitai_api_key}"
         | 
| 214 | 
            +
                    print(f"Filename: {filename}")
         | 
| 215 | 
            +
             | 
| 216 | 
            +
                    param_filename = ""
         | 
| 217 | 
            +
                    if filename:
         | 
| 218 | 
            +
                        param_filename = f"-o '{filename}'"
         | 
| 219 | 
            +
             | 
| 220 | 
            +
                    aria2_command = (
         | 
| 221 | 
            +
                        f'aria2c --console-log-level=error --summary-interval=10 -c -x 16 '
         | 
| 222 | 
            +
                        f'-k 1M -s 16 -d "{directory}" {param_filename} "{url_dl}"'
         | 
| 223 | 
            +
                    )
         | 
| 224 | 
            +
                    os.system(aria2_command)
         | 
| 225 | 
            +
             | 
| 226 | 
            +
                    if param_filename and os.path.exists(os.path.join(directory, filename)):
         | 
| 227 | 
            +
                        downloaded_file_path = os.path.join(directory, filename)
         | 
| 228 | 
            +
             | 
| 229 | 
             
                else:
         | 
| 230 | 
             
                    os.system(f"aria2c --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 -d {directory} {url}")
         | 
| 231 |  | 
| 232 | 
            +
                return downloaded_file_path
         | 
| 233 | 
            +
             | 
| 234 |  | 
| 235 | 
             
            def get_download_file(temp_dir, url, civitai_key="", progress=gr.Progress(track_tqdm=True)):
         | 
| 236 | 
             
                if not "http" in url and is_repo_name(url) and not Path(url).exists():
         | 
|  | |
| 269 |  | 
| 270 | 
             
            def to_lora_path(key: str):
         | 
| 271 | 
             
                if Path(key).is_file(): return key
         | 
| 272 | 
            +
                path = Path(f"{DIRECTORY_LORAS}/{escape_lora_basename(key)}.safetensors")
         | 
| 273 | 
             
                return str(path)
         | 
| 274 |  | 
| 275 |  | 
|  | |
| 299 | 
             
                    raise Exception(f"Failed to save image file:") from e
         | 
| 300 |  | 
| 301 |  | 
| 302 | 
            +
            def save_gallery_images(images, model_name="", progress=gr.Progress(track_tqdm=True)):
         | 
|  | |
| 303 | 
             
                progress(0, desc="Updating gallery...")
         | 
| 304 | 
            +
                basename = f"{model_name.split('/')[-1]}_{datetime.now(FILENAME_TIMEZONE).strftime('%Y%m%d_%H%M%S')}_"
         | 
| 305 | 
            +
                if not images: return images, gr.update()
         | 
|  | |
|  | |
| 306 | 
             
                output_images = []
         | 
| 307 | 
             
                output_paths = []
         | 
| 308 | 
            +
                for i, image in enumerate(images):
         | 
| 309 | 
            +
                    filename = f"{basename}{str(i + 1)}.png"
         | 
|  | |
| 310 | 
             
                    oldpath = Path(image[0])
         | 
| 311 | 
             
                    newpath = oldpath
         | 
| 312 | 
             
                    try:
         | 
| 313 | 
             
                        if oldpath.exists():
         | 
| 314 | 
             
                            newpath = oldpath.resolve().rename(Path(filename).resolve())
         | 
| 315 | 
             
                    except Exception as e:
         | 
| 316 | 
            +
                        print(e)
         | 
| 317 | 
             
                    finally: 
         | 
| 318 | 
             
                        output_paths.append(str(newpath))
         | 
| 319 | 
             
                        output_images.append((str(newpath), str(filename)))
         | 
|  | |
| 321 | 
             
                return gr.update(value=output_images), gr.update(value=output_paths, visible=True)
         | 
| 322 |  | 
| 323 |  | 
| 324 | 
            +
            def save_gallery_history(images, files, history_gallery, history_files, progress=gr.Progress(track_tqdm=True)):
         | 
| 325 | 
            +
                if not images or not files: return gr.update(), gr.update()
         | 
| 326 | 
            +
                if not history_gallery: history_gallery = []
         | 
| 327 | 
            +
                if not history_files: history_files = []
         | 
| 328 | 
            +
                output_gallery = images + history_gallery
         | 
| 329 | 
            +
                output_files = files + history_files
         | 
| 330 | 
            +
                return gr.update(value=output_gallery), gr.update(value=output_files, visible=True)
         | 
| 331 | 
            +
             | 
| 332 | 
            +
             | 
| 333 | 
            +
            def save_image_history(image, gallery, files, model_name: str, progress=gr.Progress(track_tqdm=True)):
         | 
| 334 | 
            +
                if not gallery: gallery = []
         | 
| 335 | 
            +
                if not files: files = []
         | 
| 336 | 
            +
                try:
         | 
| 337 | 
            +
                    basename = f"{model_name.split('/')[-1]}_{datetime.now(FILENAME_TIMEZONE).strftime('%Y%m%d_%H%M%S')}"
         | 
| 338 | 
            +
                    if image is None or not isinstance(image, (str, Image.Image, np.ndarray, tuple)): return gr.update(), gr.update()
         | 
| 339 | 
            +
                    filename = f"{basename}.png"
         | 
| 340 | 
            +
                    if isinstance(image, tuple): image = image[0]
         | 
| 341 | 
            +
                    if isinstance(image, str): oldpath = image
         | 
| 342 | 
            +
                    elif isinstance(image, Image.Image):
         | 
| 343 | 
            +
                        oldpath = "temp.png"
         | 
| 344 | 
            +
                        image.save(oldpath)
         | 
| 345 | 
            +
                    elif isinstance(image, np.ndarray):
         | 
| 346 | 
            +
                        oldpath = "temp.png"
         | 
| 347 | 
            +
                        Image.fromarray(image).convert('RGBA').save(oldpath)
         | 
| 348 | 
            +
                    oldpath = Path(oldpath)
         | 
| 349 | 
            +
                    newpath = oldpath
         | 
| 350 | 
            +
                    if oldpath.exists():
         | 
| 351 | 
            +
                        shutil.copy(oldpath.resolve(), Path(filename).resolve())
         | 
| 352 | 
            +
                        newpath = Path(filename).resolve()
         | 
| 353 | 
            +
                    files.insert(0, str(newpath))
         | 
| 354 | 
            +
                    gallery.insert(0, (str(newpath), str(filename)))
         | 
| 355 | 
            +
                except Exception as e:
         | 
| 356 | 
            +
                    print(e)
         | 
| 357 | 
            +
                finally: 
         | 
| 358 | 
            +
                    return gr.update(value=gallery), gr.update(value=files, visible=True)
         | 
| 359 | 
            +
             | 
| 360 | 
            +
             | 
| 361 | 
             
            def download_private_repo(repo_id, dir_path, is_replace):
         | 
| 362 | 
            +
                if not HF_READ_TOKEN: return
         | 
| 363 | 
             
                try:
         | 
| 364 | 
            +
                    snapshot_download(repo_id=repo_id, local_dir=dir_path, allow_patterns=['*.ckpt', '*.pt', '*.pth', '*.safetensors', '*.bin'], token=HF_READ_TOKEN)
         | 
| 365 | 
             
                except Exception as e:
         | 
| 366 | 
             
                    print(f"Error: Failed to download {repo_id}.")
         | 
| 367 | 
             
                    print(e)
         | 
|  | |
| 379 | 
             
            def get_private_model_list(repo_id, dir_path):
         | 
| 380 | 
             
                global private_model_path_repo_dict
         | 
| 381 | 
             
                api = HfApi()
         | 
| 382 | 
            +
                if not HF_READ_TOKEN: return []
         | 
| 383 | 
             
                try:
         | 
| 384 | 
            +
                    files = api.list_repo_files(repo_id, token=HF_READ_TOKEN)
         | 
| 385 | 
             
                except Exception as e:
         | 
| 386 | 
             
                    print(f"Error: Failed to list {repo_id}.")
         | 
| 387 | 
             
                    print(e)
         | 
|  | |
| 399 | 
             
            def download_private_file(repo_id, path, is_replace):
         | 
| 400 | 
             
                file = Path(path)
         | 
| 401 | 
             
                newpath = Path(f'{file.parent.name}/{escape_lora_basename(file.stem)}{file.suffix}') if is_replace else file
         | 
| 402 | 
            +
                if not HF_READ_TOKEN or newpath.exists(): return
         | 
| 403 | 
             
                filename = file.name
         | 
| 404 | 
             
                dirname = file.parent.name
         | 
| 405 | 
             
                try:
         | 
| 406 | 
            +
                    hf_hub_download(repo_id=repo_id, filename=filename, local_dir=dirname, token=HF_READ_TOKEN)
         | 
| 407 | 
             
                except Exception as e:
         | 
| 408 | 
             
                    print(f"Error: Failed to download {filename}.")
         | 
| 409 | 
             
                    print(e)
         | 
|  | |
| 533 | 
             
                models1 = []
         | 
| 534 | 
             
                models2 = []
         | 
| 535 | 
             
                for repo in HF_LORA_PRIVATE_REPOS1:
         | 
| 536 | 
            +
                    models1.extend(get_private_model_list(repo, DIRECTORY_LORAS))
         | 
| 537 | 
             
                for repo in HF_LORA_PRIVATE_REPOS2:
         | 
| 538 | 
            +
                    models2.extend(get_private_model_list(repo, DIRECTORY_LORAS))
         | 
| 539 | 
             
                models = list_uniq(models1 + sorted(models2))
         | 
| 540 | 
             
                private_lora_model_list = models.copy()
         | 
| 541 | 
             
                return models
         | 
|  | |
| 580 |  | 
| 581 |  | 
| 582 | 
             
            def get_lora_model_list():
         | 
| 583 | 
            +
                loras = list_uniq(get_private_lora_model_lists() + DIFFUSERS_FORMAT_LORAS + get_local_model_list(DIRECTORY_LORAS))
         | 
| 584 | 
             
                loras.insert(0, "None")
         | 
| 585 | 
             
                loras.insert(0, "")
         | 
| 586 | 
             
                return loras
         | 
|  | |
| 632 | 
             
            def download_lora(dl_urls: str):
         | 
| 633 | 
             
                global loras_url_to_path_dict
         | 
| 634 | 
             
                dl_path = ""
         | 
| 635 | 
            +
                before = get_local_model_list(DIRECTORY_LORAS)
         | 
| 636 | 
             
                urls = []
         | 
| 637 | 
             
                for url in [url.strip() for url in dl_urls.split(',')]:
         | 
| 638 | 
            +
                    local_path = f"{DIRECTORY_LORAS}/{url.split('/')[-1]}"
         | 
| 639 | 
             
                    if not Path(local_path).exists():
         | 
| 640 | 
            +
                        download_things(DIRECTORY_LORAS, url, HF_TOKEN, CIVITAI_API_KEY)
         | 
| 641 | 
             
                        urls.append(url)
         | 
| 642 | 
            +
                after = get_local_model_list(DIRECTORY_LORAS)
         | 
| 643 | 
             
                new_files = list_sub(after, before)
         | 
| 644 | 
             
                i = 0
         | 
| 645 | 
             
                for file in new_files:
         | 
|  | |
| 890 | 
             
                 gr.update(value=tag5, label=label5, visible=on5), gr.update(visible=on5), gr.update(value=md5, visible=on5)
         | 
| 891 |  | 
| 892 |  | 
| 893 | 
            +
            def get_my_lora(link_url, romanize):
         | 
| 894 | 
            +
                l_name = ""
         | 
| 895 | 
            +
                l_path = ""
         | 
| 896 | 
            +
                before = get_local_model_list(DIRECTORY_LORAS)
         | 
| 897 | 
             
                for url in [url.strip() for url in link_url.split(',')]:
         | 
| 898 | 
            +
                    if not Path(f"{DIRECTORY_LORAS}/{url.split('/')[-1]}").exists():
         | 
| 899 | 
            +
                        l_name = download_things(DIRECTORY_LORAS, url, HF_TOKEN, CIVITAI_API_KEY, romanize)
         | 
| 900 | 
            +
                after = get_local_model_list(DIRECTORY_LORAS)
         | 
| 901 | 
             
                new_files = list_sub(after, before)
         | 
| 902 | 
             
                for file in new_files:
         | 
| 903 | 
             
                    path = Path(file)
         | 
|  | |
| 905 | 
             
                        new_path = Path(f'{path.parent.name}/{escape_lora_basename(path.stem)}{path.suffix}')
         | 
| 906 | 
             
                        path.resolve().rename(new_path.resolve())
         | 
| 907 | 
             
                        update_lora_dict(str(new_path))
         | 
| 908 | 
            +
                        l_path = str(new_path)
         | 
| 909 | 
             
                new_lora_model_list = get_lora_model_list()
         | 
| 910 | 
             
                new_lora_tupled_list = get_all_lora_tupled_list()
         | 
| 911 | 
            +
                msg_lora = "Downloaded"
         | 
| 912 | 
            +
                if l_name:
         | 
| 913 | 
            +
                    msg_lora += f": <b>{l_name}</b>"
         | 
| 914 | 
            +
                    print(msg_lora)
         | 
| 915 | 
            +
             | 
| 916 | 
             
                return gr.update(
         | 
| 917 | 
            +
                    choices=new_lora_tupled_list, value=l_path
         | 
| 918 | 
             
                ), gr.update(
         | 
| 919 | 
             
                    choices=new_lora_tupled_list
         | 
| 920 | 
             
                ), gr.update(
         | 
|  | |
| 923 | 
             
                    choices=new_lora_tupled_list
         | 
| 924 | 
             
                ), gr.update(
         | 
| 925 | 
             
                    choices=new_lora_tupled_list
         | 
| 926 | 
            +
                ), gr.update(
         | 
| 927 | 
            +
                    value=msg_lora
         | 
| 928 | 
             
                )
         | 
| 929 |  | 
| 930 |  | 
|  | |
| 932 | 
             
                progress(0, desc="Uploading...")
         | 
| 933 | 
             
                file_paths = [file.name for file in files]
         | 
| 934 | 
             
                progress(1, desc="Uploaded.")
         | 
| 935 | 
            +
                return gr.update(value=file_paths, visible=True), gr.update()
         | 
| 936 |  | 
| 937 |  | 
| 938 | 
             
            def move_file_lora(filepaths):
         | 
| 939 | 
             
                for file in filepaths:
         | 
| 940 | 
            +
                    path = Path(shutil.move(Path(file).resolve(), Path(f"./{DIRECTORY_LORAS}").resolve()))
         | 
| 941 | 
             
                    newpath = Path(f'{path.parent.name}/{escape_lora_basename(path.stem)}{path.suffix}')
         | 
| 942 | 
             
                    path.resolve().rename(newpath.resolve())
         | 
| 943 | 
             
                    update_lora_dict(str(newpath))
         | 
|  | |
| 1079 | 
             
                    selected = civitai_last_choices[selected_index][1]
         | 
| 1080 | 
             
                    return gr.update(value=selected)
         | 
| 1081 | 
             
                except Exception:
         | 
| 1082 | 
            +
                    return gr.update()
         | 
| 1083 |  | 
| 1084 |  | 
| 1085 | 
             
            def select_civitai_lora(search_result):
         | 
|  | |
| 1563 | 
             
                else:
         | 
| 1564 | 
             
                    return default
         | 
| 1565 |  | 
| 1566 | 
            +
             | 
| 1567 | 
            +
            EXAMPLES_GUI = [
         | 
| 1568 | 
            +
                [
         | 
| 1569 | 
            +
                    "1girl, souryuu asuka langley, neon genesis evangelion, plugsuit, pilot suit, red bodysuit, sitting, crossing legs, black eye patch, cat hat, throne, symmetrical, looking down, from bottom, looking at viewer, outdoors, masterpiece, best quality, very aesthetic, absurdres",
         | 
| 1570 | 
            +
                    "nsfw, lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]",
         | 
| 1571 | 
            +
                    1,
         | 
| 1572 | 
            +
                    30,
         | 
| 1573 | 
            +
                    7.5,
         | 
| 1574 | 
            +
                    True,
         | 
| 1575 | 
            +
                    -1,
         | 
| 1576 | 
            +
                    "Euler a",
         | 
| 1577 | 
            +
                    1152,
         | 
| 1578 | 
            +
                    896,
         | 
| 1579 | 
            +
                    "votepurchase/animagine-xl-3.1",
         | 
| 1580 | 
            +
                ],
         | 
| 1581 | 
            +
                [
         | 
| 1582 | 
            +
                    "solo, princess Zelda OOT, score_9, score_8_up, score_8, medium breasts, cute, eyelashes, cute small face, long hair, crown braid, hairclip, pointy ears, soft curvy body, looking at viewer, smile, blush, white dress, medium body, (((holding the Master Sword))), standing, deep forest in the background",
         | 
| 1583 | 
            +
                    "score_6, score_5, score_4, busty, ugly face, mutated hands, low res, blurry face, black and white,",
         | 
| 1584 | 
            +
                    1,
         | 
| 1585 | 
            +
                    30,
         | 
| 1586 | 
            +
                    5.,
         | 
| 1587 | 
            +
                    True,
         | 
| 1588 | 
            +
                    -1,
         | 
| 1589 | 
            +
                    "Euler a",
         | 
| 1590 | 
            +
                    1024,
         | 
| 1591 | 
            +
                    1024,
         | 
| 1592 | 
            +
                    "votepurchase/ponyDiffusionV6XL",
         | 
| 1593 | 
            +
                ],
         | 
| 1594 | 
            +
                [
         | 
| 1595 | 
            +
                    "1girl, oomuro sakurako, yuru yuri, official art, school uniform, anime artwork, anime style, vibrant, studio anime, highly detailed, masterpiece, best quality, very aesthetic, absurdres",
         | 
| 1596 | 
            +
                    "photo, deformed, black and white, realism, disfigured, low contrast, lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]",
         | 
| 1597 | 
            +
                    1,
         | 
| 1598 | 
            +
                    40,
         | 
| 1599 | 
            +
                    7.0,
         | 
| 1600 | 
            +
                    True,
         | 
| 1601 | 
            +
                    -1,
         | 
| 1602 | 
            +
                    "Euler a",
         | 
| 1603 | 
            +
                    1024,
         | 
| 1604 | 
            +
                    1024,
         | 
| 1605 | 
            +
                    "Raelina/Rae-Diffusion-XL-V2",
         | 
| 1606 | 
            +
                ],
         | 
| 1607 | 
            +
                [
         | 
| 1608 | 
            +
                    "1girl, akaza akari, yuru yuri, official art, anime artwork, anime style, vibrant, studio anime, highly detailed, masterpiece, best quality, very aesthetic, absurdres",
         | 
| 1609 | 
            +
                    "photo, deformed, black and white, realism, disfigured, low contrast, lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]",
         | 
| 1610 | 
            +
                    1,
         | 
| 1611 | 
            +
                    35,
         | 
| 1612 | 
            +
                    7.0,
         | 
| 1613 | 
            +
                    True,
         | 
| 1614 | 
            +
                    -1,
         | 
| 1615 | 
            +
                    "Euler a",
         | 
| 1616 | 
            +
                    1024,
         | 
| 1617 | 
            +
                    1024,
         | 
| 1618 | 
            +
                    "Raelina/Raemu-XL-V4",
         | 
| 1619 | 
            +
                ],
         | 
| 1620 | 
            +
                [
         | 
| 1621 | 
            +
                    "yoshida yuuko, machikado mazoku, 1girl, solo, demon horns,horns, school uniform, long hair, open mouth, skirt, demon girl, ahoge, shiny, shiny hair, anime artwork",
         | 
| 1622 | 
            +
                    "nsfw, lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]",
         | 
| 1623 | 
            +
                    1,
         | 
| 1624 | 
            +
                    50,
         | 
| 1625 | 
            +
                    7.,
         | 
| 1626 | 
            +
                    True,
         | 
| 1627 | 
            +
                    -1,
         | 
| 1628 | 
            +
                    "Euler a",
         | 
| 1629 | 
            +
                    1024,
         | 
| 1630 | 
            +
                    1024,
         | 
| 1631 | 
            +
                    "cagliostrolab/animagine-xl-3.1",
         | 
| 1632 | 
            +
                ],
         | 
| 1633 | 
            +
            ]
         | 
| 1634 | 
            +
             | 
| 1635 | 
            +
             | 
| 1636 | 
            +
            RESOURCES = (
         | 
| 1637 | 
            +
                """### Resources
         | 
| 1638 | 
            +
                - You can also try the image generator in Colab’s free tier, which provides free GPU [link](https://github.com/R3gm/SD_diffusers_interactive).
         | 
| 1639 | 
            +
                """
         | 
| 1640 | 
            +
            )
         | 
    	
        requirements.txt
    CHANGED
    
    | @@ -12,3 +12,4 @@ translatepy | |
| 12 | 
             
            timm
         | 
| 13 | 
             
            rapidfuzz
         | 
| 14 | 
             
            sentencepiece
         | 
|  | 
|  | |
| 12 | 
             
            timm
         | 
| 13 | 
             
            rapidfuzz
         | 
| 14 | 
             
            sentencepiece
         | 
| 15 | 
            +
            unidecode
         | 
    	
        tagger/character_series_dict.csv
    ADDED
    
    | The diff for this file is too large to render. 
		See raw diff | 
|  | 
    	
        tagger/danbooru_e621.csv
    ADDED
    
    | The diff for this file is too large to render. 
		See raw diff | 
|  | 
    	
        tagger/output.py
    ADDED
    
    | @@ -0,0 +1,16 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            from dataclasses import dataclass
         | 
| 2 | 
            +
             | 
| 3 | 
            +
             | 
| 4 | 
            +
            @dataclass
         | 
| 5 | 
            +
            class UpsamplingOutput:
         | 
| 6 | 
            +
                upsampled_tags: str
         | 
| 7 | 
            +
             | 
| 8 | 
            +
                copyright_tags: str
         | 
| 9 | 
            +
                character_tags: str
         | 
| 10 | 
            +
                general_tags: str
         | 
| 11 | 
            +
                rating_tag: str
         | 
| 12 | 
            +
                aspect_ratio_tag: str
         | 
| 13 | 
            +
                length_tag: str
         | 
| 14 | 
            +
                identity_tag: str
         | 
| 15 | 
            +
             | 
| 16 | 
            +
                elapsed_time: float = 0.0
         | 
    	
        tagger/tag_group.csv
    ADDED
    
    | The diff for this file is too large to render. 
		See raw diff | 
|  | 
    	
        tagger/tagger.py
    ADDED
    
    | @@ -0,0 +1,556 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import spaces
         | 
| 2 | 
            +
            from PIL import Image
         | 
| 3 | 
            +
            import torch
         | 
| 4 | 
            +
            import gradio as gr
         | 
| 5 | 
            +
            from transformers import AutoImageProcessor, AutoModelForImageClassification
         | 
| 6 | 
            +
            from pathlib import Path
         | 
| 7 | 
            +
             | 
| 8 | 
            +
             | 
| 9 | 
            +
            WD_MODEL_NAMES = ["p1atdev/wd-swinv2-tagger-v3-hf"]
         | 
| 10 | 
            +
            WD_MODEL_NAME = WD_MODEL_NAMES[0]
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            device = "cuda" if torch.cuda.is_available() else "cpu"
         | 
| 13 | 
            +
            default_device = device
         | 
| 14 | 
            +
             | 
| 15 | 
            +
            try:
         | 
| 16 | 
            +
                wd_model = AutoModelForImageClassification.from_pretrained(WD_MODEL_NAME, trust_remote_code=True).to(default_device).eval()
         | 
| 17 | 
            +
                wd_processor = AutoImageProcessor.from_pretrained(WD_MODEL_NAME, trust_remote_code=True)
         | 
| 18 | 
            +
            except Exception as e:
         | 
| 19 | 
            +
                print(e)
         | 
| 20 | 
            +
                wd_model = wd_processor = None
         | 
| 21 | 
            +
             | 
| 22 | 
            +
            def _people_tag(noun: str, minimum: int = 1, maximum: int = 5):
         | 
| 23 | 
            +
                return (
         | 
| 24 | 
            +
                    [f"1{noun}"]
         | 
| 25 | 
            +
                    + [f"{num}{noun}s" for num in range(minimum + 1, maximum + 1)]
         | 
| 26 | 
            +
                    + [f"{maximum+1}+{noun}s"]
         | 
| 27 | 
            +
                )
         | 
| 28 | 
            +
             | 
| 29 | 
            +
             | 
| 30 | 
            +
            PEOPLE_TAGS = (
         | 
| 31 | 
            +
                _people_tag("girl") + _people_tag("boy") + _people_tag("other") + ["no humans"]
         | 
| 32 | 
            +
            )
         | 
| 33 | 
            +
             | 
| 34 | 
            +
             | 
| 35 | 
            +
            RATING_MAP = {
         | 
| 36 | 
            +
                "sfw": "safe",
         | 
| 37 | 
            +
                "general": "safe",
         | 
| 38 | 
            +
                "sensitive": "sensitive",
         | 
| 39 | 
            +
                "questionable": "nsfw",
         | 
| 40 | 
            +
                "explicit": "explicit, nsfw",
         | 
| 41 | 
            +
            }
         | 
| 42 | 
            +
            DANBOORU_TO_E621_RATING_MAP = {
         | 
| 43 | 
            +
                "sfw": "rating_safe",
         | 
| 44 | 
            +
                "general": "rating_safe",
         | 
| 45 | 
            +
                "safe": "rating_safe",
         | 
| 46 | 
            +
                "sensitive": "rating_safe",
         | 
| 47 | 
            +
                "nsfw": "rating_explicit",
         | 
| 48 | 
            +
                "explicit, nsfw": "rating_explicit",
         | 
| 49 | 
            +
                "explicit": "rating_explicit",
         | 
| 50 | 
            +
                "rating:safe": "rating_safe",
         | 
| 51 | 
            +
                "rating:general": "rating_safe",
         | 
| 52 | 
            +
                "rating:sensitive": "rating_safe",
         | 
| 53 | 
            +
                "rating:questionable, nsfw": "rating_explicit",
         | 
| 54 | 
            +
                "rating:explicit, nsfw": "rating_explicit",
         | 
| 55 | 
            +
            }
         | 
| 56 | 
            +
             | 
| 57 | 
            +
             | 
| 58 | 
            +
            # https://github.com/toriato/stable-diffusion-webui-wd14-tagger/blob/a9eacb1eff904552d3012babfa28b57e1d3e295c/tagger/ui.py#L368
         | 
| 59 | 
            +
            kaomojis = [
         | 
| 60 | 
            +
                "0_0",
         | 
| 61 | 
            +
                "(o)_(o)",
         | 
| 62 | 
            +
                "+_+",
         | 
| 63 | 
            +
                "+_-",
         | 
| 64 | 
            +
                "._.",
         | 
| 65 | 
            +
                "<o>_<o>",
         | 
| 66 | 
            +
                "<|>_<|>",
         | 
| 67 | 
            +
                "=_=",
         | 
| 68 | 
            +
                ">_<",
         | 
| 69 | 
            +
                "3_3",
         | 
| 70 | 
            +
                "6_9",
         | 
| 71 | 
            +
                ">_o",
         | 
| 72 | 
            +
                "@_@",
         | 
| 73 | 
            +
                "^_^",
         | 
| 74 | 
            +
                "o_o",
         | 
| 75 | 
            +
                "u_u",
         | 
| 76 | 
            +
                "x_x",
         | 
| 77 | 
            +
                "|_|",
         | 
| 78 | 
            +
                "||_||",
         | 
| 79 | 
            +
            ]
         | 
| 80 | 
            +
             | 
| 81 | 
            +
             | 
| 82 | 
            +
            def replace_underline(x: str):
         | 
| 83 | 
            +
                return x.strip().replace("_", " ") if x not in kaomojis else x.strip()
         | 
| 84 | 
            +
             | 
| 85 | 
            +
             | 
| 86 | 
            +
            def to_list(s):
         | 
| 87 | 
            +
                return [x.strip() for x in s.split(",") if not s == ""]
         | 
| 88 | 
            +
             | 
| 89 | 
            +
             | 
| 90 | 
            +
            def list_sub(a, b):
         | 
| 91 | 
            +
                return [e for e in a if e not in b]
         | 
| 92 | 
            +
             | 
| 93 | 
            +
             | 
| 94 | 
            +
            def list_uniq(l):
         | 
| 95 | 
            +
                return sorted(set(l), key=l.index)
         | 
| 96 | 
            +
             | 
| 97 | 
            +
             | 
| 98 | 
            +
            def load_dict_from_csv(filename):
         | 
| 99 | 
            +
                dict = {}
         | 
| 100 | 
            +
                if not Path(filename).exists():
         | 
| 101 | 
            +
                    if Path('./tagger/', filename).exists(): filename = str(Path('./tagger/', filename))
         | 
| 102 | 
            +
                    else: return dict
         | 
| 103 | 
            +
                try:
         | 
| 104 | 
            +
                    with open(filename, 'r', encoding="utf-8") as f:
         | 
| 105 | 
            +
                        lines = f.readlines()
         | 
| 106 | 
            +
                except Exception:
         | 
| 107 | 
            +
                    print(f"Failed to open dictionary file: {filename}")
         | 
| 108 | 
            +
                    return dict
         | 
| 109 | 
            +
                for line in lines:
         | 
| 110 | 
            +
                    parts = line.strip().split(',')
         | 
| 111 | 
            +
                    dict[parts[0]] = parts[1]
         | 
| 112 | 
            +
                return dict
         | 
| 113 | 
            +
             | 
| 114 | 
            +
             | 
| 115 | 
            +
            anime_series_dict = load_dict_from_csv('character_series_dict.csv')
         | 
| 116 | 
            +
             | 
| 117 | 
            +
             | 
| 118 | 
            +
            def character_list_to_series_list(character_list):
         | 
| 119 | 
            +
                output_series_tag = []
         | 
| 120 | 
            +
                series_tag = ""
         | 
| 121 | 
            +
                series_dict = anime_series_dict
         | 
| 122 | 
            +
                for tag in character_list:
         | 
| 123 | 
            +
                    series_tag = series_dict.get(tag, "")
         | 
| 124 | 
            +
                    if tag.endswith(")"):
         | 
| 125 | 
            +
                        tags = tag.split("(")
         | 
| 126 | 
            +
                        character_tag = "(".join(tags[:-1])
         | 
| 127 | 
            +
                        if character_tag.endswith(" "):
         | 
| 128 | 
            +
                            character_tag = character_tag[:-1]
         | 
| 129 | 
            +
                        series_tag = tags[-1].replace(")", "")
         | 
| 130 | 
            +
             | 
| 131 | 
            +
                if series_tag:
         | 
| 132 | 
            +
                    output_series_tag.append(series_tag)
         | 
| 133 | 
            +
             | 
| 134 | 
            +
                return output_series_tag
         | 
| 135 | 
            +
             | 
| 136 | 
            +
             | 
| 137 | 
            +
            def select_random_character(series: str, character: str):
         | 
| 138 | 
            +
                from random import seed, randrange
         | 
| 139 | 
            +
                seed()
         | 
| 140 | 
            +
                character_list = list(anime_series_dict.keys())
         | 
| 141 | 
            +
                character = character_list[randrange(len(character_list) - 1)]
         | 
| 142 | 
            +
                series = anime_series_dict.get(character.split(",")[0].strip(), "")
         | 
| 143 | 
            +
                return series, character
         | 
| 144 | 
            +
             | 
| 145 | 
            +
             | 
| 146 | 
            +
            def danbooru_to_e621(dtag, e621_dict):
         | 
| 147 | 
            +
                def d_to_e(match, e621_dict):
         | 
| 148 | 
            +
                    dtag = match.group(0)
         | 
| 149 | 
            +
                    etag = e621_dict.get(replace_underline(dtag), "")
         | 
| 150 | 
            +
                    if etag:
         | 
| 151 | 
            +
                        return etag
         | 
| 152 | 
            +
                    else:
         | 
| 153 | 
            +
                        return dtag
         | 
| 154 | 
            +
                
         | 
| 155 | 
            +
                import re
         | 
| 156 | 
            +
                tag = re.sub(r'[\w ]+', lambda wrapper: d_to_e(wrapper, e621_dict), dtag, 2)
         | 
| 157 | 
            +
                return tag
         | 
| 158 | 
            +
             | 
| 159 | 
            +
             | 
| 160 | 
            +
            danbooru_to_e621_dict = load_dict_from_csv('danbooru_e621.csv')
         | 
| 161 | 
            +
             | 
| 162 | 
            +
             | 
| 163 | 
            +
            def convert_danbooru_to_e621_prompt(input_prompt: str = "", prompt_type: str = "danbooru"):
         | 
| 164 | 
            +
                if prompt_type == "danbooru": return input_prompt
         | 
| 165 | 
            +
                tags = input_prompt.split(",") if input_prompt else []
         | 
| 166 | 
            +
                people_tags: list[str] = []
         | 
| 167 | 
            +
                other_tags: list[str] = []
         | 
| 168 | 
            +
                rating_tags: list[str] = []
         | 
| 169 | 
            +
             | 
| 170 | 
            +
                e621_dict = danbooru_to_e621_dict
         | 
| 171 | 
            +
                for tag in tags:
         | 
| 172 | 
            +
                    tag = replace_underline(tag)
         | 
| 173 | 
            +
                    tag = danbooru_to_e621(tag, e621_dict)
         | 
| 174 | 
            +
                    if tag in PEOPLE_TAGS:        
         | 
| 175 | 
            +
                        people_tags.append(tag)
         | 
| 176 | 
            +
                    elif tag in DANBOORU_TO_E621_RATING_MAP.keys():
         | 
| 177 | 
            +
                        rating_tags.append(DANBOORU_TO_E621_RATING_MAP.get(tag.replace(" ",""), ""))            
         | 
| 178 | 
            +
                    else:
         | 
| 179 | 
            +
                        other_tags.append(tag)
         | 
| 180 | 
            +
             | 
| 181 | 
            +
                rating_tags = sorted(set(rating_tags), key=rating_tags.index)
         | 
| 182 | 
            +
                rating_tags = [rating_tags[0]] if rating_tags else []
         | 
| 183 | 
            +
                rating_tags = ["explicit, nsfw"] if rating_tags and rating_tags[0] == "explicit" else rating_tags
         | 
| 184 | 
            +
             | 
| 185 | 
            +
                output_prompt = ", ".join(people_tags + other_tags + rating_tags)
         | 
| 186 | 
            +
                
         | 
| 187 | 
            +
                return output_prompt
         | 
| 188 | 
            +
             | 
| 189 | 
            +
             | 
| 190 | 
            +
            from translatepy import Translator
         | 
| 191 | 
            +
            translator = Translator()
         | 
| 192 | 
            +
            def translate_prompt_old(prompt: str = ""):
         | 
| 193 | 
            +
                def translate_to_english(input: str):
         | 
| 194 | 
            +
                    try:
         | 
| 195 | 
            +
                        output = str(translator.translate(input, 'English'))
         | 
| 196 | 
            +
                    except Exception as e:
         | 
| 197 | 
            +
                        output = input
         | 
| 198 | 
            +
                        print(e)
         | 
| 199 | 
            +
                    return output
         | 
| 200 | 
            +
             | 
| 201 | 
            +
                def is_japanese(s):
         | 
| 202 | 
            +
                    import unicodedata
         | 
| 203 | 
            +
                    for ch in s:
         | 
| 204 | 
            +
                        name = unicodedata.name(ch, "") 
         | 
| 205 | 
            +
                        if "CJK UNIFIED" in name or "HIRAGANA" in name or "KATAKANA" in name:
         | 
| 206 | 
            +
                            return True
         | 
| 207 | 
            +
                    return False
         | 
| 208 | 
            +
             | 
| 209 | 
            +
                def to_list(s):
         | 
| 210 | 
            +
                    return [x.strip() for x in s.split(",")]
         | 
| 211 | 
            +
                
         | 
| 212 | 
            +
                prompts = to_list(prompt)
         | 
| 213 | 
            +
                outputs = []
         | 
| 214 | 
            +
                for p in prompts:
         | 
| 215 | 
            +
                    p = translate_to_english(p) if is_japanese(p) else p
         | 
| 216 | 
            +
                    outputs.append(p)
         | 
| 217 | 
            +
             | 
| 218 | 
            +
                return ", ".join(outputs)
         | 
| 219 | 
            +
             | 
| 220 | 
            +
             | 
| 221 | 
            +
            def translate_prompt(input: str):
         | 
| 222 | 
            +
                try:
         | 
| 223 | 
            +
                    output = str(translator.translate(input, 'English'))
         | 
| 224 | 
            +
                except Exception as e:
         | 
| 225 | 
            +
                    output = input
         | 
| 226 | 
            +
                    print(e)
         | 
| 227 | 
            +
                return output
         | 
| 228 | 
            +
             | 
| 229 | 
            +
             | 
| 230 | 
            +
            def translate_prompt_to_ja(prompt: str = ""):
         | 
| 231 | 
            +
                def translate_to_japanese(input: str):
         | 
| 232 | 
            +
                    try:
         | 
| 233 | 
            +
                        output = str(translator.translate(input, 'Japanese'))
         | 
| 234 | 
            +
                    except Exception as e:
         | 
| 235 | 
            +
                        output = input
         | 
| 236 | 
            +
                        print(e)
         | 
| 237 | 
            +
                    return output
         | 
| 238 | 
            +
             | 
| 239 | 
            +
                def is_japanese(s):
         | 
| 240 | 
            +
                    import unicodedata
         | 
| 241 | 
            +
                    for ch in s:
         | 
| 242 | 
            +
                        name = unicodedata.name(ch, "") 
         | 
| 243 | 
            +
                        if "CJK UNIFIED" in name or "HIRAGANA" in name or "KATAKANA" in name:
         | 
| 244 | 
            +
                            return True
         | 
| 245 | 
            +
                    return False
         | 
| 246 | 
            +
             | 
| 247 | 
            +
                def to_list(s):
         | 
| 248 | 
            +
                    return [x.strip() for x in s.split(",")]
         | 
| 249 | 
            +
                
         | 
| 250 | 
            +
                prompts = to_list(prompt)
         | 
| 251 | 
            +
                outputs = []
         | 
| 252 | 
            +
                for p in prompts:
         | 
| 253 | 
            +
                    p = translate_to_japanese(p) if not is_japanese(p) else p
         | 
| 254 | 
            +
                    outputs.append(p)
         | 
| 255 | 
            +
             | 
| 256 | 
            +
                return ", ".join(outputs)
         | 
| 257 | 
            +
             | 
| 258 | 
            +
             | 
| 259 | 
            +
            def tags_to_ja(itag, dict):
         | 
| 260 | 
            +
                def t_to_j(match, dict):
         | 
| 261 | 
            +
                    tag = match.group(0)
         | 
| 262 | 
            +
                    ja = dict.get(replace_underline(tag), "")
         | 
| 263 | 
            +
                    if ja:
         | 
| 264 | 
            +
                        return ja
         | 
| 265 | 
            +
                    else:
         | 
| 266 | 
            +
                        return tag
         | 
| 267 | 
            +
                
         | 
| 268 | 
            +
                import re
         | 
| 269 | 
            +
                tag = re.sub(r'[\w ]+', lambda wrapper: t_to_j(wrapper, dict), itag, 2)
         | 
| 270 | 
            +
             | 
| 271 | 
            +
                return tag
         | 
| 272 | 
            +
             | 
| 273 | 
            +
             | 
| 274 | 
            +
            def convert_tags_to_ja(input_prompt: str = ""):
         | 
| 275 | 
            +
                tags = input_prompt.split(",") if input_prompt else []
         | 
| 276 | 
            +
                out_tags = []
         | 
| 277 | 
            +
             | 
| 278 | 
            +
                tags_to_ja_dict = load_dict_from_csv('all_tags_ja_ext.csv')
         | 
| 279 | 
            +
                dict = tags_to_ja_dict
         | 
| 280 | 
            +
                for tag in tags:
         | 
| 281 | 
            +
                    tag = replace_underline(tag)
         | 
| 282 | 
            +
                    tag = tags_to_ja(tag, dict)
         | 
| 283 | 
            +
                    out_tags.append(tag)
         | 
| 284 | 
            +
                
         | 
| 285 | 
            +
                return ", ".join(out_tags)
         | 
| 286 | 
            +
             | 
| 287 | 
            +
             | 
| 288 | 
            +
            enable_auto_recom_prompt = True
         | 
| 289 | 
            +
             | 
| 290 | 
            +
             | 
| 291 | 
            +
            animagine_ps = to_list("masterpiece, best quality, very aesthetic, absurdres")
         | 
| 292 | 
            +
            animagine_nps = to_list("lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]")
         | 
| 293 | 
            +
            pony_ps = to_list("score_9, score_8_up, score_7_up, masterpiece, best quality, very aesthetic, absurdres")
         | 
| 294 | 
            +
            pony_nps = to_list("source_pony, score_6, score_5, score_4, busty, ugly face, mutated hands, low res, blurry face, black and white, the simpsons, overwatch, apex legends")
         | 
| 295 | 
            +
            other_ps = to_list("anime artwork, anime style, studio anime, highly detailed, cinematic photo, 35mm photograph, film, bokeh, professional, 4k, highly detailed")
         | 
| 296 | 
            +
            other_nps = to_list("photo, deformed, black and white, realism, disfigured, low contrast, drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly")
         | 
| 297 | 
            +
            default_ps = to_list("highly detailed, masterpiece, best quality, very aesthetic, absurdres")
         | 
| 298 | 
            +
            default_nps = to_list("score_6, score_5, score_4, lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]")
         | 
| 299 | 
            +
            def insert_recom_prompt(prompt: str = "", neg_prompt: str = "", type: str = "None"):
         | 
| 300 | 
            +
                global enable_auto_recom_prompt
         | 
| 301 | 
            +
                prompts = to_list(prompt)
         | 
| 302 | 
            +
                neg_prompts = to_list(neg_prompt)
         | 
| 303 | 
            +
             | 
| 304 | 
            +
                prompts = list_sub(prompts, animagine_ps + pony_ps)
         | 
| 305 | 
            +
                neg_prompts = list_sub(neg_prompts, animagine_nps + pony_nps)
         | 
| 306 | 
            +
             | 
| 307 | 
            +
                last_empty_p = [""] if not prompts and type != "None" else []
         | 
| 308 | 
            +
                last_empty_np = [""] if not neg_prompts and type != "None" else []
         | 
| 309 | 
            +
             | 
| 310 | 
            +
                if type == "Auto":
         | 
| 311 | 
            +
                    enable_auto_recom_prompt = True
         | 
| 312 | 
            +
                else:
         | 
| 313 | 
            +
                    enable_auto_recom_prompt = False
         | 
| 314 | 
            +
                    if type == "Animagine":
         | 
| 315 | 
            +
                        prompts = prompts + animagine_ps
         | 
| 316 | 
            +
                        neg_prompts = neg_prompts + animagine_nps
         | 
| 317 | 
            +
                    elif type == "Pony":
         | 
| 318 | 
            +
                        prompts = prompts + pony_ps
         | 
| 319 | 
            +
                        neg_prompts = neg_prompts + pony_nps
         | 
| 320 | 
            +
             | 
| 321 | 
            +
                prompt = ", ".join(list_uniq(prompts) + last_empty_p)
         | 
| 322 | 
            +
                neg_prompt = ", ".join(list_uniq(neg_prompts) + last_empty_np)
         | 
| 323 | 
            +
             | 
| 324 | 
            +
                return prompt, neg_prompt
         | 
| 325 | 
            +
             | 
| 326 | 
            +
             | 
| 327 | 
            +
            def load_model_prompt_dict():
         | 
| 328 | 
            +
                import json
         | 
| 329 | 
            +
                dict = {}
         | 
| 330 | 
            +
                path = 'model_dict.json' if Path('model_dict.json').exists() else './tagger/model_dict.json'
         | 
| 331 | 
            +
                try:
         | 
| 332 | 
            +
                    with open('model_dict.json', encoding='utf-8') as f:
         | 
| 333 | 
            +
                        dict = json.load(f)
         | 
| 334 | 
            +
                except Exception:
         | 
| 335 | 
            +
                    pass
         | 
| 336 | 
            +
                return dict
         | 
| 337 | 
            +
             | 
| 338 | 
            +
             | 
| 339 | 
            +
            model_prompt_dict = load_model_prompt_dict()
         | 
| 340 | 
            +
             | 
| 341 | 
            +
             | 
| 342 | 
            +
            def insert_model_recom_prompt(prompt: str = "", neg_prompt: str = "", model_name: str = "None"):
         | 
| 343 | 
            +
                if not model_name or not enable_auto_recom_prompt: return prompt, neg_prompt
         | 
| 344 | 
            +
                prompts = to_list(prompt)
         | 
| 345 | 
            +
                neg_prompts = to_list(neg_prompt)
         | 
| 346 | 
            +
                prompts = list_sub(prompts, animagine_ps + pony_ps + other_ps)
         | 
| 347 | 
            +
                neg_prompts = list_sub(neg_prompts, animagine_nps + pony_nps + other_nps)
         | 
| 348 | 
            +
                last_empty_p = [""] if not prompts and type != "None" else []
         | 
| 349 | 
            +
                last_empty_np = [""] if not neg_prompts and type != "None" else []
         | 
| 350 | 
            +
                ps = []
         | 
| 351 | 
            +
                nps = []
         | 
| 352 | 
            +
                if model_name in model_prompt_dict.keys(): 
         | 
| 353 | 
            +
                    ps = to_list(model_prompt_dict[model_name]["prompt"])
         | 
| 354 | 
            +
                    nps = to_list(model_prompt_dict[model_name]["negative_prompt"])
         | 
| 355 | 
            +
                else:
         | 
| 356 | 
            +
                    ps = default_ps
         | 
| 357 | 
            +
                    nps = default_nps
         | 
| 358 | 
            +
                prompts = prompts + ps
         | 
| 359 | 
            +
                neg_prompts = neg_prompts + nps
         | 
| 360 | 
            +
                prompt = ", ".join(list_uniq(prompts) + last_empty_p)
         | 
| 361 | 
            +
                neg_prompt = ", ".join(list_uniq(neg_prompts) + last_empty_np)
         | 
| 362 | 
            +
                return prompt, neg_prompt
         | 
| 363 | 
            +
             | 
| 364 | 
            +
             | 
| 365 | 
            +
            tag_group_dict = load_dict_from_csv('tag_group.csv')
         | 
| 366 | 
            +
             | 
| 367 | 
            +
             | 
| 368 | 
            +
            def remove_specific_prompt(input_prompt: str = "", keep_tags: str = "all"):
         | 
| 369 | 
            +
                def is_dressed(tag):
         | 
| 370 | 
            +
                    import re
         | 
| 371 | 
            +
                    p = re.compile(r'dress|cloth|uniform|costume|vest|sweater|coat|shirt|jacket|blazer|apron|leotard|hood|sleeve|skirt|shorts|pant|loafer|ribbon|necktie|bow|collar|glove|sock|shoe|boots|wear|emblem')
         | 
| 372 | 
            +
                    return p.search(tag)
         | 
| 373 | 
            +
             | 
| 374 | 
            +
                def is_background(tag):
         | 
| 375 | 
            +
                    import re
         | 
| 376 | 
            +
                    p = re.compile(r'background|outline|light|sky|build|day|screen|tree|city')
         | 
| 377 | 
            +
                    return p.search(tag)
         | 
| 378 | 
            +
             | 
| 379 | 
            +
                un_tags = ['solo']
         | 
| 380 | 
            +
                group_list = ['groups', 'body_parts', 'attire', 'posture', 'objects', 'creatures', 'locations', 'disambiguation_pages', 'commonly_misused_tags', 'phrases', 'verbs_and_gerunds', 'subjective', 'nudity', 'sex_objects', 'sex', 'sex_acts', 'image_composition', 'artistic_license', 'text', 'year_tags', 'metatags']
         | 
| 381 | 
            +
                keep_group_dict = {
         | 
| 382 | 
            +
                    "body": ['groups', 'body_parts'],
         | 
| 383 | 
            +
                    "dress": ['groups', 'body_parts', 'attire'],
         | 
| 384 | 
            +
                    "all": group_list,
         | 
| 385 | 
            +
                }
         | 
| 386 | 
            +
             | 
| 387 | 
            +
                def is_necessary(tag, keep_tags, group_dict):
         | 
| 388 | 
            +
                    if keep_tags == "all":
         | 
| 389 | 
            +
                        return True
         | 
| 390 | 
            +
                    elif tag in un_tags or group_dict.get(tag, "") in explicit_group:
         | 
| 391 | 
            +
                        return False
         | 
| 392 | 
            +
                    elif keep_tags == "body" and is_dressed(tag):
         | 
| 393 | 
            +
                        return False
         | 
| 394 | 
            +
                    elif is_background(tag):
         | 
| 395 | 
            +
                        return False
         | 
| 396 | 
            +
                    else:
         | 
| 397 | 
            +
                        return True
         | 
| 398 | 
            +
                
         | 
| 399 | 
            +
                if keep_tags == "all": return input_prompt
         | 
| 400 | 
            +
                keep_group = keep_group_dict.get(keep_tags, keep_group_dict["body"])
         | 
| 401 | 
            +
                explicit_group = list(set(group_list) ^ set(keep_group))
         | 
| 402 | 
            +
             | 
| 403 | 
            +
                tags = input_prompt.split(",") if input_prompt else []
         | 
| 404 | 
            +
                people_tags: list[str] = []
         | 
| 405 | 
            +
                other_tags: list[str] = []
         | 
| 406 | 
            +
             | 
| 407 | 
            +
                group_dict = tag_group_dict
         | 
| 408 | 
            +
                for tag in tags:
         | 
| 409 | 
            +
                    tag = replace_underline(tag)
         | 
| 410 | 
            +
                    if tag in PEOPLE_TAGS:
         | 
| 411 | 
            +
                        people_tags.append(tag)
         | 
| 412 | 
            +
                    elif is_necessary(tag, keep_tags, group_dict):
         | 
| 413 | 
            +
                        other_tags.append(tag)
         | 
| 414 | 
            +
             | 
| 415 | 
            +
                output_prompt = ", ".join(people_tags + other_tags)
         | 
| 416 | 
            +
                
         | 
| 417 | 
            +
                return output_prompt
         | 
| 418 | 
            +
             | 
| 419 | 
            +
             | 
| 420 | 
            +
            def sort_taglist(tags: list[str]):
         | 
| 421 | 
            +
                if not tags: return []
         | 
| 422 | 
            +
                character_tags: list[str] = []
         | 
| 423 | 
            +
                series_tags: list[str] = []
         | 
| 424 | 
            +
                people_tags: list[str] = []
         | 
| 425 | 
            +
                group_list = ['groups', 'body_parts', 'attire', 'posture', 'objects', 'creatures', 'locations', 'disambiguation_pages', 'commonly_misused_tags', 'phrases', 'verbs_and_gerunds', 'subjective', 'nudity', 'sex_objects', 'sex', 'sex_acts', 'image_composition', 'artistic_license', 'text', 'year_tags', 'metatags']
         | 
| 426 | 
            +
                group_tags = {}
         | 
| 427 | 
            +
                other_tags: list[str] = []
         | 
| 428 | 
            +
                rating_tags: list[str] = []
         | 
| 429 | 
            +
             | 
| 430 | 
            +
                group_dict = tag_group_dict
         | 
| 431 | 
            +
                group_set = set(group_dict.keys())
         | 
| 432 | 
            +
                character_set = set(anime_series_dict.keys())
         | 
| 433 | 
            +
                series_set = set(anime_series_dict.values())
         | 
| 434 | 
            +
                rating_set = set(DANBOORU_TO_E621_RATING_MAP.keys()) | set(DANBOORU_TO_E621_RATING_MAP.values())
         | 
| 435 | 
            +
             | 
| 436 | 
            +
                for tag in tags:
         | 
| 437 | 
            +
                    tag = replace_underline(tag)
         | 
| 438 | 
            +
                    if tag in PEOPLE_TAGS:
         | 
| 439 | 
            +
                        people_tags.append(tag)
         | 
| 440 | 
            +
                    elif tag in rating_set:
         | 
| 441 | 
            +
                        rating_tags.append(tag)
         | 
| 442 | 
            +
                    elif tag in group_set:
         | 
| 443 | 
            +
                        elem = group_dict[tag]
         | 
| 444 | 
            +
                        group_tags[elem] = group_tags[elem] + [tag] if elem in group_tags else [tag]
         | 
| 445 | 
            +
                    elif tag in character_set:
         | 
| 446 | 
            +
                        character_tags.append(tag)
         | 
| 447 | 
            +
                    elif tag in series_set:
         | 
| 448 | 
            +
                        series_tags.append(tag)
         | 
| 449 | 
            +
                    else:
         | 
| 450 | 
            +
                        other_tags.append(tag)
         | 
| 451 | 
            +
             | 
| 452 | 
            +
                output_group_tags: list[str] = []
         | 
| 453 | 
            +
                for k in group_list:
         | 
| 454 | 
            +
                    output_group_tags.extend(group_tags.get(k, []))
         | 
| 455 | 
            +
             | 
| 456 | 
            +
                rating_tags = [rating_tags[0]] if rating_tags else []
         | 
| 457 | 
            +
                rating_tags = ["explicit, nsfw"] if rating_tags and rating_tags[0] == "explicit" else rating_tags
         | 
| 458 | 
            +
             | 
| 459 | 
            +
                output_tags = character_tags + series_tags + people_tags + output_group_tags + other_tags + rating_tags
         | 
| 460 | 
            +
                
         | 
| 461 | 
            +
                return output_tags
         | 
| 462 | 
            +
             | 
| 463 | 
            +
             | 
| 464 | 
            +
            def sort_tags(tags: str):
         | 
| 465 | 
            +
                if not tags: return ""
         | 
| 466 | 
            +
                taglist: list[str] = []
         | 
| 467 | 
            +
                for tag in tags.split(","):
         | 
| 468 | 
            +
                    taglist.append(tag.strip())
         | 
| 469 | 
            +
                taglist = list(filter(lambda x: x != "", taglist))
         | 
| 470 | 
            +
                return ", ".join(sort_taglist(taglist))
         | 
| 471 | 
            +
             | 
| 472 | 
            +
             | 
| 473 | 
            +
            def postprocess_results(results: dict[str, float], general_threshold: float, character_threshold: float):
         | 
| 474 | 
            +
                results = {
         | 
| 475 | 
            +
                    k: v for k, v in sorted(results.items(), key=lambda item: item[1], reverse=True)
         | 
| 476 | 
            +
                }
         | 
| 477 | 
            +
             | 
| 478 | 
            +
                rating = {}
         | 
| 479 | 
            +
                character = {}
         | 
| 480 | 
            +
                general = {}
         | 
| 481 | 
            +
             | 
| 482 | 
            +
                for k, v in results.items():
         | 
| 483 | 
            +
                    if k.startswith("rating:"):
         | 
| 484 | 
            +
                        rating[k.replace("rating:", "")] = v
         | 
| 485 | 
            +
                        continue
         | 
| 486 | 
            +
                    elif k.startswith("character:"):
         | 
| 487 | 
            +
                        character[k.replace("character:", "")] = v
         | 
| 488 | 
            +
                        continue
         | 
| 489 | 
            +
             | 
| 490 | 
            +
                    general[k] = v
         | 
| 491 | 
            +
             | 
| 492 | 
            +
                character = {k: v for k, v in character.items() if v >= character_threshold}
         | 
| 493 | 
            +
                general = {k: v for k, v in general.items() if v >= general_threshold}
         | 
| 494 | 
            +
             | 
| 495 | 
            +
                return rating, character, general
         | 
| 496 | 
            +
             | 
| 497 | 
            +
             | 
| 498 | 
            +
            def gen_prompt(rating: list[str], character: list[str], general: list[str]):
         | 
| 499 | 
            +
                people_tags: list[str] = []
         | 
| 500 | 
            +
                other_tags: list[str] = []
         | 
| 501 | 
            +
                rating_tag = RATING_MAP[rating[0]]
         | 
| 502 | 
            +
             | 
| 503 | 
            +
                for tag in general:
         | 
| 504 | 
            +
                    if tag in PEOPLE_TAGS:
         | 
| 505 | 
            +
                        people_tags.append(tag)
         | 
| 506 | 
            +
                    else:
         | 
| 507 | 
            +
                        other_tags.append(tag)
         | 
| 508 | 
            +
             | 
| 509 | 
            +
                all_tags = people_tags + other_tags
         | 
| 510 | 
            +
             | 
| 511 | 
            +
                return ", ".join(all_tags)
         | 
| 512 | 
            +
             | 
| 513 | 
            +
             | 
| 514 | 
            +
            @spaces.GPU(duration=30)
         | 
| 515 | 
            +
            def predict_tags(image: Image.Image, general_threshold: float = 0.3, character_threshold: float = 0.8):
         | 
| 516 | 
            +
                inputs = wd_processor.preprocess(image, return_tensors="pt")
         | 
| 517 | 
            +
             | 
| 518 | 
            +
                outputs = wd_model(**inputs.to(wd_model.device, wd_model.dtype))
         | 
| 519 | 
            +
                logits = torch.sigmoid(outputs.logits[0])  # take the first logits
         | 
| 520 | 
            +
             | 
| 521 | 
            +
                # get probabilities
         | 
| 522 | 
            +
                if device != default_device: wd_model.to(device=device)
         | 
| 523 | 
            +
                results = {
         | 
| 524 | 
            +
                    wd_model.config.id2label[i]: float(logit.float()) for i, logit in enumerate(logits)
         | 
| 525 | 
            +
                }
         | 
| 526 | 
            +
                if device != default_device: wd_model.to(device=default_device)
         | 
| 527 | 
            +
                # rating, character, general
         | 
| 528 | 
            +
                rating, character, general = postprocess_results(
         | 
| 529 | 
            +
                    results, general_threshold, character_threshold
         | 
| 530 | 
            +
                )
         | 
| 531 | 
            +
                prompt = gen_prompt(
         | 
| 532 | 
            +
                    list(rating.keys()), list(character.keys()), list(general.keys())
         | 
| 533 | 
            +
                )
         | 
| 534 | 
            +
                output_series_tag = ""
         | 
| 535 | 
            +
                output_series_list = character_list_to_series_list(character.keys())
         | 
| 536 | 
            +
                if output_series_list:
         | 
| 537 | 
            +
                    output_series_tag = output_series_list[0]
         | 
| 538 | 
            +
                else:
         | 
| 539 | 
            +
                    output_series_tag = ""
         | 
| 540 | 
            +
                return output_series_tag, ", ".join(character.keys()), prompt, gr.update(interactive=True)
         | 
| 541 | 
            +
             | 
| 542 | 
            +
             | 
| 543 | 
            +
            def predict_tags_wd(image: Image.Image, input_tags: str, algo: list[str], general_threshold: float = 0.3,
         | 
| 544 | 
            +
                                 character_threshold: float = 0.8, input_series: str = "", input_character: str = ""):
         | 
| 545 | 
            +
                if not "Use WD Tagger" in algo and len(algo) != 0:
         | 
| 546 | 
            +
                    return input_series, input_character, input_tags, gr.update(interactive=True)
         | 
| 547 | 
            +
                return predict_tags(image, general_threshold, character_threshold)
         | 
| 548 | 
            +
             | 
| 549 | 
            +
             | 
| 550 | 
            +
            def compose_prompt_to_copy(character: str, series: str, general: str):
         | 
| 551 | 
            +
                characters = character.split(",") if character else []
         | 
| 552 | 
            +
                serieses = series.split(",") if series else []
         | 
| 553 | 
            +
                generals = general.split(",") if general else []
         | 
| 554 | 
            +
                tags = characters + serieses + generals
         | 
| 555 | 
            +
                cprompt = ",".join(tags) if tags else ""
         | 
| 556 | 
            +
                return cprompt
         | 
    	
        tagger/utils.py
    ADDED
    
    | @@ -0,0 +1,50 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import gradio as gr
         | 
| 2 | 
            +
            from dartrs.v2 import AspectRatioTag, LengthTag, RatingTag, IdentityTag
         | 
| 3 | 
            +
             | 
| 4 | 
            +
             | 
| 5 | 
            +
            V2_ASPECT_RATIO_OPTIONS: list[AspectRatioTag] = [
         | 
| 6 | 
            +
                "ultra_wide",
         | 
| 7 | 
            +
                "wide",
         | 
| 8 | 
            +
                "square",
         | 
| 9 | 
            +
                "tall",
         | 
| 10 | 
            +
                "ultra_tall",
         | 
| 11 | 
            +
            ]
         | 
| 12 | 
            +
            V2_RATING_OPTIONS: list[RatingTag] = [
         | 
| 13 | 
            +
                "sfw",
         | 
| 14 | 
            +
                "general",
         | 
| 15 | 
            +
                "sensitive",
         | 
| 16 | 
            +
                "nsfw",
         | 
| 17 | 
            +
                "questionable",
         | 
| 18 | 
            +
                "explicit",
         | 
| 19 | 
            +
            ]
         | 
| 20 | 
            +
            V2_LENGTH_OPTIONS: list[LengthTag] = [
         | 
| 21 | 
            +
                "very_short",
         | 
| 22 | 
            +
                "short",
         | 
| 23 | 
            +
                "medium",
         | 
| 24 | 
            +
                "long",
         | 
| 25 | 
            +
                "very_long",
         | 
| 26 | 
            +
            ]
         | 
| 27 | 
            +
            V2_IDENTITY_OPTIONS: list[IdentityTag] = [
         | 
| 28 | 
            +
                "none",
         | 
| 29 | 
            +
                "lax",
         | 
| 30 | 
            +
                "strict",
         | 
| 31 | 
            +
            ]
         | 
| 32 | 
            +
             | 
| 33 | 
            +
             | 
| 34 | 
            +
            # ref: https://qiita.com/tregu148/items/fccccbbc47d966dd2fc2
         | 
| 35 | 
            +
            def gradio_copy_text(_text: None):
         | 
| 36 | 
            +
                gr.Info("Copied!")
         | 
| 37 | 
            +
             | 
| 38 | 
            +
             | 
| 39 | 
            +
            COPY_ACTION_JS = """\
         | 
| 40 | 
            +
            (inputs, _outputs) => {
         | 
| 41 | 
            +
              // inputs is the string value of the input_text
         | 
| 42 | 
            +
              if (inputs.trim() !== "") {
         | 
| 43 | 
            +
                navigator.clipboard.writeText(inputs);
         | 
| 44 | 
            +
              }
         | 
| 45 | 
            +
            }"""
         | 
| 46 | 
            +
             | 
| 47 | 
            +
             | 
| 48 | 
            +
            def gradio_copy_prompt(prompt: str):
         | 
| 49 | 
            +
                gr.Info("Copied!")
         | 
| 50 | 
            +
                return prompt
         | 
    	
        tagger/v2.py
    ADDED
    
    | @@ -0,0 +1,260 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import time
         | 
| 2 | 
            +
            import torch
         | 
| 3 | 
            +
            from typing import Callable
         | 
| 4 | 
            +
            from pathlib import Path
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            from dartrs.v2 import (
         | 
| 7 | 
            +
                V2Model,
         | 
| 8 | 
            +
                MixtralModel,
         | 
| 9 | 
            +
                MistralModel,
         | 
| 10 | 
            +
                compose_prompt,
         | 
| 11 | 
            +
                LengthTag,
         | 
| 12 | 
            +
                AspectRatioTag,
         | 
| 13 | 
            +
                RatingTag,
         | 
| 14 | 
            +
                IdentityTag,
         | 
| 15 | 
            +
            )
         | 
| 16 | 
            +
            from dartrs.dartrs import DartTokenizer
         | 
| 17 | 
            +
            from dartrs.utils import get_generation_config
         | 
| 18 | 
            +
             | 
| 19 | 
            +
             | 
| 20 | 
            +
            import gradio as gr
         | 
| 21 | 
            +
            from gradio.components import Component
         | 
| 22 | 
            +
             | 
| 23 | 
            +
             | 
| 24 | 
            +
            try:
         | 
| 25 | 
            +
                from output import UpsamplingOutput
         | 
| 26 | 
            +
            except:
         | 
| 27 | 
            +
                from .output import UpsamplingOutput
         | 
| 28 | 
            +
             | 
| 29 | 
            +
             | 
| 30 | 
            +
            V2_ALL_MODELS = {
         | 
| 31 | 
            +
                "dart-v2-moe-sft": {
         | 
| 32 | 
            +
                    "repo": "p1atdev/dart-v2-moe-sft",
         | 
| 33 | 
            +
                    "type": "sft",
         | 
| 34 | 
            +
                    "class": MixtralModel,
         | 
| 35 | 
            +
                },
         | 
| 36 | 
            +
                "dart-v2-sft": {
         | 
| 37 | 
            +
                    "repo": "p1atdev/dart-v2-sft",
         | 
| 38 | 
            +
                    "type": "sft",
         | 
| 39 | 
            +
                    "class": MistralModel,
         | 
| 40 | 
            +
                },
         | 
| 41 | 
            +
            }
         | 
| 42 | 
            +
             | 
| 43 | 
            +
             | 
| 44 | 
            +
            def prepare_models(model_config: dict):
         | 
| 45 | 
            +
                model_name = model_config["repo"]
         | 
| 46 | 
            +
                tokenizer = DartTokenizer.from_pretrained(model_name)
         | 
| 47 | 
            +
                model = model_config["class"].from_pretrained(model_name)
         | 
| 48 | 
            +
             | 
| 49 | 
            +
                return {
         | 
| 50 | 
            +
                    "tokenizer": tokenizer,
         | 
| 51 | 
            +
                    "model": model,
         | 
| 52 | 
            +
                }
         | 
| 53 | 
            +
             | 
| 54 | 
            +
             | 
| 55 | 
            +
            def normalize_tags(tokenizer: DartTokenizer, tags: str):
         | 
| 56 | 
            +
                """Just remove unk tokens."""
         | 
| 57 | 
            +
                return ", ".join([tag for tag in tokenizer.tokenize(tags) if tag != "<|unk|>"])
         | 
| 58 | 
            +
             | 
| 59 | 
            +
             | 
| 60 | 
            +
            @torch.no_grad()
         | 
| 61 | 
            +
            def generate_tags(
         | 
| 62 | 
            +
                model: V2Model,
         | 
| 63 | 
            +
                tokenizer: DartTokenizer,
         | 
| 64 | 
            +
                prompt: str,
         | 
| 65 | 
            +
                ban_token_ids: list[int],
         | 
| 66 | 
            +
            ):
         | 
| 67 | 
            +
                output = model.generate(
         | 
| 68 | 
            +
                    get_generation_config(
         | 
| 69 | 
            +
                        prompt,
         | 
| 70 | 
            +
                        tokenizer=tokenizer,
         | 
| 71 | 
            +
                        temperature=1,
         | 
| 72 | 
            +
                        top_p=0.9,
         | 
| 73 | 
            +
                        top_k=100,
         | 
| 74 | 
            +
                        max_new_tokens=256,
         | 
| 75 | 
            +
                        ban_token_ids=ban_token_ids,
         | 
| 76 | 
            +
                    ),
         | 
| 77 | 
            +
                )
         | 
| 78 | 
            +
             | 
| 79 | 
            +
                return output
         | 
| 80 | 
            +
             | 
| 81 | 
            +
             | 
| 82 | 
            +
            def _people_tag(noun: str, minimum: int = 1, maximum: int = 5):
         | 
| 83 | 
            +
                return (
         | 
| 84 | 
            +
                    [f"1{noun}"]
         | 
| 85 | 
            +
                    + [f"{num}{noun}s" for num in range(minimum + 1, maximum + 1)]
         | 
| 86 | 
            +
                    + [f"{maximum+1}+{noun}s"]
         | 
| 87 | 
            +
                )
         | 
| 88 | 
            +
             | 
| 89 | 
            +
             | 
| 90 | 
            +
            PEOPLE_TAGS = (
         | 
| 91 | 
            +
                _people_tag("girl") + _people_tag("boy") + _people_tag("other") + ["no humans"]
         | 
| 92 | 
            +
            )
         | 
| 93 | 
            +
             | 
| 94 | 
            +
             | 
| 95 | 
            +
            def gen_prompt_text(output: UpsamplingOutput):
         | 
| 96 | 
            +
                # separate people tags (e.g. 1girl)
         | 
| 97 | 
            +
                people_tags = []
         | 
| 98 | 
            +
                other_general_tags = []
         | 
| 99 | 
            +
                
         | 
| 100 | 
            +
                for tag in output.general_tags.split(","):
         | 
| 101 | 
            +
                    tag = tag.strip()
         | 
| 102 | 
            +
                    if tag in PEOPLE_TAGS:
         | 
| 103 | 
            +
                        people_tags.append(tag)
         | 
| 104 | 
            +
                    else:
         | 
| 105 | 
            +
                        other_general_tags.append(tag)
         | 
| 106 | 
            +
             | 
| 107 | 
            +
                return ", ".join(
         | 
| 108 | 
            +
                    [
         | 
| 109 | 
            +
                        part.strip()
         | 
| 110 | 
            +
                        for part in [
         | 
| 111 | 
            +
                            *people_tags,
         | 
| 112 | 
            +
                            output.character_tags,
         | 
| 113 | 
            +
                            output.copyright_tags,
         | 
| 114 | 
            +
                            *other_general_tags,
         | 
| 115 | 
            +
                            output.upsampled_tags,
         | 
| 116 | 
            +
                            output.rating_tag,
         | 
| 117 | 
            +
                        ]
         | 
| 118 | 
            +
                        if part.strip() != ""
         | 
| 119 | 
            +
                    ]
         | 
| 120 | 
            +
                )
         | 
| 121 | 
            +
             | 
| 122 | 
            +
             | 
| 123 | 
            +
            def elapsed_time_format(elapsed_time: float) -> str:
         | 
| 124 | 
            +
                return f"Elapsed: {elapsed_time:.2f} seconds"
         | 
| 125 | 
            +
             | 
| 126 | 
            +
             | 
| 127 | 
            +
            def parse_upsampling_output(
         | 
| 128 | 
            +
                upsampler: Callable[..., UpsamplingOutput],
         | 
| 129 | 
            +
            ):
         | 
| 130 | 
            +
                def _parse_upsampling_output(*args) -> tuple[str, str, dict]:
         | 
| 131 | 
            +
                    output = upsampler(*args)
         | 
| 132 | 
            +
             | 
| 133 | 
            +
                    return (
         | 
| 134 | 
            +
                        gen_prompt_text(output),
         | 
| 135 | 
            +
                        elapsed_time_format(output.elapsed_time),
         | 
| 136 | 
            +
                        gr.update(interactive=True),
         | 
| 137 | 
            +
                        gr.update(interactive=True),
         | 
| 138 | 
            +
                    )
         | 
| 139 | 
            +
             | 
| 140 | 
            +
                return _parse_upsampling_output
         | 
| 141 | 
            +
             | 
| 142 | 
            +
             | 
| 143 | 
            +
            class V2UI:
         | 
| 144 | 
            +
                model_name: str | None = None
         | 
| 145 | 
            +
                model: V2Model
         | 
| 146 | 
            +
                tokenizer: DartTokenizer
         | 
| 147 | 
            +
             | 
| 148 | 
            +
                input_components: list[Component] = []
         | 
| 149 | 
            +
                generate_btn: gr.Button
         | 
| 150 | 
            +
             | 
| 151 | 
            +
                def on_generate(
         | 
| 152 | 
            +
                    self,
         | 
| 153 | 
            +
                    model_name: str,
         | 
| 154 | 
            +
                    copyright_tags: str,
         | 
| 155 | 
            +
                    character_tags: str,
         | 
| 156 | 
            +
                    general_tags: str,
         | 
| 157 | 
            +
                    rating_tag: RatingTag,
         | 
| 158 | 
            +
                    aspect_ratio_tag: AspectRatioTag,
         | 
| 159 | 
            +
                    length_tag: LengthTag,
         | 
| 160 | 
            +
                    identity_tag: IdentityTag,
         | 
| 161 | 
            +
                    ban_tags: str,
         | 
| 162 | 
            +
                    *args,
         | 
| 163 | 
            +
                ) -> UpsamplingOutput:
         | 
| 164 | 
            +
                    if self.model_name is None or self.model_name != model_name:
         | 
| 165 | 
            +
                        models = prepare_models(V2_ALL_MODELS[model_name])
         | 
| 166 | 
            +
                        self.model = models["model"]
         | 
| 167 | 
            +
                        self.tokenizer = models["tokenizer"]
         | 
| 168 | 
            +
                        self.model_name = model_name
         | 
| 169 | 
            +
             | 
| 170 | 
            +
                    # normalize tags
         | 
| 171 | 
            +
                    # copyright_tags = normalize_tags(self.tokenizer, copyright_tags)
         | 
| 172 | 
            +
                    # character_tags = normalize_tags(self.tokenizer, character_tags)
         | 
| 173 | 
            +
                    # general_tags = normalize_tags(self.tokenizer, general_tags)
         | 
| 174 | 
            +
             | 
| 175 | 
            +
                    ban_token_ids = self.tokenizer.encode(ban_tags.strip())
         | 
| 176 | 
            +
             | 
| 177 | 
            +
                    prompt = compose_prompt(
         | 
| 178 | 
            +
                        prompt=general_tags,
         | 
| 179 | 
            +
                        copyright=copyright_tags,
         | 
| 180 | 
            +
                        character=character_tags,
         | 
| 181 | 
            +
                        rating=rating_tag,
         | 
| 182 | 
            +
                        aspect_ratio=aspect_ratio_tag,
         | 
| 183 | 
            +
                        length=length_tag,
         | 
| 184 | 
            +
                        identity=identity_tag,
         | 
| 185 | 
            +
                    )
         | 
| 186 | 
            +
             | 
| 187 | 
            +
                    start = time.time()
         | 
| 188 | 
            +
                    upsampled_tags = generate_tags(
         | 
| 189 | 
            +
                        self.model,
         | 
| 190 | 
            +
                        self.tokenizer,
         | 
| 191 | 
            +
                        prompt,
         | 
| 192 | 
            +
                        ban_token_ids,
         | 
| 193 | 
            +
                    )
         | 
| 194 | 
            +
                    elapsed_time = time.time() - start
         | 
| 195 | 
            +
             | 
| 196 | 
            +
                    return UpsamplingOutput(
         | 
| 197 | 
            +
                        upsampled_tags=upsampled_tags,
         | 
| 198 | 
            +
                        copyright_tags=copyright_tags,
         | 
| 199 | 
            +
                        character_tags=character_tags,
         | 
| 200 | 
            +
                        general_tags=general_tags,
         | 
| 201 | 
            +
                        rating_tag=rating_tag,
         | 
| 202 | 
            +
                        aspect_ratio_tag=aspect_ratio_tag,
         | 
| 203 | 
            +
                        length_tag=length_tag,
         | 
| 204 | 
            +
                        identity_tag=identity_tag,
         | 
| 205 | 
            +
                        elapsed_time=elapsed_time,
         | 
| 206 | 
            +
                    )
         | 
| 207 | 
            +
             | 
| 208 | 
            +
             | 
| 209 | 
            +
            def parse_upsampling_output_simple(upsampler: UpsamplingOutput):
         | 
| 210 | 
            +
                return gen_prompt_text(upsampler)
         | 
| 211 | 
            +
             | 
| 212 | 
            +
             | 
| 213 | 
            +
            v2 = V2UI()
         | 
| 214 | 
            +
             | 
| 215 | 
            +
             | 
| 216 | 
            +
            def v2_upsampling_prompt(model: str = "dart-v2-moe-sft", copyright: str = "", character: str = "",
         | 
| 217 | 
            +
                                      general_tags: str = "", rating: str = "nsfw", aspect_ratio: str = "square",
         | 
| 218 | 
            +
                                        length: str = "very_long", identity: str = "lax", ban_tags: str = "censored"):
         | 
| 219 | 
            +
                raw_prompt = parse_upsampling_output_simple(v2.on_generate(model, copyright, character, general_tags,
         | 
| 220 | 
            +
                                                                            rating, aspect_ratio, length, identity, ban_tags))
         | 
| 221 | 
            +
                return raw_prompt
         | 
| 222 | 
            +
             | 
| 223 | 
            +
             | 
| 224 | 
            +
            def load_dict_from_csv(filename):
         | 
| 225 | 
            +
                dict = {}
         | 
| 226 | 
            +
                if not Path(filename).exists():
         | 
| 227 | 
            +
                    if Path('./tagger/', filename).exists(): filename = str(Path('./tagger/', filename))
         | 
| 228 | 
            +
                    else: return dict
         | 
| 229 | 
            +
                try:
         | 
| 230 | 
            +
                    with open(filename, 'r', encoding="utf-8") as f:
         | 
| 231 | 
            +
                        lines = f.readlines()
         | 
| 232 | 
            +
                except Exception:
         | 
| 233 | 
            +
                    print(f"Failed to open dictionary file: {filename}")
         | 
| 234 | 
            +
                    return dict
         | 
| 235 | 
            +
                for line in lines:
         | 
| 236 | 
            +
                    parts = line.strip().split(',')
         | 
| 237 | 
            +
                    dict[parts[0]] = parts[1]
         | 
| 238 | 
            +
                return dict
         | 
| 239 | 
            +
             | 
| 240 | 
            +
             | 
| 241 | 
            +
            anime_series_dict = load_dict_from_csv('character_series_dict.csv')
         | 
| 242 | 
            +
             | 
| 243 | 
            +
             | 
| 244 | 
            +
            def select_random_character(series: str, character: str):
         | 
| 245 | 
            +
                from random import seed, randrange
         | 
| 246 | 
            +
                seed()
         | 
| 247 | 
            +
                character_list = list(anime_series_dict.keys())
         | 
| 248 | 
            +
                character = character_list[randrange(len(character_list) - 1)]
         | 
| 249 | 
            +
                series = anime_series_dict.get(character.split(",")[0].strip(), "")
         | 
| 250 | 
            +
                return series, character
         | 
| 251 | 
            +
             | 
| 252 | 
            +
             | 
| 253 | 
            +
            def v2_random_prompt(general_tags: str = "", copyright: str = "", character: str = "", rating: str = "nsfw",
         | 
| 254 | 
            +
                                  aspect_ratio: str = "square", length: str = "very_long", identity: str = "lax",
         | 
| 255 | 
            +
                                  ban_tags: str = "censored", model: str = "dart-v2-moe-sft"):
         | 
| 256 | 
            +
                if copyright == "" and character == "":
         | 
| 257 | 
            +
                    copyright, character = select_random_character("", "")
         | 
| 258 | 
            +
                raw_prompt = v2_upsampling_prompt(model, copyright, character, general_tags, rating,
         | 
| 259 | 
            +
                                                   aspect_ratio, length, identity, ban_tags)
         | 
| 260 | 
            +
                return raw_prompt, copyright, character
         | 
    	
        utils.py
    ADDED
    
    | @@ -0,0 +1,421 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import os
         | 
| 2 | 
            +
            import re
         | 
| 3 | 
            +
            import gradio as gr
         | 
| 4 | 
            +
            from constants import (
         | 
| 5 | 
            +
                DIFFUSERS_FORMAT_LORAS,
         | 
| 6 | 
            +
                CIVITAI_API_KEY,
         | 
| 7 | 
            +
                HF_TOKEN,
         | 
| 8 | 
            +
                MODEL_TYPE_CLASS,
         | 
| 9 | 
            +
                DIRECTORY_LORAS,
         | 
| 10 | 
            +
            )
         | 
| 11 | 
            +
            from huggingface_hub import HfApi
         | 
| 12 | 
            +
            from diffusers import DiffusionPipeline
         | 
| 13 | 
            +
            from huggingface_hub import model_info as model_info_data
         | 
| 14 | 
            +
            from diffusers.pipelines.pipeline_loading_utils import variant_compatible_siblings
         | 
| 15 | 
            +
            from pathlib import PosixPath
         | 
| 16 | 
            +
            from unidecode import unidecode
         | 
| 17 | 
            +
            import urllib.parse
         | 
| 18 | 
            +
            import copy
         | 
| 19 | 
            +
            import requests
         | 
| 20 | 
            +
            from requests.adapters import HTTPAdapter
         | 
| 21 | 
            +
            from urllib3.util import Retry
         | 
| 22 | 
            +
             | 
| 23 | 
            +
            USER_AGENT = 'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:127.0) Gecko/20100101 Firefox/127.0'
         | 
| 24 | 
            +
             | 
| 25 | 
            +
             | 
| 26 | 
            +
            def request_json_data(url):
         | 
| 27 | 
            +
                model_version_id = url.split('/')[-1]
         | 
| 28 | 
            +
                if "?modelVersionId=" in model_version_id:
         | 
| 29 | 
            +
                    match = re.search(r'modelVersionId=(\d+)', url)
         | 
| 30 | 
            +
                    model_version_id = match.group(1)
         | 
| 31 | 
            +
             | 
| 32 | 
            +
                endpoint_url = f"https://civitai.com/api/v1/model-versions/{model_version_id}"
         | 
| 33 | 
            +
             | 
| 34 | 
            +
                params = {}
         | 
| 35 | 
            +
                headers = {'User-Agent': USER_AGENT, 'content-type': 'application/json'}
         | 
| 36 | 
            +
                session = requests.Session()
         | 
| 37 | 
            +
                retries = Retry(total=5, backoff_factor=1, status_forcelist=[500, 502, 503, 504])
         | 
| 38 | 
            +
                session.mount("https://", HTTPAdapter(max_retries=retries))
         | 
| 39 | 
            +
             | 
| 40 | 
            +
                try:
         | 
| 41 | 
            +
                    result = session.get(endpoint_url, params=params, headers=headers, stream=True, timeout=(3.0, 15))
         | 
| 42 | 
            +
                    result.raise_for_status()
         | 
| 43 | 
            +
                    json_data = result.json()
         | 
| 44 | 
            +
                    return json_data if json_data else None
         | 
| 45 | 
            +
                except Exception as e:
         | 
| 46 | 
            +
                    print(f"Error: {e}")
         | 
| 47 | 
            +
                    return None
         | 
| 48 | 
            +
             | 
| 49 | 
            +
             | 
| 50 | 
            +
            class ModelInformation:
         | 
| 51 | 
            +
                def __init__(self, json_data):
         | 
| 52 | 
            +
                    self.model_version_id = json_data.get("id", "")
         | 
| 53 | 
            +
                    self.model_id = json_data.get("modelId", "")
         | 
| 54 | 
            +
                    self.download_url = json_data.get("downloadUrl", "")
         | 
| 55 | 
            +
                    self.model_url = f"https://civitai.com/models/{self.model_id}?modelVersionId={self.model_version_id}"
         | 
| 56 | 
            +
                    self.filename_url = next(
         | 
| 57 | 
            +
                        (v.get("name", "") for v in json_data.get("files", []) if str(self.model_version_id) in v.get("downloadUrl", "")), ""
         | 
| 58 | 
            +
                    )
         | 
| 59 | 
            +
                    self.filename_url = self.filename_url if self.filename_url else ""
         | 
| 60 | 
            +
                    self.description = json_data.get("description", "")
         | 
| 61 | 
            +
                    if self.description is None: self.description = ""
         | 
| 62 | 
            +
                    self.model_name = json_data.get("model", {}).get("name", "")
         | 
| 63 | 
            +
                    self.model_type = json_data.get("model", {}).get("type", "")
         | 
| 64 | 
            +
                    self.nsfw = json_data.get("model", {}).get("nsfw", False)
         | 
| 65 | 
            +
                    self.poi = json_data.get("model", {}).get("poi", False)
         | 
| 66 | 
            +
                    self.images = [img.get("url", "") for img in json_data.get("images", [])]
         | 
| 67 | 
            +
                    self.example_prompt = json_data.get("trainedWords", [""])[0] if json_data.get("trainedWords") else ""
         | 
| 68 | 
            +
                    self.original_json = copy.deepcopy(json_data)
         | 
| 69 | 
            +
             | 
| 70 | 
            +
             | 
| 71 | 
            +
            def retrieve_model_info(url):
         | 
| 72 | 
            +
                json_data = request_json_data(url)
         | 
| 73 | 
            +
                if not json_data:
         | 
| 74 | 
            +
                    return None
         | 
| 75 | 
            +
                model_descriptor = ModelInformation(json_data)
         | 
| 76 | 
            +
                return model_descriptor
         | 
| 77 | 
            +
             | 
| 78 | 
            +
             | 
| 79 | 
            +
            def download_things(directory, url, hf_token="", civitai_api_key="", romanize=False):
         | 
| 80 | 
            +
                url = url.strip()
         | 
| 81 | 
            +
                downloaded_file_path = None
         | 
| 82 | 
            +
             | 
| 83 | 
            +
                if "drive.google.com" in url:
         | 
| 84 | 
            +
                    original_dir = os.getcwd()
         | 
| 85 | 
            +
                    os.chdir(directory)
         | 
| 86 | 
            +
                    os.system(f"gdown --fuzzy {url}")
         | 
| 87 | 
            +
                    os.chdir(original_dir)
         | 
| 88 | 
            +
                elif "huggingface.co" in url:
         | 
| 89 | 
            +
                    url = url.replace("?download=true", "")
         | 
| 90 | 
            +
                    # url = urllib.parse.quote(url, safe=':/')  # fix encoding
         | 
| 91 | 
            +
                    if "/blob/" in url:
         | 
| 92 | 
            +
                        url = url.replace("/blob/", "/resolve/")
         | 
| 93 | 
            +
                    user_header = f'"Authorization: Bearer {hf_token}"'
         | 
| 94 | 
            +
             | 
| 95 | 
            +
                    filename = unidecode(url.split('/')[-1]) if romanize else url.split('/')[-1]
         | 
| 96 | 
            +
             | 
| 97 | 
            +
                    if hf_token:
         | 
| 98 | 
            +
                        os.system(f"aria2c --console-log-level=error --summary-interval=10 --header={user_header} -c -x 16 -k 1M -s 16 {url} -d {directory}  -o {filename}")
         | 
| 99 | 
            +
                    else:
         | 
| 100 | 
            +
                        os.system(f"aria2c --optimize-concurrent-downloads --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 {url} -d {directory}  -o {filename}")
         | 
| 101 | 
            +
             | 
| 102 | 
            +
                    downloaded_file_path = os.path.join(directory, filename)
         | 
| 103 | 
            +
             | 
| 104 | 
            +
                elif "civitai.com" in url:
         | 
| 105 | 
            +
             | 
| 106 | 
            +
                    if not civitai_api_key:
         | 
| 107 | 
            +
                        print("\033[91mYou need an API key to download Civitai models.\033[0m")
         | 
| 108 | 
            +
             | 
| 109 | 
            +
                    model_profile = retrieve_model_info(url)
         | 
| 110 | 
            +
                    if model_profile.download_url and model_profile.filename_url:
         | 
| 111 | 
            +
                        url = model_profile.download_url
         | 
| 112 | 
            +
                        filename = unidecode(model_profile.filename_url) if romanize else model_profile.filename_url
         | 
| 113 | 
            +
                    else:
         | 
| 114 | 
            +
                        if "?" in url:
         | 
| 115 | 
            +
                            url = url.split("?")[0]
         | 
| 116 | 
            +
                        filename = ""
         | 
| 117 | 
            +
             | 
| 118 | 
            +
                    url_dl = url + f"?token={civitai_api_key}"
         | 
| 119 | 
            +
                    print(f"Filename: {filename}")
         | 
| 120 | 
            +
             | 
| 121 | 
            +
                    param_filename = ""
         | 
| 122 | 
            +
                    if filename:
         | 
| 123 | 
            +
                        param_filename = f"-o '{filename}'"
         | 
| 124 | 
            +
             | 
| 125 | 
            +
                    aria2_command = (
         | 
| 126 | 
            +
                        f'aria2c --console-log-level=error --summary-interval=10 -c -x 16 '
         | 
| 127 | 
            +
                        f'-k 1M -s 16 -d "{directory}" {param_filename} "{url_dl}"'
         | 
| 128 | 
            +
                    )
         | 
| 129 | 
            +
                    os.system(aria2_command)
         | 
| 130 | 
            +
             | 
| 131 | 
            +
                    if param_filename and os.path.exists(os.path.join(directory, filename)):
         | 
| 132 | 
            +
                        downloaded_file_path = os.path.join(directory, filename)
         | 
| 133 | 
            +
             | 
| 134 | 
            +
                    # # PLAN B
         | 
| 135 | 
            +
                    # # Follow the redirect to get the actual download URL
         | 
| 136 | 
            +
                    # curl_command = (
         | 
| 137 | 
            +
                    #     f'curl -L -sI --connect-timeout 5 --max-time 5 '
         | 
| 138 | 
            +
                    #     f'-H "Content-Type: application/json" '
         | 
| 139 | 
            +
                    #     f'-H "Authorization: Bearer {civitai_api_key}" "{url}"'
         | 
| 140 | 
            +
                    # )
         | 
| 141 | 
            +
             | 
| 142 | 
            +
                    # headers = os.popen(curl_command).read()
         | 
| 143 | 
            +
             | 
| 144 | 
            +
                    # # Look for the redirected "Location" URL
         | 
| 145 | 
            +
                    # location_match = re.search(r'location: (.+)', headers, re.IGNORECASE)
         | 
| 146 | 
            +
             | 
| 147 | 
            +
                    # if location_match:
         | 
| 148 | 
            +
                    #     redirect_url = location_match.group(1).strip()
         | 
| 149 | 
            +
             | 
| 150 | 
            +
                    #     # Extract the filename from the redirect URL's "Content-Disposition"
         | 
| 151 | 
            +
                    #     filename_match = re.search(r'filename%3D%22(.+?)%22', redirect_url)
         | 
| 152 | 
            +
                    #     if filename_match:
         | 
| 153 | 
            +
                    #         encoded_filename = filename_match.group(1)
         | 
| 154 | 
            +
                    #         # Decode the URL-encoded filename
         | 
| 155 | 
            +
                    #         decoded_filename = urllib.parse.unquote(encoded_filename)
         | 
| 156 | 
            +
             | 
| 157 | 
            +
                    #         filename = unidecode(decoded_filename) if romanize else decoded_filename
         | 
| 158 | 
            +
                    #         print(f"Filename: {filename}")
         | 
| 159 | 
            +
             | 
| 160 | 
            +
                    #         aria2_command = (
         | 
| 161 | 
            +
                    #             f'aria2c --console-log-level=error --summary-interval=10 -c -x 16 '
         | 
| 162 | 
            +
                    #             f'-k 1M -s 16 -d "{directory}" -o "{filename}" "{redirect_url}"'
         | 
| 163 | 
            +
                    #         )
         | 
| 164 | 
            +
                    #         return_code = os.system(aria2_command)
         | 
| 165 | 
            +
             | 
| 166 | 
            +
                    #         # if return_code != 0:
         | 
| 167 | 
            +
                    #         #     raise RuntimeError(f"Failed to download file: {filename}. Error code: {return_code}")
         | 
| 168 | 
            +
                    #         downloaded_file_path = os.path.join(directory, filename)
         | 
| 169 | 
            +
                    #         if not os.path.exists(downloaded_file_path):
         | 
| 170 | 
            +
                    #             downloaded_file_path = None
         | 
| 171 | 
            +
             | 
| 172 | 
            +
                    # if not downloaded_file_path:
         | 
| 173 | 
            +
                    #     # Old method
         | 
| 174 | 
            +
                    #     if "?" in url:
         | 
| 175 | 
            +
                    #         url = url.split("?")[0]
         | 
| 176 | 
            +
                    #     url = url + f"?token={civitai_api_key}"
         | 
| 177 | 
            +
                    #     os.system(f"aria2c --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 -d {directory} {url}")
         | 
| 178 | 
            +
             | 
| 179 | 
            +
                else:
         | 
| 180 | 
            +
                    os.system(f"aria2c --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 -d {directory} {url}")
         | 
| 181 | 
            +
             | 
| 182 | 
            +
                return downloaded_file_path
         | 
| 183 | 
            +
             | 
| 184 | 
            +
             | 
| 185 | 
            +
            def get_model_list(directory_path):
         | 
| 186 | 
            +
                model_list = []
         | 
| 187 | 
            +
                valid_extensions = {'.ckpt', '.pt', '.pth', '.safetensors', '.bin'}
         | 
| 188 | 
            +
             | 
| 189 | 
            +
                for filename in os.listdir(directory_path):
         | 
| 190 | 
            +
                    if os.path.splitext(filename)[1] in valid_extensions:
         | 
| 191 | 
            +
                        # name_without_extension = os.path.splitext(filename)[0]
         | 
| 192 | 
            +
                        file_path = os.path.join(directory_path, filename)
         | 
| 193 | 
            +
                        # model_list.append((name_without_extension, file_path))
         | 
| 194 | 
            +
                        model_list.append(file_path)
         | 
| 195 | 
            +
                        print('\033[34mFILE: ' + file_path + '\033[0m')
         | 
| 196 | 
            +
                return model_list
         | 
| 197 | 
            +
             | 
| 198 | 
            +
             | 
| 199 | 
            +
            def extract_parameters(input_string):
         | 
| 200 | 
            +
                parameters = {}
         | 
| 201 | 
            +
                input_string = input_string.replace("\n", "")
         | 
| 202 | 
            +
             | 
| 203 | 
            +
                if "Negative prompt:" not in input_string:
         | 
| 204 | 
            +
                    if "Steps:" in input_string:
         | 
| 205 | 
            +
                        input_string = input_string.replace("Steps:", "Negative prompt: Steps:")
         | 
| 206 | 
            +
                    else:
         | 
| 207 | 
            +
                        print("Invalid metadata")
         | 
| 208 | 
            +
                        parameters["prompt"] = input_string
         | 
| 209 | 
            +
                        return parameters
         | 
| 210 | 
            +
             | 
| 211 | 
            +
                parm = input_string.split("Negative prompt:")
         | 
| 212 | 
            +
                parameters["prompt"] = parm[0].strip()
         | 
| 213 | 
            +
                if "Steps:" not in parm[1]:
         | 
| 214 | 
            +
                    print("Steps not detected")
         | 
| 215 | 
            +
                    parameters["neg_prompt"] = parm[1].strip()
         | 
| 216 | 
            +
                    return parameters
         | 
| 217 | 
            +
                parm = parm[1].split("Steps:")
         | 
| 218 | 
            +
                parameters["neg_prompt"] = parm[0].strip()
         | 
| 219 | 
            +
                input_string = "Steps:" + parm[1]
         | 
| 220 | 
            +
             | 
| 221 | 
            +
                # Extracting Steps
         | 
| 222 | 
            +
                steps_match = re.search(r'Steps: (\d+)', input_string)
         | 
| 223 | 
            +
                if steps_match:
         | 
| 224 | 
            +
                    parameters['Steps'] = int(steps_match.group(1))
         | 
| 225 | 
            +
             | 
| 226 | 
            +
                # Extracting Size
         | 
| 227 | 
            +
                size_match = re.search(r'Size: (\d+x\d+)', input_string)
         | 
| 228 | 
            +
                if size_match:
         | 
| 229 | 
            +
                    parameters['Size'] = size_match.group(1)
         | 
| 230 | 
            +
                    width, height = map(int, parameters['Size'].split('x'))
         | 
| 231 | 
            +
                    parameters['width'] = width
         | 
| 232 | 
            +
                    parameters['height'] = height
         | 
| 233 | 
            +
             | 
| 234 | 
            +
                # Extracting other parameters
         | 
| 235 | 
            +
                other_parameters = re.findall(r'(\w+): (.*?)(?=, \w+|$)', input_string)
         | 
| 236 | 
            +
                for param in other_parameters:
         | 
| 237 | 
            +
                    parameters[param[0]] = param[1].strip('"')
         | 
| 238 | 
            +
             | 
| 239 | 
            +
                return parameters
         | 
| 240 | 
            +
             | 
| 241 | 
            +
             | 
| 242 | 
            +
            def get_my_lora(link_url, romanize):
         | 
| 243 | 
            +
                l_name = ""
         | 
| 244 | 
            +
                for url in [url.strip() for url in link_url.split(',')]:
         | 
| 245 | 
            +
                    if not os.path.exists(f"./loras/{url.split('/')[-1]}"):
         | 
| 246 | 
            +
                        l_name = download_things(DIRECTORY_LORAS, url, HF_TOKEN, CIVITAI_API_KEY, romanize)
         | 
| 247 | 
            +
                new_lora_model_list = get_model_list(DIRECTORY_LORAS)
         | 
| 248 | 
            +
                new_lora_model_list.insert(0, "None")
         | 
| 249 | 
            +
                new_lora_model_list = new_lora_model_list + DIFFUSERS_FORMAT_LORAS
         | 
| 250 | 
            +
                msg_lora = "Downloaded"
         | 
| 251 | 
            +
                if l_name:
         | 
| 252 | 
            +
                    msg_lora += f": <b>{l_name}</b>"
         | 
| 253 | 
            +
                    print(msg_lora)
         | 
| 254 | 
            +
             | 
| 255 | 
            +
                return gr.update(
         | 
| 256 | 
            +
                    choices=new_lora_model_list
         | 
| 257 | 
            +
                ), gr.update(
         | 
| 258 | 
            +
                    choices=new_lora_model_list
         | 
| 259 | 
            +
                ), gr.update(
         | 
| 260 | 
            +
                    choices=new_lora_model_list
         | 
| 261 | 
            +
                ), gr.update(
         | 
| 262 | 
            +
                    choices=new_lora_model_list
         | 
| 263 | 
            +
                ), gr.update(
         | 
| 264 | 
            +
                    choices=new_lora_model_list
         | 
| 265 | 
            +
                ), gr.update(
         | 
| 266 | 
            +
                    value=msg_lora
         | 
| 267 | 
            +
                )
         | 
| 268 | 
            +
             | 
| 269 | 
            +
             | 
| 270 | 
            +
            def info_html(json_data, title, subtitle):
         | 
| 271 | 
            +
                return f"""
         | 
| 272 | 
            +
                    <div style='padding: 0; border-radius: 10px;'>
         | 
| 273 | 
            +
                        <p style='margin: 0; font-weight: bold;'>{title}</p>
         | 
| 274 | 
            +
                        <details>
         | 
| 275 | 
            +
                            <summary>Details</summary>
         | 
| 276 | 
            +
                            <p style='margin: 0; font-weight: bold;'>{subtitle}</p>
         | 
| 277 | 
            +
                        </details>
         | 
| 278 | 
            +
                    </div>
         | 
| 279 | 
            +
                    """
         | 
| 280 | 
            +
             | 
| 281 | 
            +
             | 
| 282 | 
            +
            def get_model_type(repo_id: str):
         | 
| 283 | 
            +
                api = HfApi(token=os.environ.get("HF_TOKEN"))  # if use private or gated model
         | 
| 284 | 
            +
                default = "SD 1.5"
         | 
| 285 | 
            +
                try:
         | 
| 286 | 
            +
                    model = api.model_info(repo_id=repo_id, timeout=5.0)
         | 
| 287 | 
            +
                    tags = model.tags
         | 
| 288 | 
            +
                    for tag in tags:
         | 
| 289 | 
            +
                        if tag in MODEL_TYPE_CLASS.keys(): return MODEL_TYPE_CLASS.get(tag, default)
         | 
| 290 | 
            +
                except Exception:
         | 
| 291 | 
            +
                    return default
         | 
| 292 | 
            +
                return default
         | 
| 293 | 
            +
             | 
| 294 | 
            +
             | 
| 295 | 
            +
            def restart_space(repo_id: str, factory_reboot: bool):
         | 
| 296 | 
            +
                api = HfApi(token=os.environ.get("HF_TOKEN"))
         | 
| 297 | 
            +
                try:
         | 
| 298 | 
            +
                    runtime = api.get_space_runtime(repo_id=repo_id)
         | 
| 299 | 
            +
                    if runtime.stage == "RUNNING":
         | 
| 300 | 
            +
                        api.restart_space(repo_id=repo_id, factory_reboot=factory_reboot)
         | 
| 301 | 
            +
                        print(f"Restarting space: {repo_id}")
         | 
| 302 | 
            +
                    else:
         | 
| 303 | 
            +
                        print(f"Space {repo_id} is in stage: {runtime.stage}")
         | 
| 304 | 
            +
                except Exception as e:
         | 
| 305 | 
            +
                    print(e)
         | 
| 306 | 
            +
             | 
| 307 | 
            +
             | 
| 308 | 
            +
            def extract_exif_data(image):
         | 
| 309 | 
            +
                if image is None: return ""
         | 
| 310 | 
            +
             | 
| 311 | 
            +
                try:
         | 
| 312 | 
            +
                    metadata_keys = ['parameters', 'metadata', 'prompt', 'Comment']
         | 
| 313 | 
            +
             | 
| 314 | 
            +
                    for key in metadata_keys:
         | 
| 315 | 
            +
                        if key in image.info:
         | 
| 316 | 
            +
                            return image.info[key]
         | 
| 317 | 
            +
             | 
| 318 | 
            +
                    return str(image.info)
         | 
| 319 | 
            +
             | 
| 320 | 
            +
                except Exception as e:
         | 
| 321 | 
            +
                    return f"Error extracting metadata: {str(e)}"
         | 
| 322 | 
            +
             | 
| 323 | 
            +
             | 
| 324 | 
            +
            def create_mask_now(img, invert):
         | 
| 325 | 
            +
                import numpy as np
         | 
| 326 | 
            +
                import time
         | 
| 327 | 
            +
             | 
| 328 | 
            +
                time.sleep(0.5)
         | 
| 329 | 
            +
             | 
| 330 | 
            +
                transparent_image = img["layers"][0]
         | 
| 331 | 
            +
             | 
| 332 | 
            +
                # Extract the alpha channel
         | 
| 333 | 
            +
                alpha_channel = np.array(transparent_image)[:, :, 3]
         | 
| 334 | 
            +
             | 
| 335 | 
            +
                # Create a binary mask by thresholding the alpha channel
         | 
| 336 | 
            +
                binary_mask = alpha_channel > 1
         | 
| 337 | 
            +
             | 
| 338 | 
            +
                if invert:
         | 
| 339 | 
            +
                    print("Invert")
         | 
| 340 | 
            +
                    # Invert the binary mask so that the drawn shape is white and the rest is black
         | 
| 341 | 
            +
                    binary_mask = np.invert(binary_mask)
         | 
| 342 | 
            +
             | 
| 343 | 
            +
                # Convert the binary mask to a 3-channel RGB mask
         | 
| 344 | 
            +
                rgb_mask = np.stack((binary_mask,) * 3, axis=-1)
         | 
| 345 | 
            +
             | 
| 346 | 
            +
                # Convert the mask to uint8
         | 
| 347 | 
            +
                rgb_mask = rgb_mask.astype(np.uint8) * 255
         | 
| 348 | 
            +
             | 
| 349 | 
            +
                return img["background"], rgb_mask
         | 
| 350 | 
            +
             | 
| 351 | 
            +
             | 
| 352 | 
            +
            def download_diffuser_repo(repo_name: str, model_type: str, revision: str = "main", token=True):
         | 
| 353 | 
            +
             | 
| 354 | 
            +
                variant = None
         | 
| 355 | 
            +
                if token is True and not os.environ.get("HF_TOKEN"):
         | 
| 356 | 
            +
                    token = None
         | 
| 357 | 
            +
             | 
| 358 | 
            +
                if model_type == "SDXL":
         | 
| 359 | 
            +
                    info = model_info_data(
         | 
| 360 | 
            +
                        repo_name,
         | 
| 361 | 
            +
                        token=token,
         | 
| 362 | 
            +
                        revision=revision,
         | 
| 363 | 
            +
                        timeout=5.0,
         | 
| 364 | 
            +
                    )
         | 
| 365 | 
            +
             | 
| 366 | 
            +
                    filenames = {sibling.rfilename for sibling in info.siblings}
         | 
| 367 | 
            +
                    model_filenames, variant_filenames = variant_compatible_siblings(
         | 
| 368 | 
            +
                        filenames, variant="fp16"
         | 
| 369 | 
            +
                    )
         | 
| 370 | 
            +
             | 
| 371 | 
            +
                    if len(variant_filenames):
         | 
| 372 | 
            +
                        variant = "fp16"
         | 
| 373 | 
            +
             | 
| 374 | 
            +
                cached_folder = DiffusionPipeline.download(
         | 
| 375 | 
            +
                    pretrained_model_name=repo_name,
         | 
| 376 | 
            +
                    force_download=False,
         | 
| 377 | 
            +
                    token=token,
         | 
| 378 | 
            +
                    revision=revision,
         | 
| 379 | 
            +
                    # mirror="https://hf-mirror.com",
         | 
| 380 | 
            +
                    variant=variant,
         | 
| 381 | 
            +
                    use_safetensors=True,
         | 
| 382 | 
            +
                    trust_remote_code=False,
         | 
| 383 | 
            +
                    timeout=5.0,
         | 
| 384 | 
            +
                )
         | 
| 385 | 
            +
             | 
| 386 | 
            +
                if isinstance(cached_folder, PosixPath):
         | 
| 387 | 
            +
                    cached_folder = cached_folder.as_posix()
         | 
| 388 | 
            +
             | 
| 389 | 
            +
                # Task model
         | 
| 390 | 
            +
                # from huggingface_hub import hf_hub_download
         | 
| 391 | 
            +
                # hf_hub_download(
         | 
| 392 | 
            +
                #     task_model,
         | 
| 393 | 
            +
                #     filename="diffusion_pytorch_model.safetensors",  # fix fp16 variant
         | 
| 394 | 
            +
                # )
         | 
| 395 | 
            +
             | 
| 396 | 
            +
                return cached_folder
         | 
| 397 | 
            +
             | 
| 398 | 
            +
             | 
| 399 | 
            +
            def progress_step_bar(step, total):
         | 
| 400 | 
            +
                # Calculate the percentage for the progress bar width
         | 
| 401 | 
            +
                percentage = min(100, ((step / total) * 100))
         | 
| 402 | 
            +
             | 
| 403 | 
            +
                return f"""
         | 
| 404 | 
            +
                    <div style="position: relative; width: 100%; background-color: gray; border-radius: 5px; overflow: hidden;">
         | 
| 405 | 
            +
                        <div style="width: {percentage}%; height: 17px; background-color: #800080; transition: width 0.5s;"></div>
         | 
| 406 | 
            +
                        <div style="position: absolute; width: 100%; text-align: center; color: white; top: 0; line-height: 19px; font-size: 13px;">
         | 
| 407 | 
            +
                            {int(percentage)}%
         | 
| 408 | 
            +
                        </div>
         | 
| 409 | 
            +
                    </div>
         | 
| 410 | 
            +
                    """
         | 
| 411 | 
            +
             | 
| 412 | 
            +
             | 
| 413 | 
            +
            def html_template_message(msg):
         | 
| 414 | 
            +
                return f"""
         | 
| 415 | 
            +
                    <div style="position: relative; width: 100%; background-color: gray; border-radius: 5px; overflow: hidden;">
         | 
| 416 | 
            +
                        <div style="width: 0%; height: 17px; background-color: #800080; transition: width 0.5s;"></div>
         | 
| 417 | 
            +
                        <div style="position: absolute; width: 100%; text-align: center; color: white; top: 0; line-height: 19px; font-size: 14px; font-weight: bold; text-shadow: 1px 1px 2px black;">
         | 
| 418 | 
            +
                            {msg}
         | 
| 419 | 
            +
                        </div>
         | 
| 420 | 
            +
                    </div>
         | 
| 421 | 
            +
                    """
         | 
 
			
