Spaces: Running on Zero
| import time | |
| import torch | |
| import torch.nn as nn | |
class BaseTrainer(nn.Module):
    """Base class for training-step objects that randomly null out conditions.

    Each call first replaces a random subset of per-sample conditions with the
    corresponding "unconditional" values (presumably to train for
    classifier-free guidance -- confirm against the subclasses), then delegates
    the actual step to ``_impl_trainstep``, which subclasses must implement.

    Args:
        null_condition_p: Per-sample probability that the condition is replaced
            by the unconditional value. Values ``<= 0`` disable the replacement.
        log_var: Stored as ``self.log_var``; not read anywhere in this class
            (presumably consumed by subclasses).
    """

    def __init__(self,
                 null_condition_p=0.1,
                 log_var=False,
                 ):
        super().__init__()
        self.null_condition_p = null_condition_p
        self.log_var = log_var

    def preproprocess(self, raw_iamges, x, condition, uncondition):
        """Randomly swap whole per-sample conditions for unconditional ones.

        The method name and the ``raw_iamges`` parameter are misspelled but kept
        unchanged so existing callers and subclass overrides keep working.

        Args:
            raw_iamges: Passed through unchanged.
            x: Batch whose leading dimension (``x.shape[0]``) is the batch size.
            condition: Tensor whose leading dimension is the batch dimension.
                It is NOT modified in place.
            uncondition: Tensor with the same shape as ``condition`` holding the
                values substituted for the nulled samples.

        Returns:
            Tuple ``(raw_iamges, x, condition)``. When nulling is enabled,
            ``condition`` is a new tensor; otherwise it is the original object.
        """
        if self.null_condition_p <= 0:
            return raw_iamges, x, condition

        bsz = x.shape[0]
        # One draw per sample decides whether that sample's condition is nulled.
        drop = torch.rand(bsz, device=condition.device) < self.null_condition_p
        # Reshape to (bsz, 1, ..., 1) so expand_as broadcasts along the batch
        # axis. A bare (bsz,) mask would be aligned with condition's LAST axis
        # and break (or silently misbehave) for multi-dimensional conditions.
        drop = drop.view(bsz, *([1] * (condition.dim() - 1))).expand_as(condition)
        # Work on a copy so the caller's tensor is left untouched.
        condition = condition.clone()
        condition[drop] = uncondition[drop]
        return raw_iamges, x, condition

    def _impl_trainstep(self, net, ema_net, raw_images, x, y):
        """Run one training step; must be implemented by subclasses.

        ``y`` receives the (possibly nulled) condition produced by
        ``preproprocess``.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def __call__(self, net, ema_net, raw_images, x, condition, uncondition):
        """Null conditions via ``preproprocess``, then run ``_impl_trainstep``.

        Returns whatever the subclass's ``_impl_trainstep`` returns.
        """
        # Overrides nn.Module.__call__, so forward() and module hooks are not
        # involved; the step logic lives entirely in _impl_trainstep.
        raw_images, x, condition = self.preproprocess(raw_images, x, condition, uncondition)
        return self._impl_trainstep(net, ema_net, raw_images, x, condition)