import argparse

import torch
from omegaconf import OmegaConf

from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
from ldm.modules.encoders.adapter import Adapter, StyleAdapter, Adapter_light
from ldm.modules.extra_condition.api import ExtraCondition
from ldm.util import fix_cond_shapes, load_model_from_config, read_state_dict

DEFAULT_NEGATIVE_PROMPT = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, ' \
                          'fewer digits, cropped, worst quality, low quality'

def get_base_argument_parser() -> argparse.ArgumentParser:
    """Get the base argument parser shared by the inference scripts."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--outdir',
        type=str,
        help='dir to write results to',
        default=None,
    )

    parser.add_argument(
        '--prompt',
        type=str,
        nargs='?',
        default=None,
        help='positive prompt',
    )

    parser.add_argument(
        '--neg_prompt',
        type=str,
        default=DEFAULT_NEGATIVE_PROMPT,
        help='negative prompt',
    )

    parser.add_argument(
        '--cond_path',
        type=str,
        default=None,
        help='condition image path',
    )

    parser.add_argument(
        '--cond_inp_type',
        type=str,
        default='image',
        help='the type of the input condition image; taking depth T2I as an example, the input can be a raw '
        'image, from which the depth map will be estimated, or directly a depth map image',
    )

    parser.add_argument(
        '--sampler',
        type=str,
        default='ddim',
        choices=['ddim', 'plms'],
        help='sampling algorithm; currently only ddim and plms are supported, more are on the way',
    )

    parser.add_argument(
        '--steps',
        type=int,
        default=50,
        help='number of sampling steps',
    )

    parser.add_argument(
        '--sd_ckpt',
        type=str,
        default='models/sd-v1-4.ckpt',
        help='path to the checkpoint of the stable diffusion model; both .ckpt and .safetensors are supported',
    )

    parser.add_argument(
        '--vae_ckpt',
        type=str,
        default=None,
        help='VAE checkpoint; anime SD models usually have a separate VAE checkpoint that needs to be loaded',
    )

    parser.add_argument(
        '--adapter_ckpt',
        type=str,
        default=None,
        help='path to the checkpoint of the adapter',
    )

    parser.add_argument(
        '--config',
        type=str,
        default='configs/stable-diffusion/sd-v1-inference.yaml',
        help='path to the config which constructs the SD model',
    )

    parser.add_argument(
        '--max_resolution',
        type=float,
        default=512 * 512,
        help='max image height * width, only for computers with limited VRAM',
    )

    parser.add_argument(
        '--resize_short_edge',
        type=int,
        default=None,
        help='resize the short edge of the input image; if this arg is set, max_resolution is not used',
    )

    parser.add_argument(
        '--C',
        type=int,
        default=4,
        help='latent channels',
    )

    parser.add_argument(
        '--f',
        type=int,
        default=8,
        help='downsampling factor',
    )

    parser.add_argument(
        '--scale',
        type=float,
        default=7.5,
        help='unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))',
    )

    parser.add_argument(
        '--cond_tau',
        type=float,
        default=1.0,
        help='timestep parameter that determines until which step the adapter is applied, '
        'similar to the tau in Prompt-to-Prompt',
    )

    parser.add_argument(
        '--style_cond_tau',
        type=float,
        default=1.0,
        help='timestep parameter that determines until which step the style adapter is applied, '
        'similar to the tau in Prompt-to-Prompt',
    )

    parser.add_argument(
        '--cond_weight',
        type=float,
        default=1.0,
        help='the adapter features are multiplied by cond_weight; the larger cond_weight is, the more closely '
        'the generated image follows the condition, but the generation quality may drop',
    )

    parser.add_argument(
        '--seed',
        type=int,
        default=42,
    )

    parser.add_argument(
        '--n_samples',
        type=int,
        default=4,
        help='number of samples to generate',
    )

    return parser
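
# Minimal usage sketch (an assumption, not part of this module: `--device` is not
# an argparse option above, so calling scripts are expected to attach `opt.device`
# themselves before using the helpers below):
#
#     opt = get_base_argument_parser().parse_args(['--prompt', 'a photo of a cat'])
#     opt.device = 'cuda' if torch.cuda.is_available() else 'cpu'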

def get_sd_models(opt):
    """Build the stable diffusion model and its sampler."""
    # SD
    config = OmegaConf.load(f"{opt.config}")
    model = load_model_from_config(config, opt.sd_ckpt, opt.vae_ckpt)
    sd_model = model.to(opt.device)

    # sampler
    if opt.sampler == 'plms':
        sampler = PLMSSampler(model)
    elif opt.sampler == 'ddim':
        sampler = DDIMSampler(model)
    else:
        raise NotImplementedError
    return sd_model, sampler
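
# Example (a sketch, assuming the default config and an SD v1.4 checkpoint already
# downloaded to models/sd-v1-4.ckpt, and `opt` prepared as shown above):
#
#     sd_model, sampler = get_sd_models(opt)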

def get_t2i_adapter_models(opt):
    """Build the stable diffusion model with the T2I-Adapter weights loaded into it."""
    config = OmegaConf.load(f"{opt.config}")
    model = load_model_from_config(config, opt.sd_ckpt, opt.vae_ckpt)

    # prefer a condition-specific adapter checkpoint, fall back to the generic one
    adapter_ckpt_path = getattr(opt, f'{opt.which_cond}_adapter_ckpt', None)
    if adapter_ckpt_path is None:
        adapter_ckpt_path = getattr(opt, 'adapter_ckpt')
    adapter_ckpt = read_state_dict(adapter_ckpt_path)

    # re-prefix every weight with 'adapter.' so it lands on the model's adapter submodule
    new_state_dict = {}
    for k, v in adapter_ckpt.items():
        if not k.startswith('adapter.'):
            new_state_dict[f'adapter.{k}'] = v
        else:
            new_state_dict[k] = v
    m, u = model.load_state_dict(new_state_dict, strict=False)
    if len(u) > 0:
        print(f"unexpected keys in loading adapter ckpt {adapter_ckpt_path}:")
        print(u)

    model = model.to(opt.device)

    # sampler
    if opt.sampler == 'plms':
        sampler = PLMSSampler(model)
    elif opt.sampler == 'ddim':
        sampler = DDIMSampler(model)
    else:
        raise NotImplementedError
    return model, sampler
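
# Usage sketch (assumptions: `which_cond` is set by the calling script rather than
# by the parser above, and the checkpoint path below is only a placeholder):
#
#     opt.which_cond = 'depth'
#     opt.adapter_ckpt = 'models/t2iadapter_depth_sd14v1.pth'
#     model, sampler = get_t2i_adapter_models(opt)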

def get_cond_ch(cond_type: ExtraCondition):
    """Number of input channels of the condition image: 1 for sketch/canny, 3 otherwise."""
    if cond_type == ExtraCondition.sketch or cond_type == ExtraCondition.canny:
        return 1
    return 3

def get_adapters(opt, cond_type: ExtraCondition):
    """Build the adapter network for the given condition type and load its weights."""
    adapter = {}
    # prefer a condition-specific weight, fall back to the generic cond_weight
    cond_weight = getattr(opt, f'{cond_type.name}_weight', None)
    if cond_weight is None:
        cond_weight = getattr(opt, 'cond_weight')
    adapter['cond_weight'] = cond_weight

    if cond_type == ExtraCondition.style:
        adapter['model'] = StyleAdapter(width=1024, context_dim=768, num_head=8, n_layes=3, num_token=8).to(opt.device)
    elif cond_type == ExtraCondition.color:
        adapter['model'] = Adapter_light(
            cin=64 * get_cond_ch(cond_type),
            channels=[320, 640, 1280, 1280],
            nums_rb=4).to(opt.device)
    else:
        adapter['model'] = Adapter(
            cin=64 * get_cond_ch(cond_type),
            channels=[320, 640, 1280, 1280][:4],
            nums_rb=2,
            ksize=1,
            sk=True,
            use_conv=False).to(opt.device)

    # prefer a condition-specific adapter checkpoint, fall back to the generic one
    ckpt_path = getattr(opt, f'{cond_type.name}_adapter_ckpt', None)
    if ckpt_path is None:
        ckpt_path = getattr(opt, 'adapter_ckpt')
    adapter['model'].load_state_dict(torch.load(ckpt_path))

    return adapter
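
# Example (a sketch; the checkpoint path is a placeholder):
#
#     opt.adapter_ckpt = 'models/t2iadapter_sketch_sd14v1.pth'
#     adapter = get_adapters(opt, ExtraCondition.sketch)
#     # adapter['model'] is the network, adapter['cond_weight'] its blend weight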

def diffusion_inference(opt, model, sampler, adapter_features, append_to_context=None):
    """Run the sampler once and decode the latents into images in [0, 1]."""
    # get text embedding
    c = model.get_learned_conditioning([opt.prompt])
    if opt.scale != 1.0:
        uc = model.get_learned_conditioning([opt.neg_prompt])
    else:
        uc = None
    c, uc = fix_cond_shapes(model, c, uc)

    if not hasattr(opt, 'H'):
        opt.H = 512
        opt.W = 512
    shape = [opt.C, opt.H // opt.f, opt.W // opt.f]

    samples_latents, _ = sampler.sample(
        S=opt.steps,
        conditioning=c,
        batch_size=1,
        shape=shape,
        verbose=False,
        unconditional_guidance_scale=opt.scale,
        unconditional_conditioning=uc,
        x_T=None,
        features_adapter=adapter_features,
        append_to_context=append_to_context,
        cond_tau=opt.cond_tau,
        style_cond_tau=opt.style_cond_tau,
    )

    x_samples = model.decode_first_stage(samples_latents)
    x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)

    return x_samples
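
# End-to-end sketch (assumptions: `cond` is a condition tensor the caller has
# already preprocessed, e.g. via the ldm.modules.extra_condition helpers, and the
# adapter features are scaled by cond_weight before sampling):
#
#     adapter = get_adapters(opt, ExtraCondition.sketch)
#     features = adapter['model'](cond)
#     if isinstance(features, list):
#         features = [f * adapter['cond_weight'] for f in features]
#     x_samples = diffusion_inference(opt, model, sampler, features)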