import torch
from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerDiscreteScheduler
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
import gradio as gr

base_model = "stabilityai/stable-diffusion-xl-base-1.0"
lightning_repo = "ByteDance/SDXL-Lightning"
checkpoint_file = "sdxl_lightning_4step_unet.safetensors"

print("Loading UNet...") |
|
unet = UNet2DConditionModel.from_config(base_model, subfolder="unet") |
|
unet.load_state_dict(load_file(hf_hub_download(lightning_repo, checkpoint_file))) |
|
unet.eval() |
|
|
|
|
|
print("Initializing pipeline...") |
|
pipe = StableDiffusionXLPipeline.from_pretrained( |
|
base_model, |
|
unet=unet, |
|
torch_dtype=torch.float16 |
|
) |
|
pipe.to("cuda" if torch.cuda.is_available() else "cpu") |
|
|
|
|
|
def generate(prompt):
    # Lightning checkpoints are distilled for a fixed step count (4 here) and
    # run without classifier-free guidance, hence guidance_scale=0.
    image = pipe(prompt, num_inference_steps=4, guidance_scale=0).images[0]
    return image


example_prompts = [
    "A futuristic city skyline at sunset, ultra-detailed, sci-fi style",
    "An astronaut riding a horse on Mars, digital art",
    "A serene forest landscape with glowing mushrooms, fantasy art",
    "Cyberpunk samurai standing under neon lights in the rain",
    "Cute robot cooking in a cozy kitchen, Pixar style",
]

interface = gr.Interface(
    fn=generate,
    inputs=gr.Textbox(label="Enter your prompt"),
    outputs=gr.Image(type="pil"),
    title="ByteDance SDXL-Lightning 4-Step Image Generator",
    description="Ultra-fast AI image generation using 4-step SDXL-Lightning by ByteDance.",
    examples=example_prompts,
)

interface.launch()