Vaani-Audio2Img-LDM / Vaani /VQVAE_summary.txt
alpha31476's picture
LDM-train-pass, checking results
87ef7b5 verified
TIME: 2025-05-09 21:58:45.534412
DEVICE: cuda
{'autoencoder_params': {'attn_down': [False, False],
'codebook_size': 20,
'down_channels': [32, 64, 128],
'down_sample': [True, True],
'mid_channels': [128, 128],
'norm_channels': 32,
'num_down_layers': 4,
'num_heads': 16,
'num_mid_layers': 4,
'num_up_layers': 4,
'z_channels': 3},
'dataset_params': {'im_channels': 3, 'im_size': 128},
'diffusion_params': {'beta_end': 0.0195, 'beta_start': 0.0015, 'num_timesteps': 1000},
'ldm_params': {'attn_down': [True, True, True],
'conv_out_channels': 128,
'down_channels': [128, 256, 256, 256],
'down_sample': [False, False, False],
'mid_channels': [256, 256],
'norm_channels': 32,
'num_down_layers': 2,
'num_heads': 16,
'num_mid_layers': 2,
'num_up_layers': 2,
'time_emb_dim': 256},
'paths': {'images_dir': '/scratch/IITB/ai-at-ieor/23m1521/datasets/Vaani/Images'},
'train_params': {'autoencoder_acc_steps': 1,
'autoencoder_batch_size': 8,
'autoencoder_epochs': 30,
'autoencoder_img_save_steps': 8,
'autoencoder_lr': 0.0001,
'codebook_weight': 1,
'commitment_beta': 0.2,
'disc_start': 1000,
'disc_weight': 0.5,
'kl_weight': 5e-06,
'ldm_batch_size': 1,
'ldm_ckpt_name': 'ddpm_ckpt.pth',
'ldm_epochs': 10,
'ldm_lr': 1e-05,
'num_grid_rows': 3,
'num_samples': 9,
'perceptual_weight': 1,
'save_latents': True,
'seed': 4422,
'task_name': 'VaaniLDM',
'vqvae_ckpt_name': 'vqvaq_ckpt.pth',
'vqvae_latent_dir_name': 'vqvae_latents'},
'training': {'_continue_': True}}
Files found: 128807
IMAGE SHAPE: torch.Size([3, 128, 128])
BATCH SHAPE: torch.Size([8, 3, 128, 128])
======================================================================================================================================================
Layer (type (var_name)) Input Shape Output Shape Param # Trainable Param %
======================================================================================================================================================
VQVAE (VQVAE) [8, 3, 128, 128] [8, 3, 128, 128] 60 True 0.00%
├─Conv2d (encoder_conv_in) [8, 3, 128, 128] [8, 32, 128, 128] 896 True 0.01%
├─ModuleList (encoder_layers) -- -- -- True --
│ └─DownBlock (0) [8, 32, 128, 128] [8, 64, 64, 64] -- True --
│ │ └─ModuleList (resnet_conv_first) -- -- (recursive) True (recursive)
│ │ │ └─Sequential (0) [8, 32, 128, 128] [8, 64, 128, 128] -- True --
│ │ │ │ └─GroupNorm (0) [8, 32, 128, 128] [8, 32, 128, 128] 64 True 0.00%
│ │ │ │ └─SiLU (1) [8, 32, 128, 128] [8, 32, 128, 128] -- -- --
│ │ │ │ └─Conv2d (2) [8, 32, 128, 128] [8, 64, 128, 128] 18,496 True 0.30%
│ │ └─ModuleList (resnet_conv_second) -- -- (recursive) True (recursive)
│ │ │ └─Sequential (0) [8, 64, 128, 128] [8, 64, 128, 128] -- True --
│ │ │ │ └─GroupNorm (0) [8, 64, 128, 128] [8, 64, 128, 128] 128 True 0.00%
│ │ │ │ └─SiLU (1) [8, 64, 128, 128] [8, 64, 128, 128] -- -- --
│ │ │ │ └─Conv2d (2) [8, 64, 128, 128] [8, 64, 128, 128] 36,928 True 0.59%
│ │ └─ModuleList (residual_input_conv) -- -- (recursive) True (recursive)
│ │ │ └─Conv2d (0) [8, 32, 128, 128] [8, 64, 128, 128] 2,112 True 0.03%
│ │ └─ModuleList (resnet_conv_first) -- -- (recursive) True (recursive)
│ │ │ └─Sequential (1) [8, 64, 128, 128] [8, 64, 128, 128] -- True --
│ │ │ │ └─GroupNorm (0) [8, 64, 128, 128] [8, 64, 128, 128] 128 True 0.00%
│ │ │ │ └─SiLU (1) [8, 64, 128, 128] [8, 64, 128, 128] -- -- --
│ │ │ │ └─Conv2d (2) [8, 64, 128, 128] [8, 64, 128, 128] 36,928 True 0.59%
│ │ └─ModuleList (resnet_conv_second) -- -- (recursive) True (recursive)
│ │ │ └─Sequential (1) [8, 64, 128, 128] [8, 64, 128, 128] -- True --
│ │ │ │ └─GroupNorm (0) [8, 64, 128, 128] [8, 64, 128, 128] 128 True 0.00%
│ │ │ │ └─SiLU (1) [8, 64, 128, 128] [8, 64, 128, 128] -- -- --
│ │ │ │ └─Conv2d (2) [8, 64, 128, 128] [8, 64, 128, 128] 36,928 True 0.59%
│ │ └─ModuleList (residual_input_conv) -- -- (recursive) True (recursive)
│ │ │ └─Conv2d (1) [8, 64, 128, 128] [8, 64, 128, 128] 4,160 True 0.07%
│ │ └─ModuleList (resnet_conv_first) -- -- (recursive) True (recursive)
│ │ │ └─Sequential (2) [8, 64, 128, 128] [8, 64, 128, 128] -- True --
│ │ │ │ └─GroupNorm (0) [8, 64, 128, 128] [8, 64, 128, 128] 128 True 0.00%
│ │ │ │ └─SiLU (1) [8, 64, 128, 128] [8, 64, 128, 128] -- -- --
│ │ │ │ └─Conv2d (2) [8, 64, 128, 128] [8, 64, 128, 128] 36,928 True 0.59%
│ │ └─ModuleList (resnet_conv_second) -- -- (recursive) True (recursive)
│ │ │ └─Sequential (2) [8, 64, 128, 128] [8, 64, 128, 128] -- True --
│ │ │ │ └─GroupNorm (0) [8, 64, 128, 128] [8, 64, 128, 128] 128 True 0.00%
│ │ │ │ └─SiLU (1) [8, 64, 128, 128] [8, 64, 128, 128] -- -- --
│ │ │ │ └─Conv2d (2) [8, 64, 128, 128] [8, 64, 128, 128] 36,928 True 0.59%
│ │ └─ModuleList (residual_input_conv) -- -- (recursive) True (recursive)
│ │ │ └─Conv2d (2) [8, 64, 128, 128] [8, 64, 128, 128] 4,160 True 0.07%
│ │ └─ModuleList (resnet_conv_first) -- -- (recursive) True (recursive)
│ │ │ └─Sequential (3) [8, 64, 128, 128] [8, 64, 128, 128] -- True --
│ │ │ │ └─GroupNorm (0) [8, 64, 128, 128] [8, 64, 128, 128] 128 True 0.00%
│ │ │ │ └─SiLU (1) [8, 64, 128, 128] [8, 64, 128, 128] -- -- --
│ │ │ │ └─Conv2d (2) [8, 64, 128, 128] [8, 64, 128, 128] 36,928 True 0.59%
│ │ └─ModuleList (resnet_conv_second) -- -- (recursive) True (recursive)
│ │ │ └─Sequential (3) [8, 64, 128, 128] [8, 64, 128, 128] -- True --
│ │ │ │ └─GroupNorm (0) [8, 64, 128, 128] [8, 64, 128, 128] 128 True 0.00%
│ │ │ │ └─SiLU (1) [8, 64, 128, 128] [8, 64, 128, 128] -- -- --
│ │ │ │ └─Conv2d (2) [8, 64, 128, 128] [8, 64, 128, 128] 36,928 True 0.59%
│ │ └─ModuleList (residual_input_conv) -- -- (recursive) True (recursive)
│ │ │ └─Conv2d (3) [8, 64, 128, 128] [8, 64, 128, 128] 4,160 True 0.07%
│ │ └─Conv2d (down_sample_conv) [8, 64, 128, 128] [8, 64, 64, 64] 65,600 True 1.05%
│ └─DownBlock (1) [8, 64, 64, 64] [8, 128, 32, 32] -- True --
│ │ └─ModuleList (resnet_conv_first) -- -- (recursive) True (recursive)
│ │ │ └─Sequential (0) [8, 64, 64, 64] [8, 128, 64, 64] -- True --
│ │ │ │ └─GroupNorm (0) [8, 64, 64, 64] [8, 64, 64, 64] 128 True 0.00%
│ │ │ │ └─SiLU (1) [8, 64, 64, 64] [8, 64, 64, 64] -- -- --
│ │ │ │ └─Conv2d (2) [8, 64, 64, 64] [8, 128, 64, 64] 73,856 True 1.19%
│ │ └─ModuleList (resnet_conv_second) -- -- (recursive) True (recursive)
│ │ │ └─Sequential (0) [8, 128, 64, 64] [8, 128, 64, 64] -- True --
│ │ │ │ └─GroupNorm (0) [8, 128, 64, 64] [8, 128, 64, 64] 256 True 0.00%
│ │ │ │ └─SiLU (1) [8, 128, 64, 64] [8, 128, 64, 64] -- -- --
│ │ │ │ └─Conv2d (2) [8, 128, 64, 64] [8, 128, 64, 64] 147,584 True 2.37%
│ │ └─ModuleList (residual_input_conv) -- -- (recursive) True (recursive)
│ │ │ └─Conv2d (0) [8, 64, 64, 64] [8, 128, 64, 64] 8,320 True 0.13%
│ │ └─ModuleList (resnet_conv_first) -- -- (recursive) True (recursive)
│ │ │ └─Sequential (1) [8, 128, 64, 64] [8, 128, 64, 64] -- True --
│ │ │ │ └─GroupNorm (0) [8, 128, 64, 64] [8, 128, 64, 64] 256 True 0.00%
│ │ │ │ └─SiLU (1) [8, 128, 64, 64] [8, 128, 64, 64] -- -- --
│ │ │ │ └─Conv2d (2) [8, 128, 64, 64] [8, 128, 64, 64] 147,584 True 2.37%
│ │ └─ModuleList (resnet_conv_second) -- -- (recursive) True (recursive)
│ │ │ └─Sequential (1) [8, 128, 64, 64] [8, 128, 64, 64] -- True --
│ │ │ │ └─GroupNorm (0) [8, 128, 64, 64] [8, 128, 64, 64] 256 True 0.00%
│ │ │ │ └─SiLU (1) [8, 128, 64, 64] [8, 128, 64, 64] -- -- --
│ │ │ │ └─Conv2d (2) [8, 128, 64, 64] [8, 128, 64, 64] 147,584 True 2.37%
│ │ └─ModuleList (residual_input_conv) -- -- (recursive) True (recursive)
│ │ │ └─Conv2d (1) [8, 128, 64, 64] [8, 128, 64, 64] 16,512 True 0.27%
│ │ └─ModuleList (resnet_conv_first) -- -- (recursive) True (recursive)
│ │ │ └─Sequential (2) [8, 128, 64, 64] [8, 128, 64, 64] -- True --
│ │ │ │ └─GroupNorm (0) [8, 128, 64, 64] [8, 128, 64, 64] 256 True 0.00%
│ │ │ │ └─SiLU (1) [8, 128, 64, 64] [8, 128, 64, 64] -- -- --
│ │ │ │ └─Conv2d (2) [8, 128, 64, 64] [8, 128, 64, 64] 147,584 True 2.37%
│ │ └─ModuleList (resnet_conv_second) -- -- (recursive) True (recursive)
│ │ │ └─Sequential (2) [8, 128, 64, 64] [8, 128, 64, 64] -- True --
│ │ │ │ └─GroupNorm (0) [8, 128, 64, 64] [8, 128, 64, 64] 256 True 0.00%
│ │ │ │ └─SiLU (1) [8, 128, 64, 64] [8, 128, 64, 64] -- -- --
│ │ │ │ └─Conv2d (2) [8, 128, 64, 64] [8, 128, 64, 64] 147,584 True 2.37%
│ │ └─ModuleList (residual_input_conv) -- -- (recursive) True (recursive)
│ │ │ └─Conv2d (2) [8, 128, 64, 64] [8, 128, 64, 64] 16,512 True 0.27%
│ │ └─ModuleList (resnet_conv_first) -- -- (recursive) True (recursive)
│ │ │ └─Sequential (3) [8, 128, 64, 64] [8, 128, 64, 64] -- True --
│ │ │ │ └─GroupNorm (0) [8, 128, 64, 64] [8, 128, 64, 64] 256 True 0.00%
│ │ │ │ └─SiLU (1) [8, 128, 64, 64] [8, 128, 64, 64] -- -- --
│ │ │ │ └─Conv2d (2) [8, 128, 64, 64] [8, 128, 64, 64] 147,584 True 2.37%
│ │ └─ModuleList (resnet_conv_second) -- -- (recursive) True (recursive)
│ │ │ └─Sequential (3) [8, 128, 64, 64] [8, 128, 64, 64] -- True --
│ │ │ │ └─GroupNorm (0) [8, 128, 64, 64] [8, 128, 64, 64] 256 True 0.00%
│ │ │ │ └─SiLU (1) [8, 128, 64, 64] [8, 128, 64, 64] -- -- --
│ │ │ │ └─Conv2d (2) [8, 128, 64, 64] [8, 128, 64, 64] 147,584 True 2.37%
│ │ └─ModuleList (residual_input_conv) -- -- (recursive) True (recursive)
│ │ │ └─Conv2d (3) [8, 128, 64, 64] [8, 128, 64, 64] 16,512 True 0.27%
│ │ └─Conv2d (down_sample_conv) [8, 128, 64, 64] [8, 128, 32, 32] 262,272 True 4.22%
├─ModuleList (encoder_mids) -- -- -- True --
│ └─MidBlock (0) [8, 128, 32, 32] [8, 128, 32, 32] -- True --
│ │ └─ModuleList (resnet_conv_first) -- -- (recursive) True (recursive)
│ │ │ └─Sequential (0) [8, 128, 32, 32] [8, 128, 32, 32] -- True --
│ │ │ │ └─GroupNorm (0) [8, 128, 32, 32] [8, 128, 32, 32] 256 True 0.00%
│ │ │ │ └─SiLU (1) [8, 128, 32, 32] [8, 128, 32, 32] -- -- --
│ │ │ │ └─Conv2d (2) [8, 128, 32, 32] [8, 128, 32, 32] 147,584 True 2.37%
│ │ └─ModuleList (resnet_conv_second) -- -- (recursive) True (recursive)
│ │ │ └─Sequential (0) [8, 128, 32, 32] [8, 128, 32, 32] -- True --
│ │ │ │ └─GroupNorm (0) [8, 128, 32, 32] [8, 128, 32, 32] 256 True 0.00%
│ │ │ │ └─SiLU (1) [8, 128, 32, 32] [8, 128, 32, 32] -- -- --
│ │ │ │ └─Conv2d (2) [8, 128, 32, 32] [8, 128, 32, 32] 147,584 True 2.37%
│ │ └─ModuleList (residual_input_conv) -- -- (recursive) True (recursive)
│ │ │ └─Conv2d (0) [8, 128, 32, 32] [8, 128, 32, 32] 16,512 True 0.27%
│ │ └─ModuleList (attention_norms) -- -- (recursive) True (recursive)
│ │ │ └─GroupNorm (0) [8, 128, 1024] [8, 128, 1024] 256 True 0.00%
│ │ └─ModuleList (attentions) -- -- (recursive) True (recursive)
│ │ │ └─MultiheadAttention (0) [8, 1024, 128] [8, 1024, 128] 66,048 True 1.06%
│ │ └─ModuleList (resnet_conv_first) -- -- (recursive) True (recursive)
│ │ │ └─Sequential (1) [8, 128, 32, 32] [8, 128, 32, 32] -- True --
│ │ │ │ └─GroupNorm (0) [8, 128, 32, 32] [8, 128, 32, 32] 256 True 0.00%
│ │ │ │ └─SiLU (1) [8, 128, 32, 32] [8, 128, 32, 32] -- -- --
│ │ │ │ └─Conv2d (2) [8, 128, 32, 32] [8, 128, 32, 32] 147,584 True 2.37%
│ │ └─ModuleList (resnet_conv_second) -- -- (recursive) True (recursive)
│ │ │ └─Sequential (1) [8, 128, 32, 32] [8, 128, 32, 32] -- True --
│ │ │ │ └─GroupNorm (0) [8, 128, 32, 32] [8, 128, 32, 32] 256 True 0.00%
│ │ │ │ └─SiLU (1) [8, 128, 32, 32] [8, 128, 32, 32] -- -- --
│ │ │ │ └─Conv2d (2) [8, 128, 32, 32] [8, 128, 32, 32] 147,584 True 2.37%
│ │ └─ModuleList (residual_input_conv) -- -- (recursive) True (recursive)
│ │ │ └─Conv2d (1) [8, 128, 32, 32] [8, 128, 32, 32] 16,512 True 0.27%
│ │ └─ModuleList (attention_norms) -- -- (recursive) True (recursive)
│ │ │ └─GroupNorm (1) [8, 128, 1024] [8, 128, 1024] 256 True 0.00%
│ │ └─ModuleList (attentions) -- -- (recursive) True (recursive)
│ │ │ └─MultiheadAttention (1) [8, 1024, 128] [8, 1024, 128] 66,048 True 1.06%
│ │ └─ModuleList (resnet_conv_first) -- -- (recursive) True (recursive)
│ │ │ └─Sequential (2) [8, 128, 32, 32] [8, 128, 32, 32] -- True --
│ │ │ │ └─GroupNorm (0) [8, 128, 32, 32] [8, 128, 32, 32] 256 True 0.00%
│ │ │ │ └─SiLU (1) [8, 128, 32, 32] [8, 128, 32, 32] -- -- --
│ │ │ │ └─Conv2d (2) [8, 128, 32, 32] [8, 128, 32, 32] 147,584 True 2.37%
│ │ └─ModuleList (resnet_conv_second) -- -- (recursive) True (recursive)
│ │ │ └─Sequential (2) [8, 128, 32, 32] [8, 128, 32, 32] -- True --
│ │ │ │ └─GroupNorm (0) [8, 128, 32, 32] [8, 128, 32, 32] 256 True 0.00%
│ │ │ │ └─SiLU (1) [8, 128, 32, 32] [8, 128, 32, 32] -- -- --
│ │ │ │ └─Conv2d (2) [8, 128, 32, 32] [8, 128, 32, 32] 147,584 True 2.37%
│ │ └─ModuleList (residual_input_conv) -- -- (recursive) True (recursive)
│ │ │ └─Conv2d (2) [8, 128, 32, 32] [8, 128, 32, 32] 16,512 True 0.27%
│ │ └─ModuleList (attention_norms) -- -- (recursive) True (recursive)
│ │ │ └─GroupNorm (2) [8, 128, 1024] [8, 128, 1024] 256 True 0.00%
│ │ └─ModuleList (attentions) -- -- (recursive) True (recursive)
│ │ │ └─MultiheadAttention (2) [8, 1024, 128] [8, 1024, 128] 66,048 True 1.06%
│ │ └─ModuleList (resnet_conv_first) -- -- (recursive) True (recursive)
│ │ │ └─Sequential (3) [8, 128, 32, 32] [8, 128, 32, 32] -- True --
│ │ │ │ └─GroupNorm (0) [8, 128, 32, 32] [8, 128, 32, 32] 256 True 0.00%
│ │ │ │ └─SiLU (1) [8, 128, 32, 32] [8, 128, 32, 32] -- -- --
│ │ │ │ └─Conv2d (2) [8, 128, 32, 32] [8, 128, 32, 32] 147,584 True 2.37%
│ │ └─ModuleList (resnet_conv_second) -- -- (recursive) True (recursive)
│ │ │ └─Sequential (3) [8, 128, 32, 32] [8, 128, 32, 32] -- True --
│ │ │ │ └─GroupNorm (0) [8, 128, 32, 32] [8, 128, 32, 32] 256 True 0.00%
│ │ │ │ └─SiLU (1) [8, 128, 32, 32] [8, 128, 32, 32] -- -- --
│ │ │ │ └─Conv2d (2) [8, 128, 32, 32] [8, 128, 32, 32] 147,584 True 2.37%
│ │ └─ModuleList (residual_input_conv) -- -- (recursive) True (recursive)
│ │ │ └─Conv2d (3) [8, 128, 32, 32] [8, 128, 32, 32] 16,512 True 0.27%
│ │ └─ModuleList (attention_norms) -- -- (recursive) True (recursive)
│ │ │ └─GroupNorm (3) [8, 128, 1024] [8, 128, 1024] 256 True 0.00%
│ │ └─ModuleList (attentions) -- -- (recursive) True (recursive)
│ │ │ └─MultiheadAttention (3) [8, 1024, 128] [8, 1024, 128] 66,048 True 1.06%
│ │ └─ModuleList (resnet_conv_first) -- -- (recursive) True (recursive)
│ │ │ └─Sequential (4) [8, 128, 32, 32] [8, 128, 32, 32] -- True --
│ │ │ │ └─GroupNorm (0) [8, 128, 32, 32] [8, 128, 32, 32] 256 True 0.00%
│ │ │ │ └─SiLU (1) [8, 128, 32, 32] [8, 128, 32, 32] -- -- --
│ │ │ │ └─Conv2d (2) [8, 128, 32, 32] [8, 128, 32, 32] 147,584 True 2.37%
│ │ └─ModuleList (resnet_conv_second) -- -- (recursive) True (recursive)
│ │ │ └─Sequential (4) [8, 128, 32, 32] [8, 128, 32, 32] -- True --
│ │ │ │ └─GroupNorm (0) [8, 128, 32, 32] [8, 128, 32, 32] 256 True 0.00%
│ │ │ │ └─SiLU (1) [8, 128, 32, 32] [8, 128, 32, 32] -- -- --
│ │ │ │ └─Conv2d (2) [8, 128, 32, 32] [8, 128, 32, 32] 147,584 True 2.37%
│ │ └─ModuleList (residual_input_conv) -- -- (recursive) True (recursive)
│ │ │ └─Conv2d (4) [8, 128, 32, 32] [8, 128, 32, 32] 16,512 True 0.27%
├─GroupNorm (encoder_norm_out) [8, 128, 32, 32] [8, 128, 32, 32] 256 True 0.00%
├─Conv2d (encoder_conv_out) [8, 128, 32, 32] [8, 3, 32, 32] 3,459 True 0.06%
├─Conv2d (pre_quant_conv) [8, 3, 32, 32] [8, 3, 32, 32] 12 True 0.00%
├─Conv2d (post_quant_conv) [8, 3, 32, 32] [8, 3, 32, 32] 12 True 0.00%
├─Conv2d (decoder_conv_in) [8, 3, 32, 32] [8, 128, 32, 32] 3,584 True 0.06%
├─ModuleList (decoder_mids) -- -- -- True --
│ └─MidBlock (0) [8, 128, 32, 32] [8, 128, 32, 32] -- True --
│ │ └─ModuleList (resnet_conv_first) -- -- (recursive) True (recursive)
│ │ │ └─Sequential (0) [8, 128, 32, 32] [8, 128, 32, 32] -- True --
│ │ │ │ └─GroupNorm (0) [8, 128, 32, 32] [8, 128, 32, 32] 256 True 0.00%
│ │ │ │ └─SiLU (1) [8, 128, 32, 32] [8, 128, 32, 32] -- -- --
│ │ │ │ └─Conv2d (2) [8, 128, 32, 32] [8, 128, 32, 32] 147,584 True 2.37%
│ │ └─ModuleList (resnet_conv_second) -- -- (recursive) True (recursive)
│ │ │ └─Sequential (0) [8, 128, 32, 32] [8, 128, 32, 32] -- True --
│ │ │ │ └─GroupNorm (0) [8, 128, 32, 32] [8, 128, 32, 32] 256 True 0.00%
│ │ │ │ └─SiLU (1) [8, 128, 32, 32] [8, 128, 32, 32] -- -- --
│ │ │ │ └─Conv2d (2) [8, 128, 32, 32] [8, 128, 32, 32] 147,584 True 2.37%
│ │ └─ModuleList (residual_input_conv) -- -- (recursive) True (recursive)
│ │ │ └─Conv2d (0) [8, 128, 32, 32] [8, 128, 32, 32] 16,512 True 0.27%
│ │ └─ModuleList (attention_norms) -- -- (recursive) True (recursive)
│ │ │ └─GroupNorm (0) [8, 128, 1024] [8, 128, 1024] 256 True 0.00%
│ │ └─ModuleList (attentions) -- -- (recursive) True (recursive)
│ │ │ └─MultiheadAttention (0) [8, 1024, 128] [8, 1024, 128] 66,048 True 1.06%
│ │ └─ModuleList (resnet_conv_first) -- -- (recursive) True (recursive)
│ │ │ └─Sequential (1) [8, 128, 32, 32] [8, 128, 32, 32] -- True --
│ │ │ │ └─GroupNorm (0) [8, 128, 32, 32] [8, 128, 32, 32] 256 True 0.00%
│ │ │ │ └─SiLU (1) [8, 128, 32, 32] [8, 128, 32, 32] -- -- --
│ │ │ │ └─Conv2d (2) [8, 128, 32, 32] [8, 128, 32, 32] 147,584 True 2.37%
│ │ └─ModuleList (resnet_conv_second) -- -- (recursive) True (recursive)
│ │ │ └─Sequential (1) [8, 128, 32, 32] [8, 128, 32, 32] -- True --
│ │ │ │ └─GroupNorm (0) [8, 128, 32, 32] [8, 128, 32, 32] 256 True 0.00%
│ │ │ │ └─SiLU (1) [8, 128, 32, 32] [8, 128, 32, 32] -- -- --
│ │ │ │ └─Conv2d (2) [8, 128, 32, 32] [8, 128, 32, 32] 147,584 True 2.37%
│ │ └─ModuleList (residual_input_conv) -- -- (recursive) True (recursive)
│ │ │ └─Conv2d (1) [8, 128, 32, 32] [8, 128, 32, 32] 16,512 True 0.27%
│ │ └─ModuleList (attention_norms) -- -- (recursive) True (recursive)
│ │ │ └─GroupNorm (1) [8, 128, 1024] [8, 128, 1024] 256 True 0.00%
│ │ └─ModuleList (attentions) -- -- (recursive) True (recursive)
│ │ │ └─MultiheadAttention (1) [8, 1024, 128] [8, 1024, 128] 66,048 True 1.06%
│ │ └─ModuleList (resnet_conv_first) -- -- (recursive) True (recursive)
│ │ │ └─Sequential (2) [8, 128, 32, 32] [8, 128, 32, 32] -- True --
│ │ │ │ └─GroupNorm (0) [8, 128, 32, 32] [8, 128, 32, 32] 256 True 0.00%
│ │ │ │ └─SiLU (1) [8, 128, 32, 32] [8, 128, 32, 32] -- -- --
│ │ │ │ └─Conv2d (2) [8, 128, 32, 32] [8, 128, 32, 32] 147,584 True 2.37%
│ │ └─ModuleList (resnet_conv_second) -- -- (recursive) True (recursive)
│ │ │ └─Sequential (2) [8, 128, 32, 32] [8, 128, 32, 32] -- True --
│ │ │ │ └─GroupNorm (0) [8, 128, 32, 32] [8, 128, 32, 32] 256 True 0.00%
│ │ │ │ └─SiLU (1) [8, 128, 32, 32] [8, 128, 32, 32] -- -- --
│ │ │ │ └─Conv2d (2) [8, 128, 32, 32] [8, 128, 32, 32] 147,584 True 2.37%
│ │ └─ModuleList (residual_input_conv) -- -- (recursive) True (recursive)
│ │ │ └─Conv2d (2) [8, 128, 32, 32] [8, 128, 32, 32] 16,512 True 0.27%
│ │ └─ModuleList (attention_norms) -- -- (recursive) True (recursive)
│ │ │ └─GroupNorm (2) [8, 128, 1024] [8, 128, 1024] 256 True 0.00%
│ │ └─ModuleList (attentions) -- -- (recursive) True (recursive)
│ │ │ └─MultiheadAttention (2) [8, 1024, 128] [8, 1024, 128] 66,048 True 1.06%
│ │ └─ModuleList (resnet_conv_first) -- -- (recursive) True (recursive)
│ │ │ └─Sequential (3) [8, 128, 32, 32] [8, 128, 32, 32] -- True --
│ │ │ │ └─GroupNorm (0) [8, 128, 32, 32] [8, 128, 32, 32] 256 True 0.00%
│ │ │ │ └─SiLU (1) [8, 128, 32, 32] [8, 128, 32, 32] -- -- --
│ │ │ │ └─Conv2d (2) [8, 128, 32, 32] [8, 128, 32, 32] 147,584 True 2.37%
│ │ └─ModuleList (resnet_conv_second) -- -- (recursive) True (recursive)
│ │ │ └─Sequential (3) [8, 128, 32, 32] [8, 128, 32, 32] -- True --
│ │ │ │ └─GroupNorm (0) [8, 128, 32, 32] [8, 128, 32, 32] 256 True 0.00%
│ │ │ │ └─SiLU (1) [8, 128, 32, 32] [8, 128, 32, 32] -- -- --
│ │ │ │ └─Conv2d (2) [8, 128, 32, 32] [8, 128, 32, 32] 147,584 True 2.37%
│ │ └─ModuleList (residual_input_conv) -- -- (recursive) True (recursive)
│ │ │ └─Conv2d (3) [8, 128, 32, 32] [8, 128, 32, 32] 16,512 True 0.27%
│ │ └─ModuleList (attention_norms) -- -- (recursive) True (recursive)
│ │ │ └─GroupNorm (3) [8, 128, 1024] [8, 128, 1024] 256 True 0.00%
│ │ └─ModuleList (attentions) -- -- (recursive) True (recursive)
│ │ │ └─MultiheadAttention (3) [8, 1024, 128] [8, 1024, 128] 66,048 True 1.06%
│ │ └─ModuleList (resnet_conv_first) -- -- (recursive) True (recursive)
│ │ │ └─Sequential (4) [8, 128, 32, 32] [8, 128, 32, 32] -- True --
│ │ │ │ └─GroupNorm (0) [8, 128, 32, 32] [8, 128, 32, 32] 256 True 0.00%
│ │ │ │ └─SiLU (1) [8, 128, 32, 32] [8, 128, 32, 32] -- -- --
│ │ │ │ └─Conv2d (2) [8, 128, 32, 32] [8, 128, 32, 32] 147,584 True 2.37%
│ │ └─ModuleList (resnet_conv_second) -- -- (recursive) True (recursive)
│ │ │ └─Sequential (4) [8, 128, 32, 32] [8, 128, 32, 32] -- True --
│ │ │ │ └─GroupNorm (0) [8, 128, 32, 32] [8, 128, 32, 32] 256 True 0.00%
│ │ │ │ └─SiLU (1) [8, 128, 32, 32] [8, 128, 32, 32] -- -- --
│ │ │ │ └─Conv2d (2) [8, 128, 32, 32] [8, 128, 32, 32] 147,584 True 2.37%
│ │ └─ModuleList (residual_input_conv) -- -- (recursive) True (recursive)
│ │ │ └─Conv2d (4) [8, 128, 32, 32] [8, 128, 32, 32] 16,512 True 0.27%
├─ModuleList (decoder_layers) -- -- -- True --
│ └─UpBlock (0) [8, 128, 32, 32] [8, 64, 64, 64] -- True --
│ │ └─ConvTranspose2d (up_sample_conv) [8, 128, 32, 32] [8, 128, 64, 64] 262,272 True 4.22%
│ │ └─ModuleList (resnet_conv_first) -- -- (recursive) True (recursive)
│ │ │ └─Sequential (0) [8, 128, 64, 64] [8, 64, 64, 64] -- True --
│ │ │ │ └─GroupNorm (0) [8, 128, 64, 64] [8, 128, 64, 64] 256 True 0.00%
│ │ │ │ └─SiLU (1) [8, 128, 64, 64] [8, 128, 64, 64] -- -- --
│ │ │ │ └─Conv2d (2) [8, 128, 64, 64] [8, 64, 64, 64] 73,792 True 1.19%
│ │ └─ModuleList (resnet_conv_second) -- -- (recursive) True (recursive)
│ │ │ └─Sequential (0) [8, 64, 64, 64] [8, 64, 64, 64] -- True --
│ │ │ │ └─GroupNorm (0) [8, 64, 64, 64] [8, 64, 64, 64] 128 True 0.00%
│ │ │ │ └─SiLU (1) [8, 64, 64, 64] [8, 64, 64, 64] -- -- --
│ │ │ │ └─Conv2d (2) [8, 64, 64, 64] [8, 64, 64, 64] 36,928 True 0.59%
│ │ └─ModuleList (residual_input_conv) -- -- (recursive) True (recursive)
│ │ │ └─Conv2d (0) [8, 128, 64, 64] [8, 64, 64, 64] 8,256 True 0.13%
│ │ └─ModuleList (resnet_conv_first) -- -- (recursive) True (recursive)
│ │ │ └─Sequential (1) [8, 64, 64, 64] [8, 64, 64, 64] -- True --
│ │ │ │ └─GroupNorm (0) [8, 64, 64, 64] [8, 64, 64, 64] 128 True 0.00%
│ │ │ │ └─SiLU (1) [8, 64, 64, 64] [8, 64, 64, 64] -- -- --
│ │ │ │ └─Conv2d (2) [8, 64, 64, 64] [8, 64, 64, 64] 36,928 True 0.59%
│ │ └─ModuleList (resnet_conv_second) -- -- (recursive) True (recursive)
│ │ │ └─Sequential (1) [8, 64, 64, 64] [8, 64, 64, 64] -- True --
│ │ │ │ └─GroupNorm (0) [8, 64, 64, 64] [8, 64, 64, 64] 128 True 0.00%
│ │ │ │ └─SiLU (1) [8, 64, 64, 64] [8, 64, 64, 64] -- -- --
│ │ │ │ └─Conv2d (2) [8, 64, 64, 64] [8, 64, 64, 64] 36,928 True 0.59%
│ │ └─ModuleList (residual_input_conv) -- -- (recursive) True (recursive)
│ │ │ └─Conv2d (1) [8, 64, 64, 64] [8, 64, 64, 64] 4,160 True 0.07%
│ │ └─ModuleList (resnet_conv_first) -- -- (recursive) True (recursive)
│ │ │ └─Sequential (2) [8, 64, 64, 64] [8, 64, 64, 64] -- True --
│ │ │ │ └─GroupNorm (0) [8, 64, 64, 64] [8, 64, 64, 64] 128 True 0.00%
│ │ │ │ └─SiLU (1) [8, 64, 64, 64] [8, 64, 64, 64] -- -- --
│ │ │ │ └─Conv2d (2) [8, 64, 64, 64] [8, 64, 64, 64] 36,928 True 0.59%
│ │ └─ModuleList (resnet_conv_second) -- -- (recursive) True (recursive)
│ │ │ └─Sequential (2) [8, 64, 64, 64] [8, 64, 64, 64] -- True --
│ │ │ │ └─GroupNorm (0) [8, 64, 64, 64] [8, 64, 64, 64] 128 True 0.00%
│ │ │ │ └─SiLU (1) [8, 64, 64, 64] [8, 64, 64, 64] -- -- --
│ │ │ │ └─Conv2d (2) [8, 64, 64, 64] [8, 64, 64, 64] 36,928 True 0.59%
│ │ └─ModuleList (residual_input_conv) -- -- (recursive) True (recursive)
│ │ │ └─Conv2d (2) [8, 64, 64, 64] [8, 64, 64, 64] 4,160 True 0.07%
│ │ └─ModuleList (resnet_conv_first) -- -- (recursive) True (recursive)
│ │ │ └─Sequential (3) [8, 64, 64, 64] [8, 64, 64, 64] -- True --
│ │ │ │ └─GroupNorm (0) [8, 64, 64, 64] [8, 64, 64, 64] 128 True 0.00%
│ │ │ │ └─SiLU (1) [8, 64, 64, 64] [8, 64, 64, 64] -- -- --
│ │ │ │ └─Conv2d (2) [8, 64, 64, 64] [8, 64, 64, 64] 36,928 True 0.59%
│ │ └─ModuleList (resnet_conv_second) -- -- (recursive) True (recursive)
│ │ │ └─Sequential (3) [8, 64, 64, 64] [8, 64, 64, 64] -- True --
│ │ │ │ └─GroupNorm (0) [8, 64, 64, 64] [8, 64, 64, 64] 128 True 0.00%
│ │ │ │ └─SiLU (1) [8, 64, 64, 64] [8, 64, 64, 64] -- -- --
│ │ │ │ └─Conv2d (2) [8, 64, 64, 64] [8, 64, 64, 64] 36,928 True 0.59%
│ │ └─ModuleList (residual_input_conv) -- -- (recursive) True (recursive)
│ │ │ └─Conv2d (3) [8, 64, 64, 64] [8, 64, 64, 64] 4,160 True 0.07%
│ └─UpBlock (1) [8, 64, 64, 64] [8, 32, 128, 128] -- True --
│ │ └─ConvTranspose2d (up_sample_conv) [8, 64, 64, 64] [8, 64, 128, 128] 65,600 True 1.05%
│ │ └─ModuleList (resnet_conv_first) -- -- (recursive) True (recursive)
│ │ │ └─Sequential (0) [8, 64, 128, 128] [8, 32, 128, 128] -- True --
│ │ │ │ └─GroupNorm (0) [8, 64, 128, 128] [8, 64, 128, 128] 128 True 0.00%
│ │ │ │ └─SiLU (1) [8, 64, 128, 128] [8, 64, 128, 128] -- -- --
│ │ │ │ └─Conv2d (2) [8, 64, 128, 128] [8, 32, 128, 128] 18,464 True 0.30%
│ │ └─ModuleList (resnet_conv_second) -- -- (recursive) True (recursive)
│ │ │ └─Sequential (0) [8, 32, 128, 128] [8, 32, 128, 128] -- True --
│ │ │ │ └─GroupNorm (0) [8, 32, 128, 128] [8, 32, 128, 128] 64 True 0.00%
│ │ │ │ └─SiLU (1) [8, 32, 128, 128] [8, 32, 128, 128] -- -- --
│ │ │ │ └─Conv2d (2) [8, 32, 128, 128] [8, 32, 128, 128] 9,248 True 0.15%
│ │ └─ModuleList (residual_input_conv) -- -- (recursive) True (recursive)
│ │ │ └─Conv2d (0) [8, 64, 128, 128] [8, 32, 128, 128] 2,080 True 0.03%
│ │ └─ModuleList (resnet_conv_first) -- -- (recursive) True (recursive)
│ │ │ └─Sequential (1) [8, 32, 128, 128] [8, 32, 128, 128] -- True --
│ │ │ │ └─GroupNorm (0) [8, 32, 128, 128] [8, 32, 128, 128] 64 True 0.00%
│ │ │ │ └─SiLU (1) [8, 32, 128, 128] [8, 32, 128, 128] -- -- --
│ │ │ │ └─Conv2d (2) [8, 32, 128, 128] [8, 32, 128, 128] 9,248 True 0.15%
│ │ └─ModuleList (resnet_conv_second) -- -- (recursive) True (recursive)
│ │ │ └─Sequential (1) [8, 32, 128, 128] [8, 32, 128, 128] -- True --
│ │ │ │ └─GroupNorm (0) [8, 32, 128, 128] [8, 32, 128, 128] 64 True 0.00%
│ │ │ │ └─SiLU (1) [8, 32, 128, 128] [8, 32, 128, 128] -- -- --
│ │ │ │ └─Conv2d (2) [8, 32, 128, 128] [8, 32, 128, 128] 9,248 True 0.15%
│ │ └─ModuleList (residual_input_conv) -- -- (recursive) True (recursive)
│ │ │ └─Conv2d (1) [8, 32, 128, 128] [8, 32, 128, 128] 1,056 True 0.02%
│ │ └─ModuleList (resnet_conv_first) -- -- (recursive) True (recursive)
│ │ │ └─Sequential (2) [8, 32, 128, 128] [8, 32, 128, 128] -- True --
│ │ │ │ └─GroupNorm (0) [8, 32, 128, 128] [8, 32, 128, 128] 64 True 0.00%
│ │ │ │ └─SiLU (1) [8, 32, 128, 128] [8, 32, 128, 128] -- -- --
│ │ │ │ └─Conv2d (2) [8, 32, 128, 128] [8, 32, 128, 128] 9,248 True 0.15%
│ │ └─ModuleList (resnet_conv_second) -- -- (recursive) True (recursive)
│ │ │ └─Sequential (2) [8, 32, 128, 128] [8, 32, 128, 128] -- True --
│ │ │ │ └─GroupNorm (0) [8, 32, 128, 128] [8, 32, 128, 128] 64 True 0.00%
│ │ │ │ └─SiLU (1) [8, 32, 128, 128] [8, 32, 128, 128] -- -- --
│ │ │ │ └─Conv2d (2) [8, 32, 128, 128] [8, 32, 128, 128] 9,248 True 0.15%
│ │ └─ModuleList (residual_input_conv) -- -- (recursive) True (recursive)
│ │ │ └─Conv2d (2) [8, 32, 128, 128] [8, 32, 128, 128] 1,056 True 0.02%
│ │ └─ModuleList (resnet_conv_first) -- -- (recursive) True (recursive)
│ │ │ └─Sequential (3) [8, 32, 128, 128] [8, 32, 128, 128] -- True --
│ │ │ │ └─GroupNorm (0) [8, 32, 128, 128] [8, 32, 128, 128] 64 True 0.00%
│ │ │ │ └─SiLU (1) [8, 32, 128, 128] [8, 32, 128, 128] -- -- --
│ │ │ │ └─Conv2d (2) [8, 32, 128, 128] [8, 32, 128, 128] 9,248 True 0.15%
│ │ └─ModuleList (resnet_conv_second) -- -- (recursive) True (recursive)
│ │ │ └─Sequential (3) [8, 32, 128, 128] [8, 32, 128, 128] -- True --
│ │ │ │ └─GroupNorm (0) [8, 32, 128, 128] [8, 32, 128, 128] 64 True 0.00%
│ │ │ │ └─SiLU (1) [8, 32, 128, 128] [8, 32, 128, 128] -- -- --
│ │ │ │ └─Conv2d (2) [8, 32, 128, 128] [8, 32, 128, 128] 9,248 True 0.15%
│ │ └─ModuleList (residual_input_conv) -- -- (recursive) True (recursive)
│ │ │ └─Conv2d (3) [8, 32, 128, 128] [8, 32, 128, 128] 1,056 True 0.02%
├─GroupNorm (decoder_norm_out) [8, 32, 128, 128] [8, 32, 128, 128] 64 True 0.00%
├─Conv2d (decoder_conv_out) [8, 32, 128, 128] [8, 3, 128, 128] 867 True 0.01%
======================================================================================================================================================
Total params: 6,219,770
Trainable params: 6,219,770
Non-trainable params: 0
Total mult-adds (Units.GIGABYTES): 146.86
======================================================================================================================================================
Input size (MB): 1.57
Forward/backward pass size (MB): 3719.89
Params size (MB): 22.77
Estimated Total Size (MB): 3744.23
======================================================================================================================================================