"""Gradio demo: text-to-image generation with SDXL-Lightning (2-step) on CPU.

At import time this script downloads the SDXL-Lightning 2-step UNet
checkpoint, builds a Stable Diffusion XL pipeline around it on the CPU, and
defines a Gradio interface (``demo``). The web server is only started when
the file is executed as a script.
"""
import torch
from diffusers import (
    EulerDiscreteScheduler,
    StableDiffusionXLPipeline,
    UNet2DConditionModel,
)
from transformers.utils.hub import cached_file
from safetensors.torch import load_file
import gradio as gr

# --- Configuration ---
base_model = "stabilityai/stable-diffusion-xl-base-1.0"
lightning_repo = "ByteDance/SDXL-Lightning"
checkpoint_filename = "sdxl_lightning_2step_unet.safetensors"
device = "cpu"  # Force CPU

# --- Download and Load Lightning UNet ---
print("Downloading Lightning checkpoint...")
ckpt_path = cached_file(lightning_repo, checkpoint_filename, cache_dir=".cache")

print("Loading UNet (CPU)...")
# Build the UNet architecture from the base model's config only (no base
# weights), then fill it with the distilled Lightning weights. Passing a repo
# id straight to ``from_config`` is deprecated in diffusers, so the config
# dict is fetched explicitly with ``load_config`` first.
unet_config = UNet2DConditionModel.load_config(base_model, subfolder="unet")
unet = UNet2DConditionModel.from_config(unet_config)
unet.load_state_dict(load_file(ckpt_path))
unet.to(device)
unet.eval()

# --- Load Pipeline without FP16 ---
print("Loading Stable Diffusion XL Pipeline...")
pipe = StableDiffusionXLPipeline.from_pretrained(base_model, unet=unet)
# NOTE(review): the SDXL-Lightning model card prescribes an Euler scheduler
# with "trailing" timestep spacing; the base model's default scheduler
# settings are not meant for the few-step distilled checkpoints.
pipe.scheduler = EulerDiscreteScheduler.from_config(
    pipe.scheduler.config, timestep_spacing="trailing"
)
pipe.to(device)


# --- Generation Function ---
def generate_image(prompt: str):
    """Generate one image for *prompt* with the 2-step Lightning pipeline.

    Args:
        prompt: Text description of the desired image.

    Returns:
        The first image produced by the pipeline call.

    Raises:
        gr.Error: If *prompt* is empty or contains only whitespace. Raising
            (rather than returning a message string) lets Gradio show the
            text to the user; a plain string cannot be rendered by the
            ``gr.Image`` output component.
    """
    if not prompt or not prompt.strip():
        raise gr.Error("Prompt cannot be empty!")

    # The step count matches the 2-step checkpoint loaded above;
    # guidance_scale=0 is the value the original script passes (presumably
    # to run without classifier-free guidance, per the model card).
    result = pipe(prompt, num_inference_steps=2, guidance_scale=0)
    return result.images[0]


# --- Example Prompts ---
examples = [
    ["A futuristic city skyline at sunset, ultra-detailed, sci-fi style"],
    ["An astronaut riding a horse on Mars, digital art"],
    ["A serene forest landscape with glowing mushrooms, fantasy art"],
    ["Cyberpunk samurai under neon lights, raining scene"],
    ["A cute robot chef in a cozy kitchen, Pixar style"],
]

# --- Gradio UI ---
demo = gr.Interface(
    fn=generate_image,
    inputs=gr.Textbox(
        label="Enter your prompt",
        placeholder="e.g., A castle floating in the clouds",
    ),
    outputs=gr.Image(type="pil"),
    title="SDXL-Lightning (2-Step, CPU) Image Generator",
    description="Fast image generation using ByteDance's SDXL-Lightning 2-step model on CPU (no optimization).",
    examples=examples,
    cache_examples=False,
)

# --- Launch Interface ---
if __name__ == "__main__":
    demo.launch()