Spaces:
Runtime error
Runtime error
| from argparse import Namespace, ArgumentParser | |
| from functools import partial | |
| from torch import nn | |
| from .resnet import ResNetBasicBlock, activation_func, norm_module, Conv2dAuto | |
def add_arguments(parser: ArgumentParser) -> ArgumentParser:
    """Register the encoder's command-line options on ``parser``.

    Adds ``--latent_size`` (int, default 512) and returns the same parser
    object so calls can be chained.
    """
    latent_size_spec = {"type": int, "default": 512, "help": "latent size"}
    parser.add_argument("--latent_size", **latent_size_spec)
    return parser
def create_model(args) -> nn.Module:
    """Build an ``Encoder`` from parsed arguments.

    Reads ``args.encoder_size`` and ``args.latent_size``. The encoder gets
    3 input channels when ``args`` contains a truthy ``rgb`` entry and
    1 input channel otherwise.
    """
    wants_rgb = "rgb" in args and args.rgb
    if wants_rgb:
        in_channels = 3
    else:
        in_channels = 1
    return Encoder(in_channels, args.encoder_size, latent_size=args.latent_size)
class Flatten(nn.Module):
    """Collapse every dimension after the first into a single one."""

    def forward(self, input_):
        """Return ``input_`` reshaped to ``(input_.size(0), -1)``.

        ``reshape`` is used instead of ``view`` so that non-contiguous
        tensors (e.g. the result of ``transpose``/``permute``) are accepted;
        for inputs where a view is possible it still returns a view, so no
        extra copy is made in that case.
        """
        return input_.reshape(input_.size(0), -1)
class Encoder(nn.Module):
    """Convolutional encoder that maps an image batch to a latent vector.

    Pipeline, in order: a stem (``Conv2dAuto`` -> norm -> activation), a 2x
    average pool, a stack of ``ResNetBasicBlock`` stages each followed by a
    2x average pool, and a head (ReLU -> global average pool -> flatten ->
    linear projection to ``latent_size``).

    Args:
        in_channels: Number of channels of the input images.
        size: Nominal input resolution. Only used to decide how many
            residual stages to build: 5 when ``size >= 256``, otherwise 4.
        latent_size: Width of the output vector.
        activation: Activation name forwarded to ``activation_func`` and to
            each ``ResNetBasicBlock``.
        norm: Normalisation name forwarded to ``norm_module`` and to each
            ``ResNetBasicBlock``.
    """

    def __init__(
        self, in_channels: int, size: int, latent_size: int = 512,
        activation: str = 'leaky_relu', norm: str = "instance"
    ):
        super().__init__()
        out_channels0 = 64
        pool_kernel = 2
        norm_m = norm_module(norm)

        # Stem: first convolution + normalisation + activation.
        self.conv0 = nn.Sequential(
            Conv2dAuto(in_channels, out_channels0, kernel_size=5),
            norm_m(out_channels0),
            activation_func(activation),
        )
        self.pool = nn.AvgPool2d(pool_kernel)

        num_channels = [128, 256, 512, 512]
        # FIXME: this is a hack -- larger inputs get one extra
        # residual stage (and therefore one extra 2x pooling step).
        if size >= 256:
            num_channels.append(512)

        residual = partial(ResNetBasicBlock, activation=activation, norm=norm, bias=True)
        # Each stage is (residual block, 2x average pool); a plain list is
        # enough here because nn.Sequential registers the modules itself.
        stages = []
        for in_channel, out_channel in zip([out_channels0] + num_channels[:-1], num_channels):
            stages.append(residual(in_channel, out_channel))
            stages.append(nn.AvgPool2d(pool_kernel))
        self.residual_blocks = nn.Sequential(*stages)

        self.last = nn.Sequential(
            nn.ReLU(),
            # Global average pooling. The previous fixed nn.AvgPool2d(4) only
            # produced the 1x1 map the Linear layer needs when the incoming
            # feature map was 4x4 (up to 7x7, silently dropping the border);
            # any other spatial size failed in the pool or in the Linear.
            # AdaptiveAvgPool2d(1) is the same mean for a 4x4 map, has no
            # parameters, and keeps the same index inside this Sequential,
            # so existing state_dicts still load.
            # NOTE(review): for 5x5..7x7 maps the result now averages the
            # whole map instead of the top-left 4x4 window -- confirm no
            # caller relies on the old cropping.
            nn.AdaptiveAvgPool2d(1),
            Flatten(),
            nn.Linear(num_channels[-1], latent_size, bias=True),
        )

    def forward(self, input_):
        """Encode ``input_`` and return the output of the final linear layer.

        ``input_`` is presumably a ``(batch, in_channels, H, W)`` image
        tensor -- TODO confirm against callers.
        """
        out = self.conv0(input_)
        out = self.pool(out)
        out = self.residual_blocks(out)
        out = self.last(out)
        return out