# Illustrious-Lumina-v0.03 / demo-proper.py
import gradio as gr
import json
import os
import random
import time
import gc
from diffusers.models import AutoencoderKL
import numpy as np
import torch
# import torch.distributed as dist # Removed distributed
from torchvision.transforms.functional import to_pil_image
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer
import models # Assuming models.py is in the same directory or accessible
from transport import Sampler, create_transport # Assuming transport.py is accessible
# --- Globals for Models and Config ---
# --- Set these paths and configurations correctly ---
CKPT_PATH = "checkpoint/consolidated.00-of-01.pth" # IMPORTANT: Set path to your specific checkpoint file
MODEL_ARGS_PATH = "checkpoint/model_args.pth" # IMPORTANT: Set path to model args
VAE_TYPE = os.environ.get("VAE_TYPE", "flux")  # "flux" (16-channel latents, matches the DiT below); "ema", "mse", "sdxl" are also recognized but use 4-channel latents
PRECISION = os.environ.get("PRECISION", "bf16") # Or "fp32"
TEXT_ENCODER_MODEL = os.environ.get("TEXT_ENCODER_MODEL", 'google/gemma-2-2B')
# --- End of required path configurations ---
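# Example launch (a sketch -- the values shown are illustrative overrides, not additional defaults):
#   PRECISION=fp16 VAE_TYPE=flux TEXT_ENCODER_MODEL=google/gemma-2-2B python demo-proper.py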
# Will be loaded by load_models() - Initialized to None
tokenizer = None
text_encoder = None
vae = None
model = None
sampler = None
transport = None
device = "cuda" if torch.cuda.is_available() else "cpu"
# Keep dtype definition global as it's used in loading and inference
dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[PRECISION]
train_args = None
cap_feat_dim = None
# --- Helper Functions (adapted from original script) ---
def encode_prompt(prompt_batch, text_encoder_model, tokenizer_model, target_device, target_dtype):
"""
Encodes prompts, moving the text encoder to the target device temporarily.
"""
captions = []
for caption in prompt_batch:
# Simplified: always use the provided caption
captions.append(caption if isinstance(caption, str) else caption[0])
# Move text_encoder to the target device for encoding
text_encoder_model.to(target_device)
prompt_embeds, prompt_masks = None, None # Initialize
try:
with torch.no_grad():
text_inputs = tokenizer_model(
captions,
padding=True,
pad_to_multiple_of=8,
max_length=256, # Make this configurable if needed
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids.to(target_device)
prompt_masks = text_inputs.attention_mask.to(target_device)
# Use autocast on the target device during model inference
with torch.autocast(device_type=target_device.split(':')[0], dtype=target_dtype):
prompt_embeds = text_encoder_model(
input_ids=text_input_ids,
attention_mask=prompt_masks,
output_hidden_states=True,
).hidden_states[-2]
finally:
# Move text encoder back to CPU regardless of success/failure
text_encoder_model.to("cpu")
if target_device != "cpu":
torch.cuda.empty_cache() # Clear cache after moving model off GPU
gc.collect() # Explicitly call garbage collector
# Return results moved to CPU to minimize GPU memory usage outside inference
return prompt_embeds.cpu(), prompt_masks.cpu()
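# Example (sketch, assuming load_models() below has populated the globals):
#   feats, mask = encode_prompt(
#       ["a watercolor painting of a fox"], text_encoder, tokenizer, device, dtype
#   )
#   # feats: [batch, seq_len, hidden_size]; mask: [batch, seq_len]; both returned on CPU.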
def none_or_str(value): # Keep this helper if used by create_transport or sampler
if value == "None" or value is None:
return None
return str(value)
# --- Model Loading Function ---
def load_models():
global tokenizer, text_encoder, vae, model, sampler, transport, device, dtype, train_args, cap_feat_dim
print("--- Starting Model Loading (targeting CPU initially) ---")
torch.set_grad_enabled(False)
# Load training args
if not os.path.exists(MODEL_ARGS_PATH):
raise FileNotFoundError(f"Model args file not found: {MODEL_ARGS_PATH}")
    # weights_only=False is required here because the file stores pickled non-tensor data (the training args), not just weights
train_args = torch.load(MODEL_ARGS_PATH, map_location='cpu', weights_only=False)
print("Loaded model arguments:", json.dumps(train_args.__dict__, indent=2))
# --- Load Tokenizer ---
print(f"Creating Tokenizer: {TEXT_ENCODER_MODEL}")
tokenizer = AutoTokenizer.from_pretrained(TEXT_ENCODER_MODEL)
tokenizer.padding_side = "right"
print("Tokenizer loaded.")
# --- Load Text Encoder (to CPU) ---
print(f"Creating Text Encoder: {TEXT_ENCODER_MODEL}")
# Load with desired dtype but map to CPU initially
text_encoder = AutoModel.from_pretrained(
TEXT_ENCODER_MODEL, torch_dtype=dtype
).eval().to("cpu") # Explicitly move to CPU
cap_feat_dim = text_encoder.config.hidden_size
print("Text encoder loaded to CPU.")
# --- Load VAE (to CPU) ---
print(f"Loading VAE: {VAE_TYPE}")
vae_path = f"black-forest-labs/FLUX.1-dev" if VAE_TYPE == "flux" else \
f"stabilityai/sdxl-vae" if VAE_TYPE == "sdxl" else \
f"stabilityai/sd-vae-ft-{VAE_TYPE}" if VAE_TYPE in ["ema", "mse"] else None
if vae_path is None:
raise ValueError(f"Unsupported VAE type: {VAE_TYPE}")
subfolder = "vae" if VAE_TYPE == "flux" else None
vae = AutoencoderKL.from_pretrained(vae_path, subfolder=subfolder, torch_dtype=dtype).to("cpu") # Explicitly move to CPU
vae.requires_grad_(False)
vae.eval()
print("VAE loaded to CPU.")
# --- Load DiT Model (to CPU) ---
print(f"Creating DiT model: {train_args.model}")
if not hasattr(models, train_args.model):
raise AttributeError(f"Model class {train_args.model} not found in models module.")
# Instantiate model on CPU
model = models.__dict__[train_args.model](
        in_channels=16,  # must match the VAE's latent channels (16 for the FLUX VAE; SD/SDXL VAEs use 4)
qk_norm=getattr(train_args, 'qk_norm', True), # Use getattr for safety
cap_feat_dim=cap_feat_dim,
# Add other args as needed
).eval().to("cpu") # Ensure model is created on CPU and in eval mode
print(f"Loading DiT checkpoint: {CKPT_PATH}")
if not os.path.exists(CKPT_PATH):
raise FileNotFoundError(f"DiT checkpoint not found: {CKPT_PATH}")
state_dict = torch.load(CKPT_PATH, map_location='cpu') # Load state dict to CPU
# Handle potential 'module.' prefix
new_state_dict = {}
for k, v in state_dict.items():
new_state_dict[k[len('module.'):] if k.startswith('module.') else k] = v
del state_dict # Free memory
missing_keys, unexpected_keys = model.load_state_dict(new_state_dict, strict=False)
if missing_keys: print(f"Warning: Missing keys in state_dict: {missing_keys}")
if unexpected_keys: print(f"Warning: Unexpected keys in state_dict: {unexpected_keys}")
del new_state_dict # Free memory
# Model remains on CPU here
print("DiT model loaded and checkpoint applied to CPU.")
# --- Setup Sampler ---
path_type = "Linear"
prediction = "velocity"
loss_weight = None
train_eps = None
sample_eps = None
transport = create_transport(path_type, prediction, loss_weight, train_eps, sample_eps)
sampler = Sampler(transport) # Sampler itself doesn't hold large weights
print("Sampler initialized.")
print("--- Model Loading Complete ---")
# --- Gradio Inference Function ---
def generate_image(
prompt,
negative_prompt,
system_type,
solver,
resolution_str,
guidance_scale,
num_steps,
seed,
time_shifting_factor, # For DPM solver flow shift
t_shift, # For ODE solver time shift
# ODE specific args
atol=1e-6,
rtol=1e-3,
):
"""Generates an image, moving models to GPU only when needed."""
global model, vae, text_encoder, tokenizer, sampler # Access global models
if model is None or vae is None or text_encoder is None or tokenizer is None or sampler is None:
return None, "Models not loaded. Please wait or check console.", -1
print("\n--- Starting Generation ---")
start_time = time.time()
# --- Prepare Inputs ---
if int(seed) == -1:
current_seed = random.randint(0, 2**32 - 1)
else:
current_seed = int(seed)
torch.manual_seed(current_seed)
np.random.seed(current_seed)
random.seed(current_seed)
print(f"Using Seed: {current_seed}")
# System Prompt Logic (same as before)
# ... (system prompt selection logic remains unchanged) ...
if system_type == "align":
system_prompt = "You are an assistant designed to generate high-quality images with the highest degree of image-text alignment based on textual prompts. <Prompt Start> " # noqa
elif system_type == "base":
system_prompt = "You are an assistant designed to generate high-quality images based on user prompts. <Prompt Start> " # noqa
elif system_type == "aesthetics":
system_prompt = "You are an assistant designed to generate high-quality images with highest degree of aesthetics based on user prompts. <Prompt Start> " # noqa
elif system_type == "real":
system_prompt = "You are an assistant designed to generate superior images with the superior degree of image-text alignment based on textual prompts or user prompts. <Prompt Start> " # noqa
elif system_type == "4grid":
system_prompt = "You are an assistant designed to generate four high-quality images with highest degree of aesthetics arranged in 2x2 grids based on user prompts. <Prompt Start> " # noqa
elif system_type == "tags":
system_prompt = "You are an assistant designed to generate high-quality images based on user prompts based on danbooru tags. <Prompt Start> "
elif system_type == "empty":
system_prompt = ""
else: # Default or fallback
system_prompt = "You are an assistant designed to generate high-quality images based on user prompts. <Prompt Start> "
full_prompt = system_prompt + prompt
    full_negative_prompt = (system_prompt + negative_prompt) if negative_prompt else ""  # an empty negative prompt stays empty (no system prefix)
try:
w_str, h_str = resolution_str.split("x")
w, h = int(w_str), int(h_str)
latent_w, latent_h = w // 8, h // 8 # Assume VAE stride 8
except Exception as e:
print(f"Error parsing resolution: {e}")
return None, f"Invalid resolution format: {resolution_str}. Use WxH.", current_seed
# --- Encode Prompts (using temporary GPU placement) ---
print("Encoding prompts...")
# encode_prompt handles moving text_encoder to device and back to CPU
cap_feats_cpu, cap_mask_cpu = encode_prompt(
[full_prompt, full_negative_prompt],
text_encoder,
tokenizer,
device, # Target device for encoding
dtype # Target dtype for encoding
)
print("Prompt encoding complete (Text Encoder back on CPU).")
# --- Prepare for Sampling ---
samples = None # Initialize samples variable
with torch.no_grad():
try:
print("Moving DiT model to GPU...")
model.to(device) # Move main model to GPU
# Initial noise needs to be on the target device
z = torch.randn([1, 16, latent_h, latent_w], device=device, dtype=dtype) # Match latent channels (16 for flux)
z = z.repeat(2, 1, 1, 1) # Repeat for 2 prompts (pos, neg) - Model forward handles CFG splitting
# Model kwargs need tensors on the target device
model_kwargs = dict(
cap_feats=cap_feats_cpu.to(device, dtype=dtype), # Move features to GPU
cap_mask=cap_mask_cpu.to(device), # Move mask to GPU
cfg_scale=guidance_scale,
)
del cap_feats_cpu, cap_mask_cpu # Free CPU memory
gc.collect()
# --- Perform Sampling (on GPU) ---
print(f"Starting sampling with solver: {solver}, steps: {num_steps}...")
with torch.autocast(device_type=device.split(':')[0], dtype=dtype): # Autocast for the diffusion model forward pass
if solver == "dpm":
# DPM typically uses Linear/velocity, create a temporary specific sampler
_transport = create_transport("Linear", "velocity")
_sampler = Sampler(_transport)
sample_fn = _sampler.sample_dpm(
model.forward_with_cfg, # Pass the model method directly
model_kwargs=model_kwargs,
)
# DPM sample_fn might need z repeated differently or handle cfg internally
# The original code's `forward_with_cfg` likely handles the split, so z shape [2, C, H, W] is correct
samples = sample_fn(z, steps=num_steps, order=2, skip_type="time_uniform_flow", method="multistep", flow_shift=time_shifting_factor)
else: # ODE Solvers
# Use the globally loaded sampler, but ensure model call uses the GPU model
sample_fn = sampler.sample_ode(
sampling_method=solver, num_steps=num_steps,
atol=atol, rtol=rtol,
time_shifting_factor=t_shift
)
# sample_fn expects (z, model_call, **model_kwargs)
# Pass the GPU model's method directly
samples = sample_fn(z, model.forward_with_cfg, **model_kwargs)[-1]
# Sampling finished, samples are on GPU. Keep only the positive sample(s) for decoding.
samples = samples[:1].detach() # Detach from graph, take positive sample(s) (index 0)
        finally:
            # Move DiT model back to CPU regardless of sampling success/failure
            print("Moving DiT model back to CPU...")
            model.to("cpu")
            # Drop references to tensors that may live on the GPU; assignment (rather
            # than `del`) avoids a NameError if an exception occurred before z or
            # model_kwargs were created.
            model_kwargs = None
            z = None
            if device != "cpu":
                torch.cuda.empty_cache()
            gc.collect()
# --- Decode and Post-process ---
pil_image = None
try:
if samples is None:
raise RuntimeError("Sampling failed, 'samples' tensor is None.")
print("Moving VAE to GPU for decoding...")
vae.to(device) # Move VAE to GPU
print("Decoding samples...")
with torch.no_grad(): # No gradients needed for VAE decode
# Ensure samples are on the correct device and dtype for VAE
samples = samples.to(device=device, dtype=vae.dtype)
            # Undo the VAE's latent scaling/shift before decoding
            scaling_factor = vae.config.scaling_factor
            # shift_factor is defined for the FLUX VAE config; older SD VAE configs may leave it as None
            shift_factor = vae.config.shift_factor if vae.config.shift_factor is not None else 0.0
            samples = samples / scaling_factor + shift_factor
# Use autocast for VAE decoding as well if using mixed precision
with torch.autocast(device_type=device.split(':')[0], dtype=dtype):
decoded_samples = vae.decode(samples)[0] # Decode expects input matching VAE dtype
# Move decoded samples to CPU for post-processing
decoded_samples = decoded_samples.cpu()
print("Decoding complete.")
# Normalize [approx -1, 1] -> [0, 1] -> PIL (on CPU)
decoded_samples = (decoded_samples + 1.0) / 2.0
decoded_samples.clamp_(0.0, 1.0)
# Convert to PIL Image (already on CPU)
pil_image = to_pil_image(decoded_samples[0].float()) # Use first sample, ensure float
    finally:
        # Move VAE back to CPU regardless of decoding success/failure
        print("Moving VAE back to CPU...")
        vae.to("cpu")
        # Drop tensor references; decoded_samples may not exist if decoding failed,
        # so assign None instead of using `del`.
        samples = None
        decoded_samples = None
        if device != "cpu":
            torch.cuda.empty_cache()
        gc.collect()
end_time = time.time()
generation_time = round(end_time - start_time, 2)
print(f"Generation finished in {generation_time} seconds.")
if pil_image is None:
return None, "Error during generation or decoding.", current_seed
return pil_image, f"Generated in {generation_time}s", current_seed
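# Example (sketch): a headless call without the Gradio UI, assuming load_models()
# has already run:
#   image, status, used_seed = generate_image(
#       prompt="a lighthouse at dawn, watercolor",
#       negative_prompt="blurry, low quality",
#       system_type="real", solver="dpm", resolution_str="1024x1024",
#       guidance_scale=4.0, num_steps=40, seed=-1,
#       time_shifting_factor=1.0, t_shift=4,
#   )
#   if image is not None:
#       image.save("sample.png")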
# --- Load Models Before Launching UI ---
load_models() # Load models to CPU initially
# --- Create Gradio Interface (UI definition remains the same) ---
with gr.Blocks() as demo:
gr.Markdown("# Lumina 2.0 Generation Demo")
gr.Markdown("Enter a prompt and adjust settings to generate an image using Lumina 2.0 (Memory Optimized).")
with gr.Row():
with gr.Column(scale=2):
prompt = gr.Textbox(label="Prompt", placeholder="e.g., A photo of an astronaut riding a horse on the moon")
negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="e.g., low quality, blurry, text, watermark")
system_type = gr.Dropdown(
label="System Prompt Type",
choices=["align", "base", "aesthetics", "real", "4grid", "tags", "empty"],
value="real"
)
resolution = gr.Dropdown(
label="Resolution (Width x Height)",
choices=["1024x1024", "1280x768", "768x1280", "1536x1024", "1024x1536"], # Common resolutions
value="1024x1024"
)
solver = gr.Dropdown(
label="Solver",
choices=["dpm", "euler", "midpoint", "heun", "rk4"],
value="dpm"
)
run_button = gr.Button("Generate Image", variant="primary")
with gr.Column(scale=1):
guidance_scale = gr.Slider(label="CFG Scale", minimum=1.0, maximum=15.0, step=0.5, value=4.0)
num_steps = gr.Slider(label="Sampling Steps", minimum=10, maximum=200, step=1, value=50)
seed = gr.Number(label="Seed", value=-1, precision=0) # -1 for random
time_shifting_factor = gr.Slider(label="Time Shift Factor (DPM)", minimum=0.0, maximum=10.0, step=0.1, value=1.0)
t_shift = gr.Slider(label="T Shift (ODE)", minimum=0, maximum=10, step=1, value=4)
with gr.Row():
output_image = gr.Image(label="Generated Image")
with gr.Row():
status_text = gr.Textbox(label="Status", interactive=False)
final_seed = gr.Number(label="Seed Used", interactive=False)
run_button.click(
fn=generate_image,
inputs=[
prompt, negative_prompt, system_type, solver, resolution,
guidance_scale, num_steps, seed, time_shifting_factor, t_shift
],
outputs=[output_image, status_text, final_seed]
)
gr.Examples(
examples=[
["A hyperrealistic photo of a cat wearing sunglasses and playing a saxophone on a beach at sunset", "", "real", "dpm", "1024x1024", 4.5, 40, -1, 1.0, 4],
["cinematic film still, A vast cyberpunk cityscape at night, raining, neon lights reflecting on wet streets, high detail", "low detail, drawing, illustration, sketch", "aesthetics", "euler", "1280x768", 5.0, 60, -1, 1.0, 4],
["Macro shot of a dewdrop on a spider web, intricate details, forest background, bokeh", "blurry, unfocused", "real", "dpm", "1024x1024", 4.0, 50, 12345, 1.0, 4],
],
inputs=[prompt, negative_prompt, system_type, solver, resolution, guidance_scale, num_steps, seed, time_shifting_factor, t_shift],
outputs=[output_image, status_text, final_seed],
fn=generate_image
)
# --- Launch the Gradio App ---
if __name__ == "__main__":
# Consider adding share=True if you need external access
demo.launch()
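    # Example launch variants (sketch; standard Gradio options):
    #   demo.launch(share=True)                                # temporary public URL
    #   demo.launch(server_name="0.0.0.0", server_port=7860)   # expose on the local network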