from abc import abstractmethod from functools import partial import math from typing import Iterable import numpy as np import torch as th #from .utils_pos_embedding.pos_embed import RoPE2D import torch.nn as nn import torch.nn.functional as F import sys from fairscale.nn.model_parallel.layers import ( ColumnParallelLinear, ParallelEmbedding, RowParallelLinear, ) from timm.models.layers import DropPath from .utils import auto_grad_checkpoint, to_2tuple from .PixArt_blocks import t2i_modulate, WindowAttention, MultiHeadCrossAttention, T2IFinalLayer, TimestepEmbedder, FinalLayer import xformers.ops import math class PatchEmbed(nn.Module): """ 2D Image to Patch Embedding """ def __init__( self, img_size=(256, 16), patch_size=(16, 4), overlap = (0, 0), in_chans=128, embed_dim=768, norm_layer=None, flatten=True, bias=True, ): super().__init__() # img_size=(256, 16) # patch_size=(16, 4) # overlap = (2, 2) # in_chans=128 # embed_dim=768 # import pdb # pdb.set_trace() self.img_size = img_size self.patch_size = patch_size self.ol = overlap self.grid_size = (math.ceil((img_size[0] - patch_size[0]) / (patch_size[0]-overlap[0])) + 1, math.ceil((img_size[1] - patch_size[1]) / (patch_size[1]-overlap[1])) + 1) self.pad_size = ((self.grid_size[0]-1) * (self.patch_size[0]-overlap[0])+self.patch_size[0]-self.img_size[0], +(self.grid_size[1]-1)*(self.patch_size[1]-overlap[1])+self.patch_size[1]-self.img_size[1]) self.pad_size = (self.pad_size[0] // 2, self.pad_size[1] // 2) # self.p-ad_size = (((img_size[0] - patch_size[0]) // ), ) self.num_patches = self.grid_size[0] * self.grid_size[1] self.flatten = flatten self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=(patch_size[0]-overlap[0], patch_size[1]-overlap[1]), bias=bias) self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() def forward(self, x): # B, C, H, W = x.shape # _assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).") # _assert(W == 
self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).") x = F.pad(x, (self.pad_size[-1], self.pad_size[-1], self.pad_size[-2], self.pad_size[-2]), "constant", 0) x = self.proj(x) if self.flatten: x = x.flatten(2).transpose(1, 2) # BCHW -> BNC x = self.norm(x) return x class PatchEmbed_1D(nn.Module): def __init__( self, img_size=(256, 16), # patch_size=(16, 4), # overlap = (0, 0), in_chans=8, embed_dim=1152, norm_layer=None, # flatten=True, bias=True, ): super().__init__() self.proj = nn.Linear(in_chans*img_size[1], embed_dim, bias=bias) self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() def forward(self, x): # B, C, H, W = x.shape # _assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).") # _assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).") # x = F.pad(x, (self.pad_size[-1], self.pad_size[-1], self.pad_size[-2], self.pad_size[-2]), "constant", 0) # x = self.proj(x) # if self.flatten: # x = x.flatten(2).transpose(1, 2) # BCHW -> BNC x = th.einsum('bctf->btfc', x) x = x.flatten(2) # BTFC -> BTD x = self.proj(x) x = self.norm(x) return x # if __name__ == '__main__': # x = th.rand(1, 256, 16).unsqueeze(0) # model = PatchEmbed(in_chans=1) # y = model(x) from timm.models.vision_transformer import Attention, Mlp def modulate(x, shift, scale): return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) from positional_encodings.torch_encodings import PositionalEncoding1D def t2i_modulate(x, shift, scale): return x * (1 + scale) + shift class PixArtBlock(nn.Module): """ A PixArt block with adaptive layer norm (adaLN-single) conditioning. 
""" def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, drop_path=0., window_size=0, input_size=None, use_rel_pos=False, **block_kwargs): super().__init__() self.hidden_size = hidden_size self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) self.attn = WindowAttention(hidden_size, num_heads=num_heads, qkv_bias=True, input_size=input_size if window_size == 0 else (window_size, window_size), use_rel_pos=use_rel_pos, **block_kwargs) self.cross_attn = MultiHeadCrossAttention(hidden_size, num_heads, **block_kwargs) self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) # to be compatible with lower version pytorch approx_gelu = lambda: nn.GELU(approximate="tanh") self.mlp = Mlp(in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu, drop=0) self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.window_size = window_size self.scale_shift_table = nn.Parameter(th.randn(6, hidden_size) / hidden_size ** 0.5) def forward(self, x, y, t, mask=None, **kwargs): B, N, C = x.shape # x [B, T, D] shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None] + t.reshape(B, 6, -1)).chunk(6, dim=1) x = x + self.drop_path(gate_msa * self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa)).reshape(B, N, C)) x = x + self.cross_attn(x, y, mask) x = x + self.drop_path(gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp))) return x from qa_mdt.audioldm_train.modules.diffusionmodules.attention import CrossAttention_1D class PixArtBlock_Slow(nn.Module): """ A PixArt block with adaptive layer norm (adaLN-single) conditioning. 
""" def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, drop_path=0., window_size=0, input_size=None, use_rel_pos=False, **block_kwargs): super().__init__() self.hidden_size = hidden_size self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) self.attn = CrossAttention_1D(query_dim=hidden_size, context_dim=hidden_size, heads=num_heads, dim_head=int(hidden_size/num_heads)) self.cross_attn = CrossAttention_1D(query_dim=hidden_size, context_dim=hidden_size, heads=num_heads, dim_head=int(hidden_size/num_heads)) self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) # to be compatible with lower version pytorch approx_gelu = lambda: nn.GELU(approximate="tanh") self.mlp = Mlp(in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu, drop=0) self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.window_size = window_size self.scale_shift_table = nn.Parameter(th.randn(6, hidden_size) / hidden_size ** 0.5) def forward(self, x, y, t, mask=None, **kwargs): B, N, C = x.shape # x [B, T, D] shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None] + t.reshape(B, 6, -1)).chunk(6, dim=1) x = x + self.drop_path(gate_msa * self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa)).reshape(B, N, C)) x = x + self.cross_attn(x, y, mask) x = x + self.drop_path(gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp))) return x class PixArt(nn.Module): """ Diffusion model with a Transformer backbone. 
""" def __init__(self, input_size=(256,16), patch_size=(16,4), overlap=(0, 0), in_channels=8, hidden_size=1152, depth=28, num_heads=16, mlp_ratio=4.0, class_dropout_prob=0.1, pred_sigma=True, drop_path: float = 0., window_size=0, window_block_indexes=None, use_rel_pos=False, cond_dim=1024, lewei_scale=1.0, use_cfg=True, cfg_scale=4.0, config=None, model_max_length=120, **kwargs): if window_block_indexes is None: window_block_indexes = [] super().__init__() self.use_cfg = use_cfg self.cfg_scale = cfg_scale self.input_size = input_size self.pred_sigma = pred_sigma self.in_channels = in_channels self.out_channels = in_channels * 2 if pred_sigma else in_channels self.patch_size = patch_size self.num_heads = num_heads self.lewei_scale = lewei_scale, self.x_embedder = PatchEmbed(input_size, patch_size, overlap, in_channels, hidden_size, bias=True) # self.x_embedder = PatchEmbed_1D(input) self.t_embedder = TimestepEmbedder(hidden_size) num_patches = self.x_embedder.num_patches self.base_size = input_size[0] // self.patch_size[0] * 2 # Will use fixed sin-cos embedding: self.register_buffer("pos_embed", th.zeros(1, num_patches, hidden_size)) approx_gelu = lambda: nn.GELU(approximate="tanh") self.t_block = nn.Sequential( nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True) ) self.y_embedder = nn.Linear(cond_dim, hidden_size) drop_path = [x.item() for x in th.linspace(0, drop_path, depth)] # stochastic depth decay rule self.blocks = nn.ModuleList([ PixArtBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path[i], input_size=(self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]), window_size=0, use_rel_pos=False) for i in range(depth) ]) self.final_layer = T2IFinalLayer(hidden_size, patch_size, self.out_channels) self.initialize_weights() # if config: # logger = get_root_logger(os.path.join(config.work_dir, 'train_log.log')) # logger.warning(f"lewei scale: {self.lewei_scale}, base size: {self.base_size}") # else: # print(f'Warning: lewei scale: 
{self.lewei_scale}, base size: {self.base_size}') def forward(self, x, timestep, context_list, context_mask_list=None, **kwargs): """ Forward pass of PixArt. x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) t: (N,) tensor of diffusion timesteps y: (N, 1, 120, C) tensor of class labels """ x = x.to(self.dtype) timestep = timestep.to(self.dtype) y = context_list[0].to(self.dtype) pos_embed = self.pos_embed.to(self.dtype) self.h, self.w = self.x_embedder.grid_size[0], self.x_embedder.grid_size[1] x = self.x_embedder(x) + pos_embed # (N, T, D), where T = H * W / patch_size ** 2 t = self.t_embedder(timestep.to(x.dtype)) # (N, D) t0 = self.t_block(t) y = self.y_embedder(y) # (N, L, D) mask = context_mask_list[0] # (N, L) assert mask is not None # if mask is not None: y = y.masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1]) y_lens = mask.sum(dim=1).tolist() y_lens = [int(_) for _ in y_lens] for block in self.blocks: x = auto_grad_checkpoint(block, x, y, t0, y_lens) # (N, T, D) #support grad checkpoint x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels) x = self.unpatchify(x) # (N, out_channels, H, W) return x def forward_with_dpmsolver(self, x, timestep, y, mask=None, **kwargs): """ dpm solver donnot need variance prediction """ # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb model_out = self.forward(x, timestep, y, mask) return model_out.chunk(2, dim=1)[0] def forward_with_cfg(self, x, timestep, y, cfg_scale, mask=None, **kwargs): """ Forward pass of PixArt, but also batches the unconditional forward pass for classifier-free guidance. 
""" # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb half = x[: len(x) // 2] combined = th.cat([half, half], dim=0) model_out = self.forward(combined, timestep, y, mask) model_out = model_out['x'] if isinstance(model_out, dict) else model_out eps, rest = model_out[:, :8], model_out[:, 8:] cond_eps, uncond_eps = th.split(eps, len(eps) // 2, dim=0) half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps) eps = th.cat([half_eps, half_eps], dim=0) return eps # return th.cat([eps, rest], dim=1) def unpatchify(self, x): """ x: (N, T, patch_size 0 * patch_size 1 * C) imgs: (Bs. 256. 16. 8) """ # torch_map = th.zeros(self.x_embedder.img_size[0]+2*self.x_embedder.pad_size[0], # self.x_embedder.img_size[1]+2*self.x_embedder.pad_size[1]).to(x.device) # lf = self.x_embedder.grid_size[0] # rf = self.x_embedder.grid_size[1] # for i in range(lf): # for j in range(rf): # xx = (i) * (self.x_embedder.patch_size[0]-self.x_embedder.ol[0]) # yy = (j) * (self.x_embedder.patch_size[1]-self.x_embedder.ol[1]) # torch_map[xx:(xx+self.x_embedder.patch_size[0]), yy:(yy+self.x_embedder.patch_size[1])]+=1 # torch_map = torch_map[self.x_embedder.pad_size[0]:self.x_embedder.pad_size[0]+self.x_embedder.img_size[0], # self.x_embedder.pad_size[1]:self.x_embedder.pad_size[1]+self.x_embedder.img_size[1]] # torch_map = th.reciprocal(torch_map) # c = self.out_channels # p0, p1 = self.x_embedder.patch_size[0], self.x_embedder.patch_size[1] # x = x.reshape(shape=(x.shape[0], self.x_embedder.grid_size[0], # self.x_embedder.grid_size[1], p0, p1, c)) # x = th.einsum('nhwpqc->nchwpq', x) # added_map = th.zeros(x.shape[0], c, # self.x_embedder.img_size[0]+2*self.x_embedder.pad_size[0], # self.x_embedder.img_size[1]+2*self.x_embedder.pad_size[1]).to(x.device) # for b_id in range(x.shape[0]): # for i in range(lf): # for j in range(rf): # for c_id in range(c): # xx = (i) * (self.x_embedder.patch_size[0]-self.x_embedder.ol[0]) # yy = (j) * 
(self.x_embedder.patch_size[1]-self.x_embedder.ol[1]) # added_map[b_id][c_id][xx:(xx+self.x_embedder.patch_size[0]), yy:(yy+self.x_embedder.patch_size[1])] += \ # x[b_id, c_id, i, j] # ret_map = th.zeros(x.shape[0], c, self.x_embedder.img_size[0], # self.x_embedder.img_size[1]).to(x.device) # for b_id in range(x.shape[0]): # for id_c in range(c): # ret_map[b_id, id_c, :, :] = th.mul(added_map[b_id][id_c][self.x_embedder.pad_size[0]:self.x_embedder.pad_size[0]+self.x_embedder.img_size[0], # self.x_embedder.pad_size[1]:self.x_embedder.pad_size[1]+self.x_embedder.img_size[1]], torch_map) c = self.out_channels p0 = self.x_embedder.patch_size[0] p1 = self.x_embedder.patch_size[1] h, w = self.x_embedder.grid_size[0], self.x_embedder.grid_size[1] # h = w = int(x.shape[1] ** 0.5) # print(x.shape, h, w, p0, p1) # import pdb # pdb.set_trace() x = x.reshape(shape=(x.shape[0], h, w, p0, p1, c)) x = th.einsum('nhwpqc->nchpwq', x) imgs = x.reshape(shape=(x.shape[0], c, h * p0, w * p1)) return imgs def initialize_weights(self): # Initialize transformer layers: def _basic_init(module): if isinstance(module, nn.Linear): th.nn.init.xavier_uniform_(module.weight) if module.bias is not None: nn.init.constant_(module.bias, 0) self.apply(_basic_init) # Initialize (and freeze) pos_embed by sin-cos embedding: pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], self.x_embedder.grid_size, lewei_scale=self.lewei_scale, base_size=self.base_size) self.pos_embed.data.copy_(th.from_numpy(pos_embed).float().unsqueeze(0)) # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): w = self.x_embedder.proj.weight.data nn.init.xavier_uniform_(w.view([w.shape[0], -1])) # Initialize timestep embedding MLP: nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) nn.init.normal_(self.t_block[1].weight, std=0.02) # Initialize caption embedding MLP: nn.init.normal_(self.y_embedder.weight, std=0.02) # 
nn.init.normal_(self.y_embedder.y_proj.fc2.weight, std=0.02) # Zero-out adaLN modulation layers in PixArt blocks: for block in self.blocks: nn.init.constant_(block.cross_attn.proj.weight, 0) nn.init.constant_(block.cross_attn.proj.bias, 0) # Zero-out output layers: nn.init.constant_(self.final_layer.linear.weight, 0) nn.init.constant_(self.final_layer.linear.bias, 0) @property def dtype(self): return next(self.parameters()).dtype class SwiGLU(nn.Module): def __init__( self, dim: int, hidden_dim: int, multiple_of: int, ): super().__init__() hidden_dim = int(2 * hidden_dim / 3) hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) self.w1 = ColumnParallelLinear( dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x ) self.w2 = RowParallelLinear( hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x ) self.w3 = ColumnParallelLinear( dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x ) def forward(self, x): return self.w2(F.silu(self.w1(x)) * self.w3(x)) class DEBlock(nn.Module): """ Decoder block with added SpecTNT transformer """ def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, FFN_type='SwiGLU', drop_path=0., window_size=0, input_size=None, use_rel_pos=False, skip=False, num_f=None, num_t=None, **block_kwargs): super().__init__() self.hidden_size = hidden_size self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) self.attn = WindowAttention(hidden_size, num_heads=num_heads, qkv_bias=True, input_size=input_size if window_size == 0 else (window_size, window_size), use_rel_pos=use_rel_pos, **block_kwargs) self.cross_attn = MultiHeadCrossAttention(hidden_size, num_heads, **block_kwargs) # self.cross_attn_f = MultiHeadCrossAttention(hidden_size, num_heads, **block_kwargs) # self.cross_attn_t = MultiHeadCrossAttention(hidden_size*num_f, num_heads, **block_kwargs) self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) self.norm3 = 
nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) self.norm4 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) self.norm5 = nn.LayerNorm(hidden_size * num_f, elementwise_affine=False, eps=1e-6) self.norm6 = nn.LayerNorm(hidden_size * num_f, elementwise_affine=False, eps=1e-6) # to be compatible with lower version pytorch approx_gelu = lambda: nn.GELU(approximate="tanh") if FFN_type == 'mlp': self.mlp = Mlp(in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu, drop=0) # self.mlp2 = Mlp(in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu, drop=0) # self.mlp3 = Mlp(in_features=hidden_size*num_f, hidden_features=int(hidden_size*num_f * mlp_ratio), act_layer=approx_gelu, drop=0) elif FFN_type == 'SwiGLU': self.mlp = SwiGLU(hidden_size, int(hidden_size * mlp_ratio), 1) self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.window_size = window_size self.scale_shift_table = nn.Parameter(th.randn(6, hidden_size) / hidden_size ** 0.5) # self.scale_shift_table_2 = nn.Parameter(th.randn(6, hidden_size) / hidden_size ** 0.5) # self.scale_shift_table_3 = nn.Parameter(th.randn(6, hidden_size) / hidden_size ** 0.5) self.skip_linear = nn.Linear(2 * hidden_size, hidden_size) if skip else None self.F_transformer = WindowAttention(hidden_size, num_heads=4, qkv_bias=True, input_size=input_size if window_size == 0 else (window_size, window_size), use_rel_pos=use_rel_pos, **block_kwargs) self.T_transformer = WindowAttention(hidden_size * num_f, num_heads=16, qkv_bias=True, input_size=input_size if window_size == 0 else (window_size, window_size), use_rel_pos=use_rel_pos, **block_kwargs) self.f_pos = nn.Embedding(num_f, hidden_size) self.t_pos = nn.Embedding(num_t, hidden_size * num_f) self.num_f = num_f self.num_t = num_t def forward(self, x_normal, end, y, t, mask=None, skip=None, ids_keep=None, **kwargs): # import pdb # pdb.set_trace() B, D, C = 
x_normal.shape T = self.num_t F_add_1 = self.num_f # B, T, F_add_1, C = x.shape # F_add_1 = F_add_1 + 1 # x_normal = th.reshape() # # x_end [B, T, 1, C] # x_end = x[:, :, -1, :].unsqueeze(2) if self.skip_linear is not None: x_normal = self.skip_linear(th.cat([x_normal, skip], dim=-1)) D = T * (F_add_1 - 1) # x_normal [B, D, C] # import pdb # pdb.set_trace() shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None] + t.reshape(B, 6, -1)).chunk(6, dim=1) x_normal = x_normal + self.drop_path(gate_msa * self.attn(t2i_modulate(self.norm1(x_normal), shift_msa, scale_msa)).reshape(B, D, C)) x_normal = x_normal.reshape(B, T, F_add_1-1, C) x_normal = th.cat((x_normal, end), 2) # x_normal [B*T, F+1, C] x_normal = x_normal.reshape(B*T, F_add_1, C) pos_f = th.arange(self.num_f, device=x.device).unsqueeze(0).expand(B*T, -1) # import pdb; pdb.set_trace() x_normal = x_normal + self.f_pos(pos_f) x_normal = x_normal + self.F_transformer(self.norm3(x_normal)) # x_normal = x_normal + self.cross_attn_f(x_normal, y, mask) # x_normal = x_normal + self.mlp2(self.norm4(x_normal)) # x_normal [B, T, (F+1) * C] x_normal = x_normal.reshape(B, T, F_add_1 * C) pos_t = th.arange(self.num_t, device=x.device).unsqueeze(0).expand(B, -1) x_normal = x_normal + self.t_pos(pos_t) x_normal = x_normal + self.T_transformer(self.norm5(x_normal)) # x_normal = x_normal + self.cross_attn_t(x_normal, y, mask) x_normal = x_normal.reshape(B, T ,F_add_1, C) end = x_normal[:, :, -1, :].unsqueeze(2) x_normal = x_normal[:, :, :-1, :] x_normal = x_normal.reshape(B, T*(F_add_1 - 1), C) x_normal = x_normal + self.cross_attn(x_normal, y, mask) x_normal = x_normal + self.drop_path(gate_mlp * self.mlp(t2i_modulate(self.norm2(x_normal), shift_mlp, scale_mlp))) # x_normal = th.cat return x_normal, end #.reshape(B, ) class MDTBlock(nn.Module): """ A PixArt block with adaptive layer norm (adaLN-single) conditioning. 
""" def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, FFN_type='mlp', drop_path=0., window_size=0, input_size=None, use_rel_pos=False, skip=False, **block_kwargs): super().__init__() self.hidden_size = hidden_size self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) self.attn = WindowAttention(hidden_size, num_heads=num_heads, qkv_bias=True, input_size=input_size if window_size == 0 else (window_size, window_size), use_rel_pos=use_rel_pos, **block_kwargs) self.cross_attn = MultiHeadCrossAttention(hidden_size, num_heads, **block_kwargs) self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) # to be compatible with lower version pytorch approx_gelu = lambda: nn.GELU(approximate="tanh") if FFN_type == 'mlp': self.mlp = Mlp(in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu, drop=0) elif FFN_type == 'SwiGLU': self.mlp = SwiGLU(hidden_size, int(hidden_size * mlp_ratio), 1) self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.window_size = window_size self.scale_shift_table = nn.Parameter(th.randn(6, hidden_size) / hidden_size ** 0.5) self.skip_linear = nn.Linear(2 * hidden_size, hidden_size) if skip else None def forward(self, x, y, t, mask=None, skip=None, ids_keep=None, **kwargs): B, N, C = x.shape if self.skip_linear is not None: x = self.skip_linear(th.cat([x, skip], dim=-1)) # x [B, T, D] shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None] + t.reshape(B, 6, -1)).chunk(6, dim=1) x = x + self.drop_path(gate_msa * self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa)).reshape(B, N, C)) x = x + self.cross_attn(x, y, mask) x = x + self.drop_path(gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp))) return x class PixArt_MDT_MASK_TF(nn.Module): """ Diffusion model with a Transformer backbone. 
""" def __init__(self, input_size=(256,16), patch_size=(16,4), overlap=(0, 0), in_channels=8, hidden_size=1152, depth=28, num_heads=16, mlp_ratio=4.0, class_dropout_prob=0.1, pred_sigma=False, drop_path: float = 0., window_size=0, window_block_indexes=None, use_rel_pos=False, cond_dim=1024, lewei_scale=1.0, use_cfg=True, cfg_scale=4.0, config=None, model_max_length=120, mask_t=0.17, mask_f=0.15, decode_layer=4,**kwargs): if window_block_indexes is None: window_block_indexes = [] super().__init__() self.use_cfg = use_cfg self.cfg_scale = cfg_scale self.input_size = input_size self.pred_sigma = pred_sigma self.in_channels = in_channels self.out_channels = in_channels self.patch_size = patch_size self.num_heads = num_heads self.lewei_scale = lewei_scale, decode_layer = int(decode_layer) self.x_embedder = PatchEmbed(input_size, patch_size, overlap, in_channels, hidden_size, bias=True) # self.x_embedder = PatchEmbed_1D(input) self.t_embedder = TimestepEmbedder(hidden_size) num_patches = self.x_embedder.num_patches self.base_size = input_size[0] // self.patch_size[0] * 2 # Will use fixed sin-cos embedding: self.register_buffer("pos_embed", th.zeros(1, num_patches, hidden_size)) approx_gelu = lambda: nn.GELU(approximate="tanh") self.t_block = nn.Sequential( nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True) ) self.y_embedder = nn.Linear(cond_dim, hidden_size) half_depth = (depth - decode_layer)//2 self.half_depth=half_depth drop_path_half = [x.item() for x in th.linspace(0, drop_path, half_depth)] # stochastic depth decay rule drop_path_decode = [x.item() for x in th.linspace(0, drop_path, decode_layer)] self.en_inblocks = nn.ModuleList([ MDTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path_half[i], input_size=(self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]), window_size=0, use_rel_pos=False, FFN_type='mlp') for i in range(half_depth) ]) self.en_outblocks = nn.ModuleList([ MDTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, 
drop_path=drop_path_half[i], input_size=(self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]), window_size=0, use_rel_pos=False, skip=True, FFN_type='mlp') for i in range(half_depth) ]) self.de_blocks = nn.ModuleList([ MDTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path_decode[i], input_size=(self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]), window_size=0, use_rel_pos=False, skip=True, FFN_type='mlp') for i in range(decode_layer) ]) self.sideblocks = nn.ModuleList([ MDTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, input_size=(self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]), window_size=0, use_rel_pos=False, FFN_type='mlp') for _ in range(1) ]) self.final_layer = T2IFinalLayer(hidden_size, patch_size, self.out_channels) self.decoder_pos_embed = nn.Parameter(th.zeros( 1, num_patches, hidden_size), requires_grad=True) # if mask_ratio is not None: # self.mask_token = nn.Parameter(th.zeros(1, 1, hidden_size)) # self.mask_ratio = float(mask_ratio) # self.decode_layer = int(decode_layer) # else: # self.mask_token = nn.Parameter(th.zeros( # 1, 1, hidden_size), requires_grad=False) # self.mask_ratio = None # self.decode_layer = int(decode_layer) assert mask_t != 0 and mask_f != 0 self.mask_token = nn.Parameter(th.zeros(1, 1, hidden_size)) self.mask_t = mask_t self.mask_f = mask_f self.decode_layer = int(decode_layer) print(f"mask ratio: T-{self.mask_t} F-{self.mask_f}", "decode_layer:", self.decode_layer) self.initialize_weights() def forward(self, x, timestep, context_list, context_mask_list=None, enable_mask=False, **kwargs): """ Forward pass of PixArt. 
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) t: (N,) tensor of diffusion timesteps y: (N, 1, 120, C) tensor of class labels """ x = x.to(self.dtype) timestep = timestep.to(self.dtype) y = context_list[0].to(self.dtype) pos_embed = self.pos_embed.to(self.dtype) self.h, self.w = self.x_embedder.grid_size[0], self.x_embedder.grid_size[1] # import pdb # pdb.set_trace() x = self.x_embedder(x) + pos_embed # (N, T, D), where T = H * W / patch_size ** 2 t = self.t_embedder(timestep.to(x.dtype)) # (N, D) t0 = self.t_block(t) y = self.y_embedder(y) # (N, L, D) # if not self.training: try: mask = context_mask_list[0] # (N, L) except: mask = th.ones(x.shape[0], 1).to(x.device) print("MASK !!!!!!!!!!!!!!!!!!!!!!!!!") assert mask is not None # if mask is not None: y = y.masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1]) y_lens = mask.sum(dim=1).tolist() y_lens = [int(_) for _ in y_lens] input_skip = x masked_stage = False skips = [] # TODO : masking op for training if self.mask_t is not None and self.training: # masking: length -> length * mask_ratio rand_mask_ratio = th.rand(1, device=x.device) # noise in [0, 1] rand_mask_ratio_t = rand_mask_ratio * 0.13 + self.mask_t # mask_ratio, mask_ratio + 0.2 rand_mask_ratio_f = rand_mask_ratio * 0.13 + self.mask_f # print(rand_mask_ratio) x, mask, ids_restore, ids_keep = self.random_masking_2d( x, rand_mask_ratio_t, rand_mask_ratio_f) masked_stage = True for block in self.en_inblocks: if masked_stage: x = auto_grad_checkpoint(block, x, y, t0, y_lens, ids_keep=ids_keep) else: x = auto_grad_checkpoint(block, x, y, t0, y_lens, ids_keep=None) skips.append(x) for block in self.en_outblocks: if masked_stage: x = auto_grad_checkpoint(block, x, y, t0, y_lens, skip=skips.pop(), ids_keep=ids_keep) else: x = auto_grad_checkpoint(block, x, y, t0, y_lens, skip=skips.pop(), ids_keep=None) if self.mask_t is not None and self.mask_f is not None and self.training: x = self.forward_side_interpolater(x, 
y, t0, y_lens, mask, ids_restore) masked_stage = False else: # add pos embed x = x + self.decoder_pos_embed for i in range(len(self.de_blocks)): block = self.de_blocks[i] this_skip = input_skip x = auto_grad_checkpoint(block, x, y, t0, y_lens, skip=this_skip, ids_keep=None) x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels) x = self.unpatchify(x) # (N, out_channels, H, W) return x def forward_with_dpmsolver(self, x, timestep, y, mask=None, **kwargs): """ dpm solver donnot need variance prediction """ # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb model_out = self.forward(x, timestep, y, mask) return model_out.chunk(2, dim=1)[0] def forward_with_cfg(self, x, timestep, context_list, context_mask_list=None, cfg_scale=4.0, **kwargs): """ Forward pass of PixArt, but also batches the unconditional forward pass for classifier-free guidance. """ # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb # import pdb # pdb.set_trace() half = x[: len(x) // 2] combined = th.cat([half, half], dim=0) model_out = self.forward(combined, timestep, context_list, context_mask_list=None) model_out = model_out['x'] if isinstance(model_out, dict) else model_out eps, rest = model_out[:, :8], model_out[:, 8:] cond_eps, uncond_eps = th.split(eps, len(eps) // 2, dim=0) half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps) eps = th.cat([half_eps, half_eps], dim=0) return eps # return th.cat([eps, rest], dim=1) def unpatchify(self, x): """ x: (N, T, patch_size 0 * patch_size 1 * C) imgs: (Bs. 256. 16. 
8) """ if self.x_embedder.ol == (0, 0) or self.x_embedder.ol == [0, 0]: c = self.out_channels p0 = self.x_embedder.patch_size[0] p1 = self.x_embedder.patch_size[1] h, w = self.x_embedder.grid_size[0], self.x_embedder.grid_size[1] x = x.reshape(shape=(x.shape[0], h, w, p0, p1, c)) x = th.einsum('nhwpqc->nchpwq', x) imgs = x.reshape(shape=(x.shape[0], c, h * p0, w * p1)) return imgs lf = self.x_embedder.grid_size[0] rf = self.x_embedder.grid_size[1] lp = self.x_embedder.patch_size[0] rp = self.x_embedder.patch_size[1] lo = self.x_embedder.ol[0] ro = self.x_embedder.ol[1] lm = self.x_embedder.img_size[0] rm = self.x_embedder.img_size[1] lpad = self.x_embedder.pad_size[0] rpad = self.x_embedder.pad_size[1] bs = x.shape[0] torch_map = self.torch_map c = self.out_channels x = x.reshape(shape=(bs, lf, rf, lp, rp, c)) x = th.einsum('nhwpqc->nchwpq', x) added_map = th.zeros(bs, c, lm+2*lpad, rm+2*rpad).to(x.device) for i in range(lf): for j in range(rf): xx = (i) * (lp - lo) yy = (j) * (rp - ro) added_map[:, :, xx:(xx+lp), yy:(yy+rp)] += \ x[:, :, i, j, :, :] # import pdb # pdb.set_trace() added_map = added_map[:][:][lpad:lm+lpad, rpad:rm+rpad] return th.mul(added_map.to(x.device), torch_map.to(x.device)) def random_masking_2d(self, x, mask_t_prob, mask_f_prob): """ 2D: Spectrogram (msking t and f under mask_t_prob and mask_f_prob) Perform per-sample random masking by per-sample shuffling. Per-sample shuffling is done by argsort random noise. 
x: [N, L, D], sequence """ N, L, D = x.shape # batch, length, dim # if self.use_custom_patch: # overlapped patch # T=101 # F=12 # else: # T=64 # F=8 T = self.x_embedder.grid_size[0] F = self.x_embedder.grid_size[1] #x = x.reshape(N, T, F, D) len_keep_t = int(T * (1 - mask_t_prob)) len_keep_f = int(F * (1 - mask_f_prob)) # noise for mask in time noise_t = th.rand(N, T, device=x.device) # noise in [0, 1] # sort noise for each sample aling time ids_shuffle_t = th.argsort(noise_t, dim=1) # ascend: small is keep, large is remove ids_restore_t = th.argsort(ids_shuffle_t, dim=1) ids_keep_t = ids_shuffle_t[:,:len_keep_t] # noise mask in freq noise_f = th.rand(N, F, device=x.device) # noise in [0, 1] ids_shuffle_f = th.argsort(noise_f, dim=1) # ascend: small is keep, large is remove ids_restore_f = th.argsort(ids_shuffle_f, dim=1) ids_keep_f = ids_shuffle_f[:,:len_keep_f] # # generate the binary mask: 0 is keep, 1 is remove # mask in freq mask_f = th.ones(N, F, device=x.device) mask_f[:,:len_keep_f] = 0 mask_f = th.gather(mask_f, dim=1, index=ids_restore_f).unsqueeze(1).repeat(1,T,1) # N,T,F # mask in time mask_t = th.ones(N, T, device=x.device) mask_t[:,:len_keep_t] = 0 mask_t = th.gather(mask_t, dim=1, index=ids_restore_t).unsqueeze(1).repeat(1,F,1).permute(0,2,1) # N,T,F mask = 1-(1-mask_t)*(1-mask_f) # N, T, F # get masked x id2res=th.Tensor(list(range(N*T*F))).reshape(N,T,F).to(x.device) id2res = id2res + 999*mask # add a large value for masked elements id2res2 = th.argsort(id2res.flatten(start_dim=1)) ids_keep=id2res2.flatten(start_dim=1)[:,:len_keep_f*len_keep_t] x_masked = th.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) ids_restore = th.argsort(id2res2.flatten(start_dim=1)) mask = mask.flatten(start_dim=1) return x_masked, mask, ids_restore, ids_keep def random_masking(self, x, mask_ratio): """ Perform per-sample random masking by per-sample shuffling. Per-sample shuffling is done by argsort random noise. 
x: [N, L, D], sequence """ N, L, D = x.shape # batch, length, dim len_keep = int(L * (1 - mask_ratio)) noise = th.rand(N, L, device=x.device) # noise in [0, 1] # sort noise for each sample # ascend: small is keep, large is remove ids_shuffle = th.argsort(noise, dim=1) ids_restore = th.argsort(ids_shuffle, dim=1) # keep the first subset ids_keep = ids_shuffle[:, :len_keep] x_masked = th.gather( x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) # generate the binary mask: 0 is keep, 1 is remove mask = th.ones([N, L], device=x.device) mask[:, :len_keep] = 0 # unshuffle to get the binary mask mask = th.gather(mask, dim=1, index=ids_restore) return x_masked, mask, ids_restore, ids_keep def forward_side_interpolater(self, x, y, t0, y_lens, mask, ids_restore): # append mask tokens to sequence mask_tokens = self.mask_token.repeat( x.shape[0], ids_restore.shape[1] - x.shape[1], 1) x_ = th.cat([x, mask_tokens], dim=1) x = th.gather( x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle # add pos embed x = x + self.decoder_pos_embed # pass to the basic block x_before = x for sideblock in self.sideblocks: x = sideblock(x, y, t0, y_lens, ids_keep=None) # masked shortcut mask = mask.unsqueeze(dim=-1) x = x*mask + (1-mask)*x_before return x def initialize_weights(self): # Initialize transformer layers: def _basic_init(module): if isinstance(module, nn.Linear): th.nn.init.xavier_uniform_(module.weight) if module.bias is not None: nn.init.constant_(module.bias, 0) self.apply(_basic_init) # Initialize (and freeze) pos_embed by sin-cos embedding: pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], self.x_embedder.grid_size, lewei_scale=self.lewei_scale, base_size=self.base_size) self.pos_embed.data.copy_(th.from_numpy(pos_embed).float().unsqueeze(0)) # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): w = self.x_embedder.proj.weight.data nn.init.xavier_uniform_(w.view([w.shape[0], -1])) # Initialize timestep embedding MLP: 
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) nn.init.normal_(self.t_block[1].weight, std=0.02) # Initialize caption embedding MLP: nn.init.normal_(self.y_embedder.weight, std=0.02) # nn.init.normal_(self.y_embedder.y_proj.fc2.weight, std=0.02) # Zero-out adaLN modulation layers in PixArt blocks: for block in self.en_inblocks: nn.init.constant_(block.cross_attn.proj.weight, 0) nn.init.constant_(block.cross_attn.proj.bias, 0) for block in self.en_outblocks: nn.init.constant_(block.cross_attn.proj.weight, 0) nn.init.constant_(block.cross_attn.proj.bias, 0) for block in self.de_blocks: nn.init.constant_(block.cross_attn.proj.weight, 0) nn.init.constant_(block.cross_attn.proj.bias, 0) for block in self.sideblocks: nn.init.constant_(block.cross_attn.proj.weight, 0) nn.init.constant_(block.cross_attn.proj.bias, 0) # Zero-out output layers: nn.init.constant_(self.final_layer.linear.weight, 0) nn.init.constant_(self.final_layer.linear.bias, 0) if self.x_embedder.ol == [0, 0] or self.x_embedder.ol == (0, 0): return lf = self.x_embedder.grid_size[0] rf = self.x_embedder.grid_size[1] lp = self.x_embedder.patch_size[0] rp = self.x_embedder.patch_size[1] lo = self.x_embedder.ol[0] ro = self.x_embedder.ol[1] lm = self.x_embedder.img_size[0] rm = self.x_embedder.img_size[1] lpad = self.x_embedder.pad_size[0] rpad = self.x_embedder.pad_size[1] torch_map = th.zeros(lm+2*lpad, rm+2*rpad).to('cuda') for i in range(lf): for j in range(rf): xx = (i) * (lp - lo) yy = (j) * (rp - ro) torch_map[xx:(xx+lp), yy:(yy+rp)]+=1 torch_map = torch_map[lpad:lm+lpad, rpad:rm+rpad] self.torch_map = th.reciprocal(torch_map) @property def dtype(self): return next(self.parameters()).dtype class PixArt_MDT_MOS_AS_TOKEN(nn.Module): """ Diffusion model with a Transformer backbone. 
""" def __init__(self, input_size=(256,16), patch_size=(16,4), overlap=(0, 0), in_channels=8, hidden_size=1152, depth=28, num_heads=16, mlp_ratio=4.0, class_dropout_prob=0.1, pred_sigma=False, drop_path: float = 0., window_size=0, window_block_indexes=None, use_rel_pos=False, cond_dim=1024, lewei_scale=1.0, use_cfg=True, cfg_scale=4.0, config=None, model_max_length=120, mask_ratio=None, decode_layer=4,**kwargs): if window_block_indexes is None: window_block_indexes = [] super().__init__() self.use_cfg = use_cfg self.cfg_scale = cfg_scale self.input_size = input_size self.pred_sigma = pred_sigma self.in_channels = in_channels self.out_channels = in_channels self.patch_size = patch_size self.num_heads = num_heads self.lewei_scale = lewei_scale, decode_layer = int(decode_layer) self.mos_embed = nn.Embedding(num_embeddings=5, embedding_dim=hidden_size) self.x_embedder = PatchEmbed(input_size, patch_size, overlap, in_channels, hidden_size, bias=True) # self.x_embedder = PatchEmbed_1D(input) self.t_embedder = TimestepEmbedder(hidden_size) num_patches = self.x_embedder.num_patches self.base_size = input_size[0] // self.patch_size[0] * 2 # Will use fixed sin-cos embedding: self.register_buffer("pos_embed", th.zeros(1, num_patches, hidden_size)) approx_gelu = lambda: nn.GELU(approximate="tanh") self.t_block = nn.Sequential( nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True) ) # self.mos_block = nn.Sequential( # ) self.y_embedder = nn.Linear(cond_dim, hidden_size) half_depth = (depth - decode_layer)//2 self.half_depth=half_depth drop_path_half = [x.item() for x in th.linspace(0, drop_path, half_depth)] # stochastic depth decay rule drop_path_decode = [x.item() for x in th.linspace(0, drop_path, decode_layer)] self.en_inblocks = nn.ModuleList([ MDTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path_half[i], input_size=(self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]), window_size=0, use_rel_pos=False, FFN_type='mlp') for i in 
range(half_depth) ]) self.en_outblocks = nn.ModuleList([ MDTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path_half[i], input_size=(self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]), window_size=0, use_rel_pos=False, skip=True, FFN_type='mlp') for i in range(half_depth) ]) self.de_blocks = nn.ModuleList([ MDTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path_decode[i], input_size=(self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]), window_size=0, use_rel_pos=False, skip=True, FFN_type='mlp') for i in range(decode_layer) ]) self.sideblocks = nn.ModuleList([ MDTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, input_size=(self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]), window_size=0, use_rel_pos=False, FFN_type='mlp') for _ in range(1) ]) self.final_layer = T2IFinalLayer(hidden_size, patch_size, self.out_channels) self.decoder_pos_embed = nn.Parameter(th.zeros( 1, num_patches, hidden_size), requires_grad=True) if mask_ratio is not None: self.mask_token = nn.Parameter(th.zeros(1, 1, hidden_size)) self.mask_ratio = float(mask_ratio) self.decode_layer = int(decode_layer) else: self.mask_token = nn.Parameter(th.zeros( 1, 1, hidden_size), requires_grad=False) self.mask_ratio = None self.decode_layer = int(decode_layer) print("mask ratio:", self.mask_ratio, "decode_layer:", self.decode_layer) self.initialize_weights() def forward(self, x, timestep, context_list, context_mask_list=None, enable_mask=False, mos=None, **kwargs): """ Forward pass of PixArt. x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) t: (N,) tensor of diffusion timesteps y: (N, 1, 120, C) tensor of class labels """ # mos = th.ones(x.shape[0], dtype=th.int).to(x.device) #print(f'DEBUG! 
{x}, {mos}') assert mos.shape[0] == x.shape[0] #import pdb; pdb.set_trace() mos = mos - 1 mos = self.mos_embed(mos.to(x.device).to(th.int)) # [N, dim] x = x.to(self.dtype) timestep = timestep.to(self.dtype) y = context_list[0].to(self.dtype) pos_embed = self.pos_embed.to(self.dtype) self.h, self.w = self.x_embedder.grid_size[0], self.x_embedder.grid_size[1] # import pdb # pdb.set_trace() x = self.x_embedder(x) + pos_embed # (N, T, D), where T = H * W / patch_size ** 2 t = self.t_embedder(timestep.to(x.dtype)) # (N, D) t0 = self.t_block(t) y = self.y_embedder(y) # (N, L, D) # if not self.training: try: mask = context_mask_list[0] # (N, L) except: mask = th.ones(x.shape[0], 1).to(x.device) assert mask is not None # if mask is not None: y = y.masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1]) y_lens = mask.sum(dim=1).tolist() y_lens = [int(_) for _ in y_lens] masked_stage = False skips = [] # TODO : masking op for training try: x = th.cat([mos, x], dim=1) # [N, L+1, dim] except: x = th.cat([mos.unsqueeze(1), x], dim=1) input_skip = x if self.mask_ratio is not None and self.training: # masking: length -> length * mask_ratio rand_mask_ratio = th.rand(1, device=x.device) # noise in [0, 1] rand_mask_ratio = rand_mask_ratio * 0.2 + self.mask_ratio # mask_ratio, mask_ratio + 0.2 # print(rand_mask_ratio) x, mask, ids_restore, ids_keep = self.random_masking( x, rand_mask_ratio) masked_stage = True for block in self.en_inblocks: if masked_stage: x = auto_grad_checkpoint(block, x, y, t0, y_lens, ids_keep=ids_keep) else: x = auto_grad_checkpoint(block, x, y, t0, y_lens, ids_keep=None) skips.append(x) for block in self.en_outblocks: if masked_stage: x = auto_grad_checkpoint(block, x, y, t0, y_lens, skip=skips.pop(), ids_keep=ids_keep) else: x = auto_grad_checkpoint(block, x, y, t0, y_lens, skip=skips.pop(), ids_keep=None) if self.mask_ratio is not None and self.training: x = self.forward_side_interpolater(x, y, t0, y_lens, mask, ids_restore) masked_stage = False 
else: # add pos embed x[:, 1:, :] = x[:, 1:, :] + self.decoder_pos_embed # x = x + self.decoder_pos_embed # import pdb # pdb.set_trace() for i in range(len(self.de_blocks)): block = self.de_blocks[i] this_skip = input_skip x = auto_grad_checkpoint(block, x, y, t0, y_lens, skip=this_skip, ids_keep=None) x = x[:, 1:, :] x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels) x = self.unpatchify(x) # (N, out_channels, H, W) # import pdb # pdb.set_trace() return x def forward_with_dpmsolver(self, x, timestep, y, mask=None, **kwargs): """ dpm solver donnot need variance prediction """ # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb model_out = self.forward(x, timestep, y, mask) return model_out.chunk(2, dim=1)[0] def forward_with_cfg(self, x, timestep, context_list, context_mask_list=None, cfg_scale=4.0, **kwargs): """ Forward pass of PixArt, but also batches the unconditional forward pass for classifier-free guidance. """ # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb # import pdb # pdb.set_trace() half = x[: len(x) // 2] combined = th.cat([half, half], dim=0) model_out = self.forward(combined, timestep, context_list, context_mask_list=None) model_out = model_out['x'] if isinstance(model_out, dict) else model_out eps, rest = model_out[:, :8], model_out[:, 8:] cond_eps, uncond_eps = th.split(eps, len(eps) // 2, dim=0) half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps) eps = th.cat([half_eps, half_eps], dim=0) return eps # return th.cat([eps, rest], dim=1) def unpatchify(self, x): """ x: (N, T, patch_size 0 * patch_size 1 * C) imgs: (Bs. 256. 16. 
8) """ if self.x_embedder.ol == (0, 0) or self.x_embedder.ol == [0, 0]: c = self.out_channels p0 = self.x_embedder.patch_size[0] p1 = self.x_embedder.patch_size[1] h, w = self.x_embedder.grid_size[0], self.x_embedder.grid_size[1] x = x.reshape(shape=(x.shape[0], h, w, p0, p1, c)) x = th.einsum('nhwpqc->nchpwq', x) imgs = x.reshape(shape=(x.shape[0], c, h * p0, w * p1)) return imgs lf = self.x_embedder.grid_size[0] rf = self.x_embedder.grid_size[1] lp = self.x_embedder.patch_size[0] rp = self.x_embedder.patch_size[1] lo = self.x_embedder.ol[0] ro = self.x_embedder.ol[1] lm = self.x_embedder.img_size[0] rm = self.x_embedder.img_size[1] lpad = self.x_embedder.pad_size[0] rpad = self.x_embedder.pad_size[1] bs = x.shape[0] torch_map = self.torch_map c = self.out_channels x = x.reshape(shape=(bs, lf, rf, lp, rp, c)) x = th.einsum('nhwpqc->nchwpq', x) added_map = th.zeros(bs, c, lm+2*lpad, rm+2*rpad).to(x.device) for i in range(lf): for j in range(rf): xx = (i) * (lp - lo) yy = (j) * (rp - ro) added_map[:, :, xx:(xx+lp), yy:(yy+rp)] += \ x[:, :, i, j, :, :] # import pdb # pdb.set_trace() added_map = added_map[:][:][lpad:lm+lpad, rpad:rm+rpad] return th.mul(added_map.to(x.device), torch_map.to(x.device)) def random_masking(self, x, mask_ratio): """ Perform per-sample random masking by per-sample shuffling. Per-sample shuffling is done by argsort random noise. 
x: [N, L, D], sequence """ N, L, D = x.shape # batch, length, dim L = L - 1 len_keep = int(L * (1 - mask_ratio)) noise = th.rand(N, L, device=x.device) # noise in [0, 1] # sort noise for each sample # ascend: small is keep, large is remove ids_shuffle = th.argsort(noise, dim=1) ids_restore = th.argsort(ids_shuffle, dim=1) # keep the first subset ids_keep = ids_shuffle[:, :len_keep] x_masked = th.gather( x[:, 1:, :], dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) x_masked = th.cat([x[:, 0, :].unsqueeze(1), x_masked], dim=1) # import pdb # pdb.set_trace() # generate the binary mask: 0 is keep, 1 is remove mask = th.ones([N, L], device=x.device) mask[:, :len_keep] = 0 # unshuffle to get the binary mask mask = th.gather(mask, dim=1, index=ids_restore) return x_masked, mask, ids_restore, ids_keep def forward_side_interpolater(self, x, y, t0, y_lens, mask, ids_restore): # append mask tokens to sequence mask_tokens = self.mask_token.repeat( x.shape[0], ids_restore.shape[1] - x.shape[1] + 1, 1) x_ = th.cat([x[:, 1:, :], mask_tokens], dim=1) x_ = th.gather( x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle # add pos embed x_ = x_ + self.decoder_pos_embed x = th.cat([x[:, 0, :].unsqueeze(1), x_], dim=1) # import pdb # pdb.set_trace() # pass to the basic block x_before = x for sideblock in self.sideblocks: x = sideblock(x, y, t0, y_lens, ids_keep=None) # masked shortcut mask = mask.unsqueeze(dim=-1) # import pdb;pdb.set_trace() mask = th.cat([th.ones(mask.shape[0], 1, 1).to(mask.device), mask], dim=1) x = x*mask + (1-mask)*x_before return x def initialize_weights(self): # Initialize transformer layers: def _basic_init(module): if isinstance(module, nn.Linear): th.nn.init.xavier_uniform_(module.weight) if module.bias is not None: nn.init.constant_(module.bias, 0) self.apply(_basic_init) # Initialize (and freeze) pos_embed by sin-cos embedding: pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], self.x_embedder.grid_size, 
lewei_scale=self.lewei_scale, base_size=self.base_size) self.pos_embed.data.copy_(th.from_numpy(pos_embed).float().unsqueeze(0)) # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): w = self.x_embedder.proj.weight.data nn.init.xavier_uniform_(w.view([w.shape[0], -1])) # Initialize timestep embedding MLP: nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) nn.init.normal_(self.t_block[1].weight, std=0.02) # Initialize caption embedding MLP: nn.init.normal_(self.y_embedder.weight, std=0.02) # nn.init.normal_(self.y_embedder.y_proj.fc2.weight, std=0.02) # Zero-out adaLN modulation layers in PixArt blocks: for block in self.en_inblocks: nn.init.constant_(block.cross_attn.proj.weight, 0) nn.init.constant_(block.cross_attn.proj.bias, 0) for block in self.en_outblocks: nn.init.constant_(block.cross_attn.proj.weight, 0) nn.init.constant_(block.cross_attn.proj.bias, 0) for block in self.de_blocks: nn.init.constant_(block.cross_attn.proj.weight, 0) nn.init.constant_(block.cross_attn.proj.bias, 0) for block in self.sideblocks: nn.init.constant_(block.cross_attn.proj.weight, 0) nn.init.constant_(block.cross_attn.proj.bias, 0) # Zero-out output layers: nn.init.constant_(self.final_layer.linear.weight, 0) nn.init.constant_(self.final_layer.linear.bias, 0) if self.x_embedder.ol == [0, 0] or self.x_embedder.ol == (0, 0): return lf = self.x_embedder.grid_size[0] rf = self.x_embedder.grid_size[1] lp = self.x_embedder.patch_size[0] rp = self.x_embedder.patch_size[1] lo = self.x_embedder.ol[0] ro = self.x_embedder.ol[1] lm = self.x_embedder.img_size[0] rm = self.x_embedder.img_size[1] lpad = self.x_embedder.pad_size[0] rpad = self.x_embedder.pad_size[1] torch_map = th.zeros(lm+2*lpad, rm+2*rpad).to('cuda') for i in range(lf): for j in range(rf): xx = (i) * (lp - lo) yy = (j) * (rp - ro) torch_map[xx:(xx+lp), yy:(yy+rp)]+=1 torch_map = torch_map[lpad:lm+lpad, rpad:rm+rpad] self.torch_map = th.reciprocal(torch_map) 
@property def dtype(self): return next(self.parameters()).dtype class PixArt_MDT_LC(nn.Module): """ Diffusion model with a Transformer backbone. """ def __init__(self, input_size=(256,16), patch_size=(16,4), overlap=(0, 0), in_channels=8, hidden_size=1152, depth=28, num_heads=16, mlp_ratio=4.0, class_dropout_prob=0.1, pred_sigma=False, drop_path: float = 0., window_size=0, window_block_indexes=None, use_rel_pos=False, cond_dim=1024, lewei_scale=1.0, use_cfg=True, cfg_scale=4.0, config=None, model_max_length=120, mask_ratio=None, decode_layer=4,**kwargs): if window_block_indexes is None: window_block_indexes = [] super().__init__() self.use_cfg = use_cfg self.cfg_scale = cfg_scale self.input_size = input_size self.pred_sigma = pred_sigma self.in_channels = in_channels self.out_channels = in_channels self.patch_size = patch_size self.num_heads = num_heads self.lewei_scale = lewei_scale, decode_layer = int(decode_layer) self.x_embedder = PatchEmbed(input_size, patch_size, overlap, in_channels, hidden_size, bias=True) # self.x_embedder = PatchEmbed_1D(input) self.t_embedder = TimestepEmbedder(hidden_size) num_patches = self.x_embedder.num_patches self.base_size = input_size[0] // self.patch_size[0] * 2 # Will use fixed sin-cos embedding: self.register_buffer("pos_embed", th.zeros(1, num_patches, hidden_size)) approx_gelu = lambda: nn.GELU(approximate="tanh") self.t_block = nn.Sequential( nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True) ) self.y_embedder = nn.Linear(cond_dim, hidden_size) half_depth = (depth - decode_layer)//2 self.half_depth=half_depth drop_path_half = [x.item() for x in th.linspace(0, drop_path, half_depth)] # stochastic depth decay rule drop_path_decode = [x.item() for x in th.linspace(0, drop_path, decode_layer)] self.en_inblocks = nn.ModuleList([ MDTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path_half[i], input_size=(self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]), window_size=0, use_rel_pos=False, 
FFN_type='mlp') for i in range(half_depth) ]) self.en_outblocks = nn.ModuleList([ MDTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path_half[i], input_size=(self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]), window_size=0, use_rel_pos=False, skip=True, FFN_type='mlp') for i in range(half_depth) ]) self.de_blocks = nn.ModuleList([ DEBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path_decode[i], input_size=(self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]), window_size=0, use_rel_pos=False, skip=True, FFN_type='mlp', num_f=self.x_embedder.grid_size[1]+1, num_t=self.x_embedder.grid_size[0]) for i in range(decode_layer) ]) self.sideblocks = nn.ModuleList([ MDTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, input_size=(self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]), window_size=0, use_rel_pos=False, FFN_type='mlp') for _ in range(1) ]) self.final_layer = T2IFinalLayer(hidden_size, patch_size, self.out_channels) self.decoder_pos_embed = nn.Parameter(th.zeros( 1, num_patches, hidden_size), requires_grad=True) if mask_ratio is not None: self.mask_token = nn.Parameter(th.zeros(1, 1, hidden_size)) self.mask_ratio = float(mask_ratio) self.decode_layer = int(decode_layer) else: self.mask_token = nn.Parameter(th.zeros( 1, 1, hidden_size), requires_grad=False) self.mask_ratio = None self.decode_layer = int(decode_layer) print("mask ratio:", self.mask_ratio, "decode_layer:", self.decode_layer) self.initialize_weights() def forward(self, x, timestep, context_list, context_mask_list=None, enable_mask=False, **kwargs): """ Forward pass of PixArt. 
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) t: (N,) tensor of diffusion timesteps y: (N, 1, 120, C) tensor of class labels """ x = x.to(self.dtype) timestep = timestep.to(self.dtype) y = context_list[0].to(self.dtype) pos_embed = self.pos_embed.to(self.dtype) self.h, self.w = self.x_embedder.grid_size[0], self.x_embedder.grid_size[1] # import pdb # pdb.set_trace() x = self.x_embedder(x) + pos_embed # (N, T, D), where T = H * W / patch_size ** 2 t = self.t_embedder(timestep.to(x.dtype)) # (N, D) t0 = self.t_block(t) y = self.y_embedder(y) # (N, L, D) # if not self.training: try: mask = context_mask_list[0] # (N, L) except: mask = th.ones(x.shape[0], 1).to(x.device) print("MASK !!!!!!!!!!!!!!!!!!!!!!!!!") assert mask is not None # if mask is not None: y = y.masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1]) y_lens = mask.sum(dim=1).tolist() y_lens = [int(_) for _ in y_lens] input_skip = x masked_stage = False skips = [] # TODO : masking op for training if self.mask_ratio is not None and self.training: # masking: length -> length * mask_ratio rand_mask_ratio = th.rand(1, device=x.device) # noise in [0, 1] rand_mask_ratio = rand_mask_ratio * 0.2 + self.mask_ratio # mask_ratio, mask_ratio + 0.2 # print(rand_mask_ratio) x, mask, ids_restore, ids_keep = self.random_masking( x, rand_mask_ratio) masked_stage = True for block in self.en_inblocks: if masked_stage: x = auto_grad_checkpoint(block, x, y, t0, y_lens, ids_keep=ids_keep) else: x = auto_grad_checkpoint(block, x, y, t0, y_lens, ids_keep=None) skips.append(x) for block in self.en_outblocks: if masked_stage: x = auto_grad_checkpoint(block, x, y, t0, y_lens, skip=skips.pop(), ids_keep=ids_keep) else: x = auto_grad_checkpoint(block, x, y, t0, y_lens, skip=skips.pop(), ids_keep=None) if self.mask_ratio is not None and self.training: x = self.forward_side_interpolater(x, y, t0, y_lens, mask, ids_restore) masked_stage = False else: # add pos embed x = x + 
self.decoder_pos_embed bs = x.shape[0] bs, D, L = x.shape T, F = self.x_embedder.grid_size[0], self.x_embedder.grid_size[1] # reshaped = x.reshape(bs, T, F, L).to(x.device) end = th.zeros(bs, T, 1, L).to(x.device) # x = th.cat((reshaped, zero_tensor), 2) # import pdb;pdb.set_trace() # assert x.shape == [bs, T, F + 1, L] for i in range(len(self.de_blocks)): block = self.de_blocks[i] this_skip = input_skip x, end = auto_grad_checkpoint(block, x, end, y, t0, y_lens, skip=this_skip, ids_keep=None) x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels) x = self.unpatchify(x) # (N, out_channels, H, W) return x def forward_with_dpmsolver(self, x, timestep, y, mask=None, **kwargs): """ dpm solver donnot need variance prediction """ # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb model_out = self.forward(x, timestep, y, mask) return model_out.chunk(2, dim=1)[0] def forward_with_cfg(self, x, timestep, context_list, context_mask_list=None, cfg_scale=4.0, **kwargs): """ Forward pass of PixArt, but also batches the unconditional forward pass for classifier-free guidance. """ # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb # import pdb # pdb.set_trace() half = x[: len(x) // 2] combined = th.cat([half, half], dim=0) model_out = self.forward(combined, timestep, context_list, context_mask_list=None) model_out = model_out['x'] if isinstance(model_out, dict) else model_out eps, rest = model_out[:, :8], model_out[:, 8:] cond_eps, uncond_eps = th.split(eps, len(eps) // 2, dim=0) half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps) eps = th.cat([half_eps, half_eps], dim=0) return eps # return th.cat([eps, rest], dim=1) def unpatchify(self, x): """ x: (N, T, patch_size 0 * patch_size 1 * C) imgs: (Bs. 256. 16. 
8) """ if self.x_embedder.ol == (0, 0) or self.x_embedder.ol == [0, 0]: c = self.out_channels p0 = self.x_embedder.patch_size[0] p1 = self.x_embedder.patch_size[1] h, w = self.x_embedder.grid_size[0], self.x_embedder.grid_size[1] x = x.reshape(shape=(x.shape[0], h, w, p0, p1, c)) x = th.einsum('nhwpqc->nchpwq', x) imgs = x.reshape(shape=(x.shape[0], c, h * p0, w * p1)) return imgs lf = self.x_embedder.grid_size[0] rf = self.x_embedder.grid_size[1] lp = self.x_embedder.patch_size[0] rp = self.x_embedder.patch_size[1] lo = self.x_embedder.ol[0] ro = self.x_embedder.ol[1] lm = self.x_embedder.img_size[0] rm = self.x_embedder.img_size[1] lpad = self.x_embedder.pad_size[0] rpad = self.x_embedder.pad_size[1] bs = x.shape[0] torch_map = self.torch_map c = self.out_channels x = x.reshape(shape=(bs, lf, rf, lp, rp, c)) x = th.einsum('nhwpqc->nchwpq', x) added_map = th.zeros(bs, c, lm+2*lpad, rm+2*rpad).to(x.device) for i in range(lf): for j in range(rf): xx = (i) * (lp - lo) yy = (j) * (rp - ro) added_map[:, :, xx:(xx+lp), yy:(yy+rp)] += \ x[:, :, i, j, :, :] # import pdb # pdb.set_trace() added_map = added_map[:][:][lpad:lm+lpad, rpad:rm+rpad] return th.mul(added_map.to(x.device), torch_map.to(x.device)) def random_masking(self, x, mask_ratio): """ Perform per-sample random masking by per-sample shuffling. Per-sample shuffling is done by argsort random noise. 
x: [N, L, D], sequence """ N, L, D = x.shape # batch, length, dim len_keep = int(L * (1 - mask_ratio)) noise = th.rand(N, L, device=x.device) # noise in [0, 1] # sort noise for each sample # ascend: small is keep, large is remove ids_shuffle = th.argsort(noise, dim=1) ids_restore = th.argsort(ids_shuffle, dim=1) # keep the first subset ids_keep = ids_shuffle[:, :len_keep] x_masked = th.gather( x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) # generate the binary mask: 0 is keep, 1 is remove mask = th.ones([N, L], device=x.device) mask[:, :len_keep] = 0 # unshuffle to get the binary mask mask = th.gather(mask, dim=1, index=ids_restore) return x_masked, mask, ids_restore, ids_keep def forward_side_interpolater(self, x, y, t0, y_lens, mask, ids_restore): # append mask tokens to sequence mask_tokens = self.mask_token.repeat( x.shape[0], ids_restore.shape[1] - x.shape[1], 1) x_ = th.cat([x, mask_tokens], dim=1) x = th.gather( x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle # add pos embed x = x + self.decoder_pos_embed # pass to the basic block x_before = x for sideblock in self.sideblocks: x = sideblock(x, y, t0, y_lens, ids_keep=None) # masked shortcut mask = mask.unsqueeze(dim=-1) x = x*mask + (1-mask)*x_before return x def initialize_weights(self): # Initialize transformer layers: def _basic_init(module): if isinstance(module, nn.Linear): th.nn.init.xavier_uniform_(module.weight) if module.bias is not None: nn.init.constant_(module.bias, 0) self.apply(_basic_init) # Initialize (and freeze) pos_embed by sin-cos embedding: pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], self.x_embedder.grid_size, lewei_scale=self.lewei_scale, base_size=self.base_size) self.pos_embed.data.copy_(th.from_numpy(pos_embed).float().unsqueeze(0)) # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): w = self.x_embedder.proj.weight.data nn.init.xavier_uniform_(w.view([w.shape[0], -1])) # Initialize timestep embedding MLP: 
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) nn.init.normal_(self.t_block[1].weight, std=0.02) # Initialize caption embedding MLP: nn.init.normal_(self.y_embedder.weight, std=0.02) # nn.init.normal_(self.y_embedder.y_proj.fc2.weight, std=0.02) # Zero-out adaLN modulation layers in PixArt blocks: for block in self.en_inblocks: nn.init.constant_(block.cross_attn.proj.weight, 0) nn.init.constant_(block.cross_attn.proj.bias, 0) for block in self.en_outblocks: nn.init.constant_(block.cross_attn.proj.weight, 0) nn.init.constant_(block.cross_attn.proj.bias, 0) for block in self.de_blocks: nn.init.constant_(block.cross_attn.proj.weight, 0) nn.init.constant_(block.cross_attn.proj.bias, 0) for block in self.sideblocks: nn.init.constant_(block.cross_attn.proj.weight, 0) nn.init.constant_(block.cross_attn.proj.bias, 0) # Zero-out output layers: nn.init.constant_(self.final_layer.linear.weight, 0) nn.init.constant_(self.final_layer.linear.bias, 0) if self.x_embedder.ol == [0, 0] or self.x_embedder.ol == (0, 0): return lf = self.x_embedder.grid_size[0] rf = self.x_embedder.grid_size[1] lp = self.x_embedder.patch_size[0] rp = self.x_embedder.patch_size[1] lo = self.x_embedder.ol[0] ro = self.x_embedder.ol[1] lm = self.x_embedder.img_size[0] rm = self.x_embedder.img_size[1] lpad = self.x_embedder.pad_size[0] rpad = self.x_embedder.pad_size[1] torch_map = th.zeros(lm+2*lpad, rm+2*rpad).to('cuda') for i in range(lf): for j in range(rf): xx = (i) * (lp - lo) yy = (j) * (rp - ro) torch_map[xx:(xx+lp), yy:(yy+rp)]+=1 torch_map = torch_map[lpad:lm+lpad, rpad:rm+rpad] self.torch_map = th.reciprocal(torch_map) @property def dtype(self): return next(self.parameters()).dtype class PixArt_MDT(nn.Module): """ Diffusion model with a Transformer backbone. 
""" def __init__(self, input_size=(256,16), patch_size=(16,4), overlap=(0, 0), in_channels=8, hidden_size=1152, depth=28, num_heads=16, mlp_ratio=4.0, class_dropout_prob=0.1, pred_sigma=False, drop_path: float = 0., window_size=0, window_block_indexes=None, use_rel_pos=False, cond_dim=1024, lewei_scale=1.0, use_cfg=True, cfg_scale=4.0, config=None, model_max_length=120, mask_ratio=None, decode_layer=4,**kwargs): if window_block_indexes is None: window_block_indexes = [] super().__init__() self.use_cfg = use_cfg self.cfg_scale = cfg_scale self.input_size = input_size self.pred_sigma = pred_sigma self.in_channels = in_channels self.out_channels = in_channels self.patch_size = patch_size self.num_heads = num_heads self.lewei_scale = lewei_scale, decode_layer = int(decode_layer) self.x_embedder = PatchEmbed(input_size, patch_size, overlap, in_channels, hidden_size, bias=True) # self.x_embedder = PatchEmbed_1D(input) self.t_embedder = TimestepEmbedder(hidden_size) num_patches = self.x_embedder.num_patches self.base_size = input_size[0] // self.patch_size[0] * 2 # Will use fixed sin-cos embedding: self.register_buffer("pos_embed", th.zeros(1, num_patches, hidden_size)) approx_gelu = lambda: nn.GELU(approximate="tanh") self.t_block = nn.Sequential( nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True) ) self.y_embedder = nn.Linear(cond_dim, hidden_size) half_depth = (depth - decode_layer)//2 self.half_depth=half_depth drop_path_half = [x.item() for x in th.linspace(0, drop_path, half_depth)] # stochastic depth decay rule drop_path_decode = [x.item() for x in th.linspace(0, drop_path, decode_layer)] self.en_inblocks = nn.ModuleList([ MDTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path_half[i], input_size=(self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]), window_size=0, use_rel_pos=False) for i in range(half_depth) ]) self.en_outblocks = nn.ModuleList([ MDTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, 
drop_path=drop_path_half[i], input_size=(self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]), window_size=0, use_rel_pos=False, skip=True) for i in range(half_depth) ]) self.de_blocks = nn.ModuleList([ MDTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path_decode[i], input_size=(self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]), window_size=0, use_rel_pos=False, skip=True) for i in range(decode_layer) ]) self.sideblocks = nn.ModuleList([ MDTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, input_size=(self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]), window_size=0, use_rel_pos=False) for _ in range(1) ]) self.final_layer = T2IFinalLayer(hidden_size, patch_size, self.out_channels) self.decoder_pos_embed = nn.Parameter(th.zeros( 1, num_patches, hidden_size), requires_grad=True) if mask_ratio is not None: self.mask_token = nn.Parameter(th.zeros(1, 1, hidden_size)) self.mask_ratio = float(mask_ratio) self.decode_layer = int(decode_layer) else: self.mask_token = nn.Parameter(th.zeros( 1, 1, hidden_size), requires_grad=False) self.mask_ratio = None self.decode_layer = int(decode_layer) print("mask ratio:", self.mask_ratio, "decode_layer:", self.decode_layer) self.initialize_weights() def forward(self, x, timestep, context_list, context_mask_list=None, enable_mask=False, **kwargs): """ Forward pass of PixArt. 
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) t: (N,) tensor of diffusion timesteps y: (N, 1, 120, C) tensor of class labels """ # print(f'debug_MDT : {x.shape[0]}') x = x.to(self.dtype) timestep = timestep.to(self.dtype) y = context_list[0].to(self.dtype) pos_embed = self.pos_embed.to(self.dtype) self.h, self.w = self.x_embedder.grid_size[0], self.x_embedder.grid_size[1] # import pdb # print(f'debug_MDT : {x.shape[0]}') # pdb.set_trace() x = self.x_embedder(x) + pos_embed # (N, T, D), where T = H * W / patch_size ** 2 # print(f'debug_MDT : {x.shape[0]}') t = self.t_embedder(timestep.to(x.dtype)) # (N, D) t0 = self.t_block(t) # print(f'debug_MDT : {x.shape[0]}') y = self.y_embedder(y) # (N, L, D) # if not self.training: try: mask = context_mask_list[0] # (N, L) except: mask = th.ones(x.shape[0], 1).to(x.device) print("MASK !!!!!!!!!!!!!!!!!!!!!!!!!") assert mask is not None # if mask is not None: y = y.masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1]) y_lens = mask.sum(dim=1).tolist() y_lens = [int(_) for _ in y_lens] # print(f'debug_MDT : {x.shape[0]}') input_skip = x masked_stage = False skips = [] # TODO : masking op for training if self.mask_ratio is not None and self.training: # masking: length -> length * mask_ratio rand_mask_ratio = th.rand(1, device=x.device) # noise in [0, 1] rand_mask_ratio = rand_mask_ratio * 0.2 + self.mask_ratio # mask_ratio, mask_ratio + 0.2 # print(rand_mask_ratio) x, mask, ids_restore, ids_keep = self.random_masking( x, rand_mask_ratio) masked_stage = True for block in self.en_inblocks: if masked_stage: x = auto_grad_checkpoint(block, x, y, t0, y_lens, ids_keep=ids_keep) else: x = auto_grad_checkpoint(block, x, y, t0, y_lens, ids_keep=None) skips.append(x) for block in self.en_outblocks: if masked_stage: x = auto_grad_checkpoint(block, x, y, t0, y_lens, skip=skips.pop(), ids_keep=ids_keep) else: x = auto_grad_checkpoint(block, x, y, t0, y_lens, skip=skips.pop(), ids_keep=None) if 
self.mask_ratio is not None and self.training: x = self.forward_side_interpolater(x, y, t0, y_lens, mask, ids_restore) masked_stage = False else: # add pos embed x = x + self.decoder_pos_embed for i in range(len(self.de_blocks)): block = self.de_blocks[i] this_skip = input_skip x = auto_grad_checkpoint(block, x, y, t0, y_lens, skip=this_skip, ids_keep=None) x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels) x = self.unpatchify(x) # (N, out_channels, H, W) return x def forward_with_dpmsolver(self, x, timestep, y, mask=None, **kwargs): """ dpm solver donnot need variance prediction """ # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb model_out = self.forward(x, timestep, y, mask) return model_out.chunk(2, dim=1)[0] def forward_with_cfg(self, x, timestep, y, cfg_scale, mask=None, **kwargs): """ Forward pass of PixArt, but also batches the unconditional forward pass for classifier-free guidance. """ # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb half = x[: len(x) // 2] combined = th.cat([half, half], dim=0) model_out = self.forward(combined, timestep, y, mask) model_out = model_out['x'] if isinstance(model_out, dict) else model_out eps, rest = model_out[:, :8], model_out[:, 8:] cond_eps, uncond_eps = th.split(eps, len(eps) // 2, dim=0) half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps) eps = th.cat([half_eps, half_eps], dim=0) return eps # return th.cat([eps, rest], dim=1) def unpatchify(self, x): """ x: (N, T, patch_size 0 * patch_size 1 * C) imgs: (Bs. 256. 16. 
8) """ if self.x_embedder.ol == (0, 0) or self.x_embedder.ol == [0, 0]: c = self.out_channels p0 = self.x_embedder.patch_size[0] p1 = self.x_embedder.patch_size[1] h, w = self.x_embedder.grid_size[0], self.x_embedder.grid_size[1] x = x.reshape(shape=(x.shape[0], h, w, p0, p1, c)) x = th.einsum('nhwpqc->nchpwq', x) imgs = x.reshape(shape=(x.shape[0], c, h * p0, w * p1)) return imgs lf = self.x_embedder.grid_size[0] rf = self.x_embedder.grid_size[1] lp = self.x_embedder.patch_size[0] rp = self.x_embedder.patch_size[1] lo = self.x_embedder.ol[0] ro = self.x_embedder.ol[1] lm = self.x_embedder.img_size[0] rm = self.x_embedder.img_size[1] lpad = self.x_embedder.pad_size[0] rpad = self.x_embedder.pad_size[1] bs = x.shape[0] torch_map = self.torch_map c = self.out_channels x = x.reshape(shape=(bs, lf, rf, lp, rp, c)) x = th.einsum('nhwpqc->nchwpq', x) added_map = th.zeros(bs, c, lm+2*lpad, rm+2*rpad).to(x.device) for i in range(lf): for j in range(rf): xx = (i) * (lp - lo) yy = (j) * (rp - ro) added_map[:, :, xx:(xx+lp), yy:(yy+rp)] += \ x[:, :, i, j, :, :] added_map = added_map[:, :, lpad:lm+lpad, rpad:rm+rpad] return th.mul(added_map, torch_map.to(added_map.device)) def random_masking(self, x, mask_ratio): """ Perform per-sample random masking by per-sample shuffling. Per-sample shuffling is done by argsort random noise. 
x: [N, L, D], sequence """ N, L, D = x.shape # batch, length, dim len_keep = int(L * (1 - mask_ratio)) noise = th.rand(N, L, device=x.device) # noise in [0, 1] # sort noise for each sample # ascend: small is keep, large is remove ids_shuffle = th.argsort(noise, dim=1) ids_restore = th.argsort(ids_shuffle, dim=1) # keep the first subset ids_keep = ids_shuffle[:, :len_keep] x_masked = th.gather( x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) # generate the binary mask: 0 is keep, 1 is remove mask = th.ones([N, L], device=x.device) mask[:, :len_keep] = 0 # unshuffle to get the binary mask mask = th.gather(mask, dim=1, index=ids_restore) return x_masked, mask, ids_restore, ids_keep def forward_side_interpolater(self, x, y, t0, y_lens, mask, ids_restore): # append mask tokens to sequence mask_tokens = self.mask_token.repeat( x.shape[0], ids_restore.shape[1] - x.shape[1], 1) x_ = th.cat([x, mask_tokens], dim=1) x = th.gather( x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle # add pos embed x = x + self.decoder_pos_embed # pass to the basic block x_before = x for sideblock in self.sideblocks: x = sideblock(x, y, t0, y_lens, ids_keep=None) # masked shortcut mask = mask.unsqueeze(dim=-1) x = x*mask + (1-mask)*x_before return x def initialize_weights(self): # Initialize transformer layers: def _basic_init(module): if isinstance(module, nn.Linear): th.nn.init.xavier_uniform_(module.weight) if module.bias is not None: nn.init.constant_(module.bias, 0) self.apply(_basic_init) # Initialize (and freeze) pos_embed by sin-cos embedding: pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], self.x_embedder.grid_size, lewei_scale=self.lewei_scale, base_size=self.base_size) self.pos_embed.data.copy_(th.from_numpy(pos_embed).float().unsqueeze(0)) # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): w = self.x_embedder.proj.weight.data nn.init.xavier_uniform_(w.view([w.shape[0], -1])) # Initialize timestep embedding MLP: 
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) nn.init.normal_(self.t_block[1].weight, std=0.02) # Initialize caption embedding MLP: nn.init.normal_(self.y_embedder.weight, std=0.02) # nn.init.normal_(self.y_embedder.y_proj.fc2.weight, std=0.02) # Zero-out adaLN modulation layers in PixArt blocks: for block in self.en_inblocks: nn.init.constant_(block.cross_attn.proj.weight, 0) nn.init.constant_(block.cross_attn.proj.bias, 0) for block in self.en_outblocks: nn.init.constant_(block.cross_attn.proj.weight, 0) nn.init.constant_(block.cross_attn.proj.bias, 0) for block in self.de_blocks: nn.init.constant_(block.cross_attn.proj.weight, 0) nn.init.constant_(block.cross_attn.proj.bias, 0) for block in self.sideblocks: nn.init.constant_(block.cross_attn.proj.weight, 0) nn.init.constant_(block.cross_attn.proj.bias, 0) # Zero-out output layers: nn.init.constant_(self.final_layer.linear.weight, 0) nn.init.constant_(self.final_layer.linear.bias, 0) if self.x_embedder.ol == [0, 0] or self.x_embedder.ol == (0, 0): return lf = self.x_embedder.grid_size[0] rf = self.x_embedder.grid_size[1] lp = self.x_embedder.patch_size[0] rp = self.x_embedder.patch_size[1] lo = self.x_embedder.ol[0] ro = self.x_embedder.ol[1] lm = self.x_embedder.img_size[0] rm = self.x_embedder.img_size[1] lpad = self.x_embedder.pad_size[0] rpad = self.x_embedder.pad_size[1] torch_map = th.zeros(lm+2*lpad, rm+2*rpad).to('cuda') for i in range(lf): for j in range(rf): xx = (i) * (lp - lo) yy = (j) * (rp - ro) torch_map[xx:(xx+lp), yy:(yy+rp)]+=1 torch_map = torch_map[lpad:lm+lpad, rpad:rm+rpad] self.torch_map = th.reciprocal(torch_map) @property def dtype(self): return next(self.parameters()).dtype class PixArt_MDT_FIT(nn.Module): """ Diffusion model with a Transformer backbone. 
""" def __init__(self, input_size=(256,16), patch_size=(16,4), overlap=(0, 0), in_channels=8, hidden_size=1152, depth=28, num_heads=16, mlp_ratio=4.0, class_dropout_prob=0.1, pred_sigma=False, drop_path: float = 0., window_size=0, window_block_indexes=None, use_rel_pos=False, cond_dim=1024, lewei_scale=1.0, use_cfg=True, cfg_scale=4.0, config=None, model_max_length=120, mask_ratio=None, decode_layer=4,**kwargs): if window_block_indexes is None: window_block_indexes = [] super().__init__() self.use_cfg = use_cfg self.cfg_scale = cfg_scale self.input_size = input_size self.pred_sigma = pred_sigma self.in_channels = in_channels self.out_channels = in_channels self.patch_size = patch_size self.num_heads = num_heads self.lewei_scale = lewei_scale, decode_layer = int(decode_layer) self.x_embedder = PatchEmbed(input_size, patch_size, overlap, in_channels, hidden_size, bias=True) # self.x_embedder = PatchEmbed_1D(input) self.t_embedder = TimestepEmbedder(hidden_size) num_patches = self.x_embedder.num_patches self.base_size = input_size[0] // self.patch_size[0] * 2 # Will use fixed sin-cos embedding: self.register_buffer("pos_embed", th.zeros(1, num_patches, hidden_size)) approx_gelu = lambda: nn.GELU(approximate="tanh") self.t_block = nn.Sequential( nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True) ) self.y_embedder = nn.Linear(cond_dim, hidden_size) half_depth = (depth - decode_layer)//2 self.half_depth=half_depth drop_path_half = [x.item() for x in th.linspace(0, drop_path, half_depth)] # stochastic depth decay rule drop_path_decode = [x.item() for x in th.linspace(0, drop_path, decode_layer)] self.en_inblocks = nn.ModuleList([ MDTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path_half[i], input_size=(self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]), window_size=0, use_rel_pos=False) for i in range(half_depth) ]) self.en_outblocks = nn.ModuleList([ MDTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, 
drop_path=drop_path_half[i], input_size=(self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]), window_size=0, use_rel_pos=False, skip=True) for i in range(half_depth) ]) self.de_blocks = nn.ModuleList([ MDTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path_decode[i], input_size=(self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]), window_size=0, use_rel_pos=False, skip=True) for i in range(decode_layer) ]) self.sideblocks = nn.ModuleList([ MDTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, input_size=(self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]), window_size=0, use_rel_pos=False) for _ in range(1) ]) self.final_layer = T2IFinalLayer(hidden_size, patch_size, self.out_channels) self.decoder_pos_embed = nn.Parameter(th.zeros( 1, num_patches, hidden_size), requires_grad=True) if mask_ratio is not None: self.mask_token = nn.Parameter(th.zeros(1, 1, hidden_size)) self.mask_ratio = float(mask_ratio) self.decode_layer = int(decode_layer) else: self.mask_token = nn.Parameter(th.zeros( 1, 1, hidden_size), requires_grad=False) self.mask_ratio = None self.decode_layer = int(decode_layer) print("mask ratio:", self.mask_ratio, "decode_layer:", self.decode_layer) self.initialize_weights() def forward(self, x, timestep, context_list, context_mask_list=None, enable_mask=False, **kwargs): """ Forward pass of PixArt. 
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) t: (N,) tensor of diffusion timesteps y: (N, 1, 120, C) tensor of class labels """ x = x.to(self.dtype) timestep = timestep.to(self.dtype) y = context_list[0].to(self.dtype) pos_embed = self.pos_embed.to(self.dtype) self.h, self.w = self.x_embedder.grid_size[0], self.x_embedder.grid_size[1] # import pdb # pdb.set_trace() x = self.x_embedder(x) + pos_embed # (N, T, D), where T = H * W / patch_size ** 2 t = self.t_embedder(timestep.to(x.dtype)) # (N, D) t0 = self.t_block(t) y = self.y_embedder(y) # (N, L, D) # if not self.training: try: mask = context_mask_list[0] # (N, L) except: mask = th.ones(x.shape[0], 1).to(x.device) print("MASK !!!!!!!!!!!!!!!!!!!!!!!!!") assert mask is not None # if mask is not None: y = y.masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1]) y_lens = mask.sum(dim=1).tolist() y_lens = [int(_) for _ in y_lens] input_skip = x masked_stage = False skips = [] # TODO : masking op for training if self.mask_ratio is not None and self.training: # masking: length -> length * mask_ratio rand_mask_ratio = th.rand(1, device=x.device) # noise in [0, 1] rand_mask_ratio = rand_mask_ratio * 0.2 + self.mask_ratio # mask_ratio, mask_ratio + 0.2 # print(rand_mask_ratio) x, mask, ids_restore, ids_keep = self.random_masking( x, rand_mask_ratio) masked_stage = True for block in self.en_inblocks: if masked_stage: x = auto_grad_checkpoint(block, x, y, t0, y_lens, ids_keep=ids_keep) else: x = auto_grad_checkpoint(block, x, y, t0, y_lens, ids_keep=None) skips.append(x) for block in self.en_outblocks: if masked_stage: x = auto_grad_checkpoint(block, x, y, t0, y_lens, skip=skips.pop(), ids_keep=ids_keep) else: x = auto_grad_checkpoint(block, x, y, t0, y_lens, skip=skips.pop(), ids_keep=None) if self.mask_ratio is not None and self.training: x = self.forward_side_interpolater(x, y, t0, y_lens, mask, ids_restore) masked_stage = False else: # add pos embed x = x + 
self.decoder_pos_embed for i in range(len(self.de_blocks)): block = self.de_blocks[i] this_skip = input_skip x = auto_grad_checkpoint(block, x, y, t0, y_lens, skip=this_skip, ids_keep=None) x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels) x = self.unpatchify(x) # (N, out_channels, H, W) return x def forward_with_dpmsolver(self, x, timestep, y, mask=None, **kwargs): """ dpm solver donnot need variance prediction """ # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb model_out = self.forward(x, timestep, y, mask) return model_out.chunk(2, dim=1)[0] def forward_with_cfg(self, x, timestep, y, cfg_scale, mask=None, **kwargs): """ Forward pass of PixArt, but also batches the unconditional forward pass for classifier-free guidance. """ # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb half = x[: len(x) // 2] combined = th.cat([half, half], dim=0) model_out = self.forward(combined, timestep, y, mask) model_out = model_out['x'] if isinstance(model_out, dict) else model_out eps, rest = model_out[:, :8], model_out[:, 8:] cond_eps, uncond_eps = th.split(eps, len(eps) // 2, dim=0) half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps) eps = th.cat([half_eps, half_eps], dim=0) return eps # return th.cat([eps, rest], dim=1) def unpatchify(self, x): """ x: (N, T, patch_size 0 * patch_size 1 * C) imgs: (Bs. 256. 16. 8) """ c = self.out_channels p0 = self.x_embedder.patch_size[0] p1 = self.x_embedder.patch_size[1] h, w = self.x_embedder.grid_size[0], self.x_embedder.grid_size[1] x = x.reshape(shape=(x.shape[0], h, w, p0, p1, c)) x = th.einsum('nhwpqc->nchpwq', x) imgs = x.reshape(shape=(x.shape[0], c, h * p0, w * p1)) return imgs def random_masking(self, x, mask_ratio): """ Perform per-sample random masking by per-sample shuffling. Per-sample shuffling is done by argsort random noise. 
x: [N, L, D], sequence """ N, L, D = x.shape # batch, length, dim len_keep = int(L * (1 - mask_ratio)) noise = th.rand(N, L, device=x.device) # noise in [0, 1] # sort noise for each sample # ascend: small is keep, large is remove ids_shuffle = th.argsort(noise, dim=1) ids_restore = th.argsort(ids_shuffle, dim=1) # keep the first subset ids_keep = ids_shuffle[:, :len_keep] x_masked = th.gather( x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) # generate the binary mask: 0 is keep, 1 is remove mask = th.ones([N, L], device=x.device) mask[:, :len_keep] = 0 # unshuffle to get the binary mask mask = th.gather(mask, dim=1, index=ids_restore) return x_masked, mask, ids_restore, ids_keep def forward_side_interpolater(self, x, y, t0, y_lens, mask, ids_restore): # append mask tokens to sequence mask_tokens = self.mask_token.repeat( x.shape[0], ids_restore.shape[1] - x.shape[1], 1) x_ = th.cat([x, mask_tokens], dim=1) x = th.gather( x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle # add pos embed x = x + self.decoder_pos_embed # pass to the basic block x_before = x for sideblock in self.sideblocks: x = sideblock(x, y, t0, y_lens, ids_keep=None) # masked shortcut mask = mask.unsqueeze(dim=-1) x = x*mask + (1-mask)*x_before return x def initialize_weights(self): # Initialize transformer layers: def _basic_init(module): if isinstance(module, nn.Linear): th.nn.init.xavier_uniform_(module.weight) if module.bias is not None: nn.init.constant_(module.bias, 0) self.apply(_basic_init) # Initialize (and freeze) pos_embed by sin-cos embedding: pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], self.x_embedder.grid_size, lewei_scale=self.lewei_scale, base_size=self.base_size) # Replace the absolute embedding with 2d-rope position embedding: # pos_embed = self.pos_embed.data.copy_(th.from_numpy(pos_embed).float().unsqueeze(0)) # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): w = self.x_embedder.proj.weight.data 
nn.init.xavier_uniform_(w.view([w.shape[0], -1])) # Initialize timestep embedding MLP: nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) nn.init.normal_(self.t_block[1].weight, std=0.02) # Initialize caption embedding MLP: nn.init.normal_(self.y_embedder.weight, std=0.02) # nn.init.normal_(self.y_embedder.y_proj.fc2.weight, std=0.02) # Zero-out adaLN modulation layers in PixArt blocks: for block in self.en_inblocks: nn.init.constant_(block.cross_attn.proj.weight, 0) nn.init.constant_(block.cross_attn.proj.bias, 0) for block in self.en_outblocks: nn.init.constant_(block.cross_attn.proj.weight, 0) nn.init.constant_(block.cross_attn.proj.bias, 0) for block in self.de_blocks: nn.init.constant_(block.cross_attn.proj.weight, 0) nn.init.constant_(block.cross_attn.proj.bias, 0) for block in self.sideblocks: nn.init.constant_(block.cross_attn.proj.weight, 0) nn.init.constant_(block.cross_attn.proj.bias, 0) # Zero-out output layers: nn.init.constant_(self.final_layer.linear.weight, 0) nn.init.constant_(self.final_layer.linear.bias, 0) @property def dtype(self): return next(self.parameters()).dtype class PixArt_Slow(nn.Module): """ Diffusion model with a Transformer backbone. 
""" def __init__(self, input_size=(256,16), patch_size=(16,4), overlap=(0, 0), in_channels=8, hidden_size=1152, depth=28, num_heads=16, mlp_ratio=4.0, class_dropout_prob=0.1, pred_sigma=True, drop_path: float = 0., window_size=0, window_block_indexes=None, use_rel_pos=False, cond_dim=1024, lewei_scale=1.0, use_cfg=True, cfg_scale=4.0, config=None, model_max_length=120, **kwargs): if window_block_indexes is None: window_block_indexes = [] super().__init__() self.use_cfg = use_cfg self.cfg_scale = cfg_scale self.input_size = input_size self.pred_sigma = pred_sigma self.in_channels = in_channels self.out_channels = in_channels * 2 if pred_sigma else in_channels self.patch_size = patch_size self.num_heads = num_heads self.lewei_scale = lewei_scale, self.x_embedder = PatchEmbed(input_size, patch_size, overlap, in_channels, hidden_size, bias=True) # self.x_embedder = PatchEmbed_1D(input) self.t_embedder = TimestepEmbedder(hidden_size) num_patches = self.x_embedder.num_patches self.base_size = input_size[0] // self.patch_size[0] * 2 # Will use fixed sin-cos embedding: self.register_buffer("pos_embed", th.zeros(1, num_patches, hidden_size)) approx_gelu = lambda: nn.GELU(approximate="tanh") self.t_block = nn.Sequential( nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True) ) self.y_embedder = nn.Linear(cond_dim, hidden_size) drop_path = [x.item() for x in th.linspace(0, drop_path, depth)] # stochastic depth decay rule self.blocks = nn.ModuleList([ PixArtBlock_Slow(hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path[i], input_size=(self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]), window_size=0, use_rel_pos=False) for i in range(depth) ]) self.final_layer = T2IFinalLayer(hidden_size, patch_size, self.out_channels) self.initialize_weights() def forward(self, x, timestep, context_list, context_mask_list=None, **kwargs): """ Forward pass of PixArt. 
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) t: (N,) tensor of diffusion timesteps y: (N, 1, 120, C) tensor of class labels """ x = x.to(self.dtype) timestep = timestep.to(self.dtype) y = context_list[0].to(self.dtype) pos_embed = self.pos_embed.to(self.dtype) self.h, self.w = self.x_embedder.grid_size[0], self.x_embedder.grid_size[1] x = self.x_embedder(x) + pos_embed # (N, T, D), where T = H * W / patch_size ** 2 t = self.t_embedder(timestep.to(x.dtype)) # (N, D) t0 = self.t_block(t) y = self.y_embedder(y) # (N, L, D) mask = context_mask_list[0] # (N, L) assert mask is not None # if mask is not None: # y = y.masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1]) # y_lens = mask.sum(dim=1).tolist() # y_lens = [int(_) for _ in y_lens] for block in self.blocks: x = auto_grad_checkpoint(block, x, y, t0, mask) # (N, T, D) #support grad checkpoint x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels) x = self.unpatchify(x) # (N, out_channels, H, W) return x def forward_with_dpmsolver(self, x, timestep, y, mask=None, **kwargs): """ dpm solver donnot need variance prediction """ # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb model_out = self.forward(x, timestep, y, mask) return model_out.chunk(2, dim=1)[0] def forward_with_cfg(self, x, timestep, y, cfg_scale, mask=None, **kwargs): """ Forward pass of PixArt, but also batches the unconditional forward pass for classifier-free guidance. 
""" # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb half = x[: len(x) // 2] combined = th.cat([half, half], dim=0) model_out = self.forward(combined, timestep, y, mask) model_out = model_out['x'] if isinstance(model_out, dict) else model_out eps, rest = model_out[:, :8], model_out[:, 8:] cond_eps, uncond_eps = th.split(eps, len(eps) // 2, dim=0) half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps) eps = th.cat([half_eps, half_eps], dim=0) return eps # return th.cat([eps, rest], dim=1) def unpatchify(self, x): """ x: (N, T, patch_size 0 * patch_size 1 * C) imgs: (Bs. 256. 16. 8) """ c = self.out_channels p0 = self.x_embedder.patch_size[0] p1 = self.x_embedder.patch_size[1] h, w = self.x_embedder.grid_size[0], self.x_embedder.grid_size[1] x = x.reshape(shape=(x.shape[0], h, w, p0, p1, c)) x = th.einsum('nhwpqc->nchpwq', x) imgs = x.reshape(shape=(x.shape[0], c, h * p0, w * p1)) return imgs def initialize_weights(self): # Initialize transformer layers: def _basic_init(module): if isinstance(module, nn.Linear): th.nn.init.xavier_uniform_(module.weight) if module.bias is not None: nn.init.constant_(module.bias, 0) self.apply(_basic_init) # Initialize (and freeze) pos_embed by sin-cos embedding: pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], self.x_embedder.grid_size, lewei_scale=self.lewei_scale, base_size=self.base_size) self.pos_embed.data.copy_(th.from_numpy(pos_embed).float().unsqueeze(0)) # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): w = self.x_embedder.proj.weight.data nn.init.xavier_uniform_(w.view([w.shape[0], -1])) # Initialize timestep embedding MLP: nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) nn.init.normal_(self.t_block[1].weight, std=0.02) # Initialize caption embedding MLP: nn.init.normal_(self.y_embedder.weight, std=0.02) # nn.init.normal_(self.y_embedder.y_proj.fc2.weight, std=0.02) # Zero-out adaLN modulation layers 
in PixArt blocks: # for block in self.blocks: # nn.init.constant_(block.cross_attn.proj.weight, 0) # nn.init.constant_(block.cross_attn.proj.bias, 0) # Zero-out output layers: nn.init.constant_(self.final_layer.linear.weight, 0) nn.init.constant_(self.final_layer.linear.bias, 0) @property def dtype(self): return next(self.parameters()).dtype class PixArtBlock_1D(nn.Module): """ A PixArt block with adaptive layer norm (adaLN-single) conditioning. """ def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, drop_path=0., window_size=0, use_rel_pos=False, **block_kwargs): super().__init__() self.hidden_size = hidden_size self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) self.attn = WindowAttention(hidden_size, num_heads=num_heads, qkv_bias=True, input_size=None, use_rel_pos=use_rel_pos, **block_kwargs) # self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs) self.cross_attn = MultiHeadCrossAttention(hidden_size, num_heads, **block_kwargs) self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) # to be compatible with lower version pytorch approx_gelu = lambda: nn.GELU(approximate="tanh") self.mlp = Mlp(in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu, drop=0) self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() self.window_size = window_size self.scale_shift_table = nn.Parameter(th.randn(6, hidden_size) / hidden_size ** 0.5) def forward(self, x, y, t, mask=None, **kwargs): B, N, C = x.shape # x [3, 133, 1152] shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None] + t.reshape(B, 6, -1)).chunk(6, dim=1) x = x + self.drop_path(gate_msa * self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa)).reshape(B, N, C)) x = x + self.cross_attn(x, y, mask) x = x + self.drop_path(gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp))) return x class PixArt_1D(nn.Module): """ Diffusion model with a Transformer backbone. """ def __init__(self, input_size=(256,16), in_channels=8, hidden_size=1152, depth=28, num_heads=16, mlp_ratio=4.0, class_dropout_prob=0.1, pred_sigma=True, drop_path: float = 0., window_size=0, window_block_indexes=None, use_rel_pos=False, cond_dim=1024, lewei_scale=1.0, use_cfg=True, cfg_scale=4.0, config=None, model_max_length=120, **kwargs): if window_block_indexes is None: window_block_indexes = [] super().__init__() self.use_cfg = use_cfg self.cfg_scale = cfg_scale self.input_size = input_size self.pred_sigma = pred_sigma self.in_channels = in_channels self.out_channels = in_channels self.num_heads = num_heads self.lewei_scale = lewei_scale, self.x_embedder = PatchEmbed_1D(input_size, in_channels, hidden_size) # self.x_embedder = PatchEmbed(input_size, patch_size, overlap, in_channels, hidden_size, bias=True) self.t_embedder = TimestepEmbedder(hidden_size) self.p_enc_1d_model = PositionalEncoding1D(hidden_size) approx_gelu = lambda: nn.GELU(approximate="tanh") self.t_block = nn.Sequential( nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True) ) self.y_embedder = nn.Linear(cond_dim, hidden_size) drop_path = [x.item() for x in th.linspace(0, drop_path, depth)] # stochastic depth decay rule self.blocks = nn.ModuleList([ PixArtBlock_1D(hidden_size, num_heads, mlp_ratio=mlp_ratio, 
drop_path=drop_path[i], window_size=0, use_rel_pos=False) for i in range(depth) ]) self.final_layer = T2IFinalLayer(hidden_size, (1, input_size[1]), self.out_channels) self.initialize_weights() # if config: # logger = get_root_logger(os.path.join(config.work_dir, 'train_log.log')) # logger.warning(f"lewei scale: {self.lewei_scale}, base size: {self.base_size}") # else: # print(f'Warning: lewei scale: {self.lewei_scale}, base size: {self.base_size}') def forward(self, x, timestep, context_list, context_mask_list=None, **kwargs): """ Forward pass of PixArt. x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) t: (N,) tensor of diffusion timesteps y: (N, 1, 120, C) tensor of class labels """ x = x.to(self.dtype) timestep = timestep.to(self.dtype) y = context_list[0].to(self.dtype) x = self.x_embedder(x) # (N, T, D) pos_embed = self.p_enc_1d_model(x) x = x + pos_embed t = self.t_embedder(timestep.to(x.dtype)) # (N, D) t0 = self.t_block(t) y = self.y_embedder(y) # (N, L, D) try: mask = context_mask_list[0] # (N, L) except: mask = th.ones(x.shape[0], 1).to(x.device) print("MASK !!!!!!!!!!!!!!!!!!!!!!!!!") assert mask is not None # if mask is not None: y = y.masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1]) y_lens = mask.sum(dim=1).tolist() y_lens = [int(_) for _ in y_lens] for block in self.blocks: x = auto_grad_checkpoint(block, x, y, t0, y_lens) # (N, T, D) #support grad checkpoint x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels) x = self.unpatchify_1D(x) # (N, out_channels, H, W) return x def forward_with_dpmsolver(self, x, timestep, y, mask=None, **kwargs): """ dpm solver donnot need variance prediction """ # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb model_out = self.forward(x, timestep, y, mask) return model_out.chunk(2, dim=1)[0] def forward_with_cfg(self, x, timestep, y, cfg_scale, mask=None, **kwargs): """ Forward pass of PixArt, but also batches the unconditional 
forward pass for classifier-free guidance. """ # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb half = x[: len(x) // 2] combined = th.cat([half, half], dim=0) model_out = self.forward(combined, timestep, y, mask) model_out = model_out['x'] if isinstance(model_out, dict) else model_out eps, rest = model_out[:, :8], model_out[:, 8:] cond_eps, uncond_eps = th.split(eps, len(eps) // 2, dim=0) half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps) eps = th.cat([half_eps, half_eps], dim=0) return eps # return th.cat([eps, rest], dim=1) def unpatchify_1D(self, x): """ """ c = self.out_channels x = x.reshape(shape=(x.shape[0], self.input_size[0], self.input_size[1], c)) x = th.einsum('btfc->bctf', x) # imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p)) return x def initialize_weights(self): # Initialize transformer layers: def _basic_init(module): if isinstance(module, nn.Linear): th.nn.init.xavier_uniform_(module.weight) if module.bias is not None: nn.init.constant_(module.bias, 0) self.apply(_basic_init) # Initialize (and freeze) pos_embed by sin-cos embedding: # pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], self.x_embedder.grid_size, lewei_scale=self.lewei_scale, base_size=self.base_size) # self.pos_embed.data.copy_(th.from_numpy(pos_embed).float().unsqueeze(0)) # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): w = self.x_embedder.proj.weight.data nn.init.xavier_uniform_(w.view([w.shape[0], -1])) # Initialize timestep embedding MLP: nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) nn.init.normal_(self.t_block[1].weight, std=0.02) # Initialize caption embedding MLP: nn.init.normal_(self.y_embedder.weight, std=0.02) # nn.init.normal_(self.y_embedder.y_proj.fc2.weight, std=0.02) # Zero-out adaLN modulation layers in PixArt blocks: for block in self.blocks: nn.init.constant_(block.cross_attn.proj.weight, 0) 
# NOTE(review): whitespace in this part of the file was lost. The next three
# statements and the `dtype` property continue the class whose
# initialize_weights() runs on the previous line; re-indent them together with
# that class when restoring the file's layout.
nn.init.constant_(block.cross_attn.proj.bias, 0)
# Zero-out output layers:
nn.init.constant_(self.final_layer.linear.weight, 0)
nn.init.constant_(self.final_layer.linear.bias, 0)

@property
def dtype(self):
    return next(self.parameters()).dtype


class PixArt_Slow_1D(nn.Module):
    """
    Diffusion model with a Transformer backbone over a 1D token sequence.

    A latent of shape (N, C, T, F) is embedded step-by-step by ``PatchEmbed_1D``:
    each of the T steps becomes one token built from its C * F values. Tokens get a
    1D positional encoding and then pass through ``depth`` ``PixArtBlock_Slow`` blocks.
    Each block receives the projected context, the timestep-derived conditioning and
    the context mask.
    """

    def __init__(self, input_size=(256, 16), in_channels=8, hidden_size=1152, depth=28, num_heads=16,
                 mlp_ratio=4.0, class_dropout_prob=0.1, pred_sigma=True, drop_path: float = 0.,
                 window_size=0, window_block_indexes=None, use_rel_pos=False, cond_dim=1024,
                 lewei_scale=1.0, use_cfg=True, cfg_scale=4.0, config=None, model_max_length=120,
                 **kwargs):
        """
        Args:
            input_size: (T, F) layout of the latent; F fixes the patch-embedding input width.
            in_channels: number of latent channels C.
            hidden_size: token width D.
            depth: number of transformer blocks.
            num_heads: attention heads per block.
            mlp_ratio: MLP expansion ratio passed to each block.
            class_dropout_prob: accepted for config compatibility; unused here.
            pred_sigma: if True, the output has 2 * in_channels channels, otherwise in_channels.
            drop_path: maximum stochastic-depth rate, linearly ramped from 0 over the blocks.
            window_size, window_block_indexes, use_rel_pos: accepted for config
                compatibility; every block is built with window_size=0 and use_rel_pos=False.
            cond_dim: feature width of the context tensor projected by ``y_embedder``.
            lewei_scale: stored on the instance only.
            use_cfg, cfg_scale: stored on the instance; ``forward_with_cfg`` takes its own scale.
            config, model_max_length, **kwargs: accepted for compatibility; unused.
        """
        super().__init__()
        self.use_cfg = use_cfg
        self.cfg_scale = cfg_scale
        self.input_size = input_size
        self.pred_sigma = pred_sigma
        self.in_channels = in_channels
        self.out_channels = in_channels * 2 if pred_sigma else in_channels
        self.num_heads = num_heads
        # A trailing comma here previously stored the 1-tuple (lewei_scale,).
        self.lewei_scale = lewei_scale

        # Registration order is kept as-is: it fixes parameter order (see `dtype`)
        # and the RNG sequence consumed during construction and initialisation.
        self.x_embedder = PatchEmbed_1D(input_size, in_channels, hidden_size)
        self.t_embedder = TimestepEmbedder(hidden_size)
        self.p_enc_1d_model = PositionalEncoding1D(hidden_size)
        self.t_block = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 6 * hidden_size, bias=True),
        )
        self.y_embedder = nn.Linear(cond_dim, hidden_size)

        # Stochastic depth decay rule: per-block drop-path rate grows linearly from 0.
        drop_path_rates = [rate.item() for rate in th.linspace(0, drop_path, depth)]
        self.blocks = nn.ModuleList([
            PixArtBlock_Slow(hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path_rates[i],
                             window_size=0, use_rel_pos=False)
            for i in range(depth)
        ])
        self.final_layer = T2IFinalLayer(hidden_size, (1, input_size[1]), self.out_channels)

        self.initialize_weights()

    def forward(self, x, timestep, context_list, context_mask_list=None, **kwargs):
        """
        Forward pass of PixArt.

        x: (N, C, T, F) latent with C == in_channels and (T, F) == input_size.
        timestep: (N,) tensor of diffusion timesteps.
        context_list: sequence whose first element is the cross-attention context;
            its last dimension must equal ``cond_dim``.
        context_mask_list: sequence whose first element is the context mask. Required.
        Returns: (N, out_channels, T, F) tensor.

        Raises:
            ValueError: if no context mask is supplied.
        """
        # Fail with a clear message instead of a TypeError on `None[0]`; do not rely on
        # `assert`, which is stripped under `python -O`.
        if context_mask_list is None or context_mask_list[0] is None:
            raise ValueError("PixArt_Slow_1D.forward requires a context mask in context_mask_list[0]")

        x = x.to(self.dtype)
        timestep = timestep.to(self.dtype)
        y = context_list[0].to(self.dtype)

        x = self.x_embedder(x)  # (N, T, D)
        x = x + self.p_enc_1d_model(x)
        t = self.t_embedder(timestep.to(x.dtype))  # (N, D)
        t0 = self.t_block(t)  # (N, 6 * D), passed to every block
        y = self.y_embedder(y)  # (N, L, D)
        mask = context_mask_list[0]  # (N, L)

        for block in self.blocks:
            x = auto_grad_checkpoint(block, x, y, t0, mask)  # (N, T, D); supports grad checkpointing
        x = self.final_layer(x, t)  # T * F * out_channels values per sample, see unpatchify_1D
        return self.unpatchify_1D(x)  # (N, out_channels, T, F)

    def forward_with_dpmsolver(self, x, timestep, y, mask=None, **kwargs):
        """
        Forward pass for DPM-Solver, which does not need the variance prediction.

        ``y`` and ``mask`` are forwarded as ``context_list`` / ``context_mask_list``,
        so each must be a sequence whose first element is the tensor.
        Returns only the first ``in_channels`` output channels. When ``pred_sigma`` is
        False, the previous ``chunk(2)`` dropped half of the actual prediction.
        """
        # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
        model_out = self.forward(x, timestep, y, mask)
        return model_out[:, :self.in_channels]

    def forward_with_cfg(self, x, timestep, y, cfg_scale, mask=None, **kwargs):
        """
        Forward pass that batches the conditional and unconditional passes for
        classifier-free guidance.

        The first half of ``x`` is duplicated to form the batch. The first half of the
        model output is treated as conditional and the second as unconditional.
        Presumably ``y``/``mask`` are laid out the same way (TODO confirm with caller).
        Returns the guided prediction (first ``in_channels`` channels) repeated for both
        halves; the variance channels are discarded.
        """
        # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
        half = x[: len(x) // 2]
        combined = th.cat([half, half], dim=0)
        model_out = self.forward(combined, timestep, y, mask)
        model_out = model_out['x'] if isinstance(model_out, dict) else model_out
        eps = model_out[:, :self.in_channels]
        cond_eps, uncond_eps = th.split(eps, len(eps) // 2, dim=0)
        half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
        return th.cat([half_eps, half_eps], dim=0)

    def unpatchify_1D(self, x):
        """
        Reshape final-layer output back to the latent layout.

        x: tensor holding T * F * out_channels values per sample, where (T, F) = input_size.
        Returns: (N, out_channels, T, F).
        """
        c = self.out_channels
        x = x.reshape(shape=(x.shape[0], self.input_size[0], self.input_size[1], c))
        return th.einsum('btfc->bctf', x)

    def initialize_weights(self):
        """Initialise parameters in place; the call order below fixes the RNG sequence."""
        def _basic_init(module):
            # Xavier-uniform weights and zero bias for every nn.Linear.
            if isinstance(module, nn.Linear):
                th.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)

        # Initialise the patch embedding's weight, viewed as 2D, with Xavier
        # (carried over from the Conv2d-based PatchEmbed).
        w = self.x_embedder.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

        # Timestep embedding MLP and the shared t_block projection.
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
        nn.init.normal_(self.t_block[1].weight, std=0.02)

        # Caption / context projection.
        nn.init.normal_(self.y_embedder.weight, std=0.02)

        # NOTE: the per-block zero-init of `block.cross_attn.proj` is disabled in this
        # class (it was commented out), so those layers keep their Xavier init.

        # Zero-out the final projection.
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)

    @property
    def dtype(self):
        """dtype of the first registered parameter; forward() casts its inputs to it."""
        return next(self.parameters()).dtype


def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0, lewei_scale=1.0, base_size=16):
    """
    Build a fixed 2D sine-cosine positional embedding.

    grid_size: int (square grid) or (H, W) pair.
    Coordinates along each axis are rescaled to [0, base_size) and divided by ``lewei_scale``.
    return: array of shape [H*W, embed_dim], or [extra_tokens + H*W, embed_dim] with
        zero rows prepended when ``cls_token`` is set and ``extra_tokens > 0``.
    """
    if isinstance(grid_size, int):
        grid_size = to_2tuple(grid_size)
    grid_h = np.arange(grid_size[0], dtype=np.float32) / (grid_size[0] / base_size) / lewei_scale
    grid_w = np.arange(grid_size[1], dtype=np.float32) / (grid_size[1] / base_size) / lewei_scale
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first; each output array is (H, W)
    grid = np.stack(grid, axis=0)
    # Label the axes as (H, W) to match the meshgrid output. The old (W, H) order
    # mislabeled non-square grids, but the values are flattened later, so the result is unchanged.
    grid = grid.reshape([2, 1, grid_size[0], grid_size[1]])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token and extra_tokens > 0:
        pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """
    Concatenate 1D sin-cos embeddings of the two coordinate planes in ``grid``.

    The first embed_dim/2 columns encode grid[0] and the rest encode grid[1]. When
    called from get_2d_sincos_pos_embed, grid[0] holds the w (column) coordinates.
    """
    assert embed_dim % 2 == 0
    emb_first = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_second = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)
    return np.concatenate([emb_first, emb_second], axis=1)


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position (must be even)
    pos: positions to be encoded; flattened to size (M,)
    out: (M, D) array of [sin | cos] features with frequencies 1 / 10000**(2i/D)
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.
    omega = 1. / 10000 ** omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)
    return np.concatenate([emb_sin, emb_cos], axis=1)


# if __name__ == '__main__' :
#     import pdb
#     pdb.set_trace()
#     model = PixArt_1D().to('cuda')
#     # x: (N, T, patch_size 0 * patch_size 1 * C)
#     th.manual_seed(233)
#     # x = th.rand(1, 4*16, 16*4*16).to('cuda')
#     x = th.rand(3, 8, 256, 16).to('cuda')
#     t = th.tensor([1, 2, 3]).to('cuda')
#     c = th.rand(3, 20, 1024).to('cuda')
#     c_mask = th.ones(3, 20).to('cuda')
#     c_list = [c]
#     c_mask_list = [c_mask]
#     y = model.forward(x, t, c_list, c_mask_list)
#     res = model.unpatchify(x)


# class DiTModel(nn.Module):
#     """
#     The full UNet model with attention and timestep embedding.
#     :param in_channels: channels in the input Tensor.
#     :param model_channels: base channel count for the model.
#     :param out_channels: channels in the output Tensor.
#     :param num_res_blocks: number of residual blocks per downsample.
#     :param attention_resolutions: a collection of downsample rates at which
#         attention will take place. May be a set, list, or tuple.
#         For example, if this contains 4, then at 4x downsampling, attention
#         will be used.
#     :param dropout: the dropout probability.
#     :param channel_mult: channel multiplier for each level of the UNet.
#     :param conv_resample: if True, use learned convolutions for upsampling and
#         downsampling.
#     :param dims: determines if the signal is 1D, 2D, or 3D.
#     :param num_classes: if specified (as an int), then this model will be
#         class-conditional with `num_classes` classes.
#     :param use_checkpoint: use gradient checkpointing to reduce memory usage.
#     :param num_heads: the number of attention heads in each attention layer.
#     :param num_heads_channels: if specified, ignore num_heads and instead use
#                                a fixed channel width per attention head.
# :param num_heads_upsample: works with num_heads to set a different number # of heads for upsampling. Deprecated. # :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. # :param resblock_updown: use residual blocks for up/downsampling. # :param use_new_attention_order: use a different attention pattern for potentially # increased efficiency. # """ # def __init__( # self, # input_size, # patch_size, # overlap, # in_channels, # embed_dim, # model_channels, # out_channels, # dims=2, # extra_film_condition_dim=None, # use_checkpoint=False, # use_fp16=False, # num_heads=-1, # num_head_channels=-1, # use_scale_shift_norm=False, # use_new_attention_order=False, # transformer_depth=1, # custom transformer support # context_dim=None, # custom transformer support # n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model # legacy=True, # ): # super().__init__() # self.x_embedder = PatchEmbed(input_size, patch_size, overlap, in_channels, embed_dim, bias=True) # num_patches = self.x_embedder.num_patches # self.pos_embed = nn.Parameter(th.zeros(1, num_patches, embed_dim), requires_grad=False) # self.blocks = nn.ModuleList([ # DiTBlock_crossattn # ]) # def convert_to_fp16(self): # """ # Convert the torso of the model to float16. # """ # # self.input_blocks.apply(convert_module_to_f16) # # self.middle_block.apply(convert_module_to_f16) # # self.output_blocks.apply(convert_module_to_f16) # def convert_to_fp32(self): # """ # Convert the torso of the model to float32. # """ # # self.input_blocks.apply(convert_module_to_f32) # # self.middle_block.apply(convert_module_to_f32) # # self.output_blocks.apply(convert_module_to_f32) # def forward( # self, # x, # timesteps=None, # y=None, # context_list=None, # context_attn_mask_list=None, # **kwargs, # ): # """ # Apply the model to an input batch. # :param x: an [N x C x ...] Tensor of inputs. # :param timesteps: a 1-D batch of timesteps. 
# :param context: conditioning plugged in via crossattn # :param y: an [N] Tensor of labels, if class-conditional. an [N, extra_film_condition_dim] Tensor if film-embed conditional # :return: an [N x C x ...] Tensor of outputs. # """ # x = self.x_embedder(x) + self.pos_embed # (N, T, D), where T = H * W / patch_size ** 2 # t = self.t_embedder(timesteps) # (N, D) # y = self.y_embedder(y, self.training) # (N, D) # c = t + y # (N, D) # for block in self.blocks: # x = block(x, c) # (N, T, D) # x = self.final_layer(x, c) # (N, T, patch_size ** 2 * out_channels) # x = self.unpatchify(x) # (N, out_channels, H, W)