# amuse/vae_module.py
import torch
import torch.nn as nn
import torch.nn.functional as F
class ResidualBlock(nn.Module):
    """Two 3x3 conv + BatchNorm layers with an identity skip connection."""

    def __init__(self, channels):
        super().__init__()
self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
self.bn1 = nn.BatchNorm2d(channels)
self.relu = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
self.bn2 = nn.BatchNorm2d(channels)
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out += residual
out = self.relu(out)
return out
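# A minimal shape check for ResidualBlock (a sketch added for illustration,
# not part of the original module): the block preserves both channel count
# and spatial size, so it can sit between any two layers with matching channels.
#
#     block = ResidualBlock(64)
#     x = torch.randn(2, 64, 32, 32)
#     assert block(x).shape == x.shape  # (2, 64, 32, 32)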
class Encoder(nn.Module):
    """Maps an input image to the mean and log-variance of a diagonal
    Gaussian posterior over the latent space."""

    def __init__(self, input_channels=1, hidden_dims=[64, 128, 256, 512, 1024], latent_dim=32):
        super().__init__()
self.hidden_dims = hidden_dims
# Build Encoder with Residual Blocks
modules = []
for h_dim in hidden_dims:
modules.append(
nn.Sequential(
nn.Conv2d(input_channels, h_dim, kernel_size=3, stride=2, padding=1),
nn.BatchNorm2d(h_dim),
nn.LeakyReLU(),
ResidualBlock(h_dim) # Adding a residual block
)
)
input_channels = h_dim
self.encoder = nn.Sequential(*modules)
        # Flattened feature size assumes 512x512 inputs: five stride-2 convs
        # leave a 16x16 map, and 16 * 16 = 256 = hidden_dims[-3], so
        # hidden_dims[-1] * hidden_dims[-3] equals 1024 * 16 * 16.
        self.fc_mu = nn.Linear(hidden_dims[-1] * hidden_dims[-3], latent_dim)
        self.fc_var = nn.Linear(hidden_dims[-1] * hidden_dims[-3], latent_dim)
def forward(self, x):
for layer in self.encoder:
x = layer(x)
x = torch.flatten(x, start_dim=1)
mu = self.fc_mu(x)
log_var = self.fc_var(x)
return mu, log_var
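# Shape walk-through for the default Encoder (assuming 1x512x512 inputs, which
# the fully-connected sizing above hard-codes): each stage halves the spatial
# resolution while widening the channels.
#
#     (N, 1, 512, 512) -> (N, 64, 256, 256) -> (N, 128, 128, 128)
#     -> (N, 256, 64, 64) -> (N, 512, 32, 32) -> (N, 1024, 16, 16)
#     -> flatten -> (N, 262144) -> fc_mu / fc_var -> (N, 32) each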
class Decoder(nn.Module):
    """Maps a latent vector back to image space, mirroring the Encoder."""

    def __init__(self, latent_dim=32, output_channels=1, hidden_dims=[64, 128, 256, 512, 1024]):
        super().__init__()
        self.hidden_dims = hidden_dims
        # Reverse the channel progression for upsampling.
        hidden_dims = hidden_dims[::-1]
        # Mirrors the encoder's flattened size: hidden_dims[0] * hidden_dims[2]
        # equals 1024 * 256 = 1024 * 16 * 16 (a 16x16 feature map, again
        # assuming 512x512 inputs).
        self.decoder_input = nn.Linear(latent_dim, hidden_dims[0] * hidden_dims[2])
# Build Decoder with Residual Blocks
modules = []
for i in range(len(hidden_dims) - 1):
modules.append(
nn.Sequential(
nn.ConvTranspose2d(hidden_dims[i], hidden_dims[i+1], kernel_size=3, stride=2, padding=1, output_padding=1),
nn.BatchNorm2d(hidden_dims[i+1]),
nn.LeakyReLU(),
ResidualBlock(hidden_dims[i+1]) # Adding a residual block
)
)
self.decoder = nn.Sequential(*modules)
self.final_layer = nn.Sequential(
nn.ConvTranspose2d(hidden_dims[-1], hidden_dims[-1], kernel_size=3, stride=2, padding=1, output_padding=1),
nn.BatchNorm2d(hidden_dims[-1]),
nn.LeakyReLU(),
nn.Conv2d(hidden_dims[-1], output_channels, kernel_size=3, padding=1),
nn.Sigmoid()
)
    def forward(self, z):
        z = self.decoder_input(z)
        # Un-flatten to the deepest feature map; the 16x16 size is hard-coded
        # for 512x512 inputs (five stride-2 downsamplings in the encoder).
        z = z.view(-1, self.hidden_dims[-1], 16, 16)
for layer in self.decoder:
z = layer(z)
result = self.final_layer(z)
return result
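# Mirror shape walk-through for the default Decoder (same 512x512 assumption):
#
#     (N, 32) -> decoder_input -> (N, 262144) -> view -> (N, 1024, 16, 16)
#     -> (N, 512, 32, 32) -> (N, 256, 64, 64) -> (N, 128, 128, 128)
#     -> (N, 64, 256, 256) -> final_layer -> (N, 1, 512, 512), values in [0, 1]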
class VAE(nn.Module):
    """Variational autoencoder combining the Encoder and Decoder above."""

    def __init__(self,
                 input_channels=1,
                 latent_dim=32,
                 hidden_dims=None):
        super().__init__()
        if hidden_dims is None:
            hidden_dims = [64, 128, 256, 512, 1024]
self.encoder = Encoder(input_channels=input_channels,
hidden_dims=hidden_dims,
latent_dim=latent_dim)
self.decoder = Decoder(latent_dim=latent_dim,
output_channels=input_channels,
hidden_dims=hidden_dims)
def encode(self, input):
mu, log_var = self.encoder(input)
return mu, log_var
    def reparameterize(self, mu, log_var):
        # Reparameterization trick: z = mu + sigma * eps with eps ~ N(0, I),
        # keeping the sampling step differentiable w.r.t. mu and log_var.
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std
def decode(self, z):
return self.decoder(z)
def forward(self, input):
mu, log_var = self.encode(input)
z = self.reparameterize(mu, log_var)
return self.decode(z), mu, log_var
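# The objective below minimizes the negative ELBO: a reconstruction term
# E_q(z|x)[-log p(x|z)] plus the KL divergence KL(q(z|x) || N(0, I)), which
# for a diagonal Gaussian posterior has the closed form used in KLD.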
def loss_function(recon_x, x, mu, log_var):
    # Binary cross-entropy assumes targets scaled to [0, 1], matching the
    # decoder's Sigmoid output; 'sum' reduction keeps BCE and KLD on the same scale.
    BCE = F.binary_cross_entropy(recon_x, x, reduction='sum')
    # Closed-form KL(N(mu, sigma^2) || N(0, I)), summed over batch and latent dims.
    KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    return BCE + KLD
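

# A minimal smoke test under the stated assumptions (1x512x512 inputs, the
# default latent_dim of 32). This guard block is a sketch added for
# illustration and is not part of the original module.
if __name__ == "__main__":
    model = VAE(input_channels=1, latent_dim=32)
    x = torch.rand(4, 1, 512, 512)  # values in [0, 1] to match the BCE loss
    recon, mu, log_var = model(x)
    assert recon.shape == x.shape
    loss = loss_function(recon, x, mu, log_var)
    print(f"recon shape: {tuple(recon.shape)}, loss: {loss.item():.2f}")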