Spaces: Running on Zero
from typing import Sequence, Union

import torch
import torch.nn as nn
class BaseConditioner(nn.Module):
    """Base class producing a ``(condition, uncondition)`` pair from an input ``y``.

    Subclasses override ``_impl_condition`` and ``_impl_uncondition``. Calling an
    instance, ``conditioner(y)``, returns ``(condition, uncondition)``.
    NOTE(review): "uncondition" is presumably the null input used for
    classifier-free guidance; confirm against callers.
    """

    def __init__(self):
        super().__init__()

    def _impl_condition(self, y):
        """Build the conditional input from ``y``; subclasses override this.

        The base implementation returns ``None``.
        """
        return None

    def _impl_uncondition(self, y):
        """Build the unconditional input matching ``y``; subclasses override this.

        The base implementation returns ``None``.
        """
        return None

    def forward(self, y):
        """Return ``(condition, uncondition)`` computed from ``y``.

        Implemented as ``forward`` rather than by overriding ``__call__``, so that
        ``nn.Module.__call__`` still dispatches here and runs any registered
        forward hooks. ``conditioner(y)`` behaves exactly as before.
        """
        condition = self._impl_condition(y)
        uncondition = self._impl_uncondition(y)
        return condition, uncondition
class LabelConditioner(BaseConditioner):
    """Conditioner for integer class labels.

    The condition is the labels themselves as a ``torch.long`` tensor. The
    uncondition is a ``torch.long`` tensor of length ``len(y)`` filled with
    ``null_class``.

    Args:
        null_class: Label id written into every entry of the uncondition tensor.
        device: Device the returned tensors are created on. Defaults to
            ``"cuda"`` (the previously hard-coded behaviour). Pass ``"cpu"`` or a
            ``torch.device`` to run without a GPU.
    """

    def __init__(self, null_class, device: Union[str, torch.device] = "cuda"):
        super().__init__()
        self.null_condition = null_class
        self.device = device

    def _impl_condition(self, y: Union[Sequence[int], torch.Tensor]) -> torch.Tensor:
        """Return ``y`` as a ``torch.long`` tensor on ``self.device``."""
        # as_tensor accepts tensors without the copy-construct warning that
        # torch.tensor() emits. If y is already a long tensor on self.device,
        # the result may share memory with y.
        return torch.as_tensor(y, dtype=torch.long, device=self.device)

    def _impl_uncondition(self, y: Union[Sequence[int], torch.Tensor]) -> torch.Tensor:
        """Return a length-``len(y)`` long tensor filled with the null class id."""
        # Allocate directly on the target device instead of building on CPU and copying.
        return torch.full(
            (len(y),), self.null_condition, dtype=torch.long, device=self.device
        )