Mariam-Elz committed on
Commit 28cc1df · verified · 1 Parent(s): 0f95337

Upload libs/sample.py with huggingface_hub
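The upload itself is a single call to the `huggingface_hub` client. Below is a minimal sketch of the kind of call that produces a commit like this one; the repo id is a placeholder and is not taken from this commit.

from huggingface_hub import HfApi

api = HfApi()  # picks up the token saved by `huggingface-cli login` by default
api.upload_file(
    path_or_fileobj="libs/sample.py",   # local file to push
    path_in_repo="libs/sample.py",      # destination path inside the repo
    repo_id="<namespace>/<repo>",       # placeholder repo id
    commit_message="Upload libs/sample.py with huggingface_hub",
)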

Files changed (1)
  1. libs/sample.py +380 -380
libs/sample.py CHANGED
@@ -1,380 +1,380 @@
(The removed and added versions of the file are identical, so the contents are shown once below.)
import numpy as np
import torch
from imagedream.camera_utils import get_camera_for_index
from imagedream.ldm.util import set_seed, add_random_background
from libs.base_utils import do_resize_content
from imagedream.ldm.models.diffusion.ddim import DDIMSampler
from torchvision import transforms as T


class ImageDreamDiffusion:
    """Multi-view diffusion wrapper around an ImageDream model.

    Pre-computes the camera conditioning for the requested views and exposes
    `diffuse` to sample `num_frames` views with DDIM, optionally conditioned
    on an input image.
    """

    def __init__(
        self,
        model,
        device,
        dtype,
        mode,
        num_frames,
        camera_views,
        ref_position,
        random_background=False,
        offset_noise=False,
        resize_rate=1,
        image_size=256,
        seed=1234,
    ) -> None:
        assert mode in ["pixel", "local"]
        size = image_size
        self.seed = seed
        batch_size = max(4, num_frames)

        neg_texts = "uniform low no texture ugly, boring, bad anatomy, blurry, pixelated, obscure, unnatural colors, poor lighting, dull, and unclear."
        uc = model.get_learned_conditioning([neg_texts]).to(device)
        sampler = DDIMSampler(model)

        # pre-compute camera matrices
        camera = [get_camera_for_index(i).squeeze() for i in camera_views]
        camera[ref_position] = torch.zeros_like(camera[ref_position])  # set ref camera to zero
        camera = torch.stack(camera)
        camera = camera.repeat(batch_size // num_frames, 1).to(device)

        self.image_transform = T.Compose(
            [
                T.Resize((size, size)),
                T.ToTensor(),
                T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]
        )
        self.dtype = dtype
        self.ref_position = ref_position
        self.mode = mode
        self.random_background = random_background
        self.resize_rate = resize_rate
        self.num_frames = num_frames
        self.size = size
        self.device = device
        self.batch_size = batch_size
        self.model = model
        self.sampler = sampler
        self.uc = uc
        self.camera = camera
        self.offset_noise = offset_noise

    @staticmethod
    def i2i(
        model,
        image_size,
        prompt,
        uc,
        sampler,
        ip=None,
        step=20,
        scale=5.0,
        batch_size=8,
        ddim_eta=0.0,
        dtype=torch.float32,
        device="cuda",
        camera=None,
        num_frames=4,
        pixel_control=False,
        transform=None,
        offset_noise=False,
    ):
        """Sample multi-view images from a text prompt, optionally with an additional image prompt.

        Args:
            model: the ImageDream model
            image_size (int): size of the diffusion output (standard 256)
            prompt (str): text prompt for the image
            uc (Tensor): unconditional embedding, shape [1, 77, 1024]
            sampler: imagedream.ldm.models.diffusion.ddim.DDIMSampler
            ip (Image, optional): the image prompt. Defaults to None.
            step (int, optional): number of DDIM sampling steps. Defaults to 20.
            scale (float, optional): classifier-free guidance scale. Defaults to 5.0.
            batch_size (int, optional): number of latents sampled per call. Defaults to 8.
            ddim_eta (float, optional): DDIM eta (0.0 = deterministic). Defaults to 0.0.
            dtype (torch.dtype, optional): autocast dtype. Defaults to torch.float32.
            device (str, optional): device to run on. Defaults to "cuda".
            camera (Tensor, optional): camera info in tensor, shape: torch.Size([5, 16]) mean: 0.11, std: 0.49, min: -1.00, max: 1.00
            num_frames (int, optional): number of frames (views) to generate
            pixel_control (bool): whether to use pixel conditioning. Defaults to False; set True when using pixel mode.
            transform: Compose(
                Resize(size=(256, 256), interpolation=bilinear, max_size=None, antialias=warn)
                ToTensor()
                Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
            )
        """
        ip_raw = ip
        if not isinstance(prompt, list):
            prompt = [prompt]
        with torch.no_grad(), torch.autocast(device_type=torch.device(device).type, dtype=dtype):
            c = model.get_learned_conditioning(prompt).to(
                device
            )  # shape: torch.Size([1, 77, 1024]) mean: -0.17, std: 1.02, min: -7.50, max: 13.05
            c_ = {"context": c.repeat(batch_size, 1, 1)}  # batch_size
            uc_ = {"context": uc.repeat(batch_size, 1, 1)}

            if camera is not None:
                c_["camera"] = uc_["camera"] = (
                    camera  # shape: torch.Size([5, 16]) mean: 0.11, std: 0.49, min: -1.00, max: 1.00
                )
                c_["num_frames"] = uc_["num_frames"] = num_frames

            if ip is not None:
                ip_embed = model.get_learned_image_conditioning(ip).to(
                    device
                )  # shape: torch.Size([1, 257, 1280]) mean: 0.06, std: 0.53, min: -6.83, max: 11.12
                ip_ = ip_embed.repeat(batch_size, 1, 1)
                c_["ip"] = ip_
                uc_["ip"] = torch.zeros_like(ip_)

            if pixel_control:
                assert camera is not None
                ip = transform(ip).to(
                    device
                )  # shape: torch.Size([3, 256, 256]) mean: 0.33, std: 0.37, min: -1.00, max: 1.00
                ip_img = model.get_first_stage_encoding(
                    model.encode_first_stage(ip[None, :, :, :])
                )  # shape: torch.Size([1, 4, 32, 32]) mean: 0.23, std: 0.77, min: -4.42, max: 3.55
                c_["ip_img"] = ip_img
                uc_["ip_img"] = torch.zeros_like(ip_img)

            shape = [4, image_size // 8, image_size // 8]  # [4, 32, 32]
            if offset_noise:
                # start from noise whose per-channel mean matches the reference latent
                ref = transform(ip_raw).to(device)
                ref_latent = model.get_first_stage_encoding(model.encode_first_stage(ref[None, :, :, :]))
                ref_mean = ref_latent.mean(dim=(-1, -2), keepdim=True)
                time_steps = torch.randint(model.num_timesteps - 1, model.num_timesteps, (batch_size,), device=device)
                x_T = model.q_sample(torch.ones([batch_size] + shape, device=device) * ref_mean, time_steps)

            samples_ddim, _ = (
                sampler.sample(  # shape: torch.Size([5, 4, 32, 32]) mean: 0.29, std: 0.85, min: -3.38, max: 4.43
                    S=step,
                    conditioning=c_,
                    batch_size=batch_size,
                    shape=shape,
                    verbose=False,
                    unconditional_guidance_scale=scale,
                    unconditional_conditioning=uc_,
                    eta=ddim_eta,
                    x_T=x_T if offset_noise else None,
                )
            )

            x_sample = model.decode_first_stage(samples_ddim)
            x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
            x_sample = 255.0 * x_sample.permute(0, 2, 3, 1).cpu().numpy()

        return list(x_sample.astype(np.uint8))

    def diffuse(self, t, ip, n_test=2):
        set_seed(self.seed)
        ip = do_resize_content(ip, self.resize_rate)
        if self.random_background:
            ip = add_random_background(ip)

        images = []
        for _ in range(n_test):
            img = self.i2i(
                self.model,
                self.size,
                t,
                self.uc,
                self.sampler,
                ip=ip,
                step=50,
                scale=5,
                batch_size=self.batch_size,
                ddim_eta=0.0,
                dtype=self.dtype,
                device=self.device,
                camera=self.camera,
                num_frames=self.num_frames,
                pixel_control=(self.mode == "pixel"),
                transform=self.image_transform,
                offset_noise=self.offset_noise,
            )
            img = np.concatenate(img, 1)  # tile the generated views horizontally
            img = np.concatenate((img, ip.resize((self.size, self.size))), axis=1)
            images.append(img)
        set_seed()  # unset random and numpy seed
        return images


class ImageDreamDiffusionStage2:
    """Stage-2 counterpart of ImageDreamDiffusion.

    In addition to the reference image, it conditions on the per-view images
    produced by stage 1 (`pixel_images`), encoded to first-stage latents.
    """

    def __init__(
        self,
        model,
        device,
        dtype,
        num_frames,
        camera_views,
        ref_position,
        random_background=False,
        offset_noise=False,
        resize_rate=1,
        mode="pixel",
        image_size=256,
        seed=1234,
    ) -> None:
        assert mode in ["pixel", "local"]

        size = image_size
        self.seed = seed
        batch_size = max(4, num_frames)

        neg_texts = "uniform low no texture ugly, boring, bad anatomy, blurry, pixelated, obscure, unnatural colors, poor lighting, dull, and unclear."
        uc = model.get_learned_conditioning([neg_texts]).to(device)
        sampler = DDIMSampler(model)

        # pre-compute camera matrices
        camera = [get_camera_for_index(i).squeeze() for i in camera_views]
        if ref_position is not None:
            camera[ref_position] = torch.zeros_like(camera[ref_position])  # set ref camera to zero
        camera = torch.stack(camera)
        camera = camera.repeat(batch_size // num_frames, 1).to(device)

        self.image_transform = T.Compose(
            [
                T.Resize((size, size)),
                T.ToTensor(),
                T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]
        )

        self.dtype = dtype
        self.mode = mode
        self.ref_position = ref_position
        self.random_background = random_background
        self.resize_rate = resize_rate
        self.num_frames = num_frames
        self.size = size
        self.device = device
        self.batch_size = batch_size
        self.model = model
        self.sampler = sampler
        self.uc = uc
        self.camera = camera
        self.offset_noise = offset_noise

    @staticmethod
    def i2iStage2(
        model,
        image_size,
        prompt,
        uc,
        sampler,
        pixel_images,
        ip=None,
        step=20,
        scale=5.0,
        batch_size=8,
        ddim_eta=0.0,
        dtype=torch.float32,
        device="cuda",
        camera=None,
        num_frames=4,
        pixel_control=False,
        transform=None,
        offset_noise=False,
    ):
        """Stage-2 sampling: like `i2i`, but additionally conditioned on the
        stage-1 views passed in `pixel_images` (encoded to first-stage latents)."""
        ip_raw = ip
        if not isinstance(prompt, list):
            prompt = [prompt]
        with torch.no_grad(), torch.autocast(device_type=torch.device(device).type, dtype=dtype):
            c = model.get_learned_conditioning(prompt).to(
                device
            )  # shape: torch.Size([1, 77, 1024]) mean: -0.17, std: 1.02, min: -7.50, max: 13.05
            c_ = {"context": c.repeat(batch_size, 1, 1)}  # batch_size
            uc_ = {"context": uc.repeat(batch_size, 1, 1)}

            if camera is not None:
                c_["camera"] = uc_["camera"] = (
                    camera  # shape: torch.Size([5, 16]) mean: 0.11, std: 0.49, min: -1.00, max: 1.00
                )
                c_["num_frames"] = uc_["num_frames"] = num_frames

            if ip is not None:
                ip_embed = model.get_learned_image_conditioning(ip).to(
                    device
                )  # shape: torch.Size([1, 257, 1280]) mean: 0.06, std: 0.53, min: -6.83, max: 11.12
                ip_ = ip_embed.repeat(batch_size, 1, 1)
                c_["ip"] = ip_
                uc_["ip"] = torch.zeros_like(ip_)

            if pixel_control:
                assert camera is not None

                transed_pixel_images = torch.stack([transform(i).to(device) for i in pixel_images])
                latent_pixel_images = model.get_first_stage_encoding(model.encode_first_stage(transed_pixel_images))

                c_["pixel_images"] = latent_pixel_images
                uc_["pixel_images"] = torch.zeros_like(latent_pixel_images)

            shape = [4, image_size // 8, image_size // 8]  # [4, 32, 32]
            if offset_noise:
                # start from noise whose per-channel mean matches the reference latent
                ref = transform(ip_raw).to(device)
                ref_latent = model.get_first_stage_encoding(model.encode_first_stage(ref[None, :, :, :]))
                ref_mean = ref_latent.mean(dim=(-1, -2), keepdim=True)
                time_steps = torch.randint(model.num_timesteps - 1, model.num_timesteps, (batch_size,), device=device)
                x_T = model.q_sample(torch.ones([batch_size] + shape, device=device) * ref_mean, time_steps)

            samples_ddim, _ = (
                sampler.sample(  # shape: torch.Size([5, 4, 32, 32]) mean: 0.29, std: 0.85, min: -3.38, max: 4.43
                    S=step,
                    conditioning=c_,
                    batch_size=batch_size,
                    shape=shape,
                    verbose=False,
                    unconditional_guidance_scale=scale,
                    unconditional_conditioning=uc_,
                    eta=ddim_eta,
                    x_T=x_T if offset_noise else None,
                )
            )
            x_sample = model.decode_first_stage(samples_ddim)
            x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
            x_sample = 255.0 * x_sample.permute(0, 2, 3, 1).cpu().numpy()

        return list(x_sample.astype(np.uint8))

    @torch.no_grad()
    def diffuse(self, t, ip, pixel_images, n_test=2):
        set_seed(self.seed)
        ip = do_resize_content(ip, self.resize_rate)
        pixel_images = [do_resize_content(i, self.resize_rate) for i in pixel_images]

        if self.random_background:
            bg_color = np.random.rand() * 255
            ip = add_random_background(ip, bg_color)
            pixel_images = [add_random_background(i, bg_color) for i in pixel_images]

        images = []
        for _ in range(n_test):
            img = self.i2iStage2(
                self.model,
                self.size,
                t,
                self.uc,
                self.sampler,
                pixel_images=pixel_images,
                ip=ip,
                step=50,
                scale=5,
                batch_size=self.batch_size,
                ddim_eta=0.0,
                dtype=self.dtype,
                device=self.device,
                camera=self.camera,
                num_frames=self.num_frames,
                pixel_control=(self.mode == "pixel"),
                transform=self.image_transform,
                offset_noise=self.offset_noise,
            )
            img = np.concatenate(img, 1)
            img = np.concatenate(
                (img, ip.resize((self.size, self.size)), *[i.resize((self.size, self.size)) for i in pixel_images]),
                axis=1,
            )
            images.append(img)
        set_seed()  # unset random and numpy seed
        return images
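
For context, a minimal usage sketch of the two wrappers above. The model loading, camera view indices, and prompt are assumptions for illustration; only the constructor and `diffuse` signatures come from this file.

from PIL import Image

# `model` is an ImageDream model instance loaded elsewhere (not part of this file).
stage1 = ImageDreamDiffusion(
    model=model,
    device="cuda",
    dtype=torch.float32,
    mode="pixel",
    num_frames=4,
    camera_views=[0, 1, 2, 3],  # indices understood by get_camera_for_index (assumed)
    ref_position=0,
)
views = stage1.diffuse("a photo of an object", Image.open("input.png"), n_test=1)
# `views` is a list of uint8 arrays: the generated views tiled next to the input image.
# Stage 2 is analogous: ImageDreamDiffusionStage2(...).diffuse(prompt, ip, pixel_images)
# takes the stage-1 per-view images as `pixel_images` and re-samples the views.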
 