Krebzonide's picture
Update app.py
3f9da05
raw
history blame
2.43 kB
from diffusers import AutoPipelineForText2Image, StableDiffusionImg2ImgPipeline
import torchvision.transforms.functional as fn
import torchvision.transforms.InterpolationMode as interp
import gradio as gr
import random
import torch
# Stylesheet handed to gr.Blocks(css=css). It defines the "btn-green" class,
# which the Generate button opts into via elem_classes="btn-green": a green
# gradient background, with a flat light-green gradient on hover.
css = """
.btn-green {
background-image: linear-gradient(to bottom right, #6dd178, #00a613) !important;
border-color: #22c55e !important;
color: #166534 !important;
}
.btn-green:hover {
background-image: linear-gradient(to bottom right, #6dd178, #6dd178) !important;
}
"""
def _upscale_images(images, size=1024):
    """Return a new list with every image resized so its shorter edge is ``size`` px.

    ``torchvision.transforms.functional.resize`` accepts a single PIL image or
    tensor, not a list, so the images are resized one at a time.
    """
    # Local import: a module-level ``import torchvision.transforms.InterpolationMode``
    # cannot work because InterpolationMode is a class, not a submodule.
    from torchvision.transforms import InterpolationMode

    return [
        fn.resize(img, size, interpolation=InterpolationMode.NEAREST_EXACT)
        for img in images
    ]


def generate(prompt, samp_steps, batch_size, seed, progress=gr.Progress(track_tqdm=True)):
    """Draft images with the global ``txt2img`` pipeline, then refine them with ``img2img``.

    Args:
        prompt: Text prompt used for both stages.
        samp_steps: Number of inference steps requested from the img2img stage.
        batch_size: Number of images to generate.
        seed: RNG seed; a negative value (or an empty field, ``None``) picks a
            random seed in ``[1, 999999]``.
        progress: Gradio progress tracker; ``track_tqdm=True`` lets Gradio
            mirror the pipelines' tqdm progress bars.

    Returns:
        A ``(gallery_update, seed)`` tuple: a ``gr.update`` holding
        ``(image, caption)`` pairs for the gallery, and the seed actually used.
    """
    if seed is None or seed < 0:
        seed = random.randint(1, 999999)
    # Gradio sliders/numbers may deliver floats; the pipelines and
    # torch.manual_seed need plain ints.
    seed = int(seed)
    batch_size = int(batch_size)
    samp_steps = int(samp_steps)

    # Stage 1: single-step draft with classifier-free guidance disabled.
    images = txt2img(
        prompt,
        num_inference_steps=1,
        num_images_per_prompt=batch_size,
        guidance_scale=0.0,
        generator=torch.manual_seed(seed),
    ).images

    upscaled_images = _upscale_images(images, 1024)

    # Stage 2: refine the upscaled drafts. They must be passed as ``image=``,
    # otherwise the img2img pipeline has no init image to start from. One
    # prompt copy per image keeps the prompt and image batch sizes equal.
    # NOTE(review): img2img's default strength scales the step count down, so
    # very small samp_steps values may yield zero denoising steps — confirm.
    final_images = img2img(
        [prompt] * len(upscaled_images),
        image=upscaled_images,
        num_inference_steps=samp_steps,
        guidance_scale=5,
        generator=torch.manual_seed(seed),
    ).images

    return gr.update(value=[(img, f"Image {i+1}") for i, img in enumerate(final_images)]), seed
def set_base_models():
    """Load the two half-precision diffusion pipelines and place them on the GPU.

    Returns:
        ``(txt2img, img2img)``: the ``stabilityai/sdxl-turbo`` text-to-image
        pipeline and the ``Lykon/dreamshaper-8`` image-to-image pipeline,
        each already moved to ``"cuda"``.
    """
    # Both checkpoints are fetched as fp16 weights in float16 precision.
    fp16_kwargs = dict(torch_dtype=torch.float16, variant="fp16")

    text_pipe = AutoPipelineForText2Image.from_pretrained(
        "stabilityai/sdxl-turbo", **fp16_kwargs
    )
    text_pipe.to("cuda")

    image_pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
        "Lykon/dreamshaper-8", **fp16_kwargs
    )
    image_pipe.to("cuda")

    return text_pipe, image_pipe
# UI: a prompt box and Generate button, a row of generation controls, and a
# gallery; clicking the button runs ``generate`` and fills the gallery and
# the read-only "Last Seed" field.
with gr.Blocks(css=css) as demo:
    with gr.Column():
        prompt = gr.Textbox(label="Prompt")
        submit_btn = gr.Button("Generate", elem_classes="btn-green")
        with gr.Row():
            sampling_steps = gr.Slider(
                minimum=1, maximum=20, value=5, step=1, label="Sampling steps"
            )
            batch_size = gr.Slider(
                minimum=1, maximum=6, value=1, step=1, label="Batch size"
            )
            # -1 requests a random seed; the seed actually used is echoed back.
            seed = gr.Number(label="Seed", value=-1, minimum=-1, precision=0)
            lastSeed = gr.Number(label="Last Seed", value=-1, interactive=False)
        gallery = gr.Gallery(show_label=False, preview=True, container=False, height=650)
    submit_btn.click(
        fn=generate,
        inputs=[prompt, sampling_steps, batch_size, seed],
        outputs=[gallery, lastSeed],
        queue=True,
    )
# Load both pipelines into module-level globals before serving: ``generate``
# reads ``txt2img`` and ``img2img`` as globals, so they must exist before the
# first request arrives.
txt2img, img2img = set_base_models()
# Start the Gradio app defined in the ``demo`` Blocks context.
demo.launch(debug=True)