# SDXS / app.py
# Hugging Face Space file by thliang01, commit 765b5c0
# ("feat: add spaces import for zero gpu"; commit signature unverified).
import torch
import spaces
from diffusers import StableDiffusionPipeline
import gradio as gr
# Hugging Face Hub repository holding the SDXS checkpoint.
repo = "IDKiro/sdxs-512-0.9"
# Seed intended for the sampling random number generator.
seed = 42
# Decide device placement once and reuse the answer, so the dtype and the
# .to("cuda") move below always agree. No unconditional .cuda() call runs at
# import time, so the app can still start on a CPU-only host.
use_cuda = torch.cuda.is_available()
# Half precision only when the model will run on CUDA: many PyTorch CPU
# kernels have no float16 implementation, so CPU inference needs float32.
weight_type = torch.float16 if use_cuda else torch.float32

# Load model.
pipe = StableDiffusionPipeline.from_pretrained(repo, torch_dtype=weight_type)
generator = pipe
# Move to GPU if available.
if use_cuda:
    generator = generator.to("cuda")
@spaces.GPU(duration=120)
def generate(prompts):
    """Render one image per prompt for a batched Gradio request.

    Decorated with ``spaces.GPU(duration=120)`` so that, on a ZeroGPU Space,
    the call runs with a GPU attached.

    Args:
        prompts: The batch of text prompts collected by Gradio (the Interface
            is created with ``batch=True``).

    Returns:
        A one-element list wrapping the list of generated images. Gradio's
        batch mode expects one list of per-sample results per output
        component, and the Interface declares a single ``"image"`` output.
    """
    prompts = list(prompts)
    # One identically seeded generator per prompt, so a given prompt yields
    # the same image regardless of its position within a batch. CPU
    # generators keep the sampled noise the same across CPU and GPU hosts;
    # diffusers draws the noise on the CPU and moves it to the model device.
    rngs = [torch.Generator().manual_seed(seed) for _ in prompts]
    # SDXS is distilled for single-step sampling without classifier-free
    # guidance (guidance_scale <= 1 disables CFG in diffusers). The pipeline
    # defaults (50 steps, guidance_scale=7.5) are far slower and degrade
    # this model's output.
    images = generator(
        prompts,
        num_inference_steps=1,
        guidance_scale=0,
        generator=rngs,
    ).images
    return [images]
# Gradio UI: a single textbox in, a single image out. batch=True lets Gradio
# group up to max_batch_size pending requests into one generate() call.
demo = gr.Interface(
    fn=generate,
    inputs="textbox",
    outputs="image",
    title="SDXS: Real-Time One-Step Latent Diffusion Models with Image Conditions",
    description="This demo showcases [SDXS](https://arxiv.org/abs/2403.16627)",
    batch=True,
    # Set the batch size based on your CPU/GPU memory
    max_batch_size=4,
)
# Enable the request queue on the app.
demo.queue()

if __name__ == "__main__":
    demo.launch()