Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn as nn | |
class DecoderBase(nn.Module):
    """Common interface for decoder modules.

    The only concrete behaviour provided here is :meth:`freeze`. Every
    decoding / scoring method below is a stub that raises
    ``NotImplementedError`` and is meant to be overridden by subclasses.
    """

    def __init__(self):
        super().__init__()

    def freeze(self):
        """Turn off gradient tracking for all parameters of this module.

        Every tensor yielded by ``self.parameters()`` gets
        ``requires_grad`` set to ``False``. Returns ``None``.
        """
        for tensor in self.parameters():
            tensor.requires_grad_(False)

    def decode(self, x, z):
        """Produce output logits for ``x`` conditioned on latent codes ``z``.

        Args:
            x: input tensor of shape ``(batch_size, seq_len)``.
            z: latent tensor of shape ``(batch_size, n_sample, nz)``.

        Returns:
            Logits tensor of shape
            ``(batch_size * n_sample, seq_len, vocab_size)``.

        Raises:
            NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError()

    def reconstruct_error(self, x, z):
        """Compute the reconstruction loss of ``x`` under each latent sample.

        Args:
            x: input tensor of shape ``(batch_size, *)``.
            z: latent tensor of shape ``(batch_size, n_sample, nz)``.

        Returns:
            Loss tensor of shape ``(batch_size, n_sample)``: one value per
            (sentence, latent sample) pair.

        Raises:
            NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError()

    def beam_search_decode(self, z, K):
        """Decode sentences from ``z`` using beam search.

        Args:
            z: latent tensor of shape ``(batch_size, nz)``.
            K: beam width.

        Returns:
            List of decoded word sentences.

        Raises:
            NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError()

    def sample_decode(self, z):
        """Decode sentences from ``z`` by sampling.

        Args:
            z: latent tensor of shape ``(batch_size, nz)``.

        Returns:
            List of decoded word sentences.

        Raises:
            NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError()

    def greedy_decode(self, z):
        """Decode sentences from ``z`` greedily.

        Args:
            z: latent tensor of shape ``(batch_size, nz)``.

        Returns:
            List of decoded word sentences.

        Raises:
            NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError()

    def log_probability(self, x, z):
        """Compute ``log p(x | z)`` for every pair of input and latent sample.

        Args:
            x: input tensor of shape ``(batch_size, *)``.
            z: latent tensor of shape ``(batch_size, n_sample, nz)``.

        Returns:
            Tensor of shape ``(batch_size, n_sample)`` holding
            ``log p(x|z)`` across the different ``x`` and ``z``.

        Raises:
            NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError()