# TiM text-to-image demo (app.py, Hugging Face Space by Julien Blanchon)
import gradio as gr
import spaces # type: ignore - ZeroGPU spaces library
import numpy as np
import random
import torch
import functools
from pathlib import Path
from PIL import Image
from omegaconf import OmegaConf # type: ignore - YAML configuration library
from tim.schedulers.transition import TransitionSchedule
from tim.utils.misc_utils import instantiate_from_config, init_from_ckpt
from tim.models.vae import get_sd_vae, get_dc_ae, sd_vae_decode, dc_ae_decode
from tim.models.utils.text_encoders import load_text_encoder, encode_prompt
from kernels import get_kernel
# Configuration
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 2048
# Global variables to store loaded components
model = None
scheduler = None
decode_func = None
config = None
text_encoder = None
tokenizer = None
def load_model_components(device: str = "cuda"):
"""Load all model components once at startup"""
global model, scheduler, decode_func, config, text_encoder, tokenizer
try:
# Load configuration
config_path = "configs/t2i/tim_xl_p1_t2i.yaml"
from huggingface_hub import hf_hub_download
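        # Download (and cache) the released T2I checkpoint from the Hub on first run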
ckpt_path = hf_hub_download(
repo_id="blanchon/TiM-checkpoints", filename="t2i_model.bin"
)
if not Path(config_path).exists():
raise FileNotFoundError(f"Config file not found: {config_path}")
if not Path(ckpt_path).exists():
raise FileNotFoundError(f"Checkpoint file not found: {ckpt_path}")
print("Loading configuration...")
config = OmegaConf.load(config_path)
model_config = config.model
print("Loading VAE...")
# Load VAE
if "dc-ae" in model_config.vae_dir:
dc_ae = get_dc_ae(model_config.vae_dir, dtype=torch.float32, device=device)
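            # Tiled decoding keeps VAE memory usage bounded for large latents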
dc_ae.enable_tiling(2560, 2560, 2560, 2560)
decode_func = functools.partial(dc_ae_decode, dc_ae, slice_vae=True)
elif "sd-vae" in model_config.vae_dir:
sd_vae = get_sd_vae(
model_config.vae_dir, dtype=torch.float32, device=device
)
decode_func = functools.partial(sd_vae_decode, sd_vae, slice_vae=True)
else:
raise ValueError("Unsupported VAE type")
# Load text encoder
text_encoder, tokenizer = load_text_encoder(
text_encoder_dir=config.model.text_encoder_dir,
device=device,
weight_dtype=dtype,
)
print("Loading main model...")
# Load main model
model = instantiate_from_config(model_config.network).to(
device=device, dtype=dtype
)
init_from_ckpt(model, checkpoint_dir=ckpt_path, ignore_keys=None, verbose=True)
model.eval()
print("Loading scheduler...")
# Load scheduler
transport = instantiate_from_config(model_config.transport)
scheduler = TransitionSchedule(
transport=transport, **OmegaConf.to_container(model_config.transition_loss)
)
print("All components loaded successfully!")
except Exception as e:
print(f"Error loading model components: {e}")
raise e
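
# ZeroGPU: each generate_image call runs in a GPU context for at most 60 seconds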
@spaces.GPU(duration=60)
def generate_image(
prompt,
seed=42,
randomize_seed=False,
width=1024,
height=1024,
guidance_scale=2.5,
num_inference_steps=16,
progress=gr.Progress(track_tqdm=True),
):
"""Generate image from text prompt"""
try:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
# Validate inputs
if not prompt or len(prompt.strip()) == 0:
raise ValueError("Please enter a valid prompt")
if model is None or scheduler is None:
raise RuntimeError("Model components not loaded. Please check the setup.")
# Validate dimensions
if (
width < 256
or width > MAX_IMAGE_SIZE
or height < 256
or height > MAX_IMAGE_SIZE
):
raise ValueError(
f"Image dimensions must be between 256 and {MAX_IMAGE_SIZE}"
)
if width % 32 != 0 or height % 32 != 0:
raise ValueError("Image dimensions must be divisible by 32")
if randomize_seed:
seed = random.randint(0, MAX_SEED)
generator = torch.Generator(device=device).manual_seed(seed)
# Calculate latent dimensions
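        # DC-AE compresses 32x spatially, SD-VAE 8x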
spatial_downsample = 32 if "dc-ae" in config.model.vae_dir else 8
latent_h = int(height / spatial_downsample)
latent_w = int(width / spatial_downsample)
progress(0.1, desc="Generating random latent...")
# Generate random latent
z = torch.randn(
(1, model.in_channels, latent_h, latent_w),
device=device,
dtype=dtype,
generator=generator,
)
progress(0.1, desc="Loading text encoder...")
        # Move the preloaded text encoder to the GPU and enable flash attention
text_encoder.set_attn_implementation("flash_attention_2")
text_encoder.to(device)
# Encode prompt
cap_features, cap_mask = encode_prompt(
tokenizer,
text_encoder.model,
device,
dtype,
[prompt],
config.model.use_last_hidden_state,
max_seq_length=config.model.max_seq_length,
)
# Encode null caption for CFG
null_cap_feat, null_cap_mask = encode_prompt(
tokenizer,
text_encoder.model,
device,
dtype,
[""],
config.model.use_last_hidden_state,
max_seq_length=config.model.max_seq_length,
)
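        # Trim padded caption features to the longest prompt and broadcast the
        # null caption to the same shape for classifier-free guidance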
cur_max_seq_len = cap_mask.sum(dim=-1).max()
y = cap_features[:, :cur_max_seq_len]
y_null = null_cap_feat[:, :cur_max_seq_len]
y_null = y_null.expand(y.shape[0], cur_max_seq_len, null_cap_feat.shape[-1])
# Generate image
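        # Sample latents with the transition scheduler under classifier-free guidance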
with torch.no_grad():
samples = scheduler.sample(
model,
y,
y_null,
z,
T_max=1.0,
T_min=0.0,
num_steps=num_inference_steps,
cfg_scale=guidance_scale,
cfg_low=0.0,
cfg_high=1.0,
stochasticity_ratio=0.0,
sample_type="transition",
step_callback=lambda step: progress(
0.1 + 0.9 * (step / num_inference_steps), desc="Generating image..."
),
)[-1]
samples = samples.to(torch.float32)
# Decode to image
images = decode_func(samples)
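        # Map decoder output from roughly [-1, 1] to 8-bit RGB in NHWC layout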
images = (
torch.clamp(127.5 * images + 128.0, 0, 255)
.permute(0, 2, 3, 1)
.to(torch.uint8)
.contiguous()
)
image = Image.fromarray(images[0].cpu().numpy())
progress(1.0, desc="Complete!")
return image, seed
except Exception as e:
print(f"Error during image generation: {e}")
        # Return a solid red placeholder image so the UI makes the failure visible
error_img = Image.new("RGB", (512, 512), color="red")
return error_img, seed
# Example prompts
examples = [
["a tiny astronaut hatching from an egg on the moon"],
["🐢 Wearing πŸ•Ά flying on the 🌈"],
["an anime illustration of a wiener schnitzel"],
["a photorealistic landscape of mountains at sunset"],
["a majestic lion in a golden savanna at sunset"],
["a futuristic city with flying cars and neon lights"],
["a cozy cabin in a snowy forest with smoke coming from the chimney"],
["a beautiful mermaid swimming in crystal clear water"],
]
# CSS styling
css = """
#col-container {
margin: 0 auto;
max-width: 520px;
}
"""
# Initialize model components
try:
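    # Fetch the prebuilt flash-attn kernel from the kernels hub before loading the models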
flash_attn = get_kernel("kernels-community/flash-attn")
load_model_components(device)
print("Model components loaded successfully!")
except Exception as e:
print(f"Error loading model components: {e}")
print("Please ensure config and checkpoint files are available")
# Create Gradio interface
with gr.Blocks(css=css) as demo:
with gr.Column(elem_id="col-container"):
gr.Markdown("# TiM Text-to-Image Generator")
gr.Markdown(
"Generate high-quality images from text prompts using the TiM (Transition in Matching) model"
)
with gr.Row():
prompt = gr.Text(
label="Prompt",
show_label=False,
max_lines=1,
placeholder="Enter your prompt",
container=False,
)
run_button = gr.Button("Generate", scale=0)
result = gr.Image(label="Result", show_label=False)
with gr.Accordion("Advanced Settings", open=False):
seed = gr.Slider(
label="Seed",
minimum=0,
maximum=MAX_SEED,
step=1,
value=0,
)
randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
with gr.Row():
width = gr.Slider(
label="Width",
minimum=256,
maximum=MAX_IMAGE_SIZE,
step=32,
value=1024,
)
height = gr.Slider(
label="Height",
minimum=256,
maximum=MAX_IMAGE_SIZE,
step=32,
value=1024,
)
with gr.Row():
guidance_scale = gr.Slider(
label="Guidance Scale",
minimum=1,
maximum=15,
step=0.1,
value=2.5,
)
num_inference_steps = gr.Slider(
label="Number of inference steps",
minimum=1,
maximum=50,
step=1,
value=16,
)
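
        # Example outputs are cached lazily: generated on first request, then reused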
gr.Examples(
examples=examples,
fn=generate_image,
inputs=[prompt],
outputs=[result, seed],
cache_examples=True,
cache_mode="lazy",
)
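
    # Both the Generate button and pressing Enter in the prompt box trigger generation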
gr.on(
triggers=[run_button.click, prompt.submit],
fn=generate_image,
inputs=[
prompt,
seed,
randomize_seed,
width,
height,
guidance_scale,
num_inference_steps,
],
outputs=[result, seed],
)
if __name__ == "__main__":
demo.launch()