Heasterian commited on
Commit
2c1ea12
·
verified ·
1 Parent(s): 1a50fd3

Removed Tiling code

Browse files

After a bit longer testing, it's causing pretty big changes in colors. I need to check out why.

Files changed (1) hide show
  1. README.md +0 -458
README.md CHANGED
@@ -53,461 +53,3 @@ upscaled_image = vae(image).sample
53
  # Save the reconstructed image
54
  utils.save_image(upscaled_image, "test.png")
55
  ```
56
-
57
- In case you want to run it on GPU and VRAM usage is too high, below you can find modified AsymmetricAutoencoderKL class with tiling support (and maybe slicing - it does not reduce VRAM usage for me, but it can be issue with ROCm on my platform). It's copy paste from AutoencoderKL with separated tile size for encode and decode.
58
-
59
- ```
60
- class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin):
61
- r"""
62
- Designing a Better Asymmetric VQGAN for StableDiffusion https://arxiv.org/abs/2306.04632 . A VAE model with KL loss
63
- for encoding images into latents and decoding latent representations into images.
64
-
65
- This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
66
- for all models (such as downloading or saving).
67
-
68
- Parameters:
69
- in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
70
- out_channels (int, *optional*, defaults to 3): Number of channels in the output.
71
- down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
72
- Tuple of downsample block types.
73
- down_block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
74
- Tuple of down block output channels.
75
- layers_per_down_block (`int`, *optional*, defaults to `1`):
76
- Number layers for down block.
77
- up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
78
- Tuple of upsample block types.
79
- up_block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
80
- Tuple of up block output channels.
81
- layers_per_up_block (`int`, *optional*, defaults to `1`):
82
- Number layers for up block.
83
- act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
84
- latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space.
85
- sample_size (`int`, *optional*, defaults to `32`): Sample input size.
86
- norm_num_groups (`int`, *optional*, defaults to `32`):
87
- Number of groups to use for the first normalization layer in ResNet blocks.
88
- scaling_factor (`float`, *optional*, defaults to 0.18215):
89
- The component-wise standard deviation of the trained latent space computed using the first batch of the
90
- training set. This is used to scale the latent space to have unit variance when training the diffusion
91
- model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
92
- diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
93
- / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
94
- Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
95
- """
96
-
97
- @register_to_config
98
- def __init__(
99
- self,
100
- in_channels: int = 3,
101
- out_channels: int = 3,
102
- down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",),
103
- down_block_out_channels: Tuple[int, ...] = (64,),
104
- layers_per_down_block: int = 1,
105
- up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",),
106
- up_block_out_channels: Tuple[int, ...] = (64,),
107
- layers_per_up_block: int = 1,
108
- act_fn: str = "silu",
109
- latent_channels: int = 4,
110
- norm_num_groups: int = 32,
111
- sample_size: int = 32,
112
- scaling_factor: float = 0.18215,
113
- use_quant_conv: bool = True,
114
- use_post_quant_conv: bool = True,
115
- ) -> None:
116
- super().__init__()
117
-
118
- # pass init params to Encoder
119
- self.encoder = Encoder(
120
- in_channels=in_channels,
121
- out_channels=latent_channels,
122
- down_block_types=down_block_types,
123
- block_out_channels=down_block_out_channels,
124
- layers_per_block=layers_per_down_block,
125
- act_fn=act_fn,
126
- norm_num_groups=norm_num_groups,
127
- double_z=True,
128
- )
129
-
130
- # pass init params to Decoder
131
- self.decoder = MaskConditionDecoder(
132
- in_channels=latent_channels,
133
- out_channels=out_channels,
134
- up_block_types=up_block_types,
135
- block_out_channels=up_block_out_channels,
136
- layers_per_block=layers_per_up_block,
137
- act_fn=act_fn,
138
- norm_num_groups=norm_num_groups,
139
- )
140
-
141
- self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) if use_quant_conv else None
142
- self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1) if use_post_quant_conv else None
143
-
144
- self.use_slicing = False
145
- self.use_tiling = False
146
-
147
- # only relevant if vae tiling is enabled
148
- self.tile_sample_min_size = self.config.sample_size
149
- sample_size = (
150
- self.config.sample_size[0]
151
- if isinstance(self.config.sample_size, (list, tuple))
152
- else self.config.sample_size
153
- )
154
- self.tile_latent_min_up_size = int(sample_size / (2 ** (len(self.config.up_block_out_channels) - 1)))
155
- self.tile_latent_min_down_size = int(sample_size / (2 ** (len(self.config.down_block_out_channels) - 1)))
156
-
157
- self.tile_overlap_factor = 0.25
158
-
159
- self.register_to_config(block_out_channels=up_block_out_channels)
160
- self.register_to_config(force_upcast=False)
161
-
162
- def enable_tiling(self, use_tiling: bool = True):
163
- r"""
164
- Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
165
- compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
166
- processing larger images.
167
- """
168
- self.use_tiling = use_tiling
169
-
170
- def disable_tiling(self):
171
- r"""
172
- Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
173
- decoding in one step.
174
- """
175
- self.enable_tiling(False)
176
-
177
- def enable_slicing(self):
178
- r"""
179
- Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
180
- compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
181
- """
182
- self.use_slicing = True
183
-
184
- def disable_slicing(self):
185
- r"""
186
- Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
187
- decoding in one step.
188
- """
189
- self.use_slicing = False
190
-
191
- def _encode(self, x: torch.Tensor) -> torch.Tensor:
192
- batch_size, num_channels, height, width = x.shape
193
-
194
- if self.use_tiling and (width > self.tile_sample_min_size or height > self.tile_sample_min_size):
195
- return self._tiled_encode(x)
196
-
197
- enc = self.encoder(x)
198
- if self.quant_conv is not None:
199
- enc = self.quant_conv(enc)
200
-
201
- return enc
202
-
203
- @apply_forward_hook
204
- def encode(
205
- self, x: torch.Tensor, return_dict: bool = True
206
- ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
207
- """
208
- Encode a batch of images into latents.
209
-
210
- Args:
211
- x (`torch.Tensor`): Input batch of images.
212
- return_dict (`bool`, *optional*, defaults to `True`):
213
- Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
214
-
215
- Returns:
216
- The latent representations of the encoded images. If `return_dict` is True, a
217
- [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
218
- """
219
- if self.use_slicing and x.shape[0] > 1:
220
- encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
221
- h = torch.cat(encoded_slices)
222
- else:
223
- h = self._encode(x)
224
-
225
- posterior = DiagonalGaussianDistribution(h)
226
-
227
- if not return_dict:
228
- return (posterior,)
229
-
230
- return AutoencoderKLOutput(latent_dist=posterior)
231
-
232
- def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
233
- if self.use_tiling and (z.shape[-1] > self.tile_latent_min_up_size or z.shape[-2] > self.tile_latent_min_up_size):
234
- return self.tiled_decode(z, return_dict=return_dict)
235
-
236
- if self.post_quant_conv is not None:
237
- z = self.post_quant_conv(z)
238
-
239
- dec = self.decoder(z)
240
-
241
- if not return_dict:
242
- return (dec,)
243
-
244
- return DecoderOutput(sample=dec)
245
-
246
- @apply_forward_hook
247
- def decode(
248
- self, z: torch.FloatTensor, return_dict: bool = True, generator=None
249
- ) -> Union[DecoderOutput, torch.FloatTensor]:
250
- """
251
- Decode a batch of images.
252
-
253
- Args:
254
- z (`torch.Tensor`): Input batch of latent vectors.
255
- return_dict (`bool`, *optional*, defaults to `True`):
256
- Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
257
-
258
- Returns:
259
- [`~models.vae.DecoderOutput`] or `tuple`:
260
- If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
261
- returned.
262
-
263
- """
264
- if self.use_slicing and z.shape[0] > 1:
265
- decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
266
- decoded = torch.cat(decoded_slices)
267
- else:
268
- decoded = self._decode(z).sample
269
-
270
- if not return_dict:
271
- return (decoded,)
272
-
273
- return DecoderOutput(sample=decoded)
274
-
275
- def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
276
- blend_extent = min(a.shape[2], b.shape[2], blend_extent)
277
- for y in range(blend_extent):
278
- b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
279
- return b
280
-
281
- def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
282
- blend_extent = min(a.shape[3], b.shape[3], blend_extent)
283
- for x in range(blend_extent):
284
- b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
285
- return b
286
-
287
- def _tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
288
- r"""Encode a batch of images using a tiled encoder.
289
-
290
- When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
291
- steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
292
- different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
293
- tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
294
- output, but they should be much less noticeable.
295
-
296
- Args:
297
- x (`torch.Tensor`): Input batch of images.
298
-
299
- Returns:
300
- `torch.Tensor`:
301
- The latent representation of the encoded videos.
302
- """
303
-
304
- overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
305
- blend_extent = int(self.tile_latent_min_down_size * self.tile_overlap_factor)
306
- row_limit = self.tile_latent_min_down_size - blend_extent
307
-
308
- # Split the image into 512x512 tiles and encode them separately.
309
- rows = []
310
- for i in range(0, x.shape[2], overlap_size):
311
- row = []
312
- for j in range(0, x.shape[3], overlap_size):
313
- tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
314
- tile = self.encoder(tile)
315
- if self.config.use_quant_conv:
316
- tile = self.quant_conv(tile)
317
- row.append(tile)
318
- rows.append(row)
319
- result_rows = []
320
- for i, row in enumerate(rows):
321
- result_row = []
322
- for j, tile in enumerate(row):
323
- # blend the above tile and the left tile
324
- # to the current tile and add the current tile to the result row
325
- if i > 0:
326
- tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
327
- if j > 0:
328
- tile = self.blend_h(row[j - 1], tile, blend_extent)
329
- result_row.append(tile[:, :, :row_limit, :row_limit])
330
- result_rows.append(torch.cat(result_row, dim=3))
331
-
332
- enc = torch.cat(result_rows, dim=2)
333
- return enc
334
-
335
- def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> AutoencoderKLOutput:
336
- r"""Encode a batch of images using a tiled encoder.
337
-
338
- When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
339
- steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
340
- different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
341
- tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
342
- output, but they should be much less noticeable.
343
-
344
- Args:
345
- x (`torch.Tensor`): Input batch of images.
346
- return_dict (`bool`, *optional*, defaults to `True`):
347
- Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
348
-
349
- Returns:
350
- [`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`:
351
- If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain
352
- `tuple` is returned.
353
- """
354
- deprecation_message = (
355
- "The tiled_encode implementation supporting the `return_dict` parameter is deprecated. In the future, the "
356
- "implementation of this method will be replaced with that of `_tiled_encode` and you will no longer be able "
357
- "to pass `return_dict`. You will also have to create a `DiagonalGaussianDistribution()` from the returned value."
358
- )
359
- deprecate("tiled_encode", "1.0.0", deprecation_message, standard_warn=False)
360
-
361
- overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
362
- blend_extent = int(self.tile_latent_min_up_size * self.tile_overlap_factor)
363
- row_limit = self.tile_latent_min_up_size - blend_extent
364
-
365
- # Split the image into 512x512 tiles and encode them separately.
366
- rows = []
367
- for i in range(0, x.shape[2], overlap_size):
368
- row = []
369
- for j in range(0, x.shape[3], overlap_size):
370
- tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
371
- tile = self.encoder(tile)
372
- if self.config.use_quant_conv:
373
- tile = self.quant_conv(tile)
374
- row.append(tile)
375
- rows.append(row)
376
- result_rows = []
377
- for i, row in enumerate(rows):
378
- result_row = []
379
- for j, tile in enumerate(row):
380
- # blend the above tile and the left tile
381
- # to the current tile and add the current tile to the result row
382
- if i > 0:
383
- tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
384
- if j > 0:
385
- tile = self.blend_h(row[j - 1], tile, blend_extent)
386
- result_row.append(tile[:, :, :row_limit, :row_limit])
387
- result_rows.append(torch.cat(result_row, dim=3))
388
-
389
- moments = torch.cat(result_rows, dim=2)
390
- posterior = DiagonalGaussianDistribution(moments)
391
-
392
- if not return_dict:
393
- return (posterior,)
394
-
395
- return AutoencoderKLOutput(latent_dist=posterior)
396
-
397
- def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
398
- r"""
399
- Decode a batch of images using a tiled decoder.
400
-
401
- Args:
402
- z (`torch.Tensor`): Input batch of latent vectors.
403
- return_dict (`bool`, *optional*, defaults to `True`):
404
- Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
405
-
406
- Returns:
407
- [`~models.vae.DecoderOutput`] or `tuple`:
408
- If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
409
- returned.
410
- """
411
- overlap_size = int(self.tile_latent_min_up_size * (1 - self.tile_overlap_factor))
412
- blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
413
- row_limit = self.tile_sample_min_size - blend_extent
414
-
415
- # Split z into overlapping 64x64 tiles and decode them separately.
416
- # The tiles have an overlap to avoid seams between tiles.
417
- rows = []
418
- for i in range(0, z.shape[2], overlap_size):
419
- row = []
420
- for j in range(0, z.shape[3], overlap_size):
421
- tile = z[:, :, i : i + self.tile_latent_min_up_size, j : j + self.tile_latent_min_up_size]
422
- if self.config.use_post_quant_conv:
423
- tile = self.post_quant_conv(tile)
424
- decoded = self.decoder(tile)
425
- row.append(decoded)
426
- rows.append(row)
427
- result_rows = []
428
- for i, row in enumerate(rows):
429
- result_row = []
430
- for j, tile in enumerate(row):
431
- # blend the above tile and the left tile
432
- # to the current tile and add the current tile to the result row
433
- if i > 0:
434
- tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
435
- if j > 0:
436
- tile = self.blend_h(row[j - 1], tile, blend_extent)
437
- result_row.append(tile[:, :, :row_limit, :row_limit])
438
- result_rows.append(torch.cat(result_row, dim=3))
439
-
440
- dec = torch.cat(result_rows, dim=2)
441
- if not return_dict:
442
- return (dec,)
443
-
444
- return DecoderOutput(sample=dec)
445
-
446
- def forward(
447
- self,
448
- sample: torch.Tensor,
449
- sample_posterior: bool = False,
450
- return_dict: bool = True,
451
- generator: Optional[torch.Generator] = None,
452
- ) -> Union[DecoderOutput, torch.Tensor]:
453
- r"""
454
- Args:
455
- sample (`torch.Tensor`): Input sample.
456
- sample_posterior (`bool`, *optional*, defaults to `False`):
457
- Whether to sample from the posterior.
458
- return_dict (`bool`, *optional*, defaults to `True`):
459
- Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
460
- """
461
- x = sample
462
- posterior = self.encode(x).latent_dist
463
- if sample_posterior:
464
- z = posterior.sample(generator=generator)
465
- else:
466
- z = posterior.mode()
467
- dec = self.decode(z).sample
468
-
469
- if not return_dict:
470
- return (dec,)
471
-
472
- return DecoderOutput(sample=dec)
473
-
474
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
475
- def fuse_qkv_projections(self):
476
- """
477
- Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
478
- are fused. For cross-attention modules, key and value projection matrices are fused.
479
-
480
- <Tip warning={true}>
481
-
482
- This API is 🧪 experimental.
483
-
484
- </Tip>
485
- """
486
- self.original_attn_processors = None
487
-
488
- for _, attn_processor in self.attn_processors.items():
489
- if "Added" in str(attn_processor.__class__.__name__):
490
- raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
491
-
492
- self.original_attn_processors = self.attn_processors
493
-
494
- for module in self.modules():
495
- if isinstance(module, Attention):
496
- module.fuse_projections(fuse=True)
497
-
498
- self.set_attn_processor(FusedAttnProcessor2_0())
499
-
500
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
501
- def unfuse_qkv_projections(self):
502
- """Disables the fused QKV projection if enabled.
503
-
504
- <Tip warning={true}>
505
-
506
- This API is 🧪 experimental.
507
-
508
- </Tip>
509
-
510
- """
511
- if self.original_attn_processors is not None:
512
- self.set_attn_processor(self.original_attn_processors)
513
- ```
 
53
  # Save the reconstructed image
54
  utils.save_image(upscaled_image, "test.png")
55
  ```