# taim-gan: src/models/modules/generator.py
"""Generator Module"""
from typing import Any, Optional
import torch
from torch import nn
from src.models.modules.acm import ACM
from src.models.modules.attention import ChannelWiseAttention, SpatialAttention
from src.models.modules.cond_augment import CondAugmentation
from src.models.modules.downsample import down_sample
from src.models.modules.residual import ResidualBlock
from src.models.modules.upsample import img_up_block, up_sample


class InitStageG(nn.Module):
"""Initial Stage Generator Module"""
# pylint: disable=too-many-instance-attributes
# pylint: disable=too-many-arguments
# pylint: disable=invalid-name
# pylint: disable=too-many-locals
def __init__(
self, Ng: int, Ng_init: int, conditioning_dim: int, D: int, noise_dim: int
):
"""
:param Ng: Number of channels.
        :param Ng_init: Initial value of Ng; this is the number of output channels of the first image upsample block.
:param conditioning_dim: Dimension of the conditioning space
:param D: Dimension of the text embedding space [D from AttnGAN paper]
:param noise_dim: Dimension of the noise space
"""
super().__init__()
self.gf_dim = Ng
self.gf_init = Ng_init
self.in_dim = noise_dim + conditioning_dim + D
self.text_dim = D
self.define_module()
def define_module(self) -> None:
"""Defines FC, Upsample, Residual, ACM, Attention modules"""
nz, ng = self.in_dim, self.gf_dim
self.fully_connect = nn.Sequential(
nn.Linear(nz, ng * 4 * 4 * 2, bias=False),
nn.BatchNorm1d(ng * 4 * 4 * 2),
            nn.GLU(dim=1),  # GLU halves the channels; forward() reshapes the output to a 4 x 4 feature map and ultimately returns hidden_64
)
self.upsample1 = up_sample(ng, ng // 2)
self.upsample2 = up_sample(ng // 2, ng // 4)
self.upsample3 = up_sample(ng // 4, ng // 8)
self.upsample4 = up_sample(
ng // 8 * 3, ng // 16
        )  # channels tripled because hidden, spatial-attention and channel-wise-attention features are concatenated
self.residual = self._make_layer(ResidualBlock, ng // 8 * 3)
self.acm_module = ACM(self.gf_init, ng // 8 * 3)
self.spatial_att = SpatialAttention(self.text_dim, ng // 8)
self.channel_att = ChannelWiseAttention(
32 * 32, self.text_dim
) # 32 x 32 is the feature map size
    def _make_layer(self, block: Any, channel_num: int) -> nn.Module:
        """Stacks residual blocks into a single sequential module."""
        layers = []
        for _ in range(2):  # number of residual blocks hardcoded to 2
            layers.append(block(channel_num))
        return nn.Sequential(*layers)
def forward(
self,
noise: torch.Tensor,
condition: torch.Tensor,
global_inception: torch.Tensor,
local_upsampled_inception: torch.Tensor,
word_embeddings: torch.Tensor,
mask: Optional[torch.Tensor] = None,
) -> Any:
"""
:param noise: Noise tensor
:param condition: Condition tensor (c^ from stackGAN++ paper)
:param global_inception: Global inception feature
:param local_upsampled_inception: Local inception feature, upsampled to 32 x 32
:param word_embeddings: Word embeddings [shape: D x L or D x T]
:param mask: Mask for padding tokens
        :return: Hidden image feature map tensor of size 64 x 64
"""
noise_concat = torch.cat((noise, condition), 1)
inception_concat = torch.cat((noise_concat, global_inception), 1)
hidden = self.fully_connect(inception_concat)
hidden = hidden.view(-1, self.gf_dim, 4, 4) # convert to 4x4 image feature map
hidden = self.upsample1(hidden)
hidden = self.upsample2(hidden)
hidden_32 = self.upsample3(hidden) # shape: (batch_size, gf_dim // 8, 32, 32)
hidden_32_view = hidden_32.view(
hidden_32.shape[0], -1, hidden_32.shape[2] * hidden_32.shape[3]
) # this reshaping is done as attention module expects this shape.
spatial_att_feat = self.spatial_att(
word_embeddings, hidden_32_view, mask
) # spatial att shape: (batch, D^, 32 * 32)
channel_att_feat = self.channel_att(
spatial_att_feat, word_embeddings
        )  # channel att shape: (batch, D^, 32 * 32), or (batch, C, Hk * Wk) from the ControlGAN paper
spatial_att_feat = spatial_att_feat.view(
word_embeddings.shape[0], -1, hidden_32.shape[2], hidden_32.shape[3]
) # reshape to (batch, D^, 32, 32)
channel_att_feat = channel_att_feat.view(
word_embeddings.shape[0], -1, hidden_32.shape[2], hidden_32.shape[3]
) # reshape to (batch, D^, 32, 32)
spatial_concat = torch.cat(
(hidden_32, spatial_att_feat), 1
) # concat spatial attention feature with hidden_32
attn_concat = torch.cat(
(spatial_concat, channel_att_feat), 1
) # concat channel and spatial attention feature
hidden_32 = self.acm_module(attn_concat, local_upsampled_inception)
hidden_32 = self.residual(hidden_32)
hidden_64 = self.upsample4(hidden_32)
return hidden_64
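

# Hedged sketch (illustration only, not part of the original module): a pure-torch
# walk-through of the FC + GLU head in InitStageG. The Linear layer produces
# ng * 4 * 4 * 2 features, GLU(dim=1) halves them, and forward() reshapes the result
# into a (batch, ng, 4, 4) starting feature map. The sizes below (Ng = 32, D = 256,
# conditioning_dim = 100, noise_dim = 100) are assumed defaults, not values taken
# from this repository's configs.
def _sketch_init_stage_head(batch: int = 2) -> torch.Tensor:
    ng = 32 * 16  # InitStageG receives Ng * 16 channels (see Generator.__init__ below)
    noise_dim, conditioning_dim, text_dim = 100, 100, 256
    in_dim = noise_dim + conditioning_dim + text_dim
    head = nn.Sequential(
        nn.Linear(in_dim, ng * 4 * 4 * 2, bias=False),
        nn.BatchNorm1d(ng * 4 * 4 * 2),
        nn.GLU(dim=1),  # halves the feature dimension: ng * 4 * 4 * 2 -> ng * 4 * 4
    )
    hidden = head(torch.randn(batch, in_dim))
    return hidden.view(-1, ng, 4, 4)  # (batch, ng, 4, 4) starting feature map
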
class NextStageG(nn.Module):
"""Next Stage Generator Module"""
# pylint: disable=too-many-instance-attributes
# pylint: disable=too-many-arguments
# pylint: disable=invalid-name
# pylint: disable=too-many-locals
def __init__(self, Ng: int, Ng_init: int, D: int, image_size: int):
"""
:param Ng: Number of channels.
:param Ng_init: Initial value of Ng.
:param D: Dimension of the text embedding space [D from AttnGAN paper]
        :param image_size: Spatial size of the hidden feature map from the previous generator stage.
"""
super().__init__()
self.gf_dim = Ng
self.gf_init = Ng_init
self.text_dim = D
self.img_size = image_size
self.define_module()
def define_module(self) -> None:
"""Defines FC, Upsample, Residual, ACM, Attention modules"""
ng = self.gf_dim
self.spatial_att = SpatialAttention(self.text_dim, ng)
self.channel_att = ChannelWiseAttention(
self.img_size * self.img_size, self.text_dim
)
self.residual = self._make_layer(ResidualBlock, ng * 3)
self.upsample = up_sample(ng * 3, ng)
self.acm_module = ACM(self.gf_init, ng * 3)
self.upsample2 = up_sample(ng, ng)
    def _make_layer(self, block: Any, channel_num: int) -> nn.Module:
        """Stacks residual blocks into a single sequential module."""
        layers = []
        for _ in range(2):  # number of residual blocks hardcoded to 2
            layers.append(block(channel_num))
        return nn.Sequential(*layers)
def forward(
self,
        hidden_feat: torch.Tensor,
word_embeddings: torch.Tensor,
vgg64_feat: torch.Tensor,
mask: Optional[torch.Tensor] = None,
) -> Any:
"""
:param hidden_feat: Hidden feature from previous generator stage [i.e. hidden_64]
:param word_embeddings: Word embeddings
:param vgg64_feat: VGG feature map of size 64 x 64
:param mask: Mask for the padding tokens
        :return: Hidden image feature map tensor of size 256 x 256
"""
hidden_view = hidden_feat.view(
hidden_feat.shape[0], -1, hidden_feat.shape[2] * hidden_feat.shape[3]
) # reshape to pass into attention modules.
spatial_att_feat = self.spatial_att(
word_embeddings, hidden_view, mask
) # spatial att shape: (batch, D^, 64 * 64), or D^ x N
channel_att_feat = self.channel_att(
spatial_att_feat, word_embeddings
        )  # channel att shape: (batch, D^, 64 * 64), or (batch, C, Hk * Wk) from the ControlGAN paper
spatial_att_feat = spatial_att_feat.view(
word_embeddings.shape[0], -1, hidden_feat.shape[2], hidden_feat.shape[3]
) # reshape to (batch, D^, 64, 64)
channel_att_feat = channel_att_feat.view(
word_embeddings.shape[0], -1, hidden_feat.shape[2], hidden_feat.shape[3]
) # reshape to (batch, D^, 64, 64)
spatial_concat = torch.cat(
(hidden_feat, spatial_att_feat), 1
) # concat spatial attention feature with hidden_64
attn_concat = torch.cat(
(spatial_concat, channel_att_feat), 1
) # concat channel and spatial attention feature
hidden_64 = self.acm_module(attn_concat, vgg64_feat)
hidden_64 = self.residual(hidden_64)
hidden_128 = self.upsample(hidden_64)
hidden_256 = self.upsample2(hidden_128)
return hidden_256
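

# Hedged sketch (illustration only): the flatten/unflatten pattern both generator
# stages wrap around their attention modules, i.e. (B, C, H, W) -> (B, C, H * W)
# before attention and back to (B, C, H, W) before concatenation with the hidden
# feature map. The shapes below are arbitrary examples, not repository defaults.
def _sketch_attention_reshape(batch: int = 2, channels: int = 48, size: int = 64) -> bool:
    hidden = torch.randn(batch, channels, size, size)
    flat = hidden.view(hidden.shape[0], -1, hidden.shape[2] * hidden.shape[3])  # (B, C, H * W)
    restored = flat.view(batch, -1, size, size)  # back to (B, C, H, W)
    return torch.equal(hidden, restored)  # True: only the view changes, not the data
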
class GetImageG(nn.Module):
"""Generates the Final Fake Image from the Image Feature Map"""
def __init__(self, Ng: int):
"""
:param Ng: Number of channels.
"""
super().__init__()
self.img = nn.Sequential(
nn.Conv2d(Ng, 3, kernel_size=3, stride=1, padding=1, bias=False), nn.Tanh()
)
def forward(self, hidden_feat: torch.Tensor) -> Any:
"""
:param hidden_feat: Image feature map
:return: Final fake image
"""
return self.img(hidden_feat)
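

# Hedged sketch (illustration only): GetImageG maps an Ng-channel feature map to a
# 3-channel image whose values lie in [-1, 1] because of the Tanh. The Ng and size
# values below are assumed examples, and the rescaling to [0, 1] is a common
# convention for visualizing or saving such images, not something this module defines.
def _sketch_to_image(batch: int = 2, ng: int = 32, size: int = 64) -> torch.Tensor:
    to_img = GetImageG(ng)
    fake = to_img(torch.randn(batch, ng, size, size))  # (batch, 3, size, size) in [-1, 1]
    return (fake + 1.0) / 2.0  # rescaled to [0, 1] for visualization
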
class Generator(nn.Module):
"""Generator Module"""
# pylint: disable=too-many-instance-attributes
# pylint: disable=too-many-arguments
# pylint: disable=invalid-name
# pylint: disable=too-many-locals
def __init__(self, Ng: int, D: int, conditioning_dim: int, noise_dim: int):
"""
:param Ng: Number of channels. [Taken from StackGAN++ paper]
:param D: Dimension of the text embedding space
:param conditioning_dim: Dimension of the conditioning space
:param noise_dim: Dimension of the noise space
"""
super().__init__()
self.cond_augment = CondAugmentation(D, conditioning_dim)
self.hidden_net1 = InitStageG(Ng * 16, Ng, conditioning_dim, D, noise_dim)
        self.inception_img_upsample = img_up_block(
            D, Ng
        )  # the channel size returned by the inception encoder is D (default in the paper: 256)
self.hidden_net2 = NextStageG(Ng, Ng, D, 64)
self.generate_img = GetImageG(Ng)
self.acm_module = ACM(Ng, Ng)
self.vgg_downsample = down_sample(D // 2, Ng)
self.upsample1 = up_sample(Ng, Ng)
self.upsample2 = up_sample(Ng, Ng)
def forward(
self,
noise: torch.Tensor,
sentence_embeddings: torch.Tensor,
word_embeddings: torch.Tensor,
global_inception_feat: torch.Tensor,
local_inception_feat: torch.Tensor,
vgg_feat: torch.Tensor,
mask: Optional[torch.Tensor] = None,
) -> Any:
"""
:param noise: Noise vector [shape: (batch, noise_dim)]
:param sentence_embeddings: Sentence embeddings [shape: (batch, D)]
        :param word_embeddings: Word embeddings [shape: (batch, D, L), where L is the sentence length]
:param global_inception_feat: Global Inception feature map [shape: (batch, D)]
:param local_inception_feat: Local Inception feature map [shape: (batch, D, 17, 17)]
:param vgg_feat: VGG feature map [shape: (batch, D // 2 = 128, 128, 128)]
:param mask: Mask for the padding tokens
:return: Final fake image
"""
c_hat, mu_tensor, logvar = self.cond_augment(sentence_embeddings)
hidden_32 = self.inception_img_upsample(local_inception_feat)
hidden_64 = self.hidden_net1(
noise, c_hat, global_inception_feat, hidden_32, word_embeddings, mask
)
vgg_64 = self.vgg_downsample(vgg_feat)
hidden_256 = self.hidden_net2(hidden_64, word_embeddings, vgg_64, mask)
vgg_128 = self.upsample1(vgg_64)
vgg_256 = self.upsample2(vgg_128)
hidden_256 = self.acm_module(hidden_256, vgg_256)
fake_img = self.generate_img(hidden_256)
return fake_img, mu_tensor, logvar
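

# Hedged sketch (illustration only, not part of the original module): constructing
# the full Generator with assumed hyperparameters (Ng = 32, D = 256,
# conditioning_dim = 100, noise_dim = 100). These are illustrative defaults, not
# values confirmed by this repository's configs. The forward pass is not invoked
# here because it depends on the referenced submodule implementations; the expected
# input shapes are documented in Generator.forward above.
if __name__ == "__main__":
    gen = Generator(Ng=32, D=256, conditioning_dim=100, noise_dim=100)
    n_params = sum(p.numel() for p in gen.parameters() if p.requires_grad)
    print(f"Generator built with {n_params:,} trainable parameters")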