"""Convolutional variational autoencoder (VAE) with residual blocks."""
import torch
import torch.nn as nn
import torch.nn.functional as F


class ResidualBlock(nn.Module):
    """Two 3x3 convolutions with batch norm and an identity skip connection."""

    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out += residual  # identity shortcut; channel count and resolution are unchanged
        out = self.relu(out)
        return out


class Encoder(nn.Module):
    def __init__(self, input_channels=1, hidden_dims=[64, 128, 256, 512, 1024], latent_dim=32):
        super(Encoder, self).__init__()
        self.hidden_dims = hidden_dims

        # Stack of stride-2 conv blocks: each block halves the spatial resolution.
        modules = []
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(input_channels, h_dim, kernel_size=3, stride=2, padding=1),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU(),
                    ResidualBlock(h_dim)
                )
            )
            input_channels = h_dim

        self.encoder = nn.Sequential(*modules)
        # Assumes 512x512 inputs: five stride-2 convs leave a 16x16 feature map,
        # so the flattened size is hidden_dims[-1] * 16 * 16. With the default
        # hidden_dims, hidden_dims[-3] == 256 == 16 * 16, which is what the
        # expression below relies on.
        self.fc_mu = nn.Linear(hidden_dims[-1]*hidden_dims[-3], latent_dim)
        self.fc_var = nn.Linear(hidden_dims[-1]*hidden_dims[-3], latent_dim)

    def forward(self, x):
        for layer in self.encoder:
            x = layer(x)
        x = torch.flatten(x, start_dim=1)
        mu = self.fc_mu(x)
        log_var = self.fc_var(x)
        return mu, log_var


class Decoder(nn.Module):
    def __init__(self, latent_dim=32, output_channels=1, hidden_dims=[64, 128, 256, 512, 1024]):
        super(Decoder, self).__init__()
        self.hidden_dims = hidden_dims

        # Mirror the encoder: work from the widest feature maps back out.
        hidden_dims = hidden_dims[::-1]
        # Projects the latent vector to 1024 * 16 * 16 features (see forward),
        # matching the encoder's flattened output for 512x512 inputs.
        self.decoder_input = nn.Linear(latent_dim, hidden_dims[0]*hidden_dims[2])

        modules = []
        for i in range(len(hidden_dims) - 1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(hidden_dims[i], hidden_dims[i+1], kernel_size=3, stride=2, padding=1, output_padding=1),
                    nn.BatchNorm2d(hidden_dims[i+1]),
                    nn.LeakyReLU(),
                    ResidualBlock(hidden_dims[i+1])
                )
            )

        self.decoder = nn.Sequential(*modules)
        # Final upsample back to the input resolution, then map to the output
        # channels with a sigmoid so reconstructions lie in [0, 1].
        self.final_layer = nn.Sequential(
            nn.ConvTranspose2d(hidden_dims[-1], hidden_dims[-1], kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(hidden_dims[-1]),
            nn.LeakyReLU(),
            nn.Conv2d(hidden_dims[-1], output_channels, kernel_size=3, padding=1),
            nn.Sigmoid()
        )

    def forward(self, z):
        z = self.decoder_input(z)
        # Reshape to (batch, 1024, 16, 16): 1024 is the deepest hidden dim and
        # 16x16 is the encoder's output resolution for 512x512 inputs.
        z = z.view(-1, 1024, 16, 16)
        for layer in self.decoder:
            z = layer(z)
        result = self.final_layer(z)
        return result


class VAE(nn.Module):
    def __init__(self,
                 input_channels=1,
                 latent_dim=32,
                 hidden_dims=None):
        super(VAE, self).__init__()

        if hidden_dims is None:
            hidden_dims = [64, 128, 256, 512, 1024]

        self.encoder = Encoder(input_channels=input_channels,
                               hidden_dims=hidden_dims,
                               latent_dim=latent_dim)

        self.decoder = Decoder(latent_dim=latent_dim,
                               output_channels=input_channels,
                               hidden_dims=hidden_dims)

    def encode(self, input):
        mu, log_var = self.encoder(input)
        return mu, log_var

    def reparameterize(self, mu, log_var):
        # Reparameterization trick: z = mu + sigma * eps with eps ~ N(0, I),
        # which keeps sampling differentiable w.r.t. mu and log_var.
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        return self.decoder(z)

    def forward(self, input):
        mu, log_var = self.encode(input)
        z = self.reparameterize(mu, log_var)
        return self.decode(z), mu, log_var


def loss_function(recon_x, x, mu, log_var):
    # Reconstruction term: pixel-wise binary cross-entropy summed over the
    # batch (the decoder's sigmoid keeps recon_x in [0, 1]).
    BCE = F.binary_cross_entropy(recon_x, x, reduction='sum')
    # KL divergence between the approximate posterior N(mu, sigma^2) and the
    # standard normal prior.
    KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    return BCE + KLD
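

# Minimal smoke-test sketch (not part of the original module). It assumes
# 512x512 single-channel inputs, which is what the 16x16 reshape in
# Decoder.forward implies. Run this file directly to check that shapes and
# the loss computation line up.
if __name__ == "__main__":
    model = VAE(input_channels=1, latent_dim=32)
    x = torch.rand(2, 1, 512, 512)  # dummy batch of two 512x512 images in [0, 1)
    recon, mu, log_var = model(x)
    loss = loss_function(recon, x, mu, log_var)
    print(recon.shape)   # expected: torch.Size([2, 1, 512, 512])
    print(loss.item())
    # Sampling sketch: decode draws from the standard normal prior.
    samples = model.decode(torch.randn(2, 32))
    print(samples.shape)  # expected: torch.Size([2, 1, 512, 512])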