lcipolina committed
Commit fbcecb4
1 Parent(s): cd6f4c0

Create app.py

Files changed (1):
  1. app.py +248 -0
app.py ADDED
@@ -0,0 +1,248 @@
+ import subprocess
+
+ # Install the GLIDE repo (glide_text2im) from the current directory at startup.
+ subprocess.run('pip install -e .', shell=True)
+
+ print("Installed the repo!")
+
+ # GLIDE imports
+ from PIL import Image
+ import numpy as np
+ import torch as th
+ import torch.nn.functional as F
+
+ from glide_text2im.download import load_checkpoint
+ from glide_text2im.model_creation import (
+     create_model_and_diffusion,
+     model_and_diffusion_defaults,
+     model_and_diffusion_defaults_upsampler,
+ )
+
+ # gradio app imports
+ import gradio as gr
+
+ from torchvision.transforms import ToTensor, ToPILImage
+ image_to_tensor = ToTensor()
+ tensor_to_image = ToPILImage()
+
+ # This app supports both CPU and GPU.
+ # On CPU, generating one sample may take on the order of 20 minutes.
+ # On a GPU, it should be under a minute.
+
+ has_cuda = th.cuda.is_available()
+ device = th.device('cpu' if not has_cuda else 'cuda')
+
+ # Create base model.
+ options = model_and_diffusion_defaults()
+ options['inpaint'] = True
+ options['use_fp16'] = has_cuda
+ options['timestep_respacing'] = '100' # use 100 diffusion steps for fast sampling
+ model, diffusion = create_model_and_diffusion(**options)
+ model.eval()
+ if has_cuda:
+     model.convert_to_fp16()
+ model.to(device)
+ model.load_state_dict(load_checkpoint('base-inpaint', device))
+ print('total base parameters', sum(x.numel() for x in model.parameters()))
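+
+ # Note: load_checkpoint downloads and caches the pretrained GLIDE weights, so
+ # the first launch may take a while.
+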
+ # Create upsampler model.
+ options_up = model_and_diffusion_defaults_upsampler()
+ options_up['inpaint'] = True
+ options_up['use_fp16'] = has_cuda
+ options_up['timestep_respacing'] = 'fast27' # use 27 diffusion steps for very fast sampling
+ model_up, diffusion_up = create_model_and_diffusion(**options_up)
+ model_up.eval()
+ if has_cuda:
+     model_up.convert_to_fp16()
+ model_up.to(device)
+ model_up.load_state_dict(load_checkpoint('upsample-inpaint', device))
+ print('total upsampler parameters', sum(x.numel() for x in model_up.parameters()))
+
+ # Sampling parameters
+ batch_size = 1
+ guidance_scale = 5.0
+
+ # Tune this parameter to control the sharpness of 256x256 images.
+ # A value of 1.0 is sharper, but sometimes results in grainy artifacts.
+ upsample_temp = 0.997
+
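+ # Classifier-free guidance runs the model on a doubled batch (a conditional
+ # half followed by an unconditional half) and extrapolates between the two
+ # noise predictions:
+ #   eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)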
+ # Create a classifier-free guidance sampling function.
+ def model_fn(x_t, ts, **kwargs):
+     half = x_t[: len(x_t) // 2]
+     combined = th.cat([half, half], dim=0)
+     model_out = model(combined, ts, **kwargs)
+     eps, rest = model_out[:, :3], model_out[:, 3:]
+     cond_eps, uncond_eps = th.split(eps, len(eps) // 2, dim=0)
+     half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)
+     eps = th.cat([half_eps, half_eps], dim=0)
+     return th.cat([eps, rest], dim=1)
+
+ def denoised_fn(x_start):
+     # Force the model to have the exact right x_start predictions
+     # for the part of the image which is known.
+     return (
+         x_start * (1 - model_kwargs['inpaint_mask'])
+         + model_kwargs['inpaint_image'] * model_kwargs['inpaint_mask']
+     )
+
+ def show_images(batch: th.Tensor):
+     """ Tile a batch of [-1, 1] images into one horizontal strip and return it as a PIL image. """
+     scaled = ((batch + 1) * 127.5).round().clamp(0, 255).to(th.uint8).cpu()
+     reshaped = scaled.permute(2, 0, 3, 1).reshape([batch.shape[2], -1, 3])
+     return Image.fromarray(reshaped.numpy())
+
+ def read_image(path: str, size: int = 256) -> th.Tensor:
+     pil_img = Image.open(path).convert('RGB')
+     pil_img = pil_img.resize((size, size), resample=Image.BICUBIC)
+     img = np.array(pil_img)
+     return th.from_numpy(img)[None].permute(0, 3, 1, 2).float() / 127.5 - 1
+
+ # Despite its name, this returns a [-1, 1]-scaled NCHW torch tensor.
+ def pil_to_numpy(pil_img: Image.Image) -> th.Tensor:
+     img = np.array(pil_img)
+     return th.from_numpy(img)[None].permute(0, 3, 1, 2).float() / 127.5 - 1
+
+ model_kwargs = dict()
+ def inpaint(input_img, input_img_with_mask, prompt):
+
+     print(prompt)
+
+     # Resize the input for the base (64x64) and upsampler (256x256) models.
+     input_img_256 = input_img.convert('RGB').resize((256, 256), resample=Image.BICUBIC)
+     input_img_64 = input_img.convert('RGB').resize((64, 64), resample=Image.BICUBIC)
+
+     # Source image we are inpainting
+     source_image_256 = pil_to_numpy(input_img_256)
+     source_image_64 = pil_to_numpy(input_img_64)
+
+     # Since gradio doesn't supply which pixels were drawn, we need to find them ourselves!
+     # Assuming that all black pixels are meant for inpainting.
+     input_img_with_mask_64 = input_img_with_mask.convert('L').resize((64, 64), resample=Image.BICUBIC)
+     gray_scale_source_image = image_to_tensor(input_img_with_mask_64)
+     source_mask_64 = (gray_scale_source_image != 0).float()
+     source_mask_64_img = tensor_to_image(source_mask_64)
+
+     # The mask should always be a boolean 64x64 mask, and then we
+     # can upsample it for the second stage.
+     source_mask_64 = source_mask_64.unsqueeze(0)
+     source_mask_256 = F.interpolate(source_mask_64, (256, 256), mode='nearest')
+
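+     # Mask convention: 1 keeps a pixel from the source image, 0 marks the
+     # blacked-out region the model will fill in.
+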
+     ##############################
+     # Sample from the base model #
+     ##############################
+
+     # Create the text tokens to feed to the model.
+     tokens = model.tokenizer.encode(prompt)
+     tokens, mask = model.tokenizer.padded_tokens_and_mask(
+         tokens, options['text_ctx']
+     )
+
+     # Create the classifier-free guidance tokens (empty)
+     full_batch_size = batch_size * 2
+     uncond_tokens, uncond_mask = model.tokenizer.padded_tokens_and_mask(
+         [], options['text_ctx']
+     )
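+     # The conditional and empty-token sequences are stacked into one batch of
+     # size full_batch_size so model_fn can split and recombine them.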
+
+     # Pack the tokens together into model kwargs.
+     # (Declared global so denoised_fn above can read the mask and image.)
+     global model_kwargs
+     model_kwargs = dict(
+         tokens=th.tensor(
+             [tokens] * batch_size + [uncond_tokens] * batch_size, device=device
+         ),
+         mask=th.tensor(
+             [mask] * batch_size + [uncond_mask] * batch_size,
+             dtype=th.bool,
+             device=device,
+         ),
+
+         # Masked inpainting image
+         inpaint_image=(source_image_64 * source_mask_64).repeat(full_batch_size, 1, 1, 1).to(device),
+         inpaint_mask=source_mask_64.repeat(full_batch_size, 1, 1, 1).to(device),
+     )
+
+     # Sample from the base model.
+     model.del_cache()
+     samples = diffusion.p_sample_loop(
+         model_fn,
+         (full_batch_size, 3, options["image_size"], options["image_size"]),
+         device=device,
+         clip_denoised=True,
+         progress=True,
+         model_kwargs=model_kwargs,
+         cond_fn=None,
+         denoised_fn=denoised_fn,
+     )[:batch_size]
+     model.del_cache()
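+     # Only the first batch_size samples (the conditional half of the doubled
+     # guidance batch) are kept by the [:batch_size] slice above.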
+
+     ##############################
+     # Upsample the 64x64 samples #
+     ##############################
+
+     tokens = model_up.tokenizer.encode(prompt)
+     tokens, mask = model_up.tokenizer.padded_tokens_and_mask(
+         tokens, options_up['text_ctx']
+     )
+
+     # Create the model conditioning dict.
+     model_kwargs = dict(
+         # Low-res image to upsample.
+         low_res=((samples + 1) * 127.5).round() / 127.5 - 1,
+
+         # Text tokens
+         tokens=th.tensor(
+             [tokens] * batch_size, device=device
+         ),
+         mask=th.tensor(
+             [mask] * batch_size,
+             dtype=th.bool,
+             device=device,
+         ),
+
+         # Masked inpainting image.
+         inpaint_image=(source_image_256 * source_mask_256).repeat(batch_size, 1, 1, 1).to(device),
+         inpaint_mask=source_mask_256.repeat(batch_size, 1, 1, 1).to(device),
+     )
+
+     # Sample from the upsampler model.
+     model_up.del_cache()
+     up_shape = (batch_size, 3, options_up["image_size"], options_up["image_size"])
+     up_samples = diffusion_up.p_sample_loop(
+         model_up,
+         up_shape,
+         noise=th.randn(up_shape, device=device) * upsample_temp,
+         device=device,
+         clip_denoised=True,
+         progress=True,
+         model_kwargs=model_kwargs,
+         cond_fn=None,
+         denoised_fn=denoised_fn,
+     )[:batch_size]
+     model_up.del_cache()
+
+     return source_mask_64_img, show_images(up_samples)
+
+ gradio_inputs = [gr.inputs.Image(type='pil',
+                                  label="Input Image"),
+                  gr.inputs.Image(type='pil',
+                                  label="Input Image With Mask"),
+                  gr.inputs.Textbox(label='Conditional Text to Inpaint')]
+
+ gradio_outputs = [gr.outputs.Image(label='Auto-Detected Mask (From drawn black pixels)'),
+                   gr.outputs.Image(label='Inpainted Image')]
+
+ # Examples are left disabled; enabling them would require the referenced
+ # image files to be present in the repo.
+ # examples = [['grass.png', 'grass_with_mask.png', 'a corgi in a field']]
+
+ title = "GLIDE Inpainting"
+
+ description = "[WARNING: Queue times may take 4-6 minutes per person if there's no GPU! If there is a GPU, it'll take around 60 seconds] Using GLIDE to inpaint black regions of an input image! Instructions: 1) For the 'Input Image', upload an image. 2) For the 'Input Image with Mask', draw a black-colored mask (either manually with something like Paint, or by using gradio's built-in image editor & add a black-colored shape). IT MUST BE BLACK, but doesn't have to be rectangular! This is because it auto-detects the mask based on 0 (black) pixel values! 3) For the Conditional Text, type something you'd like to see the black region get filled in with :)"
+
+ article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2112.10741' target='_blank'>GLIDE: Towards Photorealistic Image Generation and Editing with Text-Guided Diffusion Models</a> | <a href='https://github.com/openai/glide-text2im' target='_blank'>Github Repo</a> | <img src='https://visitor-badge.glitch.me/badge?page_id=epoching_glide_inpaint' alt='visitor badge'></p>"
+
+ iface = gr.Interface(fn=inpaint, inputs=gradio_inputs,
+                      outputs=gradio_outputs,
+                      title=title,
+                      description=description, article=article,
+                      enable_queue=True)
+ iface.launch()