"""Generator Module"""
from typing import Any, Optional
import torch
from torch import nn
from src.models.modules.acm import ACM
from src.models.modules.attention import ChannelWiseAttention, SpatialAttention
from src.models.modules.cond_augment import CondAugmentation
from src.models.modules.downsample import down_sample
from src.models.modules.residual import ResidualBlock
from src.models.modules.upsample import img_up_block, up_sample


class InitStageG(nn.Module):
"""Initial Stage Generator Module"""
# pylint: disable=too-many-instance-attributes
# pylint: disable=too-many-arguments
# pylint: disable=invalid-name
    # pylint: disable=too-many-locals

def __init__(
self, Ng: int, Ng_init: int, conditioning_dim: int, D: int, noise_dim: int
):
"""
:param Ng: Number of channels.
        :param Ng_init: Initial value of Ng; this is the output channel count of the first image upsample.
:param conditioning_dim: Dimension of the conditioning space
:param D: Dimension of the text embedding space [D from AttnGAN paper]
:param noise_dim: Dimension of the noise space
"""
super().__init__()
self.gf_dim = Ng
self.gf_init = Ng_init
self.in_dim = noise_dim + conditioning_dim + D
self.text_dim = D
        self.define_module()

def define_module(self) -> None:
"""Defines FC, Upsample, Residual, ACM, Attention modules"""
nz, ng = self.in_dim, self.gf_dim
self.fully_connect = nn.Sequential(
nn.Linear(nz, ng * 4 * 4 * 2, bias=False),
nn.BatchNorm1d(ng * 4 * 4 * 2),
            nn.GLU(dim=1), # GLU halves dim 1: 2 * ng * 4 * 4 -> ng * 4 * 4, reshaped later into a 4 x 4 feature map
)
self.upsample1 = up_sample(ng, ng // 2)
self.upsample2 = up_sample(ng // 2, ng // 4)
self.upsample3 = up_sample(ng // 4, ng // 8)
        self.upsample4 = up_sample(
            ng // 8 * 3, ng // 16
        ) # input channels tripled: hidden, spatial-attention and channel-attention features are concatenated
self.residual = self._make_layer(ResidualBlock, ng // 8 * 3)
self.acm_module = ACM(self.gf_init, ng // 8 * 3)
self.spatial_att = SpatialAttention(self.text_dim, ng // 8)
self.channel_att = ChannelWiseAttention(
32 * 32, self.text_dim
        ) # 32 x 32 is the feature map size

def _make_layer(self, block: Any, channel_num: int) -> nn.Module:
layers = []
for _ in range(2): # number of residual blocks hardcoded to 2
layers.append(block(channel_num))
        return nn.Sequential(*layers)

def forward(
self,
noise: torch.Tensor,
condition: torch.Tensor,
global_inception: torch.Tensor,
local_upsampled_inception: torch.Tensor,
word_embeddings: torch.Tensor,
mask: Optional[torch.Tensor] = None,
) -> Any:
"""
:param noise: Noise tensor
        :param condition: Condition tensor (c^ from the StackGAN++ paper)
:param global_inception: Global inception feature
:param local_upsampled_inception: Local inception feature, upsampled to 32 x 32
        :param word_embeddings: Word embeddings [shape: (batch, D, L), where L is the number of words]
:param mask: Mask for padding tokens
        :return: Hidden image feature map of spatial size 64 x 64
"""
noise_concat = torch.cat((noise, condition), 1)
inception_concat = torch.cat((noise_concat, global_inception), 1)
hidden = self.fully_connect(inception_concat)
hidden = hidden.view(-1, self.gf_dim, 4, 4) # convert to 4x4 image feature map
hidden = self.upsample1(hidden)
hidden = self.upsample2(hidden)
hidden_32 = self.upsample3(hidden) # shape: (batch_size, gf_dim // 8, 32, 32)
hidden_32_view = hidden_32.view(
hidden_32.shape[0], -1, hidden_32.shape[2] * hidden_32.shape[3]
        ) # flatten spatial dims: the attention modules expect shape (batch, C, H * W)
spatial_att_feat = self.spatial_att(
word_embeddings, hidden_32_view, mask
) # spatial att shape: (batch, D^, 32 * 32)
channel_att_feat = self.channel_att(
spatial_att_feat, word_embeddings
        ) # channel att shape: (batch, D^, 32 * 32), or (batch, C, Hk * Wk) from the ControlGAN paper
spatial_att_feat = spatial_att_feat.view(
word_embeddings.shape[0], -1, hidden_32.shape[2], hidden_32.shape[3]
) # reshape to (batch, D^, 32, 32)
channel_att_feat = channel_att_feat.view(
word_embeddings.shape[0], -1, hidden_32.shape[2], hidden_32.shape[3]
) # reshape to (batch, D^, 32, 32)
spatial_concat = torch.cat(
(hidden_32, spatial_att_feat), 1
) # concat spatial attention feature with hidden_32
attn_concat = torch.cat(
(spatial_concat, channel_att_feat), 1
) # concat channel and spatial attention feature
hidden_32 = self.acm_module(attn_concat, local_upsampled_inception)
hidden_32 = self.residual(hidden_32)
hidden_64 = self.upsample4(hidden_32)
return hidden_64
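

# A hedged shape sketch of InitStageG.forward (an illustration, not asserted by
# this file; it assumes the attention modules return Ng // 8 channels, which is
# what the channel arithmetic of upsample4 above implies):
#   cat(noise, condition, global_inception) -> fully_connect -> (B, Ng, 4, 4)
#   upsample1..3                            -> (B, Ng // 8, 32, 32)
#   cat(hidden, spatial att, channel att)   -> (B, Ng // 8 * 3, 32, 32)
#   ACM -> residual -> upsample4            -> (B, Ng // 16, 64, 64) == hidden_64

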
class NextStageG(nn.Module):
"""Next Stage Generator Module"""
# pylint: disable=too-many-instance-attributes
# pylint: disable=too-many-arguments
# pylint: disable=invalid-name
    # pylint: disable=too-many-locals

def __init__(self, Ng: int, Ng_init: int, D: int, image_size: int):
"""
:param Ng: Number of channels.
:param Ng_init: Initial value of Ng.
:param D: Dimension of the text embedding space [D from AttnGAN paper]
        :param image_size: Spatial size of the feature map from the previous generator stage.
"""
super().__init__()
self.gf_dim = Ng
self.gf_init = Ng_init
self.text_dim = D
self.img_size = image_size
        self.define_module()

def define_module(self) -> None:
"""Defines FC, Upsample, Residual, ACM, Attention modules"""
ng = self.gf_dim
self.spatial_att = SpatialAttention(self.text_dim, ng)
self.channel_att = ChannelWiseAttention(
self.img_size * self.img_size, self.text_dim
)
self.residual = self._make_layer(ResidualBlock, ng * 3)
self.upsample = up_sample(ng * 3, ng)
self.acm_module = ACM(self.gf_init, ng * 3)
        self.upsample2 = up_sample(ng, ng)

def _make_layer(self, block: Any, channel_num: int) -> nn.Module:
layers = []
        for _ in range(2): # number of residual blocks hardcoded to 2
layers.append(block(channel_num))
        return nn.Sequential(*layers)

def forward(
self,
hidden_feat: Any,
word_embeddings: torch.Tensor,
vgg64_feat: torch.Tensor,
mask: Optional[torch.Tensor] = None,
) -> Any:
"""
:param hidden_feat: Hidden feature from previous generator stage [i.e. hidden_64]
:param word_embeddings: Word embeddings
:param vgg64_feat: VGG feature map of size 64 x 64
:param mask: Mask for the padding tokens
:return: Image feature map of size 256 x 256
"""
hidden_view = hidden_feat.view(
hidden_feat.shape[0], -1, hidden_feat.shape[2] * hidden_feat.shape[3]
) # reshape to pass into attention modules.
spatial_att_feat = self.spatial_att(
word_embeddings, hidden_view, mask
) # spatial att shape: (batch, D^, 64 * 64), or D^ x N
channel_att_feat = self.channel_att(
spatial_att_feat, word_embeddings
        ) # channel att shape: (batch, D^, 64 * 64), or (batch, C, Hk * Wk) from the ControlGAN paper
spatial_att_feat = spatial_att_feat.view(
word_embeddings.shape[0], -1, hidden_feat.shape[2], hidden_feat.shape[3]
) # reshape to (batch, D^, 64, 64)
channel_att_feat = channel_att_feat.view(
word_embeddings.shape[0], -1, hidden_feat.shape[2], hidden_feat.shape[3]
) # reshape to (batch, D^, 64, 64)
spatial_concat = torch.cat(
(hidden_feat, spatial_att_feat), 1
) # concat spatial attention feature with hidden_64
attn_concat = torch.cat(
(spatial_concat, channel_att_feat), 1
) # concat channel and spatial attention feature
hidden_64 = self.acm_module(attn_concat, vgg64_feat)
hidden_64 = self.residual(hidden_64)
hidden_128 = self.upsample(hidden_64)
hidden_256 = self.upsample2(hidden_128)
return hidden_256
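

# A hedged shape sketch of NextStageG.forward (an illustration; it assumes the
# attention modules return Ng channels, matching the residual/up_sample channel
# arithmetic above):
#   hidden_feat                             -> (B, Ng, 64, 64)
#   cat(hidden, spatial att, channel att)   -> (B, Ng * 3, 64, 64)
#   ACM(vgg64_feat) -> residual -> upsample -> (B, Ng, 128, 128)
#   upsample2                               -> (B, Ng, 256, 256) == hidden_256

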
class GetImageG(nn.Module):
"""Generates the Final Fake Image from the Image Feature Map"""
def __init__(self, Ng: int):
"""
:param Ng: Number of channels.
"""
super().__init__()
self.img = nn.Sequential(
nn.Conv2d(Ng, 3, kernel_size=3, stride=1, padding=1, bias=False), nn.Tanh()
        )

def forward(self, hidden_feat: torch.Tensor) -> Any:
"""
:param hidden_feat: Image feature map
:return: Final fake image
"""
return self.img(hidden_feat)
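

# E.g. GetImageG(Ng)(feat) maps (B, Ng, 256, 256) -> (B, 3, 256, 256); the Tanh
# bounds pixel values to [-1, 1]. (A hedged illustration, not repo test code.)

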
class Generator(nn.Module):
"""Generator Module"""
# pylint: disable=too-many-instance-attributes
# pylint: disable=too-many-arguments
# pylint: disable=invalid-name
    # pylint: disable=too-many-locals

def __init__(self, Ng: int, D: int, conditioning_dim: int, noise_dim: int):
"""
:param Ng: Number of channels. [Taken from StackGAN++ paper]
:param D: Dimension of the text embedding space
:param conditioning_dim: Dimension of the conditioning space
:param noise_dim: Dimension of the noise space
"""
super().__init__()
self.cond_augment = CondAugmentation(D, conditioning_dim)
self.hidden_net1 = InitStageG(Ng * 16, Ng, conditioning_dim, D, noise_dim)
        self.inception_img_upsample = img_up_block(
            D, Ng
        ) # the inception encoder returns D channels (default in the paper: 256)
self.hidden_net2 = NextStageG(Ng, Ng, D, 64)
self.generate_img = GetImageG(Ng)
self.acm_module = ACM(Ng, Ng)
self.vgg_downsample = down_sample(D // 2, Ng)
self.upsample1 = up_sample(Ng, Ng)
        self.upsample2 = up_sample(Ng, Ng)

def forward(
self,
noise: torch.Tensor,
sentence_embeddings: torch.Tensor,
word_embeddings: torch.Tensor,
global_inception_feat: torch.Tensor,
local_inception_feat: torch.Tensor,
vgg_feat: torch.Tensor,
mask: Optional[torch.Tensor] = None,
) -> Any:
"""
:param noise: Noise vector [shape: (batch, noise_dim)]
:param sentence_embeddings: Sentence embeddings [shape: (batch, D)]
        :param word_embeddings: Word embeddings [shape: (batch, D, L), where L is the sentence length]
:param global_inception_feat: Global Inception feature map [shape: (batch, D)]
:param local_inception_feat: Local Inception feature map [shape: (batch, D, 17, 17)]
:param vgg_feat: VGG feature map [shape: (batch, D // 2 = 128, 128, 128)]
:param mask: Mask for the padding tokens
:return: Final fake image
"""
c_hat, mu_tensor, logvar = self.cond_augment(sentence_embeddings)
hidden_32 = self.inception_img_upsample(local_inception_feat)
hidden_64 = self.hidden_net1(
noise, c_hat, global_inception_feat, hidden_32, word_embeddings, mask
)
vgg_64 = self.vgg_downsample(vgg_feat)
hidden_256 = self.hidden_net2(hidden_64, word_embeddings, vgg_64, mask)
vgg_128 = self.upsample1(vgg_64)
vgg_256 = self.upsample2(vgg_128)
hidden_256 = self.acm_module(hidden_256, vgg_256)
fake_img = self.generate_img(hidden_256)
return fake_img, mu_tensor, logvar
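

if __name__ == "__main__":
    # A hedged smoke-test sketch, not part of the original training code. The
    # values below (Ng=32, D=256, conditioning_dim=100, noise_dim=100, batch
    # size 4, sentence length 18) are assumptions chosen to match the shapes
    # documented in the docstrings above, and the run relies on the imported
    # modules producing the feature-map sizes their docstrings promise.
    batch, length = 4, 18
    base_ng, emb_dim, cond_dim, noise_dim = 32, 256, 100, 100
    generator = Generator(base_ng, emb_dim, cond_dim, noise_dim)
    fake_img, mu_tensor, logvar = generator(
        torch.randn(batch, noise_dim),  # noise
        torch.randn(batch, emb_dim),  # sentence embeddings
        torch.randn(batch, emb_dim, length),  # word embeddings (batch, D, L)
        torch.randn(batch, emb_dim),  # global inception feature
        torch.randn(batch, emb_dim, 17, 17),  # local inception feature map
        torch.randn(batch, emb_dim // 2, 128, 128),  # VGG feature map
    )
    print(fake_img.shape)  # expected: torch.Size([4, 3, 256, 256])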