"""Residual block adopted from ManiGAN."""
from typing import Any

import torch
from torch import nn
class ResidualBlock(nn.Module):
    """Residual block with a gated (GLU) convolutional branch.

    The branch is ``Conv2d -> InstanceNorm2d -> GLU -> Conv2d -> InstanceNorm2d``
    and its output is added to the unmodified input (skip connection).
    The first convolution doubles the channel count so that the GLU, which
    halves the channel dimension, brings it back to ``channel_num``.
    """

    def __init__(self, channel_num: int) -> None:
        """
        :param channel_num: Number of channels in the input (and in the output)
        """
        super().__init__()
        # NOTE: the layer order (and therefore the Sequential indices) is part
        # of the state_dict layout ("block.0.weight", "block.3.weight").
        self.block = nn.Sequential(
            # Expand to 2 * channel_num: one half is the value, the other the
            # gate consumed by the GLU below.
            nn.Conv2d(
                channel_num,
                channel_num * 2,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=False,
            ),
            nn.InstanceNorm2d(channel_num * 2),
            # Gates along dim 1 (channels), halving it back to channel_num.
            nn.GLU(dim=1),
            nn.Conv2d(
                channel_num,
                channel_num,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=False,
            ),
            nn.InstanceNorm2d(channel_num),
        )

    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
        """
        :param input_tensor: Input tensor laid out as
            ``(batch, channel_num, height, width)``; the GLU gates along
            dim 1, so the channel axis must be dim 1
        :return: Output tensor with the same shape as ``input_tensor``
            (3x3 kernels with stride 1 and padding 1 keep the spatial size)
        """
        # Annotated local: calling an nn.Module is typed as returning Any.
        branch: torch.Tensor = self.block(input_tensor)
        # Out-of-place add: the tensor produced by self.block is left intact,
        # so anything that captured it (e.g. a forward hook) is not silently
        # overwritten by the skip connection.
        return branch + input_tensor