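# Gradio demo for Lumina 2.0 text-to-image generation.
# Memory-optimized: the text encoder, DiT, and VAE are kept on the CPU and moved
# to the GPU only for the stage that needs them.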
import gc
import json
import math
import os
import random
import socket
import time

import gradio as gr
import numpy as np
import torch
from diffusers.models import AutoencoderKL
from torchvision.transforms.functional import to_pil_image
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer

import models
from transport import Sampler, create_transport


CKPT_PATH = "checkpoint/consolidated.00-of-01.pth"
MODEL_ARGS_PATH = "checkpoint/model_args.pth"
VAE_TYPE = os.environ.get("VAE_TYPE", "flux")
PRECISION = os.environ.get("PRECISION", "bf16")
TEXT_ENCODER_MODEL = os.environ.get("TEXT_ENCODER_MODEL", "google/gemma-2-2B")

tokenizer = None
text_encoder = None
vae = None
model = None
sampler = None
transport = None
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[PRECISION]
train_args = None
cap_feat_dim = None

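# All heavy modules (text encoder, DiT, VAE) start on the CPU; each stage of
# generate_image() moves the module it needs to `device` and returns it to the
# CPU afterwards, keeping peak GPU memory low at the cost of transfer latency.
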
def encode_prompt(prompt_batch, text_encoder_model, tokenizer_model, target_device, target_dtype):
    """
    Encodes prompts, moving the text encoder to the target device temporarily.
    """
    captions = []
    for caption in prompt_batch:
        captions.append(caption if isinstance(caption, str) else caption[0])

    text_encoder_model.to(target_device)
    prompt_embeds, prompt_masks = None, None

    try:
        with torch.no_grad():
            text_inputs = tokenizer_model(
                captions,
                padding=True,
                pad_to_multiple_of=8,
                max_length=256,
                truncation=True,
                return_tensors="pt",
            )

            text_input_ids = text_inputs.input_ids.to(target_device)
            prompt_masks = text_inputs.attention_mask.to(target_device)

            with torch.autocast(device_type=target_device.split(":")[0], dtype=target_dtype):
                # Use the penultimate hidden state as the caption embedding.
                prompt_embeds = text_encoder_model(
                    input_ids=text_input_ids,
                    attention_mask=prompt_masks,
                    output_hidden_states=True,
                ).hidden_states[-2]
    finally:
        # Always send the text encoder back to the CPU so it does not hold GPU memory.
        text_encoder_model.to("cpu")
        if target_device != "cpu":
            torch.cuda.empty_cache()
        gc.collect()

    return prompt_embeds.cpu(), prompt_masks.cpu()


def none_or_str(value):
    if value == "None" or value is None:
        return None
    return str(value)


def load_models():
    global tokenizer, text_encoder, vae, model, sampler, transport, device, dtype, train_args, cap_feat_dim

    print("--- Starting Model Loading (targeting CPU initially) ---")
    torch.set_grad_enabled(False)

    if not os.path.exists(MODEL_ARGS_PATH):
        raise FileNotFoundError(f"Model args file not found: {MODEL_ARGS_PATH}")

    train_args = torch.load(MODEL_ARGS_PATH, map_location="cpu", weights_only=False)
    print("Loaded model arguments:", json.dumps(train_args.__dict__, indent=2))

    print(f"Creating Tokenizer: {TEXT_ENCODER_MODEL}")
    tokenizer = AutoTokenizer.from_pretrained(TEXT_ENCODER_MODEL)
    tokenizer.padding_side = "right"
    print("Tokenizer loaded.")

    print(f"Creating Text Encoder: {TEXT_ENCODER_MODEL}")
    text_encoder = AutoModel.from_pretrained(
        TEXT_ENCODER_MODEL, torch_dtype=dtype
    ).eval().to("cpu")
    cap_feat_dim = text_encoder.config.hidden_size
    print("Text encoder loaded to CPU.")

    print(f"Loading VAE: {VAE_TYPE}")
    if VAE_TYPE == "flux":
        vae_path = "black-forest-labs/FLUX.1-dev"
    elif VAE_TYPE == "sdxl":
        vae_path = "stabilityai/sdxl-vae"
    elif VAE_TYPE in ["ema", "mse"]:
        vae_path = f"stabilityai/sd-vae-ft-{VAE_TYPE}"
    else:
        raise ValueError(f"Unsupported VAE type: {VAE_TYPE}")

    subfolder = "vae" if VAE_TYPE == "flux" else None
    vae = AutoencoderKL.from_pretrained(vae_path, subfolder=subfolder, torch_dtype=dtype).to("cpu")
    vae.requires_grad_(False)
    vae.eval()
    print("VAE loaded to CPU.")

    print(f"Creating DiT model: {train_args.model}")
    if not hasattr(models, train_args.model):
        raise AttributeError(f"Model class {train_args.model} not found in models module.")

    model = models.__dict__[train_args.model](
        in_channels=16,
        qk_norm=getattr(train_args, "qk_norm", True),
        cap_feat_dim=cap_feat_dim,
    ).eval().to("cpu")

    print(f"Loading DiT checkpoint: {CKPT_PATH}")
    if not os.path.exists(CKPT_PATH):
        raise FileNotFoundError(f"DiT checkpoint not found: {CKPT_PATH}")

    state_dict = torch.load(CKPT_PATH, map_location="cpu")

    # Strip a possible "module." prefix left over from DistributedDataParallel training.
    new_state_dict = {}
    for k, v in state_dict.items():
        new_state_dict[k[len("module."):] if k.startswith("module.") else k] = v
    del state_dict

    missing_keys, unexpected_keys = model.load_state_dict(new_state_dict, strict=False)
    if missing_keys:
        print(f"Warning: Missing keys in state_dict: {missing_keys}")
    if unexpected_keys:
        print(f"Warning: Unexpected keys in state_dict: {unexpected_keys}")
    del new_state_dict

    print("DiT model loaded and checkpoint applied to CPU.")

    path_type = "Linear"
    prediction = "velocity"
    loss_weight = None
    train_eps = None
    sample_eps = None

    transport = create_transport(path_type, prediction, loss_weight, train_eps, sample_eps)
    sampler = Sampler(transport)
    print("Sampler initialized.")
    print("--- Model Loading Complete ---")


def generate_image(
    prompt,
    negative_prompt,
    system_type,
    solver,
    resolution_str,
    guidance_scale,
    num_steps,
    seed,
    time_shifting_factor,
    t_shift,
    atol=1e-6,
    rtol=1e-3,
):
    """Generates an image, moving models to GPU only when needed."""
    global model, vae, text_encoder, tokenizer, sampler

    if model is None or vae is None or text_encoder is None or tokenizer is None or sampler is None:
        return None, "Models not loaded. Please wait or check console.", -1

    print("\n--- Starting Generation ---")
    start_time = time.time()

    if int(seed) == -1:
        current_seed = random.randint(0, 2**32 - 1)
    else:
        current_seed = int(seed)
    torch.manual_seed(current_seed)
    np.random.seed(current_seed)
    random.seed(current_seed)
    print(f"Using Seed: {current_seed}")

    # Select the system prompt prefix matching the requested generation mode.
    if system_type == "align":
        system_prompt = "You are an assistant designed to generate high-quality images with the highest degree of image-text alignment based on textual prompts. <Prompt Start> "
    elif system_type == "base":
        system_prompt = "You are an assistant designed to generate high-quality images based on user prompts. <Prompt Start> "
    elif system_type == "aesthetics":
        system_prompt = "You are an assistant designed to generate high-quality images with highest degree of aesthetics based on user prompts. <Prompt Start> "
    elif system_type == "real":
        system_prompt = "You are an assistant designed to generate superior images with the superior degree of image-text alignment based on textual prompts or user prompts. <Prompt Start> "
    elif system_type == "4grid":
        system_prompt = "You are an assistant designed to generate four high-quality images with highest degree of aesthetics arranged in 2x2 grids based on user prompts. <Prompt Start> "
    elif system_type == "tags":
        system_prompt = "You are an assistant designed to generate high-quality images based on user prompts based on danbooru tags. <Prompt Start> "
    elif system_type == "empty":
        system_prompt = ""
    else:
        system_prompt = "You are an assistant designed to generate high-quality images based on user prompts. <Prompt Start> "

    full_prompt = system_prompt + prompt
    # The negative prompt also receives the system prefix; an empty negative prompt stays empty.
    full_negative_prompt = (system_prompt + negative_prompt) if negative_prompt else ""

    try:
        w_str, h_str = resolution_str.split("x")
        w, h = int(w_str), int(h_str)
        # The VAE downsamples by a factor of 8, so the latent grid is (H/8, W/8).
        latent_w, latent_h = w // 8, h // 8
    except Exception as e:
        print(f"Error parsing resolution: {e}")
        return None, f"Invalid resolution format: {resolution_str}. Use WxH.", current_seed

    print("Encoding prompts...")
    cap_feats_cpu, cap_mask_cpu = encode_prompt(
        [full_prompt, full_negative_prompt],
        text_encoder,
        tokenizer,
        device,
        dtype,
    )
    print("Prompt encoding complete (Text Encoder back on CPU).")

    samples = None
    model_kwargs = None
    z = None
    with torch.no_grad():
        try:
            print("Moving DiT model to GPU...")
            model.to(device)

            # One latent for the conditional branch, duplicated for the CFG (negative) branch.
            z = torch.randn([1, 16, latent_h, latent_w], device=device, dtype=dtype)
            z = z.repeat(2, 1, 1, 1)

            model_kwargs = dict(
                cap_feats=cap_feats_cpu.to(device, dtype=dtype),
                cap_mask=cap_mask_cpu.to(device),
                cfg_scale=guidance_scale,
            )
            del cap_feats_cpu, cap_mask_cpu
            gc.collect()

            print(f"Starting sampling with solver: {solver}, steps: {num_steps}...")
            with torch.autocast(device_type=device.split(":")[0], dtype=dtype):
                if solver == "dpm":
                    _transport = create_transport("Linear", "velocity")
                    _sampler = Sampler(_transport)
                    sample_fn = _sampler.sample_dpm(
                        model.forward_with_cfg,
                        model_kwargs=model_kwargs,
                    )
                    samples = sample_fn(
                        z,
                        steps=num_steps,
                        order=2,
                        skip_type="time_uniform_flow",
                        method="multistep",
                        flow_shift=time_shifting_factor,
                    )
                else:
                    sample_fn = sampler.sample_ode(
                        sampling_method=solver,
                        num_steps=num_steps,
                        atol=atol,
                        rtol=rtol,
                        time_shifting_factor=t_shift,
                    )
                    samples = sample_fn(z, model.forward_with_cfg, **model_kwargs)[-1]

            # Keep only the conditional sample (the first of the CFG pair).
            samples = samples[:1].detach()

        finally:
            # Return the DiT to CPU and free GPU memory even if sampling failed.
            # model_kwargs and z are pre-initialized to None above so these dels
            # cannot raise NameError when sampling aborts early.
            print("Moving DiT model back to CPU...")
            model.to("cpu")
            del model_kwargs
            del z
            if device != "cpu":
                torch.cuda.empty_cache()
            gc.collect()

    pil_image = None
    decoded_samples = None
    try:
        if samples is None:
            raise RuntimeError("Sampling failed, 'samples' tensor is None.")
        print("Moving VAE to GPU for decoding...")
        vae.to(device)

        print("Decoding samples...")
        with torch.no_grad():
            samples = samples.to(device=device, dtype=vae.dtype)

            # Undo the latent scaling/shifting applied during encoding.
            # shift_factor is only defined in the FLUX VAE config; fall back to 0 otherwise.
            scaling_factor = vae.config.scaling_factor
            shift_factor = getattr(vae.config, "shift_factor", None) or 0.0

            samples = samples / scaling_factor + shift_factor

            with torch.autocast(device_type=device.split(":")[0], dtype=dtype):
                decoded_samples = vae.decode(samples)[0]

            decoded_samples = decoded_samples.cpu()
            print("Decoding complete.")

            # Map from [-1, 1] to [0, 1] for image conversion.
            decoded_samples = (decoded_samples + 1.0) / 2.0
            decoded_samples.clamp_(0.0, 1.0)

            pil_image = to_pil_image(decoded_samples[0].float())

    finally:
        # Return the VAE to CPU and free GPU memory; decoded_samples is
        # pre-initialized to None so the del is safe if decoding failed.
        print("Moving VAE back to CPU...")
        vae.to("cpu")
        del samples
        del decoded_samples
        if device != "cpu":
            torch.cuda.empty_cache()
        gc.collect()

    end_time = time.time()
    generation_time = round(end_time - start_time, 2)
    print(f"Generation finished in {generation_time} seconds.")

    if pil_image is None:
        return None, "Error during generation or decoding.", current_seed

    return pil_image, f"Generated in {generation_time}s", current_seed


load_models()


with gr.Blocks() as demo:
    gr.Markdown("# Lumina 2.0 Generation Demo")
    gr.Markdown("Enter a prompt and adjust settings to generate an image using Lumina 2.0 (Memory Optimized).")

    with gr.Row():
        with gr.Column(scale=2):
            prompt = gr.Textbox(label="Prompt", placeholder="e.g., A photo of an astronaut riding a horse on the moon")
            negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="e.g., low quality, blurry, text, watermark")
            system_type = gr.Dropdown(
                label="System Prompt Type",
                choices=["align", "base", "aesthetics", "real", "4grid", "tags", "empty"],
                value="real",
            )
            resolution = gr.Dropdown(
                label="Resolution (Width x Height)",
                choices=["1024x1024", "1280x768", "768x1280", "1536x1024", "1024x1536"],
                value="1024x1024",
            )
            solver = gr.Dropdown(
                label="Solver",
                choices=["dpm", "euler", "midpoint", "heun", "rk4"],
                value="dpm",
            )
            run_button = gr.Button("Generate Image", variant="primary")
        with gr.Column(scale=1):
            guidance_scale = gr.Slider(label="CFG Scale", minimum=1.0, maximum=15.0, step=0.5, value=4.0)
            num_steps = gr.Slider(label="Sampling Steps", minimum=10, maximum=200, step=1, value=50)
            seed = gr.Number(label="Seed", value=-1, precision=0)
            time_shifting_factor = gr.Slider(label="Time Shift Factor (DPM)", minimum=0.0, maximum=10.0, step=0.1, value=1.0)
            t_shift = gr.Slider(label="T Shift (ODE)", minimum=0, maximum=10, step=1, value=4)

    with gr.Row():
        output_image = gr.Image(label="Generated Image")

    with gr.Row():
        status_text = gr.Textbox(label="Status", interactive=False)
        final_seed = gr.Number(label="Seed Used", interactive=False)

    run_button.click(
        fn=generate_image,
        inputs=[
            prompt, negative_prompt, system_type, solver, resolution,
            guidance_scale, num_steps, seed, time_shifting_factor, t_shift
        ],
        outputs=[output_image, status_text, final_seed],
    )

    gr.Examples(
        examples=[
            ["A hyperrealistic photo of a cat wearing sunglasses and playing a saxophone on a beach at sunset", "", "real", "dpm", "1024x1024", 4.5, 40, -1, 1.0, 4],
            ["cinematic film still, A vast cyberpunk cityscape at night, raining, neon lights reflecting on wet streets, high detail", "low detail, drawing, illustration, sketch", "aesthetics", "euler", "1280x768", 5.0, 60, -1, 1.0, 4],
            ["Macro shot of a dewdrop on a spider web, intricate details, forest background, bokeh", "blurry, unfocused", "real", "dpm", "1024x1024", 4.0, 50, 12345, 1.0, 4],
        ],
        inputs=[prompt, negative_prompt, system_type, solver, resolution, guidance_scale, num_steps, seed, time_shifting_factor, t_shift],
        outputs=[output_image, status_text, final_seed],
        fn=generate_image,
    )


if __name__ == "__main__":
    demo.launch()
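# Example invocation (script name assumed; the checkpoint files configured above must exist):
#   PRECISION=bf16 VAE_TYPE=flux TEXT_ENCODER_MODEL=google/gemma-2-2B python demo_gradio.py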