import torch
from torch.utils.checkpoint import checkpoint

from diffusers.models.autoencoders.autoencoder_dc import AutoencoderDC, Decoder, Encoder


class MyEncoder(Encoder):
    """DC-AE encoder with optional gradient checkpointing over the down blocks."""

    def __init__(
        self,
        in_channels,
        latent_channels,
        attention_head_dim=32,
        block_type="ResBlock",
        block_out_channels=...,  # placeholder: pass an explicit per-stage tuple
        layers_per_block=...,  # placeholder
        qkv_multiscales=...,  # placeholder
        downsample_block_type="pixel_unshuffle",
        out_shortcut=True,
    ):
        super().__init__(
            in_channels,
            latent_channels,
            attention_head_dim,
            block_type,
            block_out_channels,
            layers_per_block,
            qkv_multiscales,
            downsample_block_type,
            out_shortcut,
        )

    def forward(self, hidden_states: torch.Tensor, use_checkpoint: bool = False) -> torch.Tensor:
        hidden_states = self.conv_in(hidden_states)
        for down_block in self.down_blocks:
            if use_checkpoint:
                # Recompute activations during the backward pass to save memory.
                hidden_states = checkpoint(
                    self.ckpt_wrapper(down_block), hidden_states, use_reentrant=False
                )
            else:
                hidden_states = down_block(hidden_states)

        if self.out_shortcut:
            # Group-averaged channel shortcut around conv_out.
            x = hidden_states.unflatten(1, (-1, self.out_shortcut_average_group_size))
            x = x.mean(dim=2)
            hidden_states = self.conv_out(hidden_states) + x
        else:
            hidden_states = self.conv_out(hidden_states)
        return hidden_states

    def ckpt_wrapper(self, module):
        # Wrap the module so torch.utils.checkpoint receives a plain callable.
        def ckpt_forward(*inputs):
            return module(*inputs)

        return ckpt_forward


class MyDecoder(Decoder):
    """DC-AE decoder with optional gradient checkpointing over the up blocks."""

    def __init__(
        self,
        in_channels,
        latent_channels,
        attention_head_dim=32,
        block_type="ResBlock",
        block_out_channels=...,  # placeholder: pass an explicit per-stage tuple
        layers_per_block=...,  # placeholder
        qkv_multiscales=...,  # placeholder
        norm_type="rms_norm",
        act_fn="silu",
        upsample_block_type="pixel_shuffle",
        in_shortcut=True,
    ):
        super().__init__(
            in_channels,
            latent_channels,
            attention_head_dim,
            block_type,
            block_out_channels,
            layers_per_block,
            qkv_multiscales,
            norm_type,
            act_fn,
            upsample_block_type,
            in_shortcut,
        )

    def forward(self, hidden_states: torch.Tensor, use_checkpoint: bool = False) -> torch.Tensor:
        if self.in_shortcut:
            # Repeat latent channels so the shortcut matches conv_in's output width.
            x = hidden_states.repeat_interleave(
                self.in_shortcut_repeats,
                dim=1,
                output_size=hidden_states.shape[1] * self.in_shortcut_repeats,
            )
            hidden_states = self.conv_in(hidden_states) + x
        else:
            hidden_states = self.conv_in(hidden_states)

        for up_block in reversed(self.up_blocks):
            if use_checkpoint:
                hidden_states = checkpoint(
                    self.ckpt_wrapper(up_block), hidden_states, use_reentrant=False
                )
            else:
                hidden_states = up_block(hidden_states)

        # RMSNorm acts on the last dimension, so move channels there and back.
        hidden_states = self.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1)
        hidden_states = self.conv_act(hidden_states)
        hidden_states = self.conv_out(hidden_states)
        return hidden_states

    def ckpt_wrapper(self, module):
        # Wrap the module so torch.utils.checkpoint receives a plain callable.
        def ckpt_forward(*inputs):
            return module(*inputs)

        return ckpt_forward


class MyAutoencoderDC(AutoencoderDC):
    """AutoencoderDC whose latents are tracked by an affine-free BatchNorm.

    The BatchNorm's running statistics provide a data-driven latent mean/std,
    initialized so that the initial std equals 1 / scaling_factor.
    """

    def __init__(
        self,
        in_channels=3,
        latent_channels=32,
        attention_head_dim=32,
        encoder_block_types="ResBlock",
        decoder_block_types="ResBlock",
        encoder_block_out_channels=...,  # placeholder: pass an explicit per-stage tuple
        decoder_block_out_channels=...,  # placeholder
        encoder_layers_per_block=...,  # placeholder
        decoder_layers_per_block=...,  # placeholder
        encoder_qkv_multiscales=...,  # placeholder
        decoder_qkv_multiscales=...,  # placeholder
        upsample_block_type="pixel_shuffle",
        downsample_block_type="pixel_unshuffle",
        decoder_norm_types="rms_norm",
        decoder_act_fns="silu",
        scaling_factor=1.0,
        bn_momentum=0.1,
    ):
        super().__init__(
            in_channels,
            latent_channels,
            attention_head_dim,
            encoder_block_types,
            decoder_block_types,
            encoder_block_out_channels,
            decoder_block_out_channels,
            encoder_layers_per_block,
            decoder_layers_per_block,
            encoder_qkv_multiscales,
            decoder_qkv_multiscales,
            upsample_block_type,
            downsample_block_type,
            decoder_norm_types,
            decoder_act_fns,
            scaling_factor,
        )
        # Replace the stock encoder/decoder with the checkpointing variants.
        self.encoder = MyEncoder(
            in_channels=in_channels,
            latent_channels=latent_channels,
            attention_head_dim=attention_head_dim,
            block_type=encoder_block_types,
            block_out_channels=encoder_block_out_channels,
            layers_per_block=encoder_layers_per_block,
            qkv_multiscales=encoder_qkv_multiscales,
            downsample_block_type=downsample_block_type,
        )
        self.decoder = MyDecoder(
            in_channels=in_channels,
            latent_channels=latent_channels,
            attention_head_dim=attention_head_dim,
            block_type=decoder_block_types,
            block_out_channels=decoder_block_out_channels,
            layers_per_block=decoder_layers_per_block,
            qkv_multiscales=decoder_qkv_multiscales,
            norm_type=decoder_norm_types,
            act_fn=decoder_act_fns,
            upsample_block_type=upsample_block_type,
        )
        # Affine-free BatchNorm: only tracks running mean/var of the latents.
        self.bn = torch.nn.BatchNorm2d(
            latent_channels,
            eps=1e-4,
            momentum=bn_momentum,
            affine=False,
            track_running_stats=True,
        )
        self.bn.reset_running_stats()
        self.init_bn()

    def init_bn(self):
        # Initialize the running stats so that std == 1 / scaling_factor.
        # (Optionally keep the stats in float64 for extra precision.)
        self.bn.running_mean = torch.zeros_like(self.bn.running_mean)
        self.bn.running_var = torch.ones_like(self.bn.running_var) / self.config.scaling_factor ** 2

    @property
    def mean(self):
        # Running latent mean, broadcastable to (N, C, H, W).
        return self.bn.running_mean.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)

    @property
    def std(self):
        # Running latent std, broadcastable to (N, C, H, W).
        return self.bn.running_var.sqrt().unsqueeze(0).unsqueeze(-1).unsqueeze(-1)

    def forward(self, x: torch.Tensor, use_checkpoint: bool = False) -> torch.Tensor:
        z = self.encoder(x, use_checkpoint)
        latent = self.bn(z)  # normalized latent; updates running stats in train mode
        recon = self.decoder(z, use_checkpoint)  # decode the raw latent, not the normalized one
        posterior = None  # kept for API compatibility with VAE-style interfaces
        return posterior, latent, recon
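
# A minimal smoke test of MyAutoencoderDC, sketched under assumed settings:
# the block_out_channels / layers_per_block / qkv_multiscales tuples below are
# illustrative stand-ins for the `...` placeholders above (in the style of a
# DC-AE f32c32 configuration), not values defined in this file.
if __name__ == "__main__":
    model = MyAutoencoderDC(
        in_channels=3,
        latent_channels=32,
        encoder_block_out_channels=(128, 256, 512, 512, 1024, 1024),  # assumed
        decoder_block_out_channels=(128, 256, 512, 512, 1024, 1024),  # assumed
        encoder_layers_per_block=(2, 2, 2, 3, 3, 3),  # assumed
        decoder_layers_per_block=(3, 3, 3, 3, 3, 3),  # assumed
        encoder_qkv_multiscales=((), (), (), (5,), (5,), (5,)),  # assumed
        decoder_qkv_multiscales=((), (), (), (5,), (5,), (5,)),  # assumed
    )
    model.eval()
    with torch.no_grad():
        x = torch.randn(1, 3, 256, 256)
        posterior, latent, recon = model(x)
    # With five 2x downsampling stages this should yield a (1, 32, 8, 8) latent
    # and a (1, 3, 256, 256) reconstruction.
    print(latent.shape, recon.shape)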