# Spaces: Paused
import torch | |
import spaces | |
from diffusers import StableDiffusionPipeline | |
import gradio as gr | |
# Hugging Face Hub id of the checkpoint passed to from_pretrained below.
repo = "IDKiro/sdxs-512-0.9"
# NOTE(review): no other line visible in this file reads `seed`.
seed = 42
# Use half precision only when CUDA is usable; on a CPU-only host fall back
# to float32 so the CPU path does not depend on float16 kernel support.
weight_type = torch.float16 if torch.cuda.is_available() else torch.float32
# Device probe tensor. The .cuda() move is guarded so that importing this
# module on a machine without CUDA does not raise before the app starts.
zero = torch.Tensor([0])
if torch.cuda.is_available():
    zero = zero.cuda()
print(zero.device)  # <-- 'cpu' 🤔
# Build the Stable Diffusion pipeline from the configured checkpoint and dtype.
pipe = StableDiffusionPipeline.from_pretrained(
    repo,
    torch_dtype=weight_type,
)
# `generator` is the callable used for inference: the pipeline itself,
# moved to CUDA when a device is available and left where it is otherwise.
generator = pipe.to("cuda") if torch.cuda.is_available() else pipe
@spaces.GPU
def generate(prompts):
    """Render one image per prompt with the module-level SDXS pipeline.

    Args:
        prompts: Batch of prompt strings. The interface in this file is
            created with ``batch=True``, so a list of inputs is expected.

    Returns:
        A one-element list wrapping the list of generated images, i.e. one
        list of results for the single output component.
    """
    # SDXS is a one-step model: run exactly one denoising step and pass
    # guidance_scale=0 so classifier-free guidance is off, instead of
    # inheriting the pipeline's multi-step defaults.
    images = generator(
        list(prompts),
        num_inference_steps=1,
        guidance_scale=0,
    ).images
    return [images]
# Web UI: a single textbox in, a single image out, served in batches.
demo = gr.Interface(
    fn=generate,
    inputs="textbox",
    outputs="image",
    title="SDXS: Real-Time One-Step Latent Diffusion Models with Image Conditions",
    description="This demo showcases [SDXS](https://arxiv.org/abs/2403.16627)",
    batch=True,
    # Set the batch size based on your CPU/GPU memory
    max_batch_size=4,
)
# Enable request queuing; rebind to the value queue() returns.
demo = demo.queue()

if __name__ == "__main__":
    # Start the server only when executed as a script, not on import.
    demo.launch()