Mariam-Elz committed (verified)
Commit 6cb056f · Parent(s): a5de84f

Upload imagedream/ldm/models/diffusion/ddim.py with huggingface_hub

Files changed (1):
imagedream/ldm/models/diffusion/ddim.py (+430 -430)
imagedream/ldm/models/diffusion/ddim.py CHANGED
@@ -1,430 +1,430 @@ (the removed and added versions are line-for-line identical; the file is shown once below)
"""SAMPLING ONLY."""

import torch
import numpy as np
from tqdm import tqdm
from functools import partial

from ...modules.diffusionmodules.util import (
    make_ddim_sampling_parameters,
    make_ddim_timesteps,
    noise_like,
    extract_into_tensor,
)


class DDIMSampler(object):
    def __init__(self, model, schedule="linear", **kwargs):
        super().__init__()
        self.model = model
        self.ddpm_num_timesteps = model.num_timesteps
        self.schedule = schedule

    def register_buffer(self, name, attr):
        if type(attr) == torch.Tensor:
            if attr.device != torch.device("cuda"):
                attr = attr.to(torch.device("cuda"))
        setattr(self, name, attr)

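    # Note: unlike nn.Module.register_buffer, the override above simply sets an
    # attribute and moves tensors to CUDA unconditionally; running the sampler
    # on CPU would require relaxing that device check.
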
    def make_schedule(
        self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True
    ):
        self.ddim_timesteps = make_ddim_timesteps(
            ddim_discr_method=ddim_discretize,
            num_ddim_timesteps=ddim_num_steps,
            num_ddpm_timesteps=self.ddpm_num_timesteps,
            verbose=verbose,
        )
        alphas_cumprod = self.model.alphas_cumprod
        assert (
            alphas_cumprod.shape[0] == self.ddpm_num_timesteps
        ), "alphas have to be defined for each timestep"
        to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)

        self.register_buffer("betas", to_torch(self.model.betas))
        self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
        self.register_buffer(
            "alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev)
        )

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer(
            "sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu()))
        )
        self.register_buffer(
            "sqrt_one_minus_alphas_cumprod",
            to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())),
        )
        self.register_buffer(
            "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu()))
        )
        self.register_buffer(
            "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu()))
        )
        self.register_buffer(
            "sqrt_recipm1_alphas_cumprod",
            to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)),
        )

        # ddim sampling parameters
        ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
            alphacums=alphas_cumprod.cpu(),
            ddim_timesteps=self.ddim_timesteps,
            eta=ddim_eta,
            verbose=verbose,
        )
        self.register_buffer("ddim_sigmas", ddim_sigmas)
        self.register_buffer("ddim_alphas", ddim_alphas)
        self.register_buffer("ddim_alphas_prev", ddim_alphas_prev)
        self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas))
        sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
            (1 - self.alphas_cumprod_prev)
            / (1 - self.alphas_cumprod)
            * (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
        )
        self.register_buffer(
            "ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps
        )

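    # For reference, both sigma schedules above follow Eq. 16 of the DDIM paper
    # (Song et al., 2021):
    #     sigma_t = eta * sqrt((1 - a_prev) / (1 - a_t)) * sqrt(1 - a_t / a_prev)
    # where a_t = alphas_cumprod[t]; eta = 0 gives deterministic DDIM sampling,
    # while eta = 1 recovers DDPM-like stochasticity.
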
    @torch.no_grad()
    def sample(
        self,
        S,
        batch_size,
        shape,
        conditioning=None,
        callback=None,
        normals_sequence=None,
        img_callback=None,
        quantize_x0=False,
        eta=0.0,
        mask=None,
        x0=None,
        temperature=1.0,
        noise_dropout=0.0,
        score_corrector=None,
        corrector_kwargs=None,
        verbose=True,
        x_T=None,
        log_every_t=100,
        unconditional_guidance_scale=1.0,
        unconditional_conditioning=None,
        # this has to come in the same format as the conditioning, e.g. as encoded tokens, ...
        **kwargs,
    ):
        if conditioning is not None:
            if isinstance(conditioning, dict):
                cbs = conditioning[list(conditioning.keys())[0]].shape[0]
                if cbs != batch_size:
                    print(
                        f"Warning: Got {cbs} conditionings but batch-size is {batch_size}"
                    )
            else:
                if conditioning.shape[0] != batch_size:
                    print(
                        f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}"
                    )

        self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
        # sampling
        C, H, W = shape
        size = (batch_size, C, H, W)

        samples, intermediates = self.ddim_sampling(
            conditioning,
            size,
            callback=callback,
            img_callback=img_callback,
            quantize_denoised=quantize_x0,
            mask=mask,
            x0=x0,
            ddim_use_original_steps=False,
            noise_dropout=noise_dropout,
            temperature=temperature,
            score_corrector=score_corrector,
            corrector_kwargs=corrector_kwargs,
            x_T=x_T,
            log_every_t=log_every_t,
            unconditional_guidance_scale=unconditional_guidance_scale,
            unconditional_conditioning=unconditional_conditioning,
            **kwargs,
        )
        return samples, intermediates

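    # Worked example of the schedule (assuming the standard CompVis-style
    # "uniform" discretization inside make_ddim_timesteps): with
    # ddpm_num_timesteps = 1000 and S = 50, make_schedule yields
    # ddim_timesteps = [1, 21, 41, ..., 981], which ddim_sampling below
    # traverses in reverse.
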
    @torch.no_grad()
    def ddim_sampling(
        self,
        cond,
        shape,
        x_T=None,
        ddim_use_original_steps=False,
        callback=None,
        timesteps=None,
        quantize_denoised=False,
        mask=None,
        x0=None,
        img_callback=None,
        log_every_t=100,
        temperature=1.0,
        noise_dropout=0.0,
        score_corrector=None,
        corrector_kwargs=None,
        unconditional_guidance_scale=1.0,
        unconditional_conditioning=None,
        **kwargs,
    ):
        """
        Typical parameter values at inference time:
        cond.keys(): dict_keys(['context', 'camera', 'num_frames', 'ip', 'ip_img'])
        shape: (5, 4, 32, 32)
        x_T: None
        ddim_use_original_steps: False
        timesteps: None
        callback: None
        quantize_denoised: False
        mask: None
        img_callback: None
        log_every_t: 100
        temperature: 1.0
        noise_dropout: 0.0
        score_corrector: None
        corrector_kwargs: None
        unconditional_guidance_scale: 5
        unconditional_conditioning.keys(): dict_keys(['context', 'camera', 'num_frames', 'ip', 'ip_img'])
        kwargs: {}
        """
        device = self.model.betas.device
        b = shape[0]
        if x_T is None:
            # e.g. shape: torch.Size([5, 4, 32, 32]), mean: -0.00, std: 1.00, min: -3.64, max: 3.94
            img = torch.randn(shape, device=device)
        else:
            img = x_T

        if timesteps is None:  # equivalent to setting the timesteps in the HF diffusers API
            timesteps = (
                self.ddpm_num_timesteps
                if ddim_use_original_steps
                else self.ddim_timesteps
            )
        elif timesteps is not None and not ddim_use_original_steps:
            subset_end = (
                int(
                    min(timesteps / self.ddim_timesteps.shape[0], 1)
                    * self.ddim_timesteps.shape[0]
                )
                - 1
            )
            timesteps = self.ddim_timesteps[:subset_end]

        intermediates = {"x_inter": [img], "pred_x0": [img]}
        time_range = (  # reversed timesteps
            reversed(range(0, timesteps))
            if ddim_use_original_steps
            else np.flip(timesteps)
        )
        total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
        iterator = tqdm(time_range, desc="DDIM Sampler", total=total_steps)
        for i, step in enumerate(iterator):
            index = total_steps - i - 1
            ts = torch.full((b,), step, device=device, dtype=torch.long)

            if mask is not None:
                assert x0 is not None
                img_orig = self.model.q_sample(
                    x0, ts
                )  # TODO: deterministic forward pass?
                img = img_orig * mask + (1.0 - mask) * img

            outs = self.p_sample_ddim(
                img,
                cond,
                ts,
                index=index,
                use_original_steps=ddim_use_original_steps,
                quantize_denoised=quantize_denoised,
                temperature=temperature,
                noise_dropout=noise_dropout,
                score_corrector=score_corrector,
                corrector_kwargs=corrector_kwargs,
                unconditional_guidance_scale=unconditional_guidance_scale,
                unconditional_conditioning=unconditional_conditioning,
                **kwargs,
            )
            img, pred_x0 = outs
            if callback:
                callback(i)
            if img_callback:
                img_callback(pred_x0, i)

            if index % log_every_t == 0 or index == total_steps - 1:
                intermediates["x_inter"].append(img)
                intermediates["pred_x0"].append(pred_x0)

        return img, intermediates

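    # In the loop above, `step` is the absolute DDPM timestep handed to the
    # model, while `index` counts down through the DDIM parameter arrays
    # (ddim_alphas, ddim_sigmas, ...) that p_sample_ddim indexes below.
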
    @torch.no_grad()
    def p_sample_ddim(
        self,
        x,
        c,
        t,
        index,
        repeat_noise=False,
        use_original_steps=False,
        quantize_denoised=False,
        temperature=1.0,
        noise_dropout=0.0,
        score_corrector=None,
        corrector_kwargs=None,
        unconditional_guidance_scale=1.0,
        unconditional_conditioning=None,
        dynamic_threshold=None,
        **kwargs,
    ):
        b, *_, device = *x.shape, x.device

        if unconditional_conditioning is None or unconditional_guidance_scale == 1.0:
            model_output = self.model.apply_model(x, t, c)
        else:
            x_in = torch.cat([x] * 2)
            t_in = torch.cat([t] * 2)
            if isinstance(c, dict):
                assert isinstance(unconditional_conditioning, dict)
                c_in = dict()
                for k in c:
                    if isinstance(c[k], list):
                        c_in[k] = [
                            torch.cat([unconditional_conditioning[k][i], c[k][i]])
                            for i in range(len(c[k]))
                        ]
                    elif isinstance(c[k], torch.Tensor):
                        c_in[k] = torch.cat([unconditional_conditioning[k], c[k]])
                    else:
                        assert c[k] == unconditional_conditioning[k]
                        c_in[k] = c[k]
            elif isinstance(c, list):
                c_in = list()
                assert isinstance(unconditional_conditioning, list)
                for i in range(len(c)):
                    c_in.append(torch.cat([unconditional_conditioning[i], c[i]]))
            else:
                c_in = torch.cat([unconditional_conditioning, c])
            model_uncond, model_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
            model_output = model_uncond + unconditional_guidance_scale * (
                model_t - model_uncond
            )
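            # Classifier-free guidance: the batch is doubled so one forward
            # pass yields both predictions, combined as
            #     out = out_uncond + scale * (out_cond - out_uncond)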

        if self.model.parameterization == "v":
            print("using v!")
            e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
        else:
            e_t = model_output

        if score_corrector is not None:
            assert self.model.parameterization == "eps", "not implemented"
            e_t = score_corrector.modify_score(
                self.model, e_t, x, t, c, **corrector_kwargs
            )

        alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
        alphas_prev = (
            self.model.alphas_cumprod_prev
            if use_original_steps
            else self.ddim_alphas_prev
        )
        sqrt_one_minus_alphas = (
            self.model.sqrt_one_minus_alphas_cumprod
            if use_original_steps
            else self.ddim_sqrt_one_minus_alphas
        )
        sigmas = (
            self.model.ddim_sigmas_for_original_num_steps
            if use_original_steps
            else self.ddim_sigmas
        )
        # select parameters corresponding to the currently considered timestep
        a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
        a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
        sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
        sqrt_one_minus_at = torch.full(
            (b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device
        )

        # current prediction for x_0
        if self.model.parameterization != "v":
            pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
        else:
            pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)

        if quantize_denoised:
            pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)

        if dynamic_threshold is not None:
            raise NotImplementedError()

        # direction pointing to x_t
        dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t
        noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
        if noise_dropout > 0.0:
            noise = torch.nn.functional.dropout(noise, p=noise_dropout)
        x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
        return x_prev, pred_x0

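    # The update above is Eq. 12 of the DDIM paper:
    #     x_{t-1} = sqrt(a_prev) * pred_x0
    #               + sqrt(1 - a_prev - sigma_t^2) * e_t
    #               + sigma_t * eps,  with eps ~ N(0, I),
    # scaled here by `temperature` and optionally thinned by noise_dropout.
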
    @torch.no_grad()
    def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
        # fast, but does not allow for exact reconstruction
        # t serves as an index to gather the correct alphas
        if use_original_steps:
            sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
            sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
        else:
            sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
            sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas

        if noise is None:
            noise = torch.randn_like(x0)
        return (
            extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0
            + extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise
        )

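    # stochastic_encode draws from the closed-form forward process
    #     q(x_t | x_0) = N(sqrt(a_t) * x_0, (1 - a_t) * I),  a_t = alphas_cumprod[t],
    # i.e. x_t = sqrt(a_t) * x_0 + sqrt(1 - a_t) * noise, which is exactly the
    # expression returned above.
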
    @torch.no_grad()
    def decode(
        self,
        x_latent,
        cond,
        t_start,
        unconditional_guidance_scale=1.0,
        unconditional_conditioning=None,
        use_original_steps=False,
        **kwargs,
    ):
        timesteps = (
            np.arange(self.ddpm_num_timesteps)
            if use_original_steps
            else self.ddim_timesteps
        )
        timesteps = timesteps[:t_start]

        time_range = np.flip(timesteps)
        total_steps = timesteps.shape[0]

        iterator = tqdm(time_range, desc="Decoding image", total=total_steps)
        x_dec = x_latent
        for i, step in enumerate(iterator):
            index = total_steps - i - 1
            ts = torch.full(
                (x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long
            )
            x_dec, _ = self.p_sample_ddim(
                x_dec,
                cond,
                ts,
                index=index,
                use_original_steps=use_original_steps,
                unconditional_guidance_scale=unconditional_guidance_scale,
                unconditional_conditioning=unconditional_conditioning,
                **kwargs,
            )
        return x_dec
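
For context, a minimal usage sketch follows. It is not part of the commit and assumes a LatentDiffusion-style `model` (exposing `num_timesteps`, `betas`, `alphas_cumprod`, `apply_model`, ...) plus conditioning dicts `cond` and `uncond` shaped like the docstring above:

# Hypothetical driver code, not from this repository.
sampler = DDIMSampler(model)                 # model: LatentDiffusion-style wrapper
samples, intermediates = sampler.sample(
    S=50,                                    # number of DDIM steps
    batch_size=5,
    shape=(4, 32, 32),                       # latent (C, H, W), per the docstring
    conditioning=cond,                       # e.g. {'context': ..., 'camera': ...}
    unconditional_guidance_scale=5.0,
    unconditional_conditioning=uncond,
    eta=0.0,                                 # deterministic DDIM
    verbose=False,
)
# `samples` holds the final latents; decode them with the model's first stage.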
 