# Minimum Inference Code for FLUX
import argparse
import datetime
import math
import os
import random
from typing import Callable, List, Optional

import einops
import numpy as np
import torch
from tqdm import tqdm
from PIL import Image
import accelerate
from transformers import CLIPTextModel
from safetensors.torch import load_file

from library import device_utils
from library.device_utils import init_ipex, get_preferred_device
from networks import oft_flux

init_ipex()

from library.utils import setup_logging, str_to_dtype

setup_logging()
import logging

logger = logging.getLogger(__name__)

import networks.asylora_flux as lora_flux
from library import flux_models, flux_utils, sd3_utils, strategy_flux


def time_shift(mu: float, sigma: float, t: torch.Tensor):
    return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)


def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]:
    m = (y2 - y1) / (x2 - x1)
    b = y1 - m * x1
    return lambda x: m * x + b


def get_schedule(
    num_steps: int,
    image_seq_len: int,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
    shift: bool = True,
) -> list[float]:
    # extra step for zero
    timesteps = torch.linspace(1, 0, num_steps + 1)

    # shift the schedule toward high timesteps for larger (higher-signal) images
    if shift:
        # estimate mu by linear interpolation between two reference points
        mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
        timesteps = time_shift(mu, 1.0, timesteps)

    return timesteps.tolist()
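
# A worked example of the shift (illustrative values, not from the source): for a
# 1024x1024 image the packed sequence length is (1024/16) * (1024/16) = 4096,
# so mu = 1.15 and exp(mu) ~= 3.158. With sigma = 1, time_shift maps t = 0.5 to
# 3.158 / (3.158 + 1) ~= 0.76, i.e. the midpoint of the schedule is pushed
# toward the high-noise end; at 256 tokens (mu = 0.5) the shift is much milder.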


def denoise(
    model: flux_models.Flux,
    img: torch.Tensor,
    img_ids: torch.Tensor,
    txt: torch.Tensor,
    txt_ids: torch.Tensor,
    vec: torch.Tensor,
    timesteps: list[float],
    guidance: float = 4.0,
    t5_attn_mask: Optional[torch.Tensor] = None,
    neg_txt: Optional[torch.Tensor] = None,
    neg_vec: Optional[torch.Tensor] = None,
    neg_t5_attn_mask: Optional[torch.Tensor] = None,
    cfg_scale: Optional[float] = None,
):
    # this is ignored for schnell
    logger.info(f"guidance: {guidance}, cfg_scale: {cfg_scale}")
    guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)

    # prepare classifier-free guidance
    if neg_txt is not None and neg_vec is not None:
        b_img_ids = torch.cat([img_ids, img_ids], dim=0)
        b_txt_ids = torch.cat([txt_ids, txt_ids], dim=0)
        b_txt = torch.cat([neg_txt, txt], dim=0)
        b_vec = torch.cat([neg_vec, vec], dim=0)
        if t5_attn_mask is not None and neg_t5_attn_mask is not None:
            b_t5_attn_mask = torch.cat([neg_t5_attn_mask, t5_attn_mask], dim=0)
        else:
            b_t5_attn_mask = None
    else:
        b_img_ids = img_ids
        b_txt_ids = txt_ids
        b_txt = txt
        b_vec = vec
        b_t5_attn_mask = t5_attn_mask

    for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]):
        t_vec = torch.full((b_img_ids.shape[0],), t_curr, dtype=img.dtype, device=img.device)

        # classifier-free guidance: duplicate the latent for the uncond/cond batch
        if neg_txt is not None and neg_vec is not None:
            b_img = torch.cat([img, img], dim=0)
        else:
            b_img = img

        pred = model(
            img=b_img,
            img_ids=b_img_ids,
            txt=b_txt,
            txt_ids=b_txt_ids,
            y=b_vec,
            timesteps=t_vec,
            guidance=guidance_vec,
            txt_attention_mask=b_t5_attn_mask,
        )

        # classifier-free guidance: combine the two predictions
        if neg_txt is not None and neg_vec is not None:
            pred_uncond, pred = torch.chunk(pred, 2, dim=0)
            pred = pred_uncond + cfg_scale * (pred - pred_uncond)

        img = img + (t_prev - t_curr) * pred

    return img
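
# Note on the update rule above: FLUX is a rectified-flow model, so `pred` is a
# velocity estimate and the loop is a plain Euler integrator:
#     x(t_prev) = x(t_curr) + (t_prev - t_curr) * v(x, t_curr)
# Since the schedule runs from 1 down to 0, (t_prev - t_curr) is negative and
# each step moves the latent from noise toward the data distribution.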


def do_sample(
    accelerator: Optional[accelerate.Accelerator],
    model: flux_models.Flux,
    img: torch.Tensor,
    img_ids: torch.Tensor,
    l_pooled: torch.Tensor,
    t5_out: torch.Tensor,
    txt_ids: torch.Tensor,
    num_steps: int,
    guidance: float,
    t5_attn_mask: Optional[torch.Tensor],
    is_schnell: bool,
    device: torch.device,
    flux_dtype: torch.dtype,
    neg_l_pooled: Optional[torch.Tensor] = None,
    neg_t5_out: Optional[torch.Tensor] = None,
    neg_t5_attn_mask: Optional[torch.Tensor] = None,
    cfg_scale: Optional[float] = None,
):
    logger.info(f"num_steps: {num_steps}")
    timesteps = get_schedule(num_steps, img.shape[1], shift=not is_schnell)

    # denoise initial noise
    if accelerator:
        with accelerator.autocast(), torch.no_grad():
            x = denoise(
                model,
                img,
                img_ids,
                t5_out,
                txt_ids,
                l_pooled,
                timesteps,
                guidance,
                t5_attn_mask,
                neg_t5_out,
                neg_l_pooled,
                neg_t5_attn_mask,
                cfg_scale,
            )
    else:
        with torch.autocast(device_type=device.type, dtype=flux_dtype), torch.no_grad():
            x = denoise(
                model,
                img,
                img_ids,
                t5_out,
                txt_ids,
                l_pooled,
                timesteps,
                guidance,
                t5_attn_mask,
                neg_t5_out,
                neg_l_pooled,
                neg_t5_attn_mask,
                cfg_scale,
            )

    return x


def generate_image(
    model,
    clip_l: CLIPTextModel,
    t5xxl,
    ae,
    prompt: str,
    seed: Optional[int],
    image_width: int,
    image_height: int,
    steps: Optional[int],
    guidance: float,
    negative_prompt: Optional[str],
    cfg_scale: float,
):
    seed = seed if seed is not None else random.randint(0, 2**32 - 1)
    logger.info(f"Seed: {seed}")

    # make first noise with packed shape
    # original: b,16,2*h//16,2*w//16, packed: b,h//16*w//16,16*2*2
    packed_latent_height, packed_latent_width = math.ceil(image_height / 16), math.ceil(image_width / 16)
    noise_dtype = torch.float32 if is_fp8(dtype) else dtype
    noise = torch.randn(
        1,
        packed_latent_height * packed_latent_width,
        16 * 2 * 2,
        device=device,
        dtype=noise_dtype,
        generator=torch.Generator(device=device).manual_seed(seed),
    )
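
    # Shape check (illustrative, not from the source): at 1024x1024 the packed
    # grid is 64x64, so `noise` is (1, 4096, 64) -- each token is a 2x2 patch of
    # the 16-channel VAE latent, consumed by the DiT as a flat sequence.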

    # prepare img and img ids
    # this is needed only for img2img
    # img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
    # if img.shape[0] == 1 and bs > 1:
    #     img = repeat(img, "1 ... -> bs ...", bs=bs)

    # txt2img only needs img_ids
    img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width)

    # prepare fp8 models
    if is_fp8(clip_l_dtype) and (not hasattr(clip_l, "fp8_prepared") or not clip_l.fp8_prepared):
        logger.info(f"prepare CLIP-L for fp8: set to {clip_l_dtype}, set embeddings to {torch.bfloat16}")
        clip_l.to(clip_l_dtype)  # fp8
        clip_l.text_model.embeddings.to(dtype=torch.bfloat16)
        clip_l.fp8_prepared = True

    if is_fp8(t5xxl_dtype) and (not hasattr(t5xxl, "fp8_prepared") or not t5xxl.fp8_prepared):
        logger.info(f"prepare T5xxl for fp8: set to {t5xxl_dtype}")

        def prepare_fp8(text_encoder, target_dtype):
            def forward_hook(module):
                def forward(hidden_states):
                    hidden_gelu = module.act(module.wi_0(hidden_states))
                    hidden_linear = module.wi_1(hidden_states)
                    hidden_states = hidden_gelu * hidden_linear
                    hidden_states = module.dropout(hidden_states)
                    hidden_states = module.wo(hidden_states)
                    return hidden_states

                return forward

            for module in text_encoder.modules():
                if module.__class__.__name__ in ["T5LayerNorm", "Embedding"]:
                    # print("set", module.__class__.__name__, "to", target_dtype)
                    module.to(target_dtype)
                if module.__class__.__name__ in ["T5DenseGatedActDense"]:
                    # print("set", module.__class__.__name__, "hooks")
                    module.forward = forward_hook(module)

        t5xxl.to(t5xxl_dtype)
        prepare_fp8(t5xxl.encoder, torch.bfloat16)
        t5xxl.fp8_prepared = True
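
    # Why the forward hook (our reading, not stated in the source): the stock
    # transformers T5DenseGatedActDense.forward casts hidden_states to
    # self.wo.weight.dtype when they differ, which with fp8 weights would round
    # the activations to fp8 and destroy them. The replacement forward replays
    # the same gated-GELU computation without that cast, while LayerNorm and
    # Embedding modules are kept in bf16 for numerical stability.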

    # prepare embeddings
    logger.info("Encoding prompts...")
    clip_l = clip_l.to(device)
    t5xxl = t5xxl.to(device)

    def encode(prpt: str):
        tokens_and_masks = tokenize_strategy.tokenize(prpt)
        with torch.no_grad():
            if is_fp8(clip_l_dtype):
                with accelerator.autocast():
                    l_pooled, _, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks)
            else:
                with torch.autocast(device_type=device.type, dtype=clip_l_dtype):
                    l_pooled, _, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks)

            # only the T5-XXL outputs are needed here, so skip re-encoding CLIP-L
            if is_fp8(t5xxl_dtype):
                with accelerator.autocast():
                    _, t5_out, txt_ids, t5_attn_mask = encoding_strategy.encode_tokens(
                        tokenize_strategy, [None, t5xxl], tokens_and_masks, args.apply_t5_attn_mask
                    )
            else:
                with torch.autocast(device_type=device.type, dtype=t5xxl_dtype):
                    _, t5_out, txt_ids, t5_attn_mask = encoding_strategy.encode_tokens(
                        tokenize_strategy, [None, t5xxl], tokens_and_masks, args.apply_t5_attn_mask
                    )
        return l_pooled, t5_out, txt_ids, t5_attn_mask

    l_pooled, t5_out, txt_ids, t5_attn_mask = encode(prompt)
    if negative_prompt:
        neg_l_pooled, neg_t5_out, _, neg_t5_attn_mask = encode(negative_prompt)
    else:
        neg_l_pooled, neg_t5_out, neg_t5_attn_mask = None, None, None

    # NaN check
    if torch.isnan(l_pooled).any():
        raise ValueError("NaN in l_pooled")
    if torch.isnan(t5_out).any():
        raise ValueError("NaN in t5_out")

    if args.offload:
        clip_l = clip_l.cpu()
        t5xxl = t5xxl.cpu()
    # del clip_l, t5xxl
    device_utils.clean_memory()

    # generate image
    logger.info("Generating image...")
    model = model.to(device)
    if steps is None:
        steps = 4 if is_schnell else 50

    img_ids = img_ids.to(device)
    t5_attn_mask = t5_attn_mask.to(device) if args.apply_t5_attn_mask else None

    x = do_sample(
        accelerator,
        model,
        noise,
        img_ids,
        l_pooled,
        t5_out,
        txt_ids,
        steps,
        guidance,
        t5_attn_mask,
        is_schnell,
        device,
        flux_dtype,
        neg_l_pooled,
        neg_t5_out,
        neg_t5_attn_mask,
        cfg_scale,
    )
    if args.offload:
        model = model.cpu()
    # del model
    device_utils.clean_memory()

    # unpack
    x = x.float()
    x = einops.rearrange(x, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=packed_latent_height, w=packed_latent_width, ph=2, pw=2)
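
    # This rearrange inverts the packing done for the noise: continuing the
    # 1024x1024 example, (1, 4096, 64) becomes the (1, 16, 128, 128) latent the
    # AE decoder expects (each 64-channel token splits back into a 2x2 patch).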

    # decode
    logger.info("Decoding image...")
    ae = ae.to(device)
    with torch.no_grad():
        if is_fp8(ae_dtype):
            with accelerator.autocast():
                x = ae.decode(x)
        else:
            with torch.autocast(device_type=device.type, dtype=ae_dtype):
                x = ae.decode(x)
    if args.offload:
        ae = ae.cpu()

    x = x.clamp(-1, 1)
    x = x.permute(0, 2, 3, 1)
    img = Image.fromarray((127.5 * (x + 1.0)).float().cpu().numpy().astype(np.uint8)[0])
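
    # 127.5 * (x + 1.0) maps the decoder's [-1, 1] output to [0, 255]; the
    # permute puts the tensor in HWC order so PIL can consume it directly.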

    # save image
    output_dir = args.output_dir
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, f"{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.png")
    img.save(output_path)

    logger.info(f"Saved image to {output_path}")


if __name__ == "__main__":
    target_height = 768  # 1024
    target_width = 1360  # 1024

    # steps = 50  # 28  # 50
    # guidance_scale = 5
    # seed = 1  # None  # 1

    device = get_preferred_device()

    parser = argparse.ArgumentParser()
    parser.add_argument("--lora_ups_num", type=int, required=True)
    parser.add_argument("--lora_up_cur", type=int, required=True)
    parser.add_argument("--ckpt_path", type=str, required=True)
    parser.add_argument("--clip_l", type=str, required=False)
    parser.add_argument("--t5xxl", type=str, required=False)
    parser.add_argument("--ae", type=str, required=False)
    parser.add_argument("--apply_t5_attn_mask", action="store_true")
    parser.add_argument("--prompt", type=str, default="A photo of a cat")
    parser.add_argument("--output_dir", type=str, default=".")
    parser.add_argument("--dtype", type=str, default="bfloat16", help="base dtype")
    parser.add_argument("--clip_l_dtype", type=str, default=None, help="dtype for clip_l")
    parser.add_argument("--ae_dtype", type=str, default=None, help="dtype for ae")
    parser.add_argument("--t5xxl_dtype", type=str, default=None, help="dtype for t5xxl")
    parser.add_argument("--flux_dtype", type=str, default=None, help="dtype for flux")
    parser.add_argument("--seed", type=int, default=None)
    parser.add_argument("--steps", type=int, default=None, help="Number of steps. Default is 4 for schnell, 50 for dev")
    parser.add_argument("--guidance", type=float, default=3.5)
    parser.add_argument("--negative_prompt", type=str, default=None)
    parser.add_argument("--cfg_scale", type=float, default=1.0)
    parser.add_argument("--offload", action="store_true", help="Offload to CPU")
    parser.add_argument(
        "--lora_weights",
        type=str,
        nargs="*",
        default=[],
        help="LoRA weights, only supports networks.asylora_flux and networks.oft_flux, each argument is a `path;multiplier` (semicolon separated)",
    )
    parser.add_argument("--merge_lora_weights", action="store_true", help="Merge LoRA weights to model")
    parser.add_argument("--width", type=int, default=target_width)
    parser.add_argument("--height", type=int, default=target_height)
    parser.add_argument("--interactive", action="store_true")
    args = parser.parse_args()

    seed = args.seed
    steps = args.steps
    guidance_scale = args.guidance
    lora_ups_num = args.lora_ups_num
    lora_up_cur = args.lora_up_cur

    def is_fp8(dt):
        return dt in [torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz]

    dtype = str_to_dtype(args.dtype)
    clip_l_dtype = str_to_dtype(args.clip_l_dtype, dtype)
    t5xxl_dtype = str_to_dtype(args.t5xxl_dtype, dtype)
    ae_dtype = str_to_dtype(args.ae_dtype, dtype)
    flux_dtype = str_to_dtype(args.flux_dtype, dtype)

    logger.info(f"Dtypes for clip_l, t5xxl, ae, flux: {clip_l_dtype}, {t5xxl_dtype}, {ae_dtype}, {flux_dtype}")

    loading_device = "cpu" if args.offload else device

    use_fp8 = [is_fp8(d) for d in [dtype, clip_l_dtype, t5xxl_dtype, ae_dtype, flux_dtype]]
    if any(use_fp8):
        accelerator = accelerate.Accelerator(mixed_precision="bf16")
    else:
        accelerator = None

    # load clip_l
    logger.info(f"Loading clip_l from {args.clip_l}...")
    clip_l = flux_utils.load_clip_l(args.clip_l, clip_l_dtype, loading_device)
    clip_l.eval()

    logger.info(f"Loading t5xxl from {args.t5xxl}...")
    t5xxl = flux_utils.load_t5xxl(args.t5xxl, t5xxl_dtype, loading_device)
    t5xxl.eval()

    # if is_fp8(clip_l_dtype):
    #     clip_l = accelerator.prepare(clip_l)
    # if is_fp8(t5xxl_dtype):
    #     t5xxl = accelerator.prepare(t5xxl)

    # DiT
    is_schnell, model = flux_utils.load_flow_model(args.ckpt_path, None, loading_device)
    model.eval()
    logger.info(f"Casting model to {flux_dtype}")
    model.to(flux_dtype)  # make sure the model is in the requested dtype
    # if is_fp8(flux_dtype):
    #     model = accelerator.prepare(model)
    # if args.offload:
    #     model = model.to("cpu")

    t5xxl_max_length = 256 if is_schnell else 512
    tokenize_strategy = strategy_flux.FluxTokenizeStrategy(t5xxl_max_length)
    encoding_strategy = strategy_flux.FluxTextEncodingStrategy()

    # AE
    ae = flux_utils.load_ae(args.ae, ae_dtype, loading_device)
    ae.eval()
    # if is_fp8(ae_dtype):
    #     ae = accelerator.prepare(ae)

    # LoRA
    lora_models: List[lora_flux.LoRANetwork] = []
    for weights_file in args.lora_weights:
        if ";" in weights_file:
            weights_file, multiplier = weights_file.split(";")
            multiplier = float(multiplier)
        else:
            multiplier = 1.0

        weights_sd = load_file(weights_file)
        is_lora = is_oft = False
        for key in weights_sd.keys():
            if key.startswith("lora"):
                is_lora = True
            if key.startswith("oft"):
                is_oft = True
            if is_lora or is_oft:
                break
        module = lora_flux if is_lora else oft_flux

        lora_model, _ = module.create_network_from_weights(multiplier, None, ae, [clip_l, t5xxl], model, weights_sd, True, lora_ups_num)
        for sub_lora in lora_model.unet_loras:
            sub_lora.set_lora_up_cur(lora_up_cur - 1)

        if args.merge_lora_weights:
            lora_model.merge_to([clip_l, t5xxl], model, weights_sd)
        else:
            lora_model.apply_to([clip_l, t5xxl], model)
            info = lora_model.load_state_dict(weights_sd, strict=True)
            logger.info(f"Loaded LoRA weights from {weights_file}: {info}")
            lora_model.eval()
            lora_model.to(device)

        lora_models.append(lora_model)
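
    # Note on the asylora-specific arguments (our understanding of this fork,
    # not documented in the source): `--lora_ups_num` is the number of
    # up-projection branches each LoRA module was trained with, and
    # `--lora_up_cur` selects the branch used at inference -- 1-based on the
    # CLI, converted to a 0-based index via `lora_up_cur - 1` above.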

    if not args.interactive:
        generate_image(
            model,
            clip_l,
            t5xxl,
            ae,
            args.prompt,
            args.seed,
            args.width,
            args.height,
            args.steps,
            args.guidance,
            args.negative_prompt,
            args.cfg_scale,
        )
    else:
        # loop for interactive
        width = target_width
        height = target_height
        steps = None
        guidance = args.guidance
        cfg_scale = args.cfg_scale

        while True:
            print(
                "Enter prompt (empty to exit). Options: --w <width> --h <height> --s <steps> --d <seed> --g <guidance> --m <multipliers for LoRA>"
                " --n <negative prompt>, `-` for empty negative prompt --c <cfg_scale>"
            )
            prompt = input()
            if prompt == "":
                break

            # parse options
            options = prompt.split("--")
            prompt = options[0].strip()
            seed = None
            negative_prompt = None
            for opt in options[1:]:
                try:
                    opt = opt.strip()
                    if opt.startswith("w"):
                        width = int(opt[1:].strip())
                    elif opt.startswith("h"):
                        height = int(opt[1:].strip())
                    elif opt.startswith("s"):
                        steps = int(opt[1:].strip())
                    elif opt.startswith("d"):
                        seed = int(opt[1:].strip())
                    elif opt.startswith("g"):
                        guidance = float(opt[1:].strip())
                    elif opt.startswith("m"):
                        multipliers = opt[1:].strip().split(",")
                        if len(multipliers) != len(lora_models):
                            logger.error(f"Invalid number of multipliers, expected {len(lora_models)}")
                            continue
                        for i, lora_model in enumerate(lora_models):
                            lora_model.set_multiplier(float(multipliers[i]))
                    elif opt.startswith("n"):
                        negative_prompt = opt[1:].strip()
                        if negative_prompt == "-":
                            negative_prompt = ""
                    elif opt.startswith("c"):
                        cfg_scale = float(opt[1:].strip())
                except ValueError as e:
                    logger.error(f"Invalid option: {opt}, {e}")

            generate_image(model, clip_l, t5xxl, ae, prompt, seed, width, height, steps, guidance, negative_prompt, cfg_scale)

    logger.info("Done!")