test1 / app.py
matt-le-kat's picture
Update app.py
843ad3f
raw
history blame contribute delete
No virus
905 Bytes
import gradio as gr
import torch
from torch import autocast
from diffusers import StableDiffusionPipeline
from datasets import load_dataset
from PIL import Image
import re
import os
# Hugging Face Hub identifier of the checkpoint to load.
model_id = "CompVis/stable-diffusion-v1-4"
# Device string the pipeline is moved to; infer() also builds its generator on it.
device = "cuda"

# The Hub access token is read from the "auth_token" environment variable;
# if that variable is not set, the lookup raises KeyError at import time.
pipe = StableDiffusionPipeline.from_pretrained(
    model_id,
    use_auth_token=os.environ["auth_token"],
    revision="fp16",
    torch_dtype=torch.float16,
).to(device)
def infer(prompt):
    """Generate images for a text prompt with the module-level pipeline.

    Args:
        prompt: Text description; it is passed to the pipeline as a
            one-element batch.

    Returns:
        The generated images extracted from the pipeline result, suitable
        for handing to the gallery output component.
    """
    generator = torch.Generator(device=device)
    with autocast("cuda"):
        result = pipe([prompt], generator=generator)
    # The pipeline call returns a result container rather than the images
    # themselves. Current diffusers releases expose the images as `.images`;
    # early releases returned a dict keyed by "sample". Returning the
    # container as-is would make the gallery iterate over its keys.
    images = getattr(result, "images", None)
    if images is None:
        images = result["sample"]
    return images
# Single-line prompt input; its label is hidden, so the placeholder text is
# the hint the user sees.
text = gr.Textbox(
    label="Enter your prompt",
    show_label=False,
    max_lines=1,
    placeholder="Enter your prompt",
)

# Output component that displays whatever infer() returns.
gallery = gr.Gallery(label="Generated images", show_label=False)

# Wire prompt -> infer -> gallery and start the app.
intf = gr.Interface(fn=infer, inputs=text, outputs=gallery)
intf.launch()