import torch
import torch.nn as nn
import torch.nn.functional as F
from math import log2

"""
factors is used in the Discriminator and Generator to decide how much the
channels should be multiplied and expanded for each layer. Specifically, the
first four entries are 1, so the channel count stays the same for the early
layers, whereas when we increase the img_size (towards the later layers) we
decrease the number of channels by 1/2, 1/4, etc.
"""
factors = [1, 1, 1, 1, 1 / 2, 1 / 4, 1 / 8, 1 / 16, 1 / 32]
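
# Illustrative sketch (not part of the original file): with in_channels = 256,
# the value used in the __main__ test below, the channel counts implied by
# `factors` would be 256, 256, 256, 256, 128, 64, 32, 16, 8 for the stages
# from 4x4 up to 1024x1024.
def _example_channel_counts(in_channels=256):
    return [int(in_channels * f) for f in factors]
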
class WSConv2d(nn.Module):
    """
    Weight-scaled Conv2d (Equalized Learning Rate).
    Note that the input is multiplied by the scale rather than changing the
    weights; this has the same result (see the small equivalence sketch after
    this class).
    """

    def __init__(
        self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, gain=2
    ):
        super(WSConv2d, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        self.scale = (gain / (in_channels * (kernel_size ** 2))) ** 0.5
        self.bias = self.conv.bias
        self.conv.bias = None

        # initialize conv layer
        nn.init.normal_(self.conv.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x):
        return self.conv(x * self.scale) + self.bias.view(1, self.bias.shape[0], 1, 1)
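
# A minimal check sketch (not in the original file) for the docstring's claim that
# scaling the input is equivalent to scaling the weights. The helper name
# `_check_weight_scaling_equivalence` is made up for illustration.
def _check_weight_scaling_equivalence():
    layer = WSConv2d(8, 16)
    x = torch.randn(2, 8, 4, 4)
    out_input_scaled = layer(x)  # conv(x * scale) + bias

    # ordinary conv with the same weights pre-multiplied by the scale
    ref = nn.Conv2d(8, 16, kernel_size=3, stride=1, padding=1)
    with torch.no_grad():
        ref.weight.copy_(layer.conv.weight * layer.scale)
        ref.bias.copy_(layer.bias)
    out_weight_scaled = ref(x)

    assert torch.allclose(out_input_scaled, out_weight_scaled, atol=1e-5)
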
class PixelNorm(nn.Module):
    def __init__(self):
        super(PixelNorm, self).__init__()
        self.epsilon = 1e-8

    def forward(self, x):
        return x / torch.sqrt(torch.mean(x ** 2, dim=1, keepdim=True) + self.epsilon)
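
# A small sanity-check sketch (not in the original file): after PixelNorm, every
# spatial position should have approximately unit root-mean-square across channels.
def _check_pixelnorm_unit_rms():
    x = torch.randn(2, 16, 8, 8)
    out = PixelNorm()(x)
    rms = torch.sqrt(torch.mean(out ** 2, dim=1))  # shape (2, 8, 8)
    assert torch.allclose(rms, torch.ones_like(rms), atol=1e-4)
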
class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, use_pixelnorm=True):
        super(ConvBlock, self).__init__()
        self.use_pn = use_pixelnorm
        self.conv1 = WSConv2d(in_channels, out_channels)
        self.conv2 = WSConv2d(out_channels, out_channels)
        self.leaky = nn.LeakyReLU(0.2)
        self.pn = PixelNorm()

    def forward(self, x):
        x = self.leaky(self.conv1(x))
        x = self.pn(x) if self.use_pn else x
        x = self.leaky(self.conv2(x))
        x = self.pn(x) if self.use_pn else x
        return x
class Generator(nn.Module):
    def __init__(self, z_dim, in_channels, img_channels=3):
        super(Generator, self).__init__()

        # initial takes 1x1 -> 4x4
        self.initial = nn.Sequential(
            PixelNorm(),
            nn.ConvTranspose2d(z_dim, in_channels, 4, 1, 0),
            nn.LeakyReLU(0.2),
            WSConv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2),
            PixelNorm(),
        )

        self.initial_rgb = WSConv2d(
            in_channels, img_channels, kernel_size=1, stride=1, padding=0
        )
        self.prog_blocks, self.rgb_layers = (
            nn.ModuleList([]),
            nn.ModuleList([self.initial_rgb]),
        )

        for i in range(
            len(factors) - 1
        ):  # -1 to prevent an index error because of factors[i + 1]
            conv_in_c = int(in_channels * factors[i])
            conv_out_c = int(in_channels * factors[i + 1])
            self.prog_blocks.append(ConvBlock(conv_in_c, conv_out_c))
            self.rgb_layers.append(
                WSConv2d(conv_out_c, img_channels, kernel_size=1, stride=1, padding=0)
            )

    def fade_in(self, alpha, upscaled, generated):
        # alpha should be a scalar within [0, 1], and upscaled.shape == generated.shape
        # (see the sketch after this class for the alpha = 0 and alpha = 1 extremes)
        return torch.tanh(alpha * generated + (1 - alpha) * upscaled)

    def forward(self, x, alpha, steps):
        out = self.initial(x)

        if steps == 0:
            return self.initial_rgb(out)

        for step in range(steps):
            upscaled = F.interpolate(out, scale_factor=2, mode="nearest")
            out = self.prog_blocks[step](upscaled)

        # The number of channels in upscaled stays the same, while out, which has
        # moved through prog_blocks, might change. To ensure we can convert both to
        # RGB we use different rgb_layers: (steps - 1) for upscaled and steps for out.
        final_upscaled = self.rgb_layers[steps - 1](upscaled)
        final_out = self.rgb_layers[steps](out)
        return self.fade_in(alpha, final_upscaled, final_out)
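
# A sketch (not in the original file) of the fade_in extremes: with alpha = 0 the
# generator output is (tanh of) the upscaled previous-stage image only, and with
# alpha = 1 it is entirely the newly added block's output. z_dim = 100 and
# in_channels = 256 are just the values from the __main__ test below.
def _show_fade_in_extremes():
    gen = Generator(z_dim=100, in_channels=256)
    z = torch.randn(1, 100, 1, 1)
    img_prev_only = gen(z, alpha=0.0, steps=1)  # only the upscaled 4x4 path
    img_new_only = gen(z, alpha=1.0, steps=1)  # only the newly added 8x8 block
    print(img_prev_only.shape, img_new_only.shape)  # both torch.Size([1, 3, 8, 8])
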
class Discriminator(nn.Module):
    def __init__(self, z_dim, in_channels, img_channels=3):
        super(Discriminator, self).__init__()
        self.prog_blocks, self.rgb_layers = nn.ModuleList([]), nn.ModuleList([])
        self.leaky = nn.LeakyReLU(0.2)

        # here we work backwards through factors because the discriminator
        # should be mirrored from the generator. So the first prog_block and
        # rgb layer we append will work for input size 1024x1024, then 512, 256, etc.
        for i in range(len(factors) - 1, 0, -1):
            conv_in = int(in_channels * factors[i])
            conv_out = int(in_channels * factors[i - 1])
            self.prog_blocks.append(ConvBlock(conv_in, conv_out, use_pixelnorm=False))
            self.rgb_layers.append(
                WSConv2d(img_channels, conv_in, kernel_size=1, stride=1, padding=0)
            )

        # perhaps a confusing name, but "initial_rgb" is just the RGB layer for the
        # 4x4 input size; it is named this way to "mirror" the generator's initial_rgb
        self.initial_rgb = WSConv2d(
            img_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.rgb_layers.append(self.initial_rgb)
        self.avg_pool = nn.AvgPool2d(
            kernel_size=2, stride=2
        )  # downsampling using avg pool

        # this is the block for the 4x4 input size
        self.final_block = nn.Sequential(
            # +1 to in_channels because we concatenate from minibatch std
            WSConv2d(in_channels + 1, in_channels, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2),
            WSConv2d(in_channels, in_channels, kernel_size=4, padding=0, stride=1),
            nn.LeakyReLU(0.2),
            WSConv2d(
                in_channels, 1, kernel_size=1, padding=0, stride=1
            ),  # we use this instead of a linear layer
        )

    def fade_in(self, alpha, downscaled, out):
        """Fade between the avg-pooled (downscaled) input and the output from the CNN."""
        # alpha should be a scalar within [0, 1], and downscaled.shape == out.shape
        return alpha * out + (1 - alpha) * downscaled

    def minibatch_std(self, x):
        batch_statistics = (
            torch.std(x, dim=0).mean().repeat(x.shape[0], 1, x.shape[2], x.shape[3])
        )
        # we take the std across the batch for each feature, average it to a single
        # scalar, then repeat it as one extra channel and concatenate it with the
        # input. In this way the discriminator gets information about the variation
        # in the batch (see the shape sketch after this class).
        return torch.cat([x, batch_statistics], dim=1)

    def forward(self, x, alpha, steps):
        # where we should start in the list of prog_blocks: maybe a bit confusing, but
        # the last block is for 4x4. So, for example, if steps=1 we should start at the
        # second-to-last block because the input size will be 8x8. If steps == 0 we
        # just use the final block.
        cur_step = len(self.prog_blocks) - steps

        # convert from RGB as the initial step; this will depend on
        # the image size (each size has its own rgb layer)
        out = self.leaky(self.rgb_layers[cur_step](x))

        if steps == 0:  # i.e., image is 4x4
            out = self.minibatch_std(out)
            return self.final_block(out).view(out.shape[0], -1)

        # because prog_blocks might change the channels, for the downscaled path we
        # use the rgb_layer from the previous/smaller size, which here corresponds
        # to +1 in the indexing
        downscaled = self.leaky(self.rgb_layers[cur_step + 1](self.avg_pool(x)))
        out = self.avg_pool(self.prog_blocks[cur_step](out))

        # the fade_in is done first, between the downscaled input and the block
        # output; this is the opposite order from the generator
        out = self.fade_in(alpha, downscaled, out)

        for step in range(cur_step + 1, len(self.prog_blocks)):
            out = self.prog_blocks[step](out)
            out = self.avg_pool(out)

        out = self.minibatch_std(out)
        return self.final_block(out).view(out.shape[0], -1)
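
# A small shape sketch (not in the original file) showing the effect of
# minibatch_std: it appends one extra channel filled with a single batch-wide
# statistic, so 256 channels become 257.
def _show_minibatch_std_shapes():
    critic = Discriminator(z_dim=100, in_channels=256)
    feat = torch.randn(4, 256, 4, 4)
    with_std = critic.minibatch_std(feat)
    print(feat.shape, "->", with_std.shape)  # (4, 256, 4, 4) -> (4, 257, 4, 4)
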
if __name__ == "__main__":
    Z_DIM = 100
    IN_CHANNELS = 256
    gen = Generator(Z_DIM, IN_CHANNELS, img_channels=3)
    critic = Discriminator(Z_DIM, IN_CHANNELS, img_channels=3)

    for img_size in [4, 8, 16, 32, 64, 128, 256, 512, 1024]:
        num_steps = int(log2(img_size / 4))
        x = torch.randn((1, Z_DIM, 1, 1))
        z = gen(x, 0.5, steps=num_steps)
        assert z.shape == (1, 3, img_size, img_size)
        out = critic(z, alpha=0.5, steps=num_steps)
        assert out.shape == (1, 1)
        print(f"Success! At img size: {img_size}")
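
# A rough sketch (not part of the original file) of how steps and alpha are
# typically driven during progressive growing: train at one resolution while
# linearly ramping alpha from ~0 to 1, then move on to the next resolution.
# The stage list, batch count, and schedule below are placeholder assumptions,
# not the author's training settings; losses and optimizer updates are omitted.
def _progressive_alpha_schedule_sketch(batches_per_stage=1000):
    for steps, img_size in enumerate([4, 8, 16, 32]):
        alpha = 1e-5  # start the stage almost fully on the previous resolution
        for _ in range(batches_per_stage):
            # ... compute generator / critic losses at this img_size here ...
            alpha = min(alpha + 1 / batches_per_stage, 1.0)  # fade in the new block
        print(f"stage {steps}: {img_size}x{img_size}, final alpha = {alpha:.2f}")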