from .base import BaseModel

import torch
import torch.nn as nn
import torch.nn.functional as F


class VAEModel(BaseModel):
    def __init__(self, args):
        super().__init__(args)
        self.latent_dim = args.vae_latent_dim

        # Input dropout
        self.input_dropout = nn.Dropout(p=args.vae_dropout)

        # Construct a list of dimensions for the encoder and the decoder
        dims = [args.vae_hidden_dim] * 2 * args.vae_num_hidden
        dims = [args.num_items] + dims + [args.vae_latent_dim * 2]
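        # e.g. num_items=1000, vae_hidden_dim=600, vae_num_hidden=1 and
        # vae_latent_dim=200 give dims = [1000, 600, 600, 400]; the final
        # 400 units hold mu and logvar (200 each)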

        # Stack encoders and decoders
        encoder_modules, decoder_modules = [], []
        for i in range(len(dims)//2):
            encoder_modules.append(nn.Linear(dims[2*i], dims[2*i+1]))
            if i == 0:
                decoder_modules.append(nn.Linear(dims[-1]//2, dims[-2]))
            else:
                decoder_modules.append(nn.Linear(dims[-2*i-1], dims[-2*i-2]))
        self.encoder = nn.ModuleList(encoder_modules)
        self.decoder = nn.ModuleList(decoder_modules)
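        # Net effect: the decoder mirrors the encoder in reverse
        # (vae_latent_dim -> hidden layers -> num_items)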

        # Initialize weights
        self.encoder.apply(self.weight_init)
        self.decoder.apply(self.weight_init)

    def weight_init(self, m):
        if isinstance(m, nn.Linear):
            nn.init.kaiming_normal_(m.weight)
            m.bias.data.zero_()

    @classmethod
    def code(cls):
        return 'vae'

    def forward(self, x):
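        # x: a batch of user-item interaction vectors,
        # shape (batch_size, num_items)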
        x = F.normalize(x)
        x = self.input_dropout(x)

        for i, layer in enumerate(self.encoder):
            x = layer(x)
            if i != len(self.encoder) - 1:
                x = torch.tanh(x)

        mu, logvar = x[:, :self.latent_dim], x[:, self.latent_dim:]
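
        # Reparameterization trick: z = mu + eps * sigma keeps the sample
        # differentiable w.r.t. mu and logvar; at eval time, use mu directly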
        if self.training:
            # since log(var) = log(sigma^2) = 2*log(sigma)
            sigma = torch.exp(0.5 * logvar)
            eps = torch.randn_like(sigma)
            x = mu + eps * sigma
        else:
            x = mu

        for i, layer in enumerate(self.decoder):
            x = layer(x)
            if i != len(self.decoder) - 1:
                x = torch.tanh(x)

        return x, mu, logvar
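

# --- Usage sketch (illustrative, not part of the original module) ----------
# A minimal smoke test, assuming BaseModel.__init__(args) needs nothing
# beyond the hyperparameters below; the concrete values are placeholders.
# Because of the relative import above, run this from within the package
# (e.g. a training script) rather than as a standalone file.
#
#     from types import SimpleNamespace
#
#     args = SimpleNamespace(
#         num_items=1000,      # item catalog size (input/output width)
#         vae_hidden_dim=600,  # width of each hidden layer
#         vae_num_hidden=1,    # hidden layers per side
#         vae_latent_dim=200,  # dimensionality of z
#         vae_dropout=0.5,     # input dropout probability
#     )
#     model = VAEModel(args)
#     x = torch.rand(4, args.num_items)  # batch of 4 interaction vectors
#     logits, mu, logvar = model(x)
#     # logits: (4, 1000), mu: (4, 200), logvar: (4, 200)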