File size: 2,366 Bytes
4358e59
b301caf
3fec1fb
4358e59
 
f4ff201
3fec1fb
 
70e3d12
3fec1fb
 
 
 
70e3d12
3fec1fb
 
 
b47c647
7e9a760
 
4358e59
548031b
4358e59
dba1359
98b5af6
7e9a760
ae9efe4
76ca690
4358e59
 
 
 
b21649c
4358e59
 
 
3fec1fb
4358e59
 
98b5af6
ca74145
a4ca4f9
ca74145
4358e59
 
 
 
 
 
 
 
75f237b
3fec1fb
0a14984
3fec1fb
70e3d12
a4ca4f9
548031b
4358e59
a4ca4f9
75f237b
5f1159f
a4ca4f9
 
 
 
b47c647
4358e59
6fed0f7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
from diffusers import AutoPipelineForText2Image, StableDiffusionImg2ImgPipeline
from PIL import Image
import gradio as gr
import random
import torch

# Custom stylesheet passed to gr.Blocks(css=...). The .btn-green class is
# attached to the "Generate" button via elem_classes="btn-green": it paints a
# green gradient with dark-green text, and flattens to a solid light green on
# hover. The !important flags are presumably needed to win over Gradio's own
# button styling — verify if the Gradio version changes.
css = """
.btn-green {
  background-image: linear-gradient(to bottom right, #6dd178, #00a613) !important;
  border-color: #22c55e !important;
  color: #166534 !important;
}
.btn-green:hover {
  background-image: linear-gradient(to bottom right, #6dd178, #6dd178) !important;
}
"""

def generate(prompt, samp_steps, batch_size, seed, progress=gr.Progress(track_tqdm=True)):
    """Generate images with a two-stage txt2img -> img2img pipeline.

    A single-step ``txt2img`` pass (guidance disabled) produces draft images.
    The drafts are upscaled to 1024x1024, and ``img2img`` then refines them.

    Args:
        prompt: Text prompt used for both stages.
        samp_steps: Number of inference steps requested for the img2img pass.
        batch_size: Number of images to generate.
        seed: RNG seed. A negative value (or an empty field, ``None``)
            means "pick a random seed".
        progress: Gradio progress tracker (``track_tqdm=True``). It is not
            used in the body; Gradio consumes it.

    Returns:
        A tuple ``(gallery_update, seed)``. ``gallery_update`` is a
        ``gr.update`` holding ``(image, caption)`` pairs, and ``seed`` is the
        seed that was actually used, so the UI can display it.
    """
    # Gradio numeric widgets may hand back floats (or None for an empty
    # Number field); the pipelines and torch.manual_seed need plain ints.
    samp_steps = int(samp_steps)
    batch_size = int(batch_size)
    if seed is None or seed < 0:
        seed = random.randint(1, 999999)
    seed = int(seed)

    draft_images = txt2img(
        prompt,
        num_inference_steps=1,
        num_images_per_prompt=batch_size,
        guidance_scale=0.0,
        generator=torch.manual_seed(seed),
    ).images

    # Image.LANCZOS is the same filter the original selected with the
    # bare constant 1.
    upscaled_images = [img.resize((1024, 1024), Image.LANCZOS) for img in draft_images]

    # The upscaled drafts must be handed to img2img as its init images
    # (previously they were computed and then dropped). Supplying one prompt
    # per image keeps the prompt batch and image batch the same size.
    # NOTE(review): with strength=0.5, img2img presumably runs only about
    # half of samp_steps denoising steps, so samp_steps=1 may do nothing —
    # confirm against the diffusers version in use.
    final_images = img2img(
        prompt=[prompt] * len(upscaled_images),
        image=upscaled_images,
        num_inference_steps=samp_steps,
        guidance_scale=5,
        strength=0.5,
        generator=torch.manual_seed(seed),
    ).images

    gallery_items = [(img, f"Image {i + 1}") for i, img in enumerate(final_images)]
    return gr.update(value=gallery_items), seed
        
def set_base_models():
    """Load the two diffusion pipelines and move each one to the CUDA device.

    Both pipelines are loaded with float16 weights from their "fp16" variant.

    Returns:
        ``(txt2img, img2img)``: an SDXL-Turbo text-to-image pipeline and a
        DreamShaper 8 Stable Diffusion image-to-image pipeline.
    """
    half_precision = {"torch_dtype": torch.float16, "variant": "fp16"}

    # Load and move one pipeline at a time, in the same order as before.
    text_pipe = AutoPipelineForText2Image.from_pretrained(
        "stabilityai/sdxl-turbo", **half_precision
    )
    text_pipe.to("cuda")

    image_pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
        "Lykon/dreamshaper-8", **half_precision
    )
    image_pipe.to("cuda")

    return text_pipe, image_pipe

# Build the UI: a prompt box, a green Generate button, generation controls,
# and a gallery for the results.
with gr.Blocks(css=css) as demo:
    with gr.Column():
        prompt = gr.Textbox(label="Prompt")
        submit_btn = gr.Button(value="Generate", elem_classes="btn-green")

        with gr.Row():
            sampling_steps = gr.Slider(
                minimum=1, maximum=20, value=5, step=1, label="Sampling steps"
            )
            batch_size = gr.Slider(
                minimum=1, maximum=6, value=1, step=1, label="Batch size"
            )
            # A negative seed makes generate() choose a random one.
            seed = gr.Number(label="Seed", value=-1, minimum=-1, precision=0)
            # Read-only field showing the seed generate() actually used.
            lastSeed = gr.Number(label="Last Seed", value=-1, interactive=False)

        gallery = gr.Gallery(show_label=False, preview=True, container=False, height=650)

    submit_btn.click(
        fn=generate,
        inputs=[prompt, sampling_steps, batch_size, seed],
        outputs=[gallery, lastSeed],
        queue=True,
    )

# generate() reads these module globals at call time, so loading the models
# after the UI is built (but before launch) is sufficient.
txt2img, img2img = set_base_models()
demo.launch(debug=True)