import torch
import torch.nn as nn
import torch.nn.functional as F


class ResidualBlock(nn.Module):
    """Two 3x3 convolutions with batch normalization and an identity skip connection."""

    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out += residual
        out = self.relu(out)
        return out


class Encoder(nn.Module):
    """Convolutional encoder: each stage halves the spatial resolution and is
    followed by a residual block. A 512x512 input is reduced to a 16x16 map."""

    def __init__(self, input_channels=1, hidden_dims=None, latent_dim=32):
        super().__init__()
        if hidden_dims is None:
            hidden_dims = [64, 128, 256, 512, 1024]
        self.hidden_dims = hidden_dims

        # Build encoder stages: stride-2 conv -> batch norm -> LeakyReLU -> residual block
        modules = []
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(input_channels, h_dim, kernel_size=3, stride=2, padding=1),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU(),
                    ResidualBlock(h_dim),
                )
            )
            input_channels = h_dim
        self.encoder = nn.Sequential(*modules)

        # After len(hidden_dims) stride-2 convolutions a 512x512 input becomes a
        # 16x16 feature map, so the flattened size is hidden_dims[-1] * 16 * 16.
        flattened_size = hidden_dims[-1] * 16 * 16
        self.fc_mu = nn.Linear(flattened_size, latent_dim)
        self.fc_var = nn.Linear(flattened_size, latent_dim)

    def forward(self, x):
        x = self.encoder(x)
        x = torch.flatten(x, start_dim=1)
        mu = self.fc_mu(x)
        log_var = self.fc_var(x)
        return mu, log_var


class Decoder(nn.Module):
    """Mirror of the encoder: the latent vector is projected back to a 16x16
    feature map and upsampled with transposed convolutions and residual blocks."""

    def __init__(self, latent_dim=32, output_channels=1, hidden_dims=None):
        super().__init__()
        if hidden_dims is None:
            hidden_dims = [64, 128, 256, 512, 1024]
        # Reverse the channel progression so the decoder mirrors the encoder
        hidden_dims = hidden_dims[::-1]
        self.hidden_dims = hidden_dims

        # Project the latent vector back to the flattened 16x16 bottleneck
        self.decoder_input = nn.Linear(latent_dim, hidden_dims[0] * 16 * 16)

        # Build decoder stages: transposed conv -> batch norm -> LeakyReLU -> residual block
        modules = []
        for i in range(len(hidden_dims) - 1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(hidden_dims[i], hidden_dims[i + 1], kernel_size=3,
                                       stride=2, padding=1, output_padding=1),
                    nn.BatchNorm2d(hidden_dims[i + 1]),
                    nn.LeakyReLU(),
                    ResidualBlock(hidden_dims[i + 1]),
                )
            )
        self.decoder = nn.Sequential(*modules)

        # Final upsampling stage, then a 3x3 convolution with a sigmoid so the
        # output lies in [0, 1], matching the BCE reconstruction loss below.
        self.final_layer = nn.Sequential(
            nn.ConvTranspose2d(hidden_dims[-1], hidden_dims[-1], kernel_size=3,
                               stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(hidden_dims[-1]),
            nn.LeakyReLU(),
            nn.Conv2d(hidden_dims[-1], output_channels, kernel_size=3, padding=1),
            nn.Sigmoid(),
        )

    def forward(self, z):
        z = self.decoder_input(z)
        z = z.view(-1, self.hidden_dims[0], 16, 16)
        z = self.decoder(z)
        return self.final_layer(z)


class VAE(nn.Module):
    """Variational autoencoder with residual blocks in both encoder and decoder."""

    def __init__(self, input_channels=1, latent_dim=32, hidden_dims=None):
        super().__init__()
        if hidden_dims is None:
            hidden_dims = [64, 128, 256, 512, 1024]
        self.encoder = Encoder(input_channels=input_channels, hidden_dims=hidden_dims,
                               latent_dim=latent_dim)
        self.decoder = Decoder(latent_dim=latent_dim, output_channels=input_channels,
                               hidden_dims=hidden_dims)

    def encode(self, x):
        mu, log_var = self.encoder(x)
        return mu, log_var

    def reparameterize(self, mu, log_var):
        # Reparameterization trick: z = mu + sigma * eps with eps ~ N(0, I),
        # keeping the sampling step differentiable with respect to mu and log_var.
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        return self.decoder(z)

    def forward(self, x):
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        return self.decode(z), mu, log_var


# Loss function for the VAE: summed binary cross-entropy reconstruction term plus
# the closed-form KL divergence between N(mu, sigma^2) and the standard normal prior.
def loss_function(recon_x, x, mu, log_var):
    BCE = F.binary_cross_entropy(recon_x, x, reduction='sum')
    KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    return BCE + KLD
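

# --- Minimal usage sketch -----------------------------------------------------
# A hedged example of one training step and one sampling call, assuming
# single-channel 512x512 inputs (the spatial size implied by the 16x16
# bottleneck above). The random tensors stand in for a real data loader batch;
# the batch size, learning rate, and Adam optimizer are illustrative choices,
# not part of the model definition.
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = VAE(input_channels=1, latent_dim=32).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    # Dummy batch in [0, 1], matching the Sigmoid output and BCE reconstruction loss
    x = torch.rand(2, 1, 512, 512, device=device)

    # One optimization step on the ELBO-style loss
    recon_x, mu, log_var = model(x)
    loss = loss_function(recon_x, x, mu, log_var)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print(f"reconstruction shape: {tuple(recon_x.shape)}, loss: {loss.item():.2f}")

    # Sampling: decode latent vectors drawn from the standard normal prior
    with torch.no_grad():
        z = torch.randn(2, 32, device=device)
        samples = model.decode(z)
    print(f"sample shape: {tuple(samples.shape)}")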