import torch
import torch.nn as nn
import torch.nn.functional as F

from .base import BaseModel


class DAEModel(BaseModel):
    def __init__(self, args):
        super().__init__(args)

        # Input dropout: the denoising corruption applied to the input vector
        self.input_dropout = nn.Dropout(p=args.dae_dropout)

        # Construct the layer dimensions for the encoder and the decoder.
        # e.g. num_items=1000, dae_num_hidden=1, dae_hidden_dim=600,
        # dae_latent_dim=200 gives dims = [1000, 600, 600, 200]:
        # encoder 1000 -> 600 -> 200, decoder 200 -> 600 -> 1000
        dims = [args.dae_hidden_dim] * 2 * args.dae_num_hidden
        dims = [args.num_items] + dims + [args.dae_latent_dim]

        # Stack encoder and decoder layers symmetrically around the latent layer
        encoder_modules, decoder_modules = [], []
        for i in range(len(dims) // 2):
            encoder_modules.append(nn.Linear(dims[2*i], dims[2*i+1]))
            decoder_modules.append(nn.Linear(dims[-2*i-1], dims[-2*i-2]))
        self.encoder = nn.ModuleList(encoder_modules)
        self.decoder = nn.ModuleList(decoder_modules)

        # Initialize weights
        self.encoder.apply(self.weight_init)
        self.decoder.apply(self.weight_init)

    def weight_init(self, m):
        if isinstance(m, nn.Linear):
            nn.init.kaiming_normal_(m.weight)
            nn.init.normal_(m.bias, 0.0, 0.001)

    @classmethod
    def code(cls):
        return 'dae'

    def forward(self, x):
        # L2-normalize each user's interaction vector, then corrupt it with dropout
        x = F.normalize(x)
        x = self.input_dropout(x)

        # Encoder: tanh after every layer, including the latent layer
        for layer in self.encoder:
            x = layer(x)
            x = torch.tanh(x)

        # Decoder: tanh after every layer except the last, which outputs raw scores
        for i, layer in enumerate(self.decoder):
            x = layer(x)
            if i != len(self.decoder) - 1:
                x = torch.tanh(x)
        return x
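

# A minimal usage sketch (an assumption about how this model is driven, not
# part of the original file): BaseModel is assumed to be an nn.Module subclass
# whose __init__ accepts the same `args` namespace. Any object carrying the
# five hyperparameters read in __init__ works; the values below are
# illustrative only.
#
#   from argparse import Namespace
#
#   args = Namespace(num_items=1000, dae_num_hidden=1, dae_hidden_dim=600,
#                    dae_latent_dim=200, dae_dropout=0.5)
#   model = DAEModel(args)
#   batch = torch.rand(16, args.num_items)  # 16 users' interaction vectors
#   logits = model(batch)                   # reconstructed scores, shape (16, 1000)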