Bluestrike commited on
Commit
a36f40b
·
verified ·
1 Parent(s): 66e19c5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -13
app.py CHANGED
@@ -8,25 +8,23 @@ import gradio as gr
8
  base_model = "stabilityai/stable-diffusion-xl-base-1.0"
9
  lightning_repo = "ByteDance/SDXL-Lightning"
10
  checkpoint_filename = "sdxl_lightning_2step_unet.safetensors"
11
- device = "cuda" if torch.cuda.is_available() else "cpu"
12
 
13
  # --- Download and Load Lightning UNet ---
14
  print("Downloading Lightning checkpoint...")
15
  ckpt_path = cached_file(lightning_repo, checkpoint_filename, cache_dir=".cache")
16
 
17
- print("Loading UNet (FP16)...")
18
  unet = UNet2DConditionModel.from_config(base_model, subfolder="unet")
19
  unet.load_state_dict(load_file(ckpt_path))
20
- unet.to(device, dtype=torch.float16)
21
  unet.eval()
22
 
23
- # --- Load Pipeline with FP16 ---
24
  print("Loading Stable Diffusion XL Pipeline...")
25
  pipe = StableDiffusionXLPipeline.from_pretrained(
26
  base_model,
27
- unet=unet,
28
- torch_dtype=torch.float16,
29
- variant="fp16"
30
  )
31
  pipe.to(device)
32
 
@@ -34,8 +32,7 @@ pipe.to(device)
34
  def generate_image(prompt):
35
  if not prompt:
36
  return "Prompt cannot be empty!"
37
- with torch.autocast(device_type="cuda"):
38
- image = pipe(prompt, num_inference_steps=2, guidance_scale=0).images[0]
39
  return image
40
 
41
  # --- Example Prompts ---
@@ -52,11 +49,11 @@ demo = gr.Interface(
52
  fn=generate_image,
53
  inputs=gr.Textbox(label="Enter your prompt", placeholder="e.g., A castle floating in the clouds"),
54
  outputs=gr.Image(type="pil"),
55
- title="SDXL-Lightning (2-Step, FP16) Image Generator",
56
- description="Fast & optimized image generation using ByteDance's SDXL-Lightning 4-step model with FP16 on CUDA.",
57
  examples=examples,
58
- cache_examples=False # Prevent FileNotFoundError
59
  )
60
 
61
- # --- Launch Public Server ---
62
  demo.launch()
 
8
  base_model = "stabilityai/stable-diffusion-xl-base-1.0"
9
  lightning_repo = "ByteDance/SDXL-Lightning"
10
  checkpoint_filename = "sdxl_lightning_2step_unet.safetensors"
11
+ device = "cpu" # Force CPU
12
 
13
  # --- Download and Load Lightning UNet ---
14
  print("Downloading Lightning checkpoint...")
15
  ckpt_path = cached_file(lightning_repo, checkpoint_filename, cache_dir=".cache")
16
 
17
+ print("Loading UNet (CPU)...")
18
  unet = UNet2DConditionModel.from_config(base_model, subfolder="unet")
19
  unet.load_state_dict(load_file(ckpt_path))
20
+ unet.to(device)
21
  unet.eval()
22
 
23
+ # --- Load Pipeline without FP16 ---
24
  print("Loading Stable Diffusion XL Pipeline...")
25
  pipe = StableDiffusionXLPipeline.from_pretrained(
26
  base_model,
27
+ unet=unet
 
 
28
  )
29
  pipe.to(device)
30
 
 
32
  def generate_image(prompt):
33
  if not prompt:
34
  return "Prompt cannot be empty!"
35
+ image = pipe(prompt, num_inference_steps=2, guidance_scale=0).images[0]
 
36
  return image
37
 
38
  # --- Example Prompts ---
 
49
  fn=generate_image,
50
  inputs=gr.Textbox(label="Enter your prompt", placeholder="e.g., A castle floating in the clouds"),
51
  outputs=gr.Image(type="pil"),
52
+ title="SDXL-Lightning (2-Step, CPU) Image Generator",
53
+ description="Fast image generation using ByteDance's SDXL-Lightning 2-step model on CPU (no optimization).",
54
  examples=examples,
55
+ cache_examples=False
56
  )
57
 
58
+ # --- Launch Interface ---
59
  demo.launch()