Update app.py
Browse files
app.py
CHANGED
@@ -8,25 +8,23 @@ import gradio as gr
|
|
8 |
base_model = "stabilityai/stable-diffusion-xl-base-1.0"
|
9 |
lightning_repo = "ByteDance/SDXL-Lightning"
|
10 |
checkpoint_filename = "sdxl_lightning_2step_unet.safetensors"
|
11 |
-
device = "
|
12 |
|
13 |
# --- Download and Load Lightning UNet ---
|
14 |
print("Downloading Lightning checkpoint...")
|
15 |
ckpt_path = cached_file(lightning_repo, checkpoint_filename, cache_dir=".cache")
|
16 |
|
17 |
-
print("Loading UNet (
|
18 |
unet = UNet2DConditionModel.from_config(base_model, subfolder="unet")
|
19 |
unet.load_state_dict(load_file(ckpt_path))
|
20 |
-
unet.to(device
|
21 |
unet.eval()
|
22 |
|
23 |
-
# --- Load Pipeline
|
24 |
print("Loading Stable Diffusion XL Pipeline...")
|
25 |
pipe = StableDiffusionXLPipeline.from_pretrained(
|
26 |
base_model,
|
27 |
-
unet=unet
|
28 |
-
torch_dtype=torch.float16,
|
29 |
-
variant="fp16"
|
30 |
)
|
31 |
pipe.to(device)
|
32 |
|
@@ -34,8 +32,7 @@ pipe.to(device)
|
|
34 |
def generate_image(prompt):
|
35 |
if not prompt:
|
36 |
return "Prompt cannot be empty!"
|
37 |
-
|
38 |
-
image = pipe(prompt, num_inference_steps=2, guidance_scale=0).images[0]
|
39 |
return image
|
40 |
|
41 |
# --- Example Prompts ---
|
@@ -52,11 +49,11 @@ demo = gr.Interface(
|
|
52 |
fn=generate_image,
|
53 |
inputs=gr.Textbox(label="Enter your prompt", placeholder="e.g., A castle floating in the clouds"),
|
54 |
outputs=gr.Image(type="pil"),
|
55 |
-
title="SDXL-Lightning (2-Step,
|
56 |
-
description="Fast
|
57 |
examples=examples,
|
58 |
-
cache_examples=False
|
59 |
)
|
60 |
|
61 |
-
# --- Launch
|
62 |
demo.launch()
|
|
|
8 |
# --- Model configuration ---
# SDXL base provides the pipeline scaffolding; the Lightning repo supplies a
# distilled 2-step UNet checkpoint that replaces the base UNet.
base_model = "stabilityai/stable-diffusion-xl-base-1.0"
lightning_repo = "ByteDance/SDXL-Lightning"
checkpoint_filename = "sdxl_lightning_2step_unet.safetensors"
device = "cpu"  # Force CPU

# --- Download and Load Lightning UNet ---
print("Downloading Lightning checkpoint...")
ckpt_path = cached_file(lightning_repo, checkpoint_filename, cache_dir=".cache")

print("Loading UNet (CPU)...")
# Build the UNet from the base model's config (random weights), then overwrite
# every weight with the Lightning checkpoint.
unet = UNet2DConditionModel.from_config(base_model, subfolder="unet")
unet.load_state_dict(load_file(ckpt_path))
unet.to(device)
unet.eval()

# --- Load Pipeline without FP16 ---
print("Loading Stable Diffusion XL Pipeline...")
pipe = StableDiffusionXLPipeline.from_pretrained(
    base_model,
    unet=unet
)
pipe.to(device)
|
30 |
|
|
|
32 |
def generate_image(prompt):
    """Generate a PIL image from *prompt* with the 2-step Lightning pipeline.

    Raises:
        gr.Error: if the prompt is empty. (The original returned a plain
            string here, but the output component is ``gr.Image(type="pil")``,
            so Gradio would try to treat the string as an image path and fail;
            raising ``gr.Error`` surfaces the message to the user instead.)
    """
    if not prompt:
        raise gr.Error("Prompt cannot be empty!")
    # num_inference_steps=2 / guidance_scale=0 matches the 2-step Lightning
    # checkpoint loaded above.
    image = pipe(prompt, num_inference_steps=2, guidance_scale=0).images[0]
    return image
|
37 |
|
38 |
# --- Example Prompts ---
|
|
|
49 |
# --- Gradio Interface ---
# NOTE(review): the `demo = gr.Interface(` opener is reconstructed from the
# diff hunk header; confirm against the full file.
demo = gr.Interface(
    fn=generate_image,
    inputs=gr.Textbox(label="Enter your prompt", placeholder="e.g., A castle floating in the clouds"),
    outputs=gr.Image(type="pil"),
    title="SDXL-Lightning (2-Step, CPU) Image Generator",
    description="Fast image generation using ByteDance's SDXL-Lightning 2-step model on CPU (no optimization).",
    examples=examples,
    cache_examples=False
)

# --- Launch Interface ---
demo.launch()
|