import argparse
import copy
import math
import random
from typing import Any
import os
import time

from PIL import Image, ImageOps
import torch
from accelerate import Accelerator
from library.device_utils import clean_memory_on_device
from safetensors.torch import load_file
from networks import lora_flux
from library import flux_models, flux_train_utils_recraft as flux_train_utils, flux_utils, sd3_train_utils, \
    strategy_base, strategy_flux, train_util
from torchvision import transforms
import train_network
from library.utils import setup_logging
from diffusers.utils import load_image
import numpy as np

setup_logging()
import logging

logger = logging.getLogger(__name__)
def load_target_model(
    fp8_base: bool,
    pretrained_model_name_or_path: str,
    disable_mmap_load_safetensors: bool,
    clip_l_path: str,
    fp8_base_unet: bool,
    t5xxl_path: str,
    ae_path: str,
    weight_dtype: torch.dtype,
    accelerator: Accelerator,
):
    # Load the FLUX transformer to CPU first; it is moved to the accelerator
    # device just before sampling. The weights are loaded as fp8 (e4m3fn)
    # regardless of weight_dtype to reduce memory usage.
    _, model = flux_utils.load_flow_model(
        pretrained_model_name_or_path,
        torch.float8_e4m3fn,
        "cpu",
        disable_mmap=disable_mmap_load_safetensors,
    )
    if fp8_base:
        # Only float8_e4m3fn checkpoints are supported; other fp8 variants are rejected.
        if model.dtype in {torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz}:
            raise ValueError(f"Unsupported fp8 model dtype: {model.dtype}")
        elif model.dtype == torch.float8_e4m3fn:
            logger.info("Loaded fp8 FLUX model")

    # Load the CLIP-L text encoder to CPU
    clip_l = flux_utils.load_clip_l(
        clip_l_path,
        weight_dtype,
        "cpu",
        disable_mmap=disable_mmap_load_safetensors,
    )
    clip_l.eval()

    # Determine the loading dtype for T5XXL (None keeps the checkpoint's dtype)
    if fp8_base and not fp8_base_unet:
        loading_dtype_t5xxl = None  # as is
    else:
        loading_dtype_t5xxl = weight_dtype

    # Load the T5XXL text encoder to CPU
    t5xxl = flux_utils.load_t5xxl(
        t5xxl_path,
        loading_dtype_t5xxl,
        "cpu",
        disable_mmap=disable_mmap_load_safetensors,
    )
    t5xxl.eval()

    if fp8_base and not fp8_base_unet:
        # Only float8_e4m3fn T5XXL checkpoints are supported.
        if t5xxl.dtype in {torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz}:
            raise ValueError(f"Unsupported fp8 model dtype: {t5xxl.dtype}")
        elif t5xxl.dtype == torch.float8_e4m3fn:
            logger.info("Loaded fp8 T5XXL model")

    # Load the autoencoder to CPU
    ae = flux_utils.load_ae(
        ae_path,
        weight_dtype,
        "cpu",
        disable_mmap=disable_mmap_load_safetensors,
    )

    return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model
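# Note: loading the transformer as float8_e4m3fn halves its weight memory
# relative to bf16 (1 byte vs. 2 bytes per parameter). Everything is loaded to
# CPU here and moved to the accelerator device only while it is actually needed.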
class ResizeWithPadding:
    """Pad a non-square image to a square with a solid fill color, then resize."""

    def __init__(self, size, fill=255):
        self.size = size
        self.fill = fill

    def __call__(self, img):
        if isinstance(img, np.ndarray):
            img = Image.fromarray(img)
        elif not isinstance(img, Image.Image):
            raise TypeError("Input must be a PIL Image or a NumPy array")

        width, height = img.size
        if width == height:
            img = img.resize((self.size, self.size), Image.LANCZOS)
        else:
            # Paste the image centered on a square canvas before resizing
            max_dim = max(width, height)
            new_img = Image.new("RGB", (max_dim, max_dim), (self.fill, self.fill, self.fill))
            new_img.paste(img, ((max_dim - width) // 2, (max_dim - height) // 2))
            img = new_img.resize((self.size, self.size), Image.LANCZOS)
        return img
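# Example (hypothetical sizes): a 300x200 input is centered on a 300x300 white
# canvas and then resized, so ResizeWithPadding(512)(img) always yields a
# 512x512 image with the original content's aspect ratio preserved.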
def sample(args, accelerator, vae, text_encoder, flux, output_dir, sample_images, sample_prompts):
    def encode_images_to_latents(vae, images):
        # Get image dimensions
        b, c, h, w = images.shape
        num_split = 2 if args.frame_num == 4 else 3

        # Split the image into equal-width vertical strips
        img_parts = [images[:, :, :, i * w // num_split:(i + 1) * w // num_split] for i in range(num_split)]

        # Encode each strip separately
        latents = [vae.encode(img) for img in img_parts]

        # Concatenate the latents along the width axis to reconstruct the full image
        latents = torch.cat(latents, dim=-1)
        return latents

    def encode_images_to_latents2(vae, images):
        # Encode the full image in a single pass
        latents = vae.encode(images)
        return latents
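    # Note: encode_images_to_latents (split, encode, re-concatenate) is kept for
    # reference but unused here; the condition latents below come from
    # encode_images_to_latents2, which encodes the whole image at once.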
    # Build the image conditions used during sampling
    conditions = {}
    with torch.no_grad():
        for image_path, prompt_dict in zip(sample_images, sample_prompts):
            prompt = prompt_dict.get("prompt", "")
            if prompt not in conditions:
                logger.info(f"Cache conditions for image: {image_path} with prompt: {prompt}")
                resize_transform = ResizeWithPadding(size=512, fill=255) if args.frame_num == 4 else ResizeWithPadding(size=352, fill=255)
                img_transforms = transforms.Compose([
                    resize_transform,
                    transforms.ToTensor(),
                    transforms.Normalize([0.5], [0.5]),
                ])

                # Load and preprocess the condition image on the VAE's device
                image = img_transforms(np.array(load_image(image_path), dtype=np.uint8)).unsqueeze(0).to(
                    vae.device, dtype=vae.dtype
                )
                latents = encode_images_to_latents2(vae, image)
                logger.debug(f"Encoded latents shape for prompt '{prompt}': {latents.shape}")

                # Keep the condition latents on CPU until they are needed
                conditions[prompt] = latents.to("cpu")

    sample_conditions = conditions
    if sample_conditions is not None:
        conditions = {k: v for k, v in sample_conditions.items()}
    sample_prompts_te_outputs = {}  # key: prompt, value: text encoder outputs
    text_encoder[0].to(accelerator.device)
    text_encoder[1].to(accelerator.device)

    tokenize_strategy = strategy_flux.FluxTokenizeStrategy(512)
    text_encoding_strategy = strategy_flux.FluxTextEncodingStrategy(True)

    # Cache text encoder outputs for every prompt (and negative prompt) up
    # front, so the text encoders can be offloaded before the transformer runs.
    with accelerator.autocast(), torch.no_grad():
        for prompt_dict in sample_prompts:
            for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]:
                if p not in sample_prompts_te_outputs:
                    logger.info(f"Cache Text Encoder outputs for prompt: {p}")
                    tokens_and_masks = tokenize_strategy.tokenize(p)
                    sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens(
                        tokenize_strategy, text_encoder, tokens_and_masks, True
                    )
    logger.info("Generating images")
    save_dir = output_dir
    os.makedirs(save_dir, exist_ok=True)

    with torch.no_grad(), accelerator.autocast():
        for prompt_dict in sample_prompts:
            sample_image_inference(
                args,
                accelerator,
                flux,
                text_encoder,
                vae,
                save_dir,
                prompt_dict,
                sample_prompts_te_outputs,
                None,
                conditions,
            )

    clean_memory_on_device(accelerator.device)
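# Note: the `conditions` dict built above is passed to sample_image_inference
# as its `sample_images_ae_outputs` argument, so each prompt's condition
# latents travel alongside its cached text encoder outputs.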
def sample_image_inference(
    args,
    accelerator: Accelerator,
    flux: flux_models.Flux,
    text_encoder,
    ae: flux_models.AutoEncoder,
    save_dir,
    prompt_dict,
    sample_prompts_te_outputs,
    prompt_replacement,
    sample_images_ae_outputs,
):
    # Extract sampling parameters from prompt_dict
    sample_steps = prompt_dict.get("sample_steps", 20)
    width = prompt_dict.get("width", 1024) if args.frame_num == 4 else prompt_dict.get("width", 1056)
    height = prompt_dict.get("height", 1024) if args.frame_num == 4 else prompt_dict.get("height", 1056)
    scale = prompt_dict.get("scale", 1.0)
    seed = prompt_dict.get("seed")
    prompt: str = prompt_dict.get("prompt", "")

    if prompt_replacement is not None:
        prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])

    if seed is not None:
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
    else:
        # Truly random sampling when no seed is given
        torch.seed()
        torch.cuda.seed()

    # Round height and width down to multiples of 16
    height = max(64, height - height % 16)
    width = max(64, width - width % 16)
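    # FLUX latents are 8x downsampled by the AE and then packed into 2x2
    # patches, so pixel dimensions must be multiples of 16 for the packing
    # arithmetic below to work out.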
| logger.info(f"prompt: {prompt}") | |
| logger.info(f"height: {height}") | |
| logger.info(f"width: {width}") | |
| logger.info(f"sample_steps: {sample_steps}") | |
| logger.info(f"scale: {scale}") | |
| if seed is not None: | |
| logger.info(f"seed: {seed}") | |
| # Encode prompts | |
| # Assuming that TokenizeStrategy and TextEncodingStrategy are compatible with Accelerator | |
| text_encoder_conds = [] | |
| if sample_prompts_te_outputs and prompt in sample_prompts_te_outputs: | |
| text_encoder_conds = sample_prompts_te_outputs[prompt] | |
| logger.info(f"Using cached text encoder outputs for prompt: {prompt}") | |
| if sample_images_ae_outputs and prompt in sample_images_ae_outputs: | |
| ae_outputs = sample_images_ae_outputs[prompt] | |
| else: | |
| ae_outputs = None | |
| # ae_outputs = torch.load('ae_outputs.pth', map_location='cuda:0') | |
| # text_encoder_conds = torch.load('text_encoder_conds.pth', map_location='cuda:0') | |
| l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds | |
| # 打印调试信息 | |
| logger.debug( | |
| f"l_pooled shape: {l_pooled.shape}, t5_out shape: {t5_out.shape}, txt_ids shape: {txt_ids.shape}, t5_attn_mask shape: {t5_attn_mask.shape}") | |
| # 采样图像 | |
| weight_dtype = ae.dtype # TODO: give dtype as argument | |
| packed_latent_height = height // 16 | |
| packed_latent_width = width // 16 | |
| # 打印调试信息 | |
| logger.debug(f"packed_latent_height: {packed_latent_height}, packed_latent_width: {packed_latent_width}") | |
| # 准备噪声张量在 CUDA 上 | |
| noise = torch.randn( | |
| 1, | |
| packed_latent_height * packed_latent_width, | |
| 16 * 2 * 2, | |
| device=accelerator.device, | |
| dtype=weight_dtype, | |
| generator=torch.Generator(device=accelerator.device).manual_seed(seed) if seed is not None else None, | |
| ) | |
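    # Shape rationale: each token of the packed sequence covers a 2x2 patch of
    # 16-channel latents (16 * 2 * 2 = 64 values per token), and the sequence
    # holds one token per patch: packed_latent_height * packed_latent_width.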
    timesteps = flux_train_utils.get_schedule(sample_steps, noise.shape[1], shift=True)  # FLUX.1 dev -> shift=True
    img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width).to(
        accelerator.device, dtype=weight_dtype
    )
    t5_attn_mask = t5_attn_mask.to(accelerator.device)

    # Offload the text encoders and move the transformer onto the device
    clip_l, t5xxl = text_encoder
    clip_l.to("cpu")
    t5xxl.to("cpu")
    clean_memory_on_device(accelerator.device)
    flux.to(accelerator.device)
    for param in flux.parameters():
        param.requires_grad = False

    # Run the denoising loop
    with accelerator.autocast(), torch.no_grad():
        x = flux_train_utils.denoise(args, flux, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps,
                                     guidance=scale, t5_attn_mask=t5_attn_mask, ae_outputs=ae_outputs)
    logger.debug(f"x shape after denoise: {x.shape}")

    x = x.float()
    x = flux_utils.unpack_latents(x, packed_latent_height, packed_latent_width)

    # Decode the latents to pixels, then offload the AE again
    ae.to(accelerator.device)
    with accelerator.autocast(), torch.no_grad():
        x = ae.decode(x)
    ae.to("cpu")
    clean_memory_on_device(accelerator.device)

    # Convert from [-1, 1] to an 8-bit image
    x = x.clamp(-1, 1)
    x = x.permute(0, 2, 3, 1)
    image = Image.fromarray((127.5 * (x + 1.0)).float().cpu().numpy().astype(np.uint8)[0])
    # Build a unique output filename
    ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
    seed_suffix = "" if seed is None else f"_{seed}"
    i: int = prompt_dict.get("enum", 0)
    img_filename = f"{ts_str}{seed_suffix}_{i}.png"  # 'enum' keeps same-second samples distinct
    image.save(os.path.join(save_dir, img_filename))
def setup_argparse():
    parser = argparse.ArgumentParser(description="FLUX-Controlnet-Inpainting Inference Script")

    # Paths
    parser.add_argument('--base_flux_checkpoint', type=str, required=True,
                        help='Path to the base FLUX checkpoint')
    parser.add_argument('--lora_weights_path', type=str, required=True,
                        help='Path to the LoRA weights')
    parser.add_argument('--clip_l_path', type=str, required=True,
                        help='Path to the CLIP-L checkpoint')
    parser.add_argument('--t5xxl_path', type=str, required=True,
                        help='Path to the T5XXL checkpoint')
    parser.add_argument('--ae_path', type=str, required=True,
                        help='Path to the autoencoder checkpoint')
    parser.add_argument('--sample_images_file', type=str, required=True,
                        help='Path to the file listing condition image paths')
    parser.add_argument('--sample_prompts_file', type=str, required=True,
                        help='Path to the sample prompts file')
    parser.add_argument('--output_dir', type=str, required=True,
                        help='Directory to save generated images')
    parser.add_argument('--frame_num', type=int, choices=[4, 9], required=True,
                        help='Number of steps in the generated step diagram (choose 4 or 9)')

    return parser.parse_args()
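# The size defaults above line up with the grid layout: 4 frames use a 2x2 grid
# of 512px tiles (1024px canvas), 9 frames a 3x3 grid of 352px tiles (1056px
# canvas), matching the condition sizes chosen in sample().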
def main(args):
    accelerator = Accelerator(mixed_precision='bf16', device_placement=True)

    BASE_FLUX_CHECKPOINT = args.base_flux_checkpoint
    LORA_WEIGHTS_PATH = args.lora_weights_path
    CLIP_L_PATH = args.clip_l_path
    T5XXL_PATH = args.t5xxl_path
    AE_PATH = args.ae_path
    SAMPLE_IMAGES_FILE = args.sample_images_file
    SAMPLE_PROMPTS_FILE = args.sample_prompts_file
    OUTPUT_DIR = args.output_dir

    # Read condition image paths, skipping blank lines and '#' comments
    with open(SAMPLE_IMAGES_FILE, "r", encoding="utf-8") as f:
        image_lines = f.readlines()
    sample_images = [line.strip() for line in image_lines if line.strip() and not line.strip().startswith("#")]

    sample_prompts = train_util.load_prompts(SAMPLE_PROMPTS_FILE)
    # Load the models (fp8 transformer; text encoders and AE in bf16 where possible)
    _, [clip_l, t5xxl], ae, model = load_target_model(
        fp8_base=True,
        pretrained_model_name_or_path=BASE_FLUX_CHECKPOINT,
        disable_mmap_load_safetensors=False,
        clip_l_path=CLIP_L_PATH,
        fp8_base_unet=False,
        t5xxl_path=T5XXL_PATH,
        ae_path=AE_PATH,
        weight_dtype=torch.bfloat16,
        accelerator=accelerator,
    )

    model.eval()
    clip_l.eval()
    t5xxl.eval()
    ae.eval()
    # LoRA: create the network from the saved weights and attach it
    multiplier = 1.0
    weights_sd = load_file(LORA_WEIGHTS_PATH)
    lora_model, _ = lora_flux.create_network_from_weights(multiplier, None, ae, [clip_l, t5xxl], model, weights_sd,
                                                          True)
    lora_model.apply_to([clip_l, t5xxl], model)
    info = lora_model.load_state_dict(weights_sd, strict=True)
    logger.info(f"Loaded LoRA weights from {LORA_WEIGHTS_PATH}: {info}")
    lora_model.eval()
    lora_model.to(accelerator.device)
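    # apply_to() hooks the LoRA modules into the targeted layers, so each
    # adapted layer computes its base output plus the scaled low-rank update
    # rather than having the weights merged in place.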
    # Set text encoders
    text_encoder = [clip_l, t5xxl]
    sample(args, accelerator, vae=ae, text_encoder=text_encoder, flux=model, output_dir=OUTPUT_DIR,
           sample_images=sample_images, sample_prompts=sample_prompts)
if __name__ == "__main__":
    args = setup_argparse()
    main(args)
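# Example invocation (the script name and all paths below are placeholders,
# not shipped defaults):
#   python infer.py \
#       --base_flux_checkpoint flux1-dev.safetensors \
#       --lora_weights_path lora.safetensors \
#       --clip_l_path clip_l.safetensors \
#       --t5xxl_path t5xxl.safetensors \
#       --ae_path ae.safetensors \
#       --sample_images_file sample_images.txt \
#       --sample_prompts_file sample_prompts.txt \
#       --frame_num 9 \
#       --output_dir outputs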