|
import gradio as gr
import torch
from diffusers import EulerDiscreteScheduler, StableDiffusionXLPipeline, UNet2DConditionModel
from safetensors.torch import load_file
from transformers.utils.hub import cached_file
|
|
|
|
|
# Base SDXL repository: the UNet architecture config and every other
# pipeline component are loaded from here.
base_model: str = "stabilityai/stable-diffusion-xl-base-1.0"
# Hub repository and file holding the distilled 2-step Lightning UNet weights.
lightning_repo: str = "ByteDance/SDXL-Lightning"
checkpoint_filename: str = "sdxl_lightning_2step_unet.safetensors"
# Target device for the UNet and the full pipeline.
device: str = "cpu"
|
|
|
|
|
print("Downloading Lightning checkpoint...")

# Resolve the Lightning checkpoint from the Hugging Face Hub into the local
# ".cache" directory and get back its local file path.
ckpt_path = cached_file(lightning_repo, checkpoint_filename, cache_dir=".cache")

print("Loading UNet (CPU)...")

# Build the UNet architecture from the base model's config only; the weights
# come from the Lightning checkpoint below. `load_config` + `from_config`
# replaces passing a repo id directly to `from_config`, which diffusers
# deprecates.
unet_config = UNet2DConditionModel.load_config(base_model, subfolder="unet")
unet = UNet2DConditionModel.from_config(unet_config)

# Materialize the checkpoint tensors on the target device. `load_state_dict`
# is strict by default, so a checkpoint/architecture mismatch raises.
unet.load_state_dict(load_file(ckpt_path, device=device))
unet.to(device)
unet.eval()
|
|
|
|
|
print("Loading Stable Diffusion XL Pipeline...")

# All components except the UNet come from the base SDXL repository; the
# Lightning UNet built above replaces the base one.
pipe = StableDiffusionXLPipeline.from_pretrained(
    base_model,
    unet=unet,
)

# Per the SDXL-Lightning model card, the sampler must use "trailing" timestep
# spacing to match how the few-step checkpoints were distilled; keep every
# other setting from the base model's scheduler config.
pipe.scheduler = EulerDiscreteScheduler.from_config(
    pipe.scheduler.config,
    timestep_spacing="trailing",
)

pipe.to(device)
|
|
|
|
|
def generate_image(prompt):
    """Generate a single image for *prompt* with the 2-step Lightning pipeline.

    Args:
        prompt: Text prompt entered in the Gradio textbox.

    Returns:
        The first image produced by ``pipe`` (PIL by default, matching the
        ``gr.Image(type="pil")`` output component).

    Raises:
        gr.Error: If the prompt is missing, empty, or whitespace-only.
            Raising (rather than returning a message) makes Gradio display
            the text in the UI; a returned string would be routed into the
            Image output and treated as an image source.
    """
    if not prompt or not prompt.strip():
        raise gr.Error("Prompt cannot be empty!")

    # The loaded checkpoint is the 2-step Lightning UNet, so the step count
    # must stay at 2; guidance_scale=0 disables classifier-free guidance,
    # which SDXL-Lightning is distilled to run without.
    result = pipe(prompt, num_inference_steps=2, guidance_scale=0)
    return result.images[0]
|
|
|
|
|
# Clickable sample prompts for the Interface. Gradio expects one inner list
# per example, holding a value for each input component (here: just the
# prompt textbox).
examples = [
    [text]
    for text in (
        "A futuristic city skyline at sunset, ultra-detailed, sci-fi style",
        "An astronaut riding a horse on Mars, digital art",
        "A serene forest landscape with glowing mushrooms, fantasy art",
        "Cyberpunk samurai under neon lights, raining scene",
        "A cute robot chef in a cozy kitchen, Pixar style",
    )
]
|
|
|
|
|
# Single-textbox -> single-image UI around generate_image. Examples are not
# pre-computed (cache_examples=False), so nothing is generated at startup.
demo = gr.Interface(
    fn=generate_image,
    inputs=gr.Textbox(label="Enter your prompt", placeholder="e.g., A castle floating in the clouds"),
    outputs=gr.Image(type="pil"),
    title="SDXL-Lightning (2-Step, CPU) Image Generator",
    description="Fast image generation using ByteDance's SDXL-Lightning 2-step model on CPU (no optimization).",
    examples=examples,
    cache_examples=False,
)

# Start the server only when run as a script (e.g. `python app.py`), so that
# importing this module does not block on a running web server. `demo` stays
# at module level for tooling that looks it up by name.
if __name__ == "__main__":
    demo.launch()