from .base import BaseModel
import torch
import torch.nn as nn
import torch.nn.functional as F


class VAEModel(BaseModel):
    def __init__(self, args):
        super().__init__(args)
        self.latent_dim = args.vae_latent_dim

        # Input dropout
        self.input_dropout = nn.Dropout(p=args.vae_dropout)

        # Construct the list of layer dimensions for the encoder and the decoder.
        # The last entry is 2 * latent_dim because the final encoder layer emits
        # mu and logvar concatenated along the feature dimension.
        dims = [args.vae_hidden_dim] * 2 * args.vae_num_hidden
        dims = [args.num_items] + dims + [args.vae_latent_dim * 2]

        # Stack encoder and decoder layers; the decoder mirrors the encoder,
        # except that its first layer maps from the latent dimension.
        encoder_modules, decoder_modules = [], []
        for i in range(len(dims) // 2):
            encoder_modules.append(nn.Linear(dims[2 * i], dims[2 * i + 1]))
            if i == 0:
                decoder_modules.append(nn.Linear(dims[-1] // 2, dims[-2]))
            else:
                decoder_modules.append(nn.Linear(dims[-2 * i - 1], dims[-2 * i - 2]))
        self.encoder = nn.ModuleList(encoder_modules)
        self.decoder = nn.ModuleList(decoder_modules)
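
        # Worked example of the wiring above (hypothetical hyperparameters):
        # with num_items=1000, vae_hidden_dim=600, vae_num_hidden=1 and
        # vae_latent_dim=200, dims == [1000, 600, 600, 400], giving
        #   encoder: Linear(1000, 600), Linear(600, 400)  # 400 = mu (200) + logvar (200)
        #   decoder: Linear(200, 600),  Linear(600, 1000)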

        # Initialize weights
        self.encoder.apply(self.weight_init)
        self.decoder.apply(self.weight_init)

    def weight_init(self, m):
        if isinstance(m, nn.Linear):
            nn.init.kaiming_normal_(m.weight)
            m.bias.data.zero_()

    @classmethod
    def code(cls):
        return 'vae'

    def forward(self, x):
        # L2-normalize the input interaction vector, then apply dropout.
        x = F.normalize(x)
        x = self.input_dropout(x)

        # Encode: tanh between hidden layers, no activation on the last layer,
        # which outputs [mu, logvar] concatenated.
        for i, layer in enumerate(self.encoder):
            x = layer(x)
            if i != len(self.encoder) - 1:
                x = torch.tanh(x)
        mu, logvar = x[:, :self.latent_dim], x[:, self.latent_dim:]

        # Reparameterization trick: sample z = mu + eps * sigma during training;
        # use the posterior mean deterministically at evaluation time.
        if self.training:
            # sigma = exp(0.5 * logvar), since logvar = log(sigma^2) = 2 * log(sigma)
            sigma = torch.exp(0.5 * logvar)
            eps = torch.randn_like(sigma)
            x = mu + eps * sigma
        else:
            x = mu

        # Decode: tanh between hidden layers; the final layer returns raw
        # logits over items.
        for i, layer in enumerate(self.decoder):
            x = layer(x)
            if i != len(self.decoder) - 1:
                x = torch.tanh(x)
        return x, mu, logvar
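

# ---------------------------------------------------------------------------
# Usage sketch (illustration only; the actual training criterion lives
# elsewhere in this repo). It shows the standard Mult-VAE ELBO one could
# apply to the (logits, mu, logvar) triple returned by forward(); the
# `anneal` weight on the KL term is an assumed hyperparameter, not something
# this file defines.
def _example_vae_loss(logits, mu, logvar, x, anneal=1.0):
    # Multinomial log-likelihood: log-softmax over items, masked by the
    # observed interactions x and summed per user.
    recon = -torch.mean(torch.sum(F.log_softmax(logits, dim=-1) * x, dim=-1))
    # Closed-form KL divergence between N(mu, sigma^2) and N(0, I),
    # averaged over the batch.
    kl = -0.5 * torch.mean(torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=-1))
    return recon + anneal * kl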