File size: 4,162 Bytes
3943768
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import os
import filelock

import torch

from src.utils import makedirs
from src.vision.sdxl_turbo import get_device


def get_pipe_make_image(gpu_id, refine=True,
                        base_model="stabilityai/stable-diffusion-xl-base-1.0",
                        refiner_model="stabilityai/stable-diffusion-xl-refiner-1.0",
                        high_noise_frac=0.8):
    """Load a diffusers text-to-image pipeline and (optionally) its refiner.

    Parameters
    ----------
    gpu_id : int or 'auto'
        Device selector, resolved via ``get_device()``.
    refine : bool
        If False, never load a refiner even when ``refiner_model`` is set.
    base_model : str or None
        Hugging Face model id of the base pipeline; ``None`` falls back to
        SDXL base 1.0.
    refiner_model : str or None
        Hugging Face model id of the refiner; auto-filled with the SDXL
        refiner only when the base is the default SDXL base.
    high_noise_frac : float
        Currently unused — the denoising_end expert split is disabled (see
        NOTE below); kept in the signature for backward compatibility.

    Returns
    -------
    tuple
        ``(base, refiner, extra1, extra2)`` — the base pipeline, the refiner
        pipeline (or ``None``), and extra keyword-argument dicts for the base
        and refiner calls respectively.
    """
    if base_model is None:
        base_model = "stabilityai/stable-diffusion-xl-base-1.0"
    # Only pair the SDXL refiner with the default SDXL base automatically.
    if base_model == "stabilityai/stable-diffusion-xl-base-1.0" and refiner_model is None:
        refiner_model = "stabilityai/stable-diffusion-xl-refiner-1.0"

    device = get_device(gpu_id)

    # SD3 needs its dedicated pipeline class; everything else loads through
    # the generic DiffusionPipeline entry point.
    if 'diffusion-3' in base_model:
        from diffusers import StableDiffusion3Pipeline
        cls = StableDiffusion3Pipeline
    else:
        from diffusers import DiffusionPipeline
        cls = DiffusionPipeline
    # NOTE: the ensemble-of-experts split is currently disabled, so both
    # extras are empty and high_noise_frac is ignored.  To re-enable, use:
    #   extra1 = dict(denoising_end=high_noise_frac, output_type="latent")
    #   extra2 = dict(denoising_end=high_noise_frac)
    extra1 = dict()
    extra2 = dict()

    base = cls.from_pretrained(
        base_model,
        torch_dtype=torch.float16,
        use_safetensors=True,
        add_watermarker=False,
        # variant="fp16"
    ).to(device)
    if not refine or not refiner_model:
        refiner = None
    else:
        # Share the second text encoder and VAE with the base pipeline to
        # avoid loading duplicate weights onto the device.
        refiner = cls.from_pretrained(
            refiner_model,
            text_encoder_2=base.text_encoder_2,
            vae=base.vae,
            torch_dtype=torch.float16,
            use_safetensors=True,
            # variant="fp16",
        ).to(device)

    return base, refiner, extra1, extra2


def make_image(prompt,
               filename=None,
               gpu_id='auto',
               pipe=None,
               image_size="1024x1024",
               image_quality='standard',
               image_guidance_scale=3.0,
               base_model=None,
               refiner_model=None,
               image_num_inference_steps=40, high_noise_frac=0.8):
    """Generate an image for *prompt* with an SDXL/SD3 diffusers pipeline.

    Parameters
    ----------
    prompt : str
        Text prompt for the diffusion model.
    filename : str or None
        If given, save the (last) generated image to this path and return
        the path; otherwise return the image object(s).
    gpu_id : int or 'auto'
        Device selector, forwarded to ``get_pipe_make_image()``.
    pipe : tuple or None
        Pre-built ``(base, refiner, extra1, extra2)`` tuple; pipelines are
        loaded on demand when ``None``.
    image_size : str
        ``"<height>x<width>"`` string, e.g. ``"1024x1024"``.
    image_quality : str
        ``'quick'`` / ``'standard'`` / ``'hd'`` presets override
        ``image_num_inference_steps`` (and the size, for ``'quick'``);
        ``'manual'`` keeps the passed-in values unchanged.
    image_guidance_scale : float
        Classifier-free guidance scale passed to the pipelines.
    base_model, refiner_model : str or None
        Model ids forwarded to ``get_pipe_make_image()``.
    image_num_inference_steps : int
        Diffusion step count (overridden by non-manual quality presets).
    high_noise_frac : float
        Forwarded to ``get_pipe_make_image()`` (currently unused there).

    Returns
    -------
    str or image or list
        ``filename`` when one was given; otherwise the refined image, or
        the base pipeline's image list when no refiner ran.
    """
    # Quality presets override step count / size unless 'manual'.
    if image_quality == 'manual':
        # listen to guidance_scale and num_inference_steps passed in
        pass
    else:
        if image_quality == 'quick':
            image_num_inference_steps = 10
            image_size = "512x512"
        elif image_quality == 'standard':
            image_num_inference_steps = 20
        elif image_quality == 'hd':
            image_num_inference_steps = 50

    if pipe is None:
        base, refiner, extra1, extra2 = get_pipe_make_image(gpu_id=gpu_id,
                                                            base_model=base_model,
                                                            refiner_model=refiner_model,
                                                            high_noise_frac=high_noise_frac)
    else:
        base, refiner, extra1, extra2 = pipe

    # Parse "<height>x<width>" once, after the presets may have rewritten
    # image_size, instead of re-splitting the string at every call site.
    height, width = (int(d) for d in image_size.lower().split('x')[:2])

    # Serialize generation across processes with a file lock so concurrent
    # workers don't contend for the same device memory.
    lock_type = 'image'
    base_path = os.path.join('locks', 'image_locks')
    base_path = makedirs(base_path, exist_ok=True, tmp_ok=True, use_base=True)
    lock_file = os.path.join(base_path, "%s.lock" % lock_type)
    makedirs(os.path.dirname(lock_file))  # ensure made
    with filelock.FileLock(lock_file):
        # Run the base model; extra1/extra2 would carry the expert-split
        # kwargs when that is enabled in get_pipe_make_image().
        image = base(
            prompt=prompt,
            height=height,
            width=width,
            num_inference_steps=image_num_inference_steps,
            guidance_scale=image_guidance_scale,
            **extra1,
        ).images
        if refiner:
            image = refiner(
                prompt=prompt,
                height=height,
                width=width,
                num_inference_steps=image_num_inference_steps,
                guidance_scale=image_guidance_scale,
                **extra2,
                image=image,
            ).images[0]

    if filename:
        # base() returned a list when no refiner ran; save the last image.
        if isinstance(image, list):
            image = image[-1]
        image.save(filename)
        return filename
    return image