import gradio as gr import json import math import os import random import socket import time import gc from diffusers.models import AutoencoderKL import numpy as np import torch # import torch.distributed as dist # Removed distributed from torchvision.transforms.functional import to_pil_image from tqdm import tqdm from transformers import AutoModel, AutoTokenizer import models # Assuming models.py is in the same directory or accessible from transport import Sampler, create_transport # Assuming transport.py is accessible # --- Globals for Models and Config --- # --- Set these paths and configurations correctly --- CKPT_PATH = "checkpoint/consolidated.00-of-01.pth" # IMPORTANT: Set path to your specific checkpoint file MODEL_ARGS_PATH = "checkpoint/model_args.pth" # IMPORTANT: Set path to model args VAE_TYPE = os.environ.get("VAE_TYPE", "flux") # Or "ema", "mse", "sdxl" PRECISION = os.environ.get("PRECISION", "bf16") # Or "fp32" TEXT_ENCODER_MODEL = os.environ.get("TEXT_ENCODER_MODEL", 'google/gemma-2-2B') # --- End of required path configurations --- # Will be loaded by load_models() - Initialized to None tokenizer = None text_encoder = None vae = None model = None sampler = None transport = None device = "cuda" if torch.cuda.is_available() else "cpu" # Keep dtype definition global as it's used in loading and inference dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[PRECISION] train_args = None cap_feat_dim = None # --- Helper Functions (adapted from original script) --- def encode_prompt(prompt_batch, text_encoder_model, tokenizer_model, target_device, target_dtype): """ Encodes prompts, moving the text encoder to the target device temporarily. """ captions = [] for caption in prompt_batch: # Simplified: always use the provided caption captions.append(caption if isinstance(caption, str) else caption[0]) # Move text_encoder to the target device for encoding text_encoder_model.to(target_device) prompt_embeds, prompt_masks = None, None # Initialize try: with torch.no_grad(): text_inputs = tokenizer_model( captions, padding=True, pad_to_multiple_of=8, max_length=256, # Make this configurable if needed truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids.to(target_device) prompt_masks = text_inputs.attention_mask.to(target_device) # Use autocast on the target device during model inference with torch.autocast(device_type=target_device.split(':')[0], dtype=target_dtype): prompt_embeds = text_encoder_model( input_ids=text_input_ids, attention_mask=prompt_masks, output_hidden_states=True, ).hidden_states[-2] finally: # Move text encoder back to CPU regardless of success/failure text_encoder_model.to("cpu") if target_device != "cpu": torch.cuda.empty_cache() # Clear cache after moving model off GPU gc.collect() # Explicitly call garbage collector # Return results moved to CPU to minimize GPU memory usage outside inference return prompt_embeds.cpu(), prompt_masks.cpu() def none_or_str(value): # Keep this helper if used by create_transport or sampler if value == "None" or value is None: return None return str(value) # --- Model Loading Function --- def load_models(): global tokenizer, text_encoder, vae, model, sampler, transport, device, dtype, train_args, cap_feat_dim print("--- Starting Model Loading (targeting CPU initially) ---") torch.set_grad_enabled(False) # Load training args if not os.path.exists(MODEL_ARGS_PATH): raise FileNotFoundError(f"Model args file not found: {MODEL_ARGS_PATH}") # Force loading pickled data with weights_only=False if it contains non-tensor data (like args) train_args = torch.load(MODEL_ARGS_PATH, map_location='cpu', weights_only=False) print("Loaded model arguments:", json.dumps(train_args.__dict__, indent=2)) # --- Load Tokenizer --- print(f"Creating Tokenizer: {TEXT_ENCODER_MODEL}") tokenizer = AutoTokenizer.from_pretrained(TEXT_ENCODER_MODEL) tokenizer.padding_side = "right" print("Tokenizer loaded.") # --- Load Text Encoder (to CPU) --- print(f"Creating Text Encoder: {TEXT_ENCODER_MODEL}") # Load with desired dtype but map to CPU initially text_encoder = AutoModel.from_pretrained( TEXT_ENCODER_MODEL, torch_dtype=dtype ).eval().to("cpu") # Explicitly move to CPU cap_feat_dim = text_encoder.config.hidden_size print("Text encoder loaded to CPU.") # --- Load VAE (to CPU) --- print(f"Loading VAE: {VAE_TYPE}") vae_path = f"black-forest-labs/FLUX.1-dev" if VAE_TYPE == "flux" else \ f"stabilityai/sdxl-vae" if VAE_TYPE == "sdxl" else \ f"stabilityai/sd-vae-ft-{VAE_TYPE}" if VAE_TYPE in ["ema", "mse"] else None if vae_path is None: raise ValueError(f"Unsupported VAE type: {VAE_TYPE}") subfolder = "vae" if VAE_TYPE == "flux" else None vae = AutoencoderKL.from_pretrained(vae_path, subfolder=subfolder, torch_dtype=dtype).to("cpu") # Explicitly move to CPU vae.requires_grad_(False) vae.eval() print("VAE loaded to CPU.") # --- Load DiT Model (to CPU) --- print(f"Creating DiT model: {train_args.model}") if not hasattr(models, train_args.model): raise AttributeError(f"Model class {train_args.model} not found in models module.") # Instantiate model on CPU model = models.__dict__[train_args.model]( in_channels=16, # Should match FLUX VAE latent channels qk_norm=getattr(train_args, 'qk_norm', True), # Use getattr for safety cap_feat_dim=cap_feat_dim, # Add other args as needed ).eval().to("cpu") # Ensure model is created on CPU and in eval mode print(f"Loading DiT checkpoint: {CKPT_PATH}") if not os.path.exists(CKPT_PATH): raise FileNotFoundError(f"DiT checkpoint not found: {CKPT_PATH}") state_dict = torch.load(CKPT_PATH, map_location='cpu') # Load state dict to CPU # Handle potential 'module.' prefix new_state_dict = {} for k, v in state_dict.items(): new_state_dict[k[len('module.'):] if k.startswith('module.') else k] = v del state_dict # Free memory missing_keys, unexpected_keys = model.load_state_dict(new_state_dict, strict=False) if missing_keys: print(f"Warning: Missing keys in state_dict: {missing_keys}") if unexpected_keys: print(f"Warning: Unexpected keys in state_dict: {unexpected_keys}") del new_state_dict # Free memory # Model remains on CPU here print("DiT model loaded and checkpoint applied to CPU.") # --- Setup Sampler --- path_type = "Linear" prediction = "velocity" loss_weight = None train_eps = None sample_eps = None transport = create_transport(path_type, prediction, loss_weight, train_eps, sample_eps) sampler = Sampler(transport) # Sampler itself doesn't hold large weights print("Sampler initialized.") print("--- Model Loading Complete ---") # --- Gradio Inference Function --- def generate_image( prompt, negative_prompt, system_type, solver, resolution_str, guidance_scale, num_steps, seed, time_shifting_factor, # For DPM solver flow shift t_shift, # For ODE solver time shift # ODE specific args atol=1e-6, rtol=1e-3, ): """Generates an image, moving models to GPU only when needed.""" global model, vae, text_encoder, tokenizer, sampler # Access global models if model is None or vae is None or text_encoder is None or tokenizer is None or sampler is None: return None, "Models not loaded. Please wait or check console.", -1 print("\n--- Starting Generation ---") start_time = time.time() # --- Prepare Inputs --- if int(seed) == -1: current_seed = random.randint(0, 2**32 - 1) else: current_seed = int(seed) torch.manual_seed(current_seed) np.random.seed(current_seed) random.seed(current_seed) print(f"Using Seed: {current_seed}") # System Prompt Logic (same as before) # ... (system prompt selection logic remains unchanged) ... if system_type == "align": system_prompt = "You are an assistant designed to generate high-quality images with the highest degree of image-text alignment based on textual prompts. " # noqa elif system_type == "base": system_prompt = "You are an assistant designed to generate high-quality images based on user prompts. " # noqa elif system_type == "aesthetics": system_prompt = "You are an assistant designed to generate high-quality images with highest degree of aesthetics based on user prompts. " # noqa elif system_type == "real": system_prompt = "You are an assistant designed to generate superior images with the superior degree of image-text alignment based on textual prompts or user prompts. " # noqa elif system_type == "4grid": system_prompt = "You are an assistant designed to generate four high-quality images with highest degree of aesthetics arranged in 2x2 grids based on user prompts. " # noqa elif system_type == "tags": system_prompt = "You are an assistant designed to generate high-quality images based on user prompts based on danbooru tags. " elif system_type == "empty": system_prompt = "" else: # Default or fallback system_prompt = "You are an assistant designed to generate high-quality images based on user prompts. " full_prompt = system_prompt + prompt full_negative_prompt = system_prompt + negative_prompt if negative_prompt else "" try: w_str, h_str = resolution_str.split("x") w, h = int(w_str), int(h_str) latent_w, latent_h = w // 8, h // 8 # Assume VAE stride 8 except Exception as e: print(f"Error parsing resolution: {e}") return None, f"Invalid resolution format: {resolution_str}. Use WxH.", current_seed # --- Encode Prompts (using temporary GPU placement) --- print("Encoding prompts...") # encode_prompt handles moving text_encoder to device and back to CPU cap_feats_cpu, cap_mask_cpu = encode_prompt( [full_prompt, full_negative_prompt], text_encoder, tokenizer, device, # Target device for encoding dtype # Target dtype for encoding ) print("Prompt encoding complete (Text Encoder back on CPU).") # --- Prepare for Sampling --- samples = None # Initialize samples variable with torch.no_grad(): try: print("Moving DiT model to GPU...") model.to(device) # Move main model to GPU # Initial noise needs to be on the target device z = torch.randn([1, 16, latent_h, latent_w], device=device, dtype=dtype) # Match latent channels (16 for flux) z = z.repeat(2, 1, 1, 1) # Repeat for 2 prompts (pos, neg) - Model forward handles CFG splitting # Model kwargs need tensors on the target device model_kwargs = dict( cap_feats=cap_feats_cpu.to(device, dtype=dtype), # Move features to GPU cap_mask=cap_mask_cpu.to(device), # Move mask to GPU cfg_scale=guidance_scale, ) del cap_feats_cpu, cap_mask_cpu # Free CPU memory gc.collect() # --- Perform Sampling (on GPU) --- print(f"Starting sampling with solver: {solver}, steps: {num_steps}...") with torch.autocast(device_type=device.split(':')[0], dtype=dtype): # Autocast for the diffusion model forward pass if solver == "dpm": # DPM typically uses Linear/velocity, create a temporary specific sampler _transport = create_transport("Linear", "velocity") _sampler = Sampler(_transport) sample_fn = _sampler.sample_dpm( model.forward_with_cfg, # Pass the model method directly model_kwargs=model_kwargs, ) # DPM sample_fn might need z repeated differently or handle cfg internally # The original code's `forward_with_cfg` likely handles the split, so z shape [2, C, H, W] is correct samples = sample_fn(z, steps=num_steps, order=2, skip_type="time_uniform_flow", method="multistep", flow_shift=time_shifting_factor) else: # ODE Solvers # Use the globally loaded sampler, but ensure model call uses the GPU model sample_fn = sampler.sample_ode( sampling_method=solver, num_steps=num_steps, atol=atol, rtol=rtol, time_shifting_factor=t_shift ) # sample_fn expects (z, model_call, **model_kwargs) # Pass the GPU model's method directly samples = sample_fn(z, model.forward_with_cfg, **model_kwargs)[-1] # Sampling finished, samples are on GPU. Keep only the positive sample(s) for decoding. samples = samples[:1].detach() # Detach from graph, take positive sample(s) (index 0) finally: # Move DiT model back to CPU regardless of sampling success/failure print("Moving DiT model back to CPU...") model.to("cpu") del model_kwargs # Remove reference to tensors potentially on GPU del z if device != "cpu": torch.cuda.empty_cache() gc.collect() # --- Decode and Post-process --- pil_image = None try: if samples is None: raise RuntimeError("Sampling failed, 'samples' tensor is None.") print("Moving VAE to GPU for decoding...") vae.to(device) # Move VAE to GPU print("Decoding samples...") with torch.no_grad(): # No gradients needed for VAE decode # Ensure samples are on the correct device and dtype for VAE samples = samples.to(device=device, dtype=vae.dtype) # VAE scaling factor and shift (using previously determined values) scaling_factor = vae.config.scaling_factor shift_factor = vae.config.shift_factor samples = samples / scaling_factor + shift_factor # Use autocast for VAE decoding as well if using mixed precision with torch.autocast(device_type=device.split(':')[0], dtype=dtype): decoded_samples = vae.decode(samples)[0] # Decode expects input matching VAE dtype # Move decoded samples to CPU for post-processing decoded_samples = decoded_samples.cpu() print("Decoding complete.") # Normalize [approx -1, 1] -> [0, 1] -> PIL (on CPU) decoded_samples = (decoded_samples + 1.0) / 2.0 decoded_samples.clamp_(0.0, 1.0) # Convert to PIL Image (already on CPU) pil_image = to_pil_image(decoded_samples[0].float()) # Use first sample, ensure float finally: # Move VAE back to CPU regardless of decoding success/failure print("Moving VAE back to CPU...") vae.to("cpu") del samples # Remove reference to tensor del decoded_samples if device != "cpu": torch.cuda.empty_cache() gc.collect() end_time = time.time() generation_time = round(end_time - start_time, 2) print(f"Generation finished in {generation_time} seconds.") if pil_image is None: return None, "Error during generation or decoding.", current_seed return pil_image, f"Generated in {generation_time}s", current_seed # --- Load Models Before Launching UI --- load_models() # Load models to CPU initially # --- Create Gradio Interface (UI definition remains the same) --- with gr.Blocks() as demo: gr.Markdown("# Lumina 2.0 Generation Demo") gr.Markdown("Enter a prompt and adjust settings to generate an image using Lumina 2.0 (Memory Optimized).") with gr.Row(): with gr.Column(scale=2): prompt = gr.Textbox(label="Prompt", placeholder="e.g., A photo of an astronaut riding a horse on the moon") negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="e.g., low quality, blurry, text, watermark") system_type = gr.Dropdown( label="System Prompt Type", choices=["align", "base", "aesthetics", "real", "4grid", "tags", "empty"], value="real" ) resolution = gr.Dropdown( label="Resolution (Width x Height)", choices=["1024x1024", "1280x768", "768x1280", "1536x1024", "1024x1536"], # Common resolutions value="1024x1024" ) solver = gr.Dropdown( label="Solver", choices=["dpm", "euler", "midpoint", "heun", "rk4"], value="dpm" ) run_button = gr.Button("Generate Image", variant="primary") with gr.Column(scale=1): guidance_scale = gr.Slider(label="CFG Scale", minimum=1.0, maximum=15.0, step=0.5, value=4.0) num_steps = gr.Slider(label="Sampling Steps", minimum=10, maximum=200, step=1, value=50) seed = gr.Number(label="Seed", value=-1, precision=0) # -1 for random time_shifting_factor = gr.Slider(label="Time Shift Factor (DPM)", minimum=0.0, maximum=10.0, step=0.1, value=1.0) t_shift = gr.Slider(label="T Shift (ODE)", minimum=0, maximum=10, step=1, value=4) with gr.Row(): output_image = gr.Image(label="Generated Image") with gr.Row(): status_text = gr.Textbox(label="Status", interactive=False) final_seed = gr.Number(label="Seed Used", interactive=False) run_button.click( fn=generate_image, inputs=[ prompt, negative_prompt, system_type, solver, resolution, guidance_scale, num_steps, seed, time_shifting_factor, t_shift ], outputs=[output_image, status_text, final_seed] ) gr.Examples( examples=[ ["A hyperrealistic photo of a cat wearing sunglasses and playing a saxophone on a beach at sunset", "", "real", "dpm", "1024x1024", 4.5, 40, -1, 1.0, 4], ["cinematic film still, A vast cyberpunk cityscape at night, raining, neon lights reflecting on wet streets, high detail", "low detail, drawing, illustration, sketch", "aesthetics", "euler", "1280x768", 5.0, 60, -1, 1.0, 4], ["Macro shot of a dewdrop on a spider web, intricate details, forest background, bokeh", "blurry, unfocused", "real", "dpm", "1024x1024", 4.0, 50, 12345, 1.0, 4], ], inputs=[prompt, negative_prompt, system_type, solver, resolution, guidance_scale, num_steps, seed, time_shifting_factor, t_shift], outputs=[output_image, status_text, final_seed], fn=generate_image ) # --- Launch the Gradio App --- if __name__ == "__main__": # Consider adding share=True if you need external access demo.launch()