from abc import abstractmethod
from functools import partial
import math
from typing import Iterable
import numpy as np
import torch as th
#from .utils_pos_embedding.pos_embed import RoPE2D
import torch.nn as nn
import torch.nn.functional as F
import sys
from fairscale.nn.model_parallel.layers import (
ColumnParallelLinear,
ParallelEmbedding,
RowParallelLinear,
)
from timm.models.layers import DropPath
from .utils import auto_grad_checkpoint, to_2tuple
from .PixArt_blocks import t2i_modulate, WindowAttention, MultiHeadCrossAttention, T2IFinalLayer, TimestepEmbedder, FinalLayer
import xformers.ops
class PatchEmbed(nn.Module):
""" 2D Image to Patch Embedding
"""
def __init__(
self,
img_size=(256, 16),
patch_size=(16, 4),
overlap = (0, 0),
in_chans=128,
embed_dim=768,
norm_layer=None,
flatten=True,
bias=True,
):
super().__init__()
# img_size=(256, 16)
# patch_size=(16, 4)
# overlap = (2, 2)
# in_chans=128
# embed_dim=768
# import pdb
# pdb.set_trace()
self.img_size = img_size
self.patch_size = patch_size
self.ol = overlap
self.grid_size = (math.ceil((img_size[0] - patch_size[0]) / (patch_size[0]-overlap[0])) + 1,
math.ceil((img_size[1] - patch_size[1]) / (patch_size[1]-overlap[1])) + 1)
self.pad_size = ((self.grid_size[0]-1) * (self.patch_size[0]-overlap[0])+self.patch_size[0]-self.img_size[0],
+(self.grid_size[1]-1)*(self.patch_size[1]-overlap[1])+self.patch_size[1]-self.img_size[1])
self.pad_size = (self.pad_size[0] // 2, self.pad_size[1] // 2)
self.num_patches = self.grid_size[0] * self.grid_size[1]
self.flatten = flatten
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=(patch_size[0]-overlap[0], patch_size[1]-overlap[1]), bias=bias)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x):
# B, C, H, W = x.shape
# _assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).")
# _assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).")
x = F.pad(x, (self.pad_size[-1], self.pad_size[-1], self.pad_size[-2], self.pad_size[-2]), "constant", 0)
x = self.proj(x)
if self.flatten:
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
x = self.norm(x)
return x
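# Illustrative sketch (not part of the original model code): with the default
# arguments above -- img_size=(256, 16), patch_size=(16, 4), overlap=(0, 0),
# in_chans=128 -- the grid is (16, 4), padding is (0, 0), and a (B, 128, 256, 16)
# input is projected to (B, 64, 768) patch tokens. The helper below is unused and
# only documents that shape flow; the dummy sizes are assumptions from the defaults.
def _patch_embed_shape_demo():
    embed = PatchEmbed()  # default img_size/patch_size/overlap/in_chans/embed_dim
    x = th.randn(2, 128, 256, 16)  # (B, C, H, W) dummy input
    tokens = embed(x)
    # grid_size = (ceil((256-16)/16)+1, ceil((16-4)/4)+1) = (16, 4) -> 64 patches
    assert tokens.shape == (2, embed.num_patches, 768)
    return tokens.shape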
class PatchEmbed_1D(nn.Module):
def __init__(
self,
img_size=(256, 16),
# patch_size=(16, 4),
# overlap = (0, 0),
in_chans=8,
embed_dim=1152,
norm_layer=None,
# flatten=True,
bias=True,
):
super().__init__()
self.proj = nn.Linear(in_chans*img_size[1], embed_dim, bias=bias)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x):
# B, C, H, W = x.shape
# _assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).")
# _assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).")
# x = F.pad(x, (self.pad_size[-1], self.pad_size[-1], self.pad_size[-2], self.pad_size[-2]), "constant", 0)
# x = self.proj(x)
# if self.flatten:
# x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
x = th.einsum('bctf->btfc', x)
x = x.flatten(2) # BTFC -> BTD
x = self.proj(x)
x = self.norm(x)
return x
# if __name__ == '__main__':
# x = th.rand(1, 256, 16).unsqueeze(0)
# model = PatchEmbed(in_chans=1)
# y = model(x)
from timm.models.vision_transformer import Attention, Mlp
def modulate(x, shift, scale):
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
from positional_encodings.torch_encodings import PositionalEncoding1D
def t2i_modulate(x, shift, scale):
return x * (1 + scale) + shift
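# Note: this module-level t2i_modulate shadows the version imported from
# .PixArt_blocks at the top of the file. Unlike modulate() above, it does not
# unsqueeze shift/scale; callers pass them already broadcastable, e.g. the
# (B, 1, C) chunks produced from scale_shift_table + t.reshape(B, 6, -1) in the
# blocks below.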
class PixArtBlock(nn.Module):
"""
A PixArt block with adaptive layer norm (adaLN-single) conditioning.
"""
def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, drop_path=0., window_size=0, input_size=None, use_rel_pos=False, **block_kwargs):
super().__init__()
self.hidden_size = hidden_size
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.attn = WindowAttention(hidden_size, num_heads=num_heads, qkv_bias=True,
input_size=input_size if window_size == 0 else (window_size, window_size),
use_rel_pos=use_rel_pos, **block_kwargs)
self.cross_attn = MultiHeadCrossAttention(hidden_size, num_heads, **block_kwargs)
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
# to be compatible with lower version pytorch
approx_gelu = lambda: nn.GELU(approximate="tanh")
self.mlp = Mlp(in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu, drop=0)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.window_size = window_size
self.scale_shift_table = nn.Parameter(th.randn(6, hidden_size) / hidden_size ** 0.5)
def forward(self, x, y, t, mask=None, **kwargs):
B, N, C = x.shape
# x [B, T, D]
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None] + t.reshape(B, 6, -1)).chunk(6, dim=1)
x = x + self.drop_path(gate_msa * self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa)).reshape(B, N, C))
x = x + self.cross_attn(x, y, mask)
x = x + self.drop_path(gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp)))
return x
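# Minimal sketch (illustration only) of the adaLN-single conditioning used by the
# blocks in this file: a per-block learned (6, C) table plus a per-sample projection
# of the timestep embedding yields shift/scale/gate triples for the attention and
# MLP branches. Shapes below (B=2, C=8, N=5) are arbitrary placeholders.
def _adaln_single_demo():
    B, C, N = 2, 8, 5
    scale_shift_table = th.randn(6, C) / C ** 0.5   # stands in for the nn.Parameter
    t0 = th.randn(B, 6 * C)                         # stands in for self.t_block(t)
    chunks = (scale_shift_table[None] + t0.reshape(B, 6, -1)).chunk(6, dim=1)
    shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = chunks
    x = th.randn(B, N, C)                           # (B, N, C) tokens
    x_mod = t2i_modulate(x, shift_msa, scale_msa)   # broadcasts (B, 1, C) over N
    assert x_mod.shape == x.shape and gate_mlp.shape == (B, 1, C)
    return x_mod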
from qa_mdt.audioldm_train.modules.diffusionmodules.attention import CrossAttention_1D
class PixArtBlock_Slow(nn.Module):
"""
A PixArt block with adaptive layer norm (adaLN-single) conditioning.
"""
def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, drop_path=0., window_size=0, input_size=None, use_rel_pos=False, **block_kwargs):
super().__init__()
self.hidden_size = hidden_size
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.attn = CrossAttention_1D(query_dim=hidden_size, context_dim=hidden_size, heads=num_heads, dim_head=int(hidden_size/num_heads))
self.cross_attn = CrossAttention_1D(query_dim=hidden_size, context_dim=hidden_size, heads=num_heads, dim_head=int(hidden_size/num_heads))
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
# to be compatible with lower version pytorch
approx_gelu = lambda: nn.GELU(approximate="tanh")
self.mlp = Mlp(in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu, drop=0)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.window_size = window_size
self.scale_shift_table = nn.Parameter(th.randn(6, hidden_size) / hidden_size ** 0.5)
def forward(self, x, y, t, mask=None, **kwargs):
B, N, C = x.shape
# x [B, T, D]
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None] + t.reshape(B, 6, -1)).chunk(6, dim=1)
x = x + self.drop_path(gate_msa * self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa)).reshape(B, N, C))
x = x + self.cross_attn(x, y, mask)
x = x + self.drop_path(gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp)))
return x
class PixArt(nn.Module):
"""
Diffusion model with a Transformer backbone.
"""
def __init__(self, input_size=(256,16), patch_size=(16,4), overlap=(0, 0), in_channels=8, hidden_size=1152, depth=28, num_heads=16, mlp_ratio=4.0, class_dropout_prob=0.1, pred_sigma=True, drop_path: float = 0., window_size=0, window_block_indexes=None, use_rel_pos=False, cond_dim=1024, lewei_scale=1.0,
use_cfg=True, cfg_scale=4.0, config=None, model_max_length=120, **kwargs):
if window_block_indexes is None:
window_block_indexes = []
super().__init__()
self.use_cfg = use_cfg
self.cfg_scale = cfg_scale
self.input_size = input_size
self.pred_sigma = pred_sigma
self.in_channels = in_channels
self.out_channels = in_channels * 2 if pred_sigma else in_channels
self.patch_size = patch_size
self.num_heads = num_heads
self.lewei_scale = lewei_scale,
self.x_embedder = PatchEmbed(input_size, patch_size, overlap, in_channels, hidden_size, bias=True)
# self.x_embedder = PatchEmbed_1D(input)
self.t_embedder = TimestepEmbedder(hidden_size)
num_patches = self.x_embedder.num_patches
self.base_size = input_size[0] // self.patch_size[0] * 2
# Will use fixed sin-cos embedding:
self.register_buffer("pos_embed", th.zeros(1, num_patches, hidden_size))
approx_gelu = lambda: nn.GELU(approximate="tanh")
self.t_block = nn.Sequential(
nn.SiLU(),
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
)
self.y_embedder = nn.Linear(cond_dim, hidden_size)
drop_path = [x.item() for x in th.linspace(0, drop_path, depth)] # stochastic depth decay rule
self.blocks = nn.ModuleList([
PixArtBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path[i],
input_size=(self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]),
window_size=0,
use_rel_pos=False)
for i in range(depth)
])
self.final_layer = T2IFinalLayer(hidden_size, patch_size, self.out_channels)
self.initialize_weights()
# if config:
# logger = get_root_logger(os.path.join(config.work_dir, 'train_log.log'))
# logger.warning(f"lewei scale: {self.lewei_scale}, base size: {self.base_size}")
# else:
# print(f'Warning: lewei scale: {self.lewei_scale}, base size: {self.base_size}')
def forward(self, x, timestep, context_list, context_mask_list=None, **kwargs):
"""
Forward pass of PixArt.
        x: (N, C, H, W) tensor of spatial inputs (e.g. latent spectrograms)
        timestep: (N,) tensor of diffusion timesteps
        context_list: list of conditioning tensors; context_list[0] is (N, L, cond_dim)
        context_mask_list: list of masks; context_mask_list[0] is (N, L)
"""
x = x.to(self.dtype)
timestep = timestep.to(self.dtype)
y = context_list[0].to(self.dtype)
pos_embed = self.pos_embed.to(self.dtype)
self.h, self.w = self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]
x = self.x_embedder(x) + pos_embed # (N, T, D), where T = H * W / patch_size ** 2
t = self.t_embedder(timestep.to(x.dtype)) # (N, D)
t0 = self.t_block(t)
y = self.y_embedder(y) # (N, L, D)
mask = context_mask_list[0] # (N, L)
assert mask is not None
# if mask is not None:
y = y.masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
y_lens = mask.sum(dim=1).tolist()
y_lens = [int(_) for _ in y_lens]
for block in self.blocks:
x = auto_grad_checkpoint(block, x, y, t0, y_lens) # (N, T, D) #support grad checkpoint
x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels)
x = self.unpatchify(x) # (N, out_channels, H, W)
return x
def forward_with_dpmsolver(self, x, timestep, y, mask=None, **kwargs):
"""
        DPM-Solver does not need variance prediction
"""
# https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
model_out = self.forward(x, timestep, y, mask)
return model_out.chunk(2, dim=1)[0]
def forward_with_cfg(self, x, timestep, y, cfg_scale, mask=None, **kwargs):
"""
Forward pass of PixArt, but also batches the unconditional forward pass for classifier-free guidance.
"""
# https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
half = x[: len(x) // 2]
combined = th.cat([half, half], dim=0)
model_out = self.forward(combined, timestep, y, mask)
model_out = model_out['x'] if isinstance(model_out, dict) else model_out
eps, rest = model_out[:, :8], model_out[:, 8:]
cond_eps, uncond_eps = th.split(eps, len(eps) // 2, dim=0)
half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
eps = th.cat([half_eps, half_eps], dim=0)
return eps
# return th.cat([eps, rest], dim=1)
def unpatchify(self, x):
"""
        x: (N, T, patch_size[0] * patch_size[1] * C)
        imgs: (N, out_channels, H, W)
"""
# torch_map = th.zeros(self.x_embedder.img_size[0]+2*self.x_embedder.pad_size[0],
# self.x_embedder.img_size[1]+2*self.x_embedder.pad_size[1]).to(x.device)
# lf = self.x_embedder.grid_size[0]
# rf = self.x_embedder.grid_size[1]
# for i in range(lf):
# for j in range(rf):
# xx = (i) * (self.x_embedder.patch_size[0]-self.x_embedder.ol[0])
# yy = (j) * (self.x_embedder.patch_size[1]-self.x_embedder.ol[1])
# torch_map[xx:(xx+self.x_embedder.patch_size[0]), yy:(yy+self.x_embedder.patch_size[1])]+=1
# torch_map = torch_map[self.x_embedder.pad_size[0]:self.x_embedder.pad_size[0]+self.x_embedder.img_size[0],
# self.x_embedder.pad_size[1]:self.x_embedder.pad_size[1]+self.x_embedder.img_size[1]]
# torch_map = th.reciprocal(torch_map)
# c = self.out_channels
# p0, p1 = self.x_embedder.patch_size[0], self.x_embedder.patch_size[1]
# x = x.reshape(shape=(x.shape[0], self.x_embedder.grid_size[0],
# self.x_embedder.grid_size[1], p0, p1, c))
# x = th.einsum('nhwpqc->nchwpq', x)
# added_map = th.zeros(x.shape[0], c,
# self.x_embedder.img_size[0]+2*self.x_embedder.pad_size[0],
# self.x_embedder.img_size[1]+2*self.x_embedder.pad_size[1]).to(x.device)
# for b_id in range(x.shape[0]):
# for i in range(lf):
# for j in range(rf):
# for c_id in range(c):
# xx = (i) * (self.x_embedder.patch_size[0]-self.x_embedder.ol[0])
# yy = (j) * (self.x_embedder.patch_size[1]-self.x_embedder.ol[1])
# added_map[b_id][c_id][xx:(xx+self.x_embedder.patch_size[0]), yy:(yy+self.x_embedder.patch_size[1])] += \
# x[b_id, c_id, i, j]
# ret_map = th.zeros(x.shape[0], c, self.x_embedder.img_size[0],
# self.x_embedder.img_size[1]).to(x.device)
# for b_id in range(x.shape[0]):
# for id_c in range(c):
# ret_map[b_id, id_c, :, :] = th.mul(added_map[b_id][id_c][self.x_embedder.pad_size[0]:self.x_embedder.pad_size[0]+self.x_embedder.img_size[0],
# self.x_embedder.pad_size[1]:self.x_embedder.pad_size[1]+self.x_embedder.img_size[1]], torch_map)
c = self.out_channels
p0 = self.x_embedder.patch_size[0]
p1 = self.x_embedder.patch_size[1]
h, w = self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]
# h = w = int(x.shape[1] ** 0.5)
# print(x.shape, h, w, p0, p1)
# import pdb
# pdb.set_trace()
x = x.reshape(shape=(x.shape[0], h, w, p0, p1, c))
x = th.einsum('nhwpqc->nchpwq', x)
imgs = x.reshape(shape=(x.shape[0], c, h * p0, w * p1))
return imgs
def initialize_weights(self):
# Initialize transformer layers:
def _basic_init(module):
if isinstance(module, nn.Linear):
th.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
self.apply(_basic_init)
# Initialize (and freeze) pos_embed by sin-cos embedding:
pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], self.x_embedder.grid_size, lewei_scale=self.lewei_scale, base_size=self.base_size)
self.pos_embed.data.copy_(th.from_numpy(pos_embed).float().unsqueeze(0))
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
w = self.x_embedder.proj.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
# Initialize timestep embedding MLP:
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
nn.init.normal_(self.t_block[1].weight, std=0.02)
# Initialize caption embedding MLP:
nn.init.normal_(self.y_embedder.weight, std=0.02)
# nn.init.normal_(self.y_embedder.y_proj.fc2.weight, std=0.02)
# Zero-out adaLN modulation layers in PixArt blocks:
for block in self.blocks:
nn.init.constant_(block.cross_attn.proj.weight, 0)
nn.init.constant_(block.cross_attn.proj.bias, 0)
# Zero-out output layers:
nn.init.constant_(self.final_layer.linear.weight, 0)
nn.init.constant_(self.final_layer.linear.bias, 0)
@property
def dtype(self):
return next(self.parameters()).dtype
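# Minimal sketch of the non-overlapping unpatchify used by PixArt.unpatchify (and
# by the zero-overlap branch of the MDT variants below): (N, h*w, p0*p1*c) tokens
# are reshaped back to an (N, c, h*p0, w*p1) map. The round-trip against a
# hand-rolled patchify is only an illustration; grid/patch sizes are arbitrary.
def _unpatchify_roundtrip_demo():
    n, c, h, w, p0, p1 = 2, 3, 4, 2, 16, 4
    imgs = th.randn(n, c, h * p0, w * p1)
    # patchify: (N, C, H, W) -> (N, h*w, p0*p1*C), matching the einsum order below
    patches = imgs.reshape(n, c, h, p0, w, p1)
    patches = th.einsum('nchpwq->nhwpqc', patches).reshape(n, h * w, p0 * p1 * c)
    # unpatchify, mirroring PixArt.unpatchify
    x = patches.reshape(n, h, w, p0, p1, c)
    x = th.einsum('nhwpqc->nchpwq', x)
    recon = x.reshape(n, c, h * p0, w * p1)
    assert th.allclose(recon, imgs)
    return recon.shape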
class SwiGLU(nn.Module):
def __init__(
self,
dim: int,
hidden_dim: int,
multiple_of: int,
):
super().__init__()
hidden_dim = int(2 * hidden_dim / 3)
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
self.w1 = ColumnParallelLinear(
dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x
)
self.w2 = RowParallelLinear(
hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x
)
self.w3 = ColumnParallelLinear(
dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x
)
def forward(self, x):
return self.w2(F.silu(self.w1(x)) * self.w3(x))
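# Minimal single-device sketch of the SwiGLU FFN above, replacing the fairscale
# Column/RowParallelLinear layers with plain nn.Linear; the hidden-dim rounding
# mirrors the constructor. This is an assumed equivalent for illustration only,
# not the model-parallel implementation actually used by the model.
class _SwiGLUSketch(nn.Module):
    def __init__(self, dim: int, hidden_dim: int, multiple_of: int = 1):
        super().__init__()
        hidden_dim = int(2 * hidden_dim / 3)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)
    def forward(self, x):
        # SwiGLU: w2(SiLU(w1 x) * w3 x)
        return self.w2(F.silu(self.w1(x)) * self.w3(x))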
class DEBlock(nn.Module):
"""
Decoder block with added SpecTNT transformer
"""
def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, FFN_type='SwiGLU', drop_path=0., window_size=0, input_size=None, use_rel_pos=False, skip=False, num_f=None, num_t=None, **block_kwargs):
super().__init__()
self.hidden_size = hidden_size
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.attn = WindowAttention(hidden_size, num_heads=num_heads, qkv_bias=True,
input_size=input_size if window_size == 0 else (window_size, window_size),
use_rel_pos=use_rel_pos, **block_kwargs)
self.cross_attn = MultiHeadCrossAttention(hidden_size, num_heads, **block_kwargs)
# self.cross_attn_f = MultiHeadCrossAttention(hidden_size, num_heads, **block_kwargs)
# self.cross_attn_t = MultiHeadCrossAttention(hidden_size*num_f, num_heads, **block_kwargs)
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.norm3 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.norm4 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.norm5 = nn.LayerNorm(hidden_size * num_f, elementwise_affine=False, eps=1e-6)
self.norm6 = nn.LayerNorm(hidden_size * num_f, elementwise_affine=False, eps=1e-6)
# to be compatible with lower version pytorch
approx_gelu = lambda: nn.GELU(approximate="tanh")
if FFN_type == 'mlp':
self.mlp = Mlp(in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu, drop=0)
# self.mlp2 = Mlp(in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu, drop=0)
# self.mlp3 = Mlp(in_features=hidden_size*num_f, hidden_features=int(hidden_size*num_f * mlp_ratio), act_layer=approx_gelu, drop=0)
elif FFN_type == 'SwiGLU':
self.mlp = SwiGLU(hidden_size, int(hidden_size * mlp_ratio), 1)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.window_size = window_size
self.scale_shift_table = nn.Parameter(th.randn(6, hidden_size) / hidden_size ** 0.5)
# self.scale_shift_table_2 = nn.Parameter(th.randn(6, hidden_size) / hidden_size ** 0.5)
# self.scale_shift_table_3 = nn.Parameter(th.randn(6, hidden_size) / hidden_size ** 0.5)
self.skip_linear = nn.Linear(2 * hidden_size, hidden_size) if skip else None
self.F_transformer = WindowAttention(hidden_size, num_heads=4, qkv_bias=True,
input_size=input_size if window_size == 0 else (window_size, window_size),
use_rel_pos=use_rel_pos, **block_kwargs)
self.T_transformer = WindowAttention(hidden_size * num_f, num_heads=16, qkv_bias=True,
input_size=input_size if window_size == 0 else (window_size, window_size),
use_rel_pos=use_rel_pos, **block_kwargs)
self.f_pos = nn.Embedding(num_f, hidden_size)
self.t_pos = nn.Embedding(num_t, hidden_size * num_f)
self.num_f = num_f
self.num_t = num_t
def forward(self, x_normal, end, y, t, mask=None, skip=None, ids_keep=None, **kwargs):
# import pdb
# pdb.set_trace()
B, D, C = x_normal.shape
T = self.num_t
F_add_1 = self.num_f
# B, T, F_add_1, C = x.shape
# F_add_1 = F_add_1 + 1
# x_normal = th.reshape()
# # x_end [B, T, 1, C]
# x_end = x[:, :, -1, :].unsqueeze(2)
if self.skip_linear is not None:
x_normal = self.skip_linear(th.cat([x_normal, skip], dim=-1))
D = T * (F_add_1 - 1)
# x_normal [B, D, C]
# import pdb
# pdb.set_trace()
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None] + t.reshape(B, 6, -1)).chunk(6, dim=1)
x_normal = x_normal + self.drop_path(gate_msa * self.attn(t2i_modulate(self.norm1(x_normal), shift_msa, scale_msa)).reshape(B, D, C))
x_normal = x_normal.reshape(B, T, F_add_1-1, C)
x_normal = th.cat((x_normal, end), 2)
# x_normal [B*T, F+1, C]
x_normal = x_normal.reshape(B*T, F_add_1, C)
        pos_f = th.arange(self.num_f, device=x_normal.device).unsqueeze(0).expand(B*T, -1)
# import pdb; pdb.set_trace()
x_normal = x_normal + self.f_pos(pos_f)
x_normal = x_normal + self.F_transformer(self.norm3(x_normal))
# x_normal = x_normal + self.cross_attn_f(x_normal, y, mask)
# x_normal = x_normal + self.mlp2(self.norm4(x_normal))
# x_normal [B, T, (F+1) * C]
x_normal = x_normal.reshape(B, T, F_add_1 * C)
        pos_t = th.arange(self.num_t, device=x_normal.device).unsqueeze(0).expand(B, -1)
x_normal = x_normal + self.t_pos(pos_t)
x_normal = x_normal + self.T_transformer(self.norm5(x_normal))
# x_normal = x_normal + self.cross_attn_t(x_normal, y, mask)
x_normal = x_normal.reshape(B, T ,F_add_1, C)
end = x_normal[:, :, -1, :].unsqueeze(2)
x_normal = x_normal[:, :, :-1, :]
x_normal = x_normal.reshape(B, T*(F_add_1 - 1), C)
x_normal = x_normal + self.cross_attn(x_normal, y, mask)
x_normal = x_normal + self.drop_path(gate_mlp * self.mlp(t2i_modulate(self.norm2(x_normal), shift_mlp, scale_mlp)))
# x_normal = th.cat
return x_normal, end #.reshape(B, )
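# Minimal sketch of the SpecTNT-style reshaping inside DEBlock.forward: tokens are
# viewed as a (T, F) grid, an extra "end" token is appended along the frequency
# axis, frequency attention runs over (B*T, F+1, C) sequences and temporal
# attention over (B, T, (F+1)*C). Only the tensor bookkeeping is shown here; the
# sizes are arbitrary placeholders.
def _spectnt_reshape_demo():
    B, T, F_, C = 2, 6, 4, 8                   # F_ frequency bins per time step
    x = th.randn(B, T * F_, C)                 # flattened (B, D, C) tokens
    end = th.zeros(B, T, 1, C)                 # per-time-step summary token
    grid = th.cat((x.reshape(B, T, F_, C), end), dim=2)      # (B, T, F_+1, C)
    freq_seq = grid.reshape(B * T, F_ + 1, C)                # frequency-axis attention
    time_seq = freq_seq.reshape(B, T, (F_ + 1) * C)          # time-axis attention
    grid = time_seq.reshape(B, T, F_ + 1, C)
    end_out, x_out = grid[:, :, -1:, :], grid[:, :, :-1, :].reshape(B, T * F_, C)
    assert x_out.shape == x.shape and end_out.shape == end.shape
    return x_out.shape, end_out.shape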
class MDTBlock(nn.Module):
"""
A PixArt block with adaptive layer norm (adaLN-single) conditioning.
"""
def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, FFN_type='mlp', drop_path=0., window_size=0, input_size=None, use_rel_pos=False, skip=False, **block_kwargs):
super().__init__()
self.hidden_size = hidden_size
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.attn = WindowAttention(hidden_size, num_heads=num_heads, qkv_bias=True,
input_size=input_size if window_size == 0 else (window_size, window_size),
use_rel_pos=use_rel_pos, **block_kwargs)
self.cross_attn = MultiHeadCrossAttention(hidden_size, num_heads, **block_kwargs)
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
# to be compatible with lower version pytorch
approx_gelu = lambda: nn.GELU(approximate="tanh")
if FFN_type == 'mlp':
self.mlp = Mlp(in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu, drop=0)
elif FFN_type == 'SwiGLU':
self.mlp = SwiGLU(hidden_size, int(hidden_size * mlp_ratio), 1)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.window_size = window_size
self.scale_shift_table = nn.Parameter(th.randn(6, hidden_size) / hidden_size ** 0.5)
self.skip_linear = nn.Linear(2 * hidden_size, hidden_size) if skip else None
def forward(self, x, y, t, mask=None, skip=None, ids_keep=None, **kwargs):
B, N, C = x.shape
if self.skip_linear is not None:
x = self.skip_linear(th.cat([x, skip], dim=-1))
# x [B, T, D]
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None] + t.reshape(B, 6, -1)).chunk(6, dim=1)
x = x + self.drop_path(gate_msa * self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa)).reshape(B, N, C))
x = x + self.cross_attn(x, y, mask)
x = x + self.drop_path(gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp)))
return x
class PixArt_MDT_MASK_TF(nn.Module):
"""
Diffusion model with a Transformer backbone.
"""
def __init__(self, input_size=(256,16), patch_size=(16,4), overlap=(0, 0), in_channels=8, hidden_size=1152, depth=28, num_heads=16, mlp_ratio=4.0, class_dropout_prob=0.1, pred_sigma=False, drop_path: float = 0., window_size=0, window_block_indexes=None, use_rel_pos=False, cond_dim=1024, lewei_scale=1.0,
use_cfg=True, cfg_scale=4.0, config=None, model_max_length=120, mask_t=0.17, mask_f=0.15, decode_layer=4,**kwargs):
if window_block_indexes is None:
window_block_indexes = []
super().__init__()
self.use_cfg = use_cfg
self.cfg_scale = cfg_scale
self.input_size = input_size
self.pred_sigma = pred_sigma
self.in_channels = in_channels
self.out_channels = in_channels
self.patch_size = patch_size
self.num_heads = num_heads
self.lewei_scale = lewei_scale,
decode_layer = int(decode_layer)
self.x_embedder = PatchEmbed(input_size, patch_size, overlap, in_channels, hidden_size, bias=True)
# self.x_embedder = PatchEmbed_1D(input)
self.t_embedder = TimestepEmbedder(hidden_size)
num_patches = self.x_embedder.num_patches
self.base_size = input_size[0] // self.patch_size[0] * 2
# Will use fixed sin-cos embedding:
self.register_buffer("pos_embed", th.zeros(1, num_patches, hidden_size))
approx_gelu = lambda: nn.GELU(approximate="tanh")
self.t_block = nn.Sequential(
nn.SiLU(),
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
)
self.y_embedder = nn.Linear(cond_dim, hidden_size)
half_depth = (depth - decode_layer)//2
self.half_depth=half_depth
drop_path_half = [x.item() for x in th.linspace(0, drop_path, half_depth)] # stochastic depth decay rule
drop_path_decode = [x.item() for x in th.linspace(0, drop_path, decode_layer)]
self.en_inblocks = nn.ModuleList([
MDTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path_half[i],
input_size=(self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]),
window_size=0,
use_rel_pos=False, FFN_type='mlp') for i in range(half_depth)
])
self.en_outblocks = nn.ModuleList([
MDTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path_half[i],
input_size=(self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]),
window_size=0,
use_rel_pos=False, skip=True, FFN_type='mlp') for i in range(half_depth)
])
self.de_blocks = nn.ModuleList([
MDTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path_decode[i],
input_size=(self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]),
window_size=0,
use_rel_pos=False, skip=True, FFN_type='mlp') for i in range(decode_layer)
])
self.sideblocks = nn.ModuleList([
MDTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio,
input_size=(self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]),
window_size=0,
use_rel_pos=False, FFN_type='mlp') for _ in range(1)
])
self.final_layer = T2IFinalLayer(hidden_size, patch_size, self.out_channels)
self.decoder_pos_embed = nn.Parameter(th.zeros(
1, num_patches, hidden_size), requires_grad=True)
# if mask_ratio is not None:
# self.mask_token = nn.Parameter(th.zeros(1, 1, hidden_size))
# self.mask_ratio = float(mask_ratio)
# self.decode_layer = int(decode_layer)
# else:
# self.mask_token = nn.Parameter(th.zeros(
# 1, 1, hidden_size), requires_grad=False)
# self.mask_ratio = None
# self.decode_layer = int(decode_layer)
assert mask_t != 0 and mask_f != 0
self.mask_token = nn.Parameter(th.zeros(1, 1, hidden_size))
self.mask_t = mask_t
self.mask_f = mask_f
self.decode_layer = int(decode_layer)
print(f"mask ratio: T-{self.mask_t} F-{self.mask_f}", "decode_layer:", self.decode_layer)
self.initialize_weights()
def forward(self, x, timestep, context_list, context_mask_list=None, enable_mask=False, **kwargs):
"""
Forward pass of PixArt.
        x: (N, C, H, W) tensor of spatial inputs (e.g. latent spectrograms)
        timestep: (N,) tensor of diffusion timesteps
        context_list: list of conditioning tensors; context_list[0] is (N, L, cond_dim)
        context_mask_list: list of masks; context_mask_list[0] is (N, L)
"""
x = x.to(self.dtype)
timestep = timestep.to(self.dtype)
y = context_list[0].to(self.dtype)
pos_embed = self.pos_embed.to(self.dtype)
self.h, self.w = self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]
# import pdb
# pdb.set_trace()
x = self.x_embedder(x) + pos_embed # (N, T, D), where T = H * W / patch_size ** 2
t = self.t_embedder(timestep.to(x.dtype)) # (N, D)
t0 = self.t_block(t)
y = self.y_embedder(y) # (N, L, D)
# if not self.training:
        try:
            mask = context_mask_list[0]  # (N, L)
        except (TypeError, IndexError):
            mask = th.ones(x.shape[0], 1).to(x.device)
            print("Warning: no context mask provided; falling back to an all-ones mask.")
assert mask is not None
# if mask is not None:
y = y.masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
y_lens = mask.sum(dim=1).tolist()
y_lens = [int(_) for _ in y_lens]
input_skip = x
masked_stage = False
skips = []
# TODO : masking op for training
if self.mask_t is not None and self.training:
# masking: length -> length * mask_ratio
rand_mask_ratio = th.rand(1, device=x.device) # noise in [0, 1]
            rand_mask_ratio_t = rand_mask_ratio * 0.13 + self.mask_t  # in [mask_t, mask_t + 0.13]
            rand_mask_ratio_f = rand_mask_ratio * 0.13 + self.mask_f  # in [mask_f, mask_f + 0.13]
# print(rand_mask_ratio)
x, mask, ids_restore, ids_keep = self.random_masking_2d(
x, rand_mask_ratio_t, rand_mask_ratio_f)
masked_stage = True
for block in self.en_inblocks:
if masked_stage:
x = auto_grad_checkpoint(block, x, y, t0, y_lens, ids_keep=ids_keep)
else:
x = auto_grad_checkpoint(block, x, y, t0, y_lens, ids_keep=None)
skips.append(x)
for block in self.en_outblocks:
if masked_stage:
x = auto_grad_checkpoint(block, x, y, t0, y_lens, skip=skips.pop(), ids_keep=ids_keep)
else:
x = auto_grad_checkpoint(block, x, y, t0, y_lens, skip=skips.pop(), ids_keep=None)
if self.mask_t is not None and self.mask_f is not None and self.training:
x = self.forward_side_interpolater(x, y, t0, y_lens, mask, ids_restore)
masked_stage = False
else:
# add pos embed
x = x + self.decoder_pos_embed
for i in range(len(self.de_blocks)):
block = self.de_blocks[i]
this_skip = input_skip
x = auto_grad_checkpoint(block, x, y, t0, y_lens, skip=this_skip, ids_keep=None)
x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels)
x = self.unpatchify(x) # (N, out_channels, H, W)
return x
def forward_with_dpmsolver(self, x, timestep, y, mask=None, **kwargs):
"""
        DPM-Solver does not need variance prediction
"""
# https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
model_out = self.forward(x, timestep, y, mask)
return model_out.chunk(2, dim=1)[0]
def forward_with_cfg(self, x, timestep, context_list, context_mask_list=None, cfg_scale=4.0, **kwargs):
"""
Forward pass of PixArt, but also batches the unconditional forward pass for classifier-free guidance.
"""
# https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
# import pdb
# pdb.set_trace()
half = x[: len(x) // 2]
combined = th.cat([half, half], dim=0)
model_out = self.forward(combined, timestep, context_list, context_mask_list=None)
model_out = model_out['x'] if isinstance(model_out, dict) else model_out
eps, rest = model_out[:, :8], model_out[:, 8:]
cond_eps, uncond_eps = th.split(eps, len(eps) // 2, dim=0)
half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
eps = th.cat([half_eps, half_eps], dim=0)
return eps
# return th.cat([eps, rest], dim=1)
def unpatchify(self, x):
"""
        x: (N, T, patch_size[0] * patch_size[1] * C)
        imgs: (N, out_channels, H, W)
"""
if self.x_embedder.ol == (0, 0) or self.x_embedder.ol == [0, 0]:
c = self.out_channels
p0 = self.x_embedder.patch_size[0]
p1 = self.x_embedder.patch_size[1]
h, w = self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]
x = x.reshape(shape=(x.shape[0], h, w, p0, p1, c))
x = th.einsum('nhwpqc->nchpwq', x)
imgs = x.reshape(shape=(x.shape[0], c, h * p0, w * p1))
return imgs
lf = self.x_embedder.grid_size[0]
rf = self.x_embedder.grid_size[1]
lp = self.x_embedder.patch_size[0]
rp = self.x_embedder.patch_size[1]
lo = self.x_embedder.ol[0]
ro = self.x_embedder.ol[1]
lm = self.x_embedder.img_size[0]
rm = self.x_embedder.img_size[1]
lpad = self.x_embedder.pad_size[0]
rpad = self.x_embedder.pad_size[1]
bs = x.shape[0]
torch_map = self.torch_map
c = self.out_channels
x = x.reshape(shape=(bs, lf, rf, lp, rp, c))
x = th.einsum('nhwpqc->nchwpq', x)
added_map = th.zeros(bs, c, lm+2*lpad, rm+2*rpad).to(x.device)
for i in range(lf):
for j in range(rf):
xx = (i) * (lp - lo)
yy = (j) * (rp - ro)
added_map[:, :, xx:(xx+lp), yy:(yy+rp)] += \
x[:, :, i, j, :, :]
# import pdb
# pdb.set_trace()
        added_map = added_map[:, :, lpad:lm+lpad, rpad:rm+rpad]
return th.mul(added_map.to(x.device), torch_map.to(x.device))
def random_masking_2d(self, x, mask_t_prob, mask_f_prob):
"""
        2D spectrogram masking: mask time and frequency bins with probabilities mask_t_prob and mask_f_prob.
Perform per-sample random masking by per-sample shuffling.
Per-sample shuffling is done by argsort random noise.
x: [N, L, D], sequence
"""
N, L, D = x.shape # batch, length, dim
# if self.use_custom_patch: # overlapped patch
# T=101
# F=12
# else:
# T=64
# F=8
T = self.x_embedder.grid_size[0]
F = self.x_embedder.grid_size[1]
#x = x.reshape(N, T, F, D)
len_keep_t = int(T * (1 - mask_t_prob))
len_keep_f = int(F * (1 - mask_f_prob))
# noise for mask in time
noise_t = th.rand(N, T, device=x.device) # noise in [0, 1]
        # sort noise for each sample along time
ids_shuffle_t = th.argsort(noise_t, dim=1) # ascend: small is keep, large is remove
ids_restore_t = th.argsort(ids_shuffle_t, dim=1)
ids_keep_t = ids_shuffle_t[:,:len_keep_t]
# noise mask in freq
noise_f = th.rand(N, F, device=x.device) # noise in [0, 1]
ids_shuffle_f = th.argsort(noise_f, dim=1) # ascend: small is keep, large is remove
ids_restore_f = th.argsort(ids_shuffle_f, dim=1)
ids_keep_f = ids_shuffle_f[:,:len_keep_f] #
# generate the binary mask: 0 is keep, 1 is remove
# mask in freq
mask_f = th.ones(N, F, device=x.device)
mask_f[:,:len_keep_f] = 0
mask_f = th.gather(mask_f, dim=1, index=ids_restore_f).unsqueeze(1).repeat(1,T,1) # N,T,F
# mask in time
mask_t = th.ones(N, T, device=x.device)
mask_t[:,:len_keep_t] = 0
mask_t = th.gather(mask_t, dim=1, index=ids_restore_t).unsqueeze(1).repeat(1,F,1).permute(0,2,1) # N,T,F
mask = 1-(1-mask_t)*(1-mask_f) # N, T, F
# get masked x
id2res=th.Tensor(list(range(N*T*F))).reshape(N,T,F).to(x.device)
id2res = id2res + 999*mask # add a large value for masked elements
id2res2 = th.argsort(id2res.flatten(start_dim=1))
ids_keep=id2res2.flatten(start_dim=1)[:,:len_keep_f*len_keep_t]
x_masked = th.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
ids_restore = th.argsort(id2res2.flatten(start_dim=1))
mask = mask.flatten(start_dim=1)
return x_masked, mask, ids_restore, ids_keep
def random_masking(self, x, mask_ratio):
"""
Perform per-sample random masking by per-sample shuffling.
Per-sample shuffling is done by argsort random noise.
x: [N, L, D], sequence
"""
N, L, D = x.shape # batch, length, dim
len_keep = int(L * (1 - mask_ratio))
noise = th.rand(N, L, device=x.device) # noise in [0, 1]
# sort noise for each sample
# ascend: small is keep, large is remove
ids_shuffle = th.argsort(noise, dim=1)
ids_restore = th.argsort(ids_shuffle, dim=1)
# keep the first subset
ids_keep = ids_shuffle[:, :len_keep]
x_masked = th.gather(
x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
# generate the binary mask: 0 is keep, 1 is remove
mask = th.ones([N, L], device=x.device)
mask[:, :len_keep] = 0
# unshuffle to get the binary mask
mask = th.gather(mask, dim=1, index=ids_restore)
return x_masked, mask, ids_restore, ids_keep
def forward_side_interpolater(self, x, y, t0, y_lens, mask, ids_restore):
# append mask tokens to sequence
mask_tokens = self.mask_token.repeat(
x.shape[0], ids_restore.shape[1] - x.shape[1], 1)
x_ = th.cat([x, mask_tokens], dim=1)
x = th.gather(
x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle
# add pos embed
x = x + self.decoder_pos_embed
# pass to the basic block
x_before = x
for sideblock in self.sideblocks:
x = sideblock(x, y, t0, y_lens, ids_keep=None)
# masked shortcut
mask = mask.unsqueeze(dim=-1)
x = x*mask + (1-mask)*x_before
return x
def initialize_weights(self):
# Initialize transformer layers:
def _basic_init(module):
if isinstance(module, nn.Linear):
th.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
self.apply(_basic_init)
# Initialize (and freeze) pos_embed by sin-cos embedding:
pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], self.x_embedder.grid_size, lewei_scale=self.lewei_scale, base_size=self.base_size)
self.pos_embed.data.copy_(th.from_numpy(pos_embed).float().unsqueeze(0))
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
w = self.x_embedder.proj.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
# Initialize timestep embedding MLP:
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
nn.init.normal_(self.t_block[1].weight, std=0.02)
# Initialize caption embedding MLP:
nn.init.normal_(self.y_embedder.weight, std=0.02)
# nn.init.normal_(self.y_embedder.y_proj.fc2.weight, std=0.02)
# Zero-out adaLN modulation layers in PixArt blocks:
for block in self.en_inblocks:
nn.init.constant_(block.cross_attn.proj.weight, 0)
nn.init.constant_(block.cross_attn.proj.bias, 0)
for block in self.en_outblocks:
nn.init.constant_(block.cross_attn.proj.weight, 0)
nn.init.constant_(block.cross_attn.proj.bias, 0)
for block in self.de_blocks:
nn.init.constant_(block.cross_attn.proj.weight, 0)
nn.init.constant_(block.cross_attn.proj.bias, 0)
for block in self.sideblocks:
nn.init.constant_(block.cross_attn.proj.weight, 0)
nn.init.constant_(block.cross_attn.proj.bias, 0)
# Zero-out output layers:
nn.init.constant_(self.final_layer.linear.weight, 0)
nn.init.constant_(self.final_layer.linear.bias, 0)
if self.x_embedder.ol == [0, 0] or self.x_embedder.ol == (0, 0):
return
lf = self.x_embedder.grid_size[0]
rf = self.x_embedder.grid_size[1]
lp = self.x_embedder.patch_size[0]
rp = self.x_embedder.patch_size[1]
lo = self.x_embedder.ol[0]
ro = self.x_embedder.ol[1]
lm = self.x_embedder.img_size[0]
rm = self.x_embedder.img_size[1]
lpad = self.x_embedder.pad_size[0]
rpad = self.x_embedder.pad_size[1]
torch_map = th.zeros(lm+2*lpad, rm+2*rpad).to('cuda')
for i in range(lf):
for j in range(rf):
xx = (i) * (lp - lo)
yy = (j) * (rp - ro)
torch_map[xx:(xx+lp), yy:(yy+rp)]+=1
torch_map = torch_map[lpad:lm+lpad, rpad:rm+rpad]
self.torch_map = th.reciprocal(torch_map)
@property
def dtype(self):
return next(self.parameters()).dtype
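# Minimal sketch of the mask combination rule used in
# PixArt_MDT_MASK_TF.random_masking_2d: a time mask (N, T) and a frequency mask
# (N, F) are broadcast to the (T, F) grid and combined as
# mask = 1 - (1 - mask_t) * (1 - mask_f), i.e. a patch is dropped if its time row
# OR its frequency column is dropped. Sizes and mask patterns are illustrative only.
def _tf_mask_combination_demo():
    N, T, F_ = 1, 4, 3
    mask_t = th.tensor([[0., 1., 0., 0.]])                 # 1 = remove this time step
    mask_f = th.tensor([[0., 0., 1.]])                     # 1 = remove this freq bin
    mt = mask_t.unsqueeze(-1).expand(N, T, F_)             # (N, T, F)
    mf = mask_f.unsqueeze(1).expand(N, T, F_)              # (N, T, F)
    mask = 1 - (1 - mt) * (1 - mf)                         # union of the two masks
    kept = int((mask == 0).sum())                          # (T-1) * (F-1) patches kept
    assert kept == (T - 1) * (F_ - 1)
    return mask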
class PixArt_MDT_MOS_AS_TOKEN(nn.Module):
"""
Diffusion model with a Transformer backbone.
"""
def __init__(self, input_size=(256,16), patch_size=(16,4), overlap=(0, 0), in_channels=8, hidden_size=1152, depth=28, num_heads=16, mlp_ratio=4.0, class_dropout_prob=0.1, pred_sigma=False, drop_path: float = 0., window_size=0, window_block_indexes=None, use_rel_pos=False, cond_dim=1024, lewei_scale=1.0,
use_cfg=True, cfg_scale=4.0, config=None, model_max_length=120, mask_ratio=None, decode_layer=4,**kwargs):
if window_block_indexes is None:
window_block_indexes = []
super().__init__()
self.use_cfg = use_cfg
self.cfg_scale = cfg_scale
self.input_size = input_size
self.pred_sigma = pred_sigma
self.in_channels = in_channels
self.out_channels = in_channels
self.patch_size = patch_size
self.num_heads = num_heads
self.lewei_scale = lewei_scale,
decode_layer = int(decode_layer)
self.mos_embed = nn.Embedding(num_embeddings=5, embedding_dim=hidden_size)
self.x_embedder = PatchEmbed(input_size, patch_size, overlap, in_channels, hidden_size, bias=True)
# self.x_embedder = PatchEmbed_1D(input)
self.t_embedder = TimestepEmbedder(hidden_size)
num_patches = self.x_embedder.num_patches
self.base_size = input_size[0] // self.patch_size[0] * 2
# Will use fixed sin-cos embedding:
self.register_buffer("pos_embed", th.zeros(1, num_patches, hidden_size))
approx_gelu = lambda: nn.GELU(approximate="tanh")
self.t_block = nn.Sequential(
nn.SiLU(),
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
)
# self.mos_block = nn.Sequential(
# )
self.y_embedder = nn.Linear(cond_dim, hidden_size)
half_depth = (depth - decode_layer)//2
self.half_depth=half_depth
drop_path_half = [x.item() for x in th.linspace(0, drop_path, half_depth)] # stochastic depth decay rule
drop_path_decode = [x.item() for x in th.linspace(0, drop_path, decode_layer)]
self.en_inblocks = nn.ModuleList([
MDTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path_half[i],
input_size=(self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]),
window_size=0,
use_rel_pos=False, FFN_type='mlp') for i in range(half_depth)
])
self.en_outblocks = nn.ModuleList([
MDTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path_half[i],
input_size=(self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]),
window_size=0,
use_rel_pos=False, skip=True, FFN_type='mlp') for i in range(half_depth)
])
self.de_blocks = nn.ModuleList([
MDTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path_decode[i],
input_size=(self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]),
window_size=0,
use_rel_pos=False, skip=True, FFN_type='mlp') for i in range(decode_layer)
])
self.sideblocks = nn.ModuleList([
MDTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio,
input_size=(self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]),
window_size=0,
use_rel_pos=False, FFN_type='mlp') for _ in range(1)
])
self.final_layer = T2IFinalLayer(hidden_size, patch_size, self.out_channels)
self.decoder_pos_embed = nn.Parameter(th.zeros(
1, num_patches, hidden_size), requires_grad=True)
if mask_ratio is not None:
self.mask_token = nn.Parameter(th.zeros(1, 1, hidden_size))
self.mask_ratio = float(mask_ratio)
self.decode_layer = int(decode_layer)
else:
self.mask_token = nn.Parameter(th.zeros(
1, 1, hidden_size), requires_grad=False)
self.mask_ratio = None
self.decode_layer = int(decode_layer)
print("mask ratio:", self.mask_ratio, "decode_layer:", self.decode_layer)
self.initialize_weights()
def forward(self, x, timestep, context_list, context_mask_list=None, enable_mask=False, mos=None, **kwargs):
"""
Forward pass of PixArt.
        x: (N, C, H, W) tensor of spatial inputs (e.g. latent spectrograms)
        timestep: (N,) tensor of diffusion timesteps
        context_list: list of conditioning tensors; context_list[0] is (N, L, cond_dim)
        context_mask_list: list of masks; context_mask_list[0] is (N, L)
        mos: (N,) or (N, 1) tensor of MOS quality ratings in {1, ..., 5}
"""
# mos = th.ones(x.shape[0], dtype=th.int).to(x.device)
#print(f'DEBUG! {x}, {mos}')
assert mos.shape[0] == x.shape[0]
#import pdb; pdb.set_trace()
mos = mos - 1
mos = self.mos_embed(mos.to(x.device).to(th.int)) # [N, dim]
x = x.to(self.dtype)
timestep = timestep.to(self.dtype)
y = context_list[0].to(self.dtype)
pos_embed = self.pos_embed.to(self.dtype)
self.h, self.w = self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]
# import pdb
# pdb.set_trace()
x = self.x_embedder(x) + pos_embed # (N, T, D), where T = H * W / patch_size ** 2
t = self.t_embedder(timestep.to(x.dtype)) # (N, D)
t0 = self.t_block(t)
y = self.y_embedder(y) # (N, L, D)
# if not self.training:
        try:
            mask = context_mask_list[0]  # (N, L)
        except (TypeError, IndexError):
            mask = th.ones(x.shape[0], 1).to(x.device)
assert mask is not None
# if mask is not None:
y = y.masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
y_lens = mask.sum(dim=1).tolist()
y_lens = [int(_) for _ in y_lens]
masked_stage = False
skips = []
# TODO : masking op for training
        # prepend the MOS embedding as an extra leading token: (N, L, D) -> (N, L+1, D)
        try:
            x = th.cat([mos, x], dim=1)  # mos already shaped (N, 1, D)
        except RuntimeError:
            x = th.cat([mos.unsqueeze(1), x], dim=1)  # mos shaped (N, D)
input_skip = x
if self.mask_ratio is not None and self.training:
# masking: length -> length * mask_ratio
rand_mask_ratio = th.rand(1, device=x.device) # noise in [0, 1]
rand_mask_ratio = rand_mask_ratio * 0.2 + self.mask_ratio # mask_ratio, mask_ratio + 0.2
# print(rand_mask_ratio)
x, mask, ids_restore, ids_keep = self.random_masking(
x, rand_mask_ratio)
masked_stage = True
for block in self.en_inblocks:
if masked_stage:
x = auto_grad_checkpoint(block, x, y, t0, y_lens, ids_keep=ids_keep)
else:
x = auto_grad_checkpoint(block, x, y, t0, y_lens, ids_keep=None)
skips.append(x)
for block in self.en_outblocks:
if masked_stage:
x = auto_grad_checkpoint(block, x, y, t0, y_lens, skip=skips.pop(), ids_keep=ids_keep)
else:
x = auto_grad_checkpoint(block, x, y, t0, y_lens, skip=skips.pop(), ids_keep=None)
if self.mask_ratio is not None and self.training:
x = self.forward_side_interpolater(x, y, t0, y_lens, mask, ids_restore)
masked_stage = False
else:
# add pos embed
x[:, 1:, :] = x[:, 1:, :] + self.decoder_pos_embed
# x = x + self.decoder_pos_embed
# import pdb
# pdb.set_trace()
for i in range(len(self.de_blocks)):
block = self.de_blocks[i]
this_skip = input_skip
x = auto_grad_checkpoint(block, x, y, t0, y_lens, skip=this_skip, ids_keep=None)
x = x[:, 1:, :]
x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels)
x = self.unpatchify(x) # (N, out_channels, H, W)
# import pdb
# pdb.set_trace()
return x
def forward_with_dpmsolver(self, x, timestep, y, mask=None, **kwargs):
"""
        DPM-Solver does not need variance prediction
"""
# https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
model_out = self.forward(x, timestep, y, mask)
return model_out.chunk(2, dim=1)[0]
def forward_with_cfg(self, x, timestep, context_list, context_mask_list=None, cfg_scale=4.0, **kwargs):
"""
Forward pass of PixArt, but also batches the unconditional forward pass for classifier-free guidance.
"""
# https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
# import pdb
# pdb.set_trace()
half = x[: len(x) // 2]
combined = th.cat([half, half], dim=0)
model_out = self.forward(combined, timestep, context_list, context_mask_list=None)
model_out = model_out['x'] if isinstance(model_out, dict) else model_out
eps, rest = model_out[:, :8], model_out[:, 8:]
cond_eps, uncond_eps = th.split(eps, len(eps) // 2, dim=0)
half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
eps = th.cat([half_eps, half_eps], dim=0)
return eps
# return th.cat([eps, rest], dim=1)
def unpatchify(self, x):
"""
        x: (N, T, patch_size[0] * patch_size[1] * C)
        imgs: (N, out_channels, H, W)
"""
if self.x_embedder.ol == (0, 0) or self.x_embedder.ol == [0, 0]:
c = self.out_channels
p0 = self.x_embedder.patch_size[0]
p1 = self.x_embedder.patch_size[1]
h, w = self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]
x = x.reshape(shape=(x.shape[0], h, w, p0, p1, c))
x = th.einsum('nhwpqc->nchpwq', x)
imgs = x.reshape(shape=(x.shape[0], c, h * p0, w * p1))
return imgs
lf = self.x_embedder.grid_size[0]
rf = self.x_embedder.grid_size[1]
lp = self.x_embedder.patch_size[0]
rp = self.x_embedder.patch_size[1]
lo = self.x_embedder.ol[0]
ro = self.x_embedder.ol[1]
lm = self.x_embedder.img_size[0]
rm = self.x_embedder.img_size[1]
lpad = self.x_embedder.pad_size[0]
rpad = self.x_embedder.pad_size[1]
bs = x.shape[0]
torch_map = self.torch_map
c = self.out_channels
x = x.reshape(shape=(bs, lf, rf, lp, rp, c))
x = th.einsum('nhwpqc->nchwpq', x)
added_map = th.zeros(bs, c, lm+2*lpad, rm+2*rpad).to(x.device)
for i in range(lf):
for j in range(rf):
xx = (i) * (lp - lo)
yy = (j) * (rp - ro)
added_map[:, :, xx:(xx+lp), yy:(yy+rp)] += \
x[:, :, i, j, :, :]
# import pdb
# pdb.set_trace()
        added_map = added_map[:, :, lpad:lm+lpad, rpad:rm+rpad]
return th.mul(added_map.to(x.device), torch_map.to(x.device))
def random_masking(self, x, mask_ratio):
"""
Perform per-sample random masking by per-sample shuffling.
Per-sample shuffling is done by argsort random noise.
x: [N, L, D], sequence
"""
        N, L, D = x.shape  # batch, length (including the MOS token), dim
        L = L - 1  # exclude the prepended MOS token from masking
len_keep = int(L * (1 - mask_ratio))
noise = th.rand(N, L, device=x.device) # noise in [0, 1]
# sort noise for each sample
# ascend: small is keep, large is remove
ids_shuffle = th.argsort(noise, dim=1)
ids_restore = th.argsort(ids_shuffle, dim=1)
# keep the first subset
ids_keep = ids_shuffle[:, :len_keep]
x_masked = th.gather(
x[:, 1:, :], dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
x_masked = th.cat([x[:, 0, :].unsqueeze(1), x_masked], dim=1)
# import pdb
# pdb.set_trace()
# generate the binary mask: 0 is keep, 1 is remove
mask = th.ones([N, L], device=x.device)
mask[:, :len_keep] = 0
# unshuffle to get the binary mask
mask = th.gather(mask, dim=1, index=ids_restore)
return x_masked, mask, ids_restore, ids_keep
def forward_side_interpolater(self, x, y, t0, y_lens, mask, ids_restore):
# append mask tokens to sequence
mask_tokens = self.mask_token.repeat(
x.shape[0], ids_restore.shape[1] - x.shape[1] + 1, 1)
x_ = th.cat([x[:, 1:, :], mask_tokens], dim=1)
x_ = th.gather(
x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle
# add pos embed
x_ = x_ + self.decoder_pos_embed
x = th.cat([x[:, 0, :].unsqueeze(1), x_], dim=1)
# import pdb
# pdb.set_trace()
# pass to the basic block
x_before = x
for sideblock in self.sideblocks:
x = sideblock(x, y, t0, y_lens, ids_keep=None)
# masked shortcut
mask = mask.unsqueeze(dim=-1)
# import pdb;pdb.set_trace()
mask = th.cat([th.ones(mask.shape[0], 1, 1).to(mask.device), mask], dim=1)
x = x*mask + (1-mask)*x_before
return x
def initialize_weights(self):
# Initialize transformer layers:
def _basic_init(module):
if isinstance(module, nn.Linear):
th.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
self.apply(_basic_init)
# Initialize (and freeze) pos_embed by sin-cos embedding:
pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], self.x_embedder.grid_size, lewei_scale=self.lewei_scale, base_size=self.base_size)
self.pos_embed.data.copy_(th.from_numpy(pos_embed).float().unsqueeze(0))
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
w = self.x_embedder.proj.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
# Initialize timestep embedding MLP:
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
nn.init.normal_(self.t_block[1].weight, std=0.02)
# Initialize caption embedding MLP:
nn.init.normal_(self.y_embedder.weight, std=0.02)
# nn.init.normal_(self.y_embedder.y_proj.fc2.weight, std=0.02)
# Zero-out adaLN modulation layers in PixArt blocks:
for block in self.en_inblocks:
nn.init.constant_(block.cross_attn.proj.weight, 0)
nn.init.constant_(block.cross_attn.proj.bias, 0)
for block in self.en_outblocks:
nn.init.constant_(block.cross_attn.proj.weight, 0)
nn.init.constant_(block.cross_attn.proj.bias, 0)
for block in self.de_blocks:
nn.init.constant_(block.cross_attn.proj.weight, 0)
nn.init.constant_(block.cross_attn.proj.bias, 0)
for block in self.sideblocks:
nn.init.constant_(block.cross_attn.proj.weight, 0)
nn.init.constant_(block.cross_attn.proj.bias, 0)
# Zero-out output layers:
nn.init.constant_(self.final_layer.linear.weight, 0)
nn.init.constant_(self.final_layer.linear.bias, 0)
if self.x_embedder.ol == [0, 0] or self.x_embedder.ol == (0, 0):
return
lf = self.x_embedder.grid_size[0]
rf = self.x_embedder.grid_size[1]
lp = self.x_embedder.patch_size[0]
rp = self.x_embedder.patch_size[1]
lo = self.x_embedder.ol[0]
ro = self.x_embedder.ol[1]
lm = self.x_embedder.img_size[0]
rm = self.x_embedder.img_size[1]
lpad = self.x_embedder.pad_size[0]
rpad = self.x_embedder.pad_size[1]
torch_map = th.zeros(lm+2*lpad, rm+2*rpad).to('cuda')
for i in range(lf):
for j in range(rf):
xx = (i) * (lp - lo)
yy = (j) * (rp - ro)
torch_map[xx:(xx+lp), yy:(yy+rp)]+=1
torch_map = torch_map[lpad:lm+lpad, rpad:rm+rpad]
self.torch_map = th.reciprocal(torch_map)
@property
def dtype(self):
return next(self.parameters()).dtype
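# Minimal sketch of the MOS-as-token mechanism used by PixArt_MDT_MOS_AS_TOKEN
# above: an integer quality rating in {1..5} is embedded, prepended to the patch
# sequence, kept out of random masking and the decoder positional embedding, and
# stripped again before the final layer. Sizes here are arbitrary placeholders.
def _mos_token_demo():
    N, L, D = 2, 10, 16
    mos_embed = nn.Embedding(num_embeddings=5, embedding_dim=D)
    mos = th.tensor([5, 3])                          # raw ratings in {1..5}
    mos_tok = mos_embed(mos - 1).unsqueeze(1)        # (N, 1, D), indices shifted to 0..4
    x = th.randn(N, L, D)                            # patch tokens
    x = th.cat([mos_tok, x], dim=1)                  # (N, L+1, D)
    x_out = x[:, 1:, :]                              # drop the MOS token before unpatchify
    assert x.shape == (N, L + 1, D) and x_out.shape == (N, L, D)
    return x_out.shape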
class PixArt_MDT_LC(nn.Module):
"""
Diffusion model with a Transformer backbone.
"""
def __init__(self, input_size=(256,16), patch_size=(16,4), overlap=(0, 0), in_channels=8, hidden_size=1152, depth=28, num_heads=16, mlp_ratio=4.0, class_dropout_prob=0.1, pred_sigma=False, drop_path: float = 0., window_size=0, window_block_indexes=None, use_rel_pos=False, cond_dim=1024, lewei_scale=1.0,
use_cfg=True, cfg_scale=4.0, config=None, model_max_length=120, mask_ratio=None, decode_layer=4,**kwargs):
if window_block_indexes is None:
window_block_indexes = []
super().__init__()
self.use_cfg = use_cfg
self.cfg_scale = cfg_scale
self.input_size = input_size
self.pred_sigma = pred_sigma
self.in_channels = in_channels
self.out_channels = in_channels
self.patch_size = patch_size
self.num_heads = num_heads
self.lewei_scale = lewei_scale,
decode_layer = int(decode_layer)
self.x_embedder = PatchEmbed(input_size, patch_size, overlap, in_channels, hidden_size, bias=True)
# self.x_embedder = PatchEmbed_1D(input)
self.t_embedder = TimestepEmbedder(hidden_size)
num_patches = self.x_embedder.num_patches
self.base_size = input_size[0] // self.patch_size[0] * 2
# Will use fixed sin-cos embedding:
self.register_buffer("pos_embed", th.zeros(1, num_patches, hidden_size))
approx_gelu = lambda: nn.GELU(approximate="tanh")
self.t_block = nn.Sequential(
nn.SiLU(),
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
)
self.y_embedder = nn.Linear(cond_dim, hidden_size)
half_depth = (depth - decode_layer)//2
self.half_depth=half_depth
drop_path_half = [x.item() for x in th.linspace(0, drop_path, half_depth)] # stochastic depth decay rule
drop_path_decode = [x.item() for x in th.linspace(0, drop_path, decode_layer)]
self.en_inblocks = nn.ModuleList([
MDTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path_half[i],
input_size=(self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]),
window_size=0,
use_rel_pos=False, FFN_type='mlp') for i in range(half_depth)
])
self.en_outblocks = nn.ModuleList([
MDTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path_half[i],
input_size=(self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]),
window_size=0,
use_rel_pos=False, skip=True, FFN_type='mlp') for i in range(half_depth)
])
self.de_blocks = nn.ModuleList([
DEBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path_decode[i],
input_size=(self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]),
window_size=0,
use_rel_pos=False, skip=True, FFN_type='mlp', num_f=self.x_embedder.grid_size[1]+1, num_t=self.x_embedder.grid_size[0]) for i in range(decode_layer)
])
self.sideblocks = nn.ModuleList([
MDTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio,
input_size=(self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]),
window_size=0,
use_rel_pos=False, FFN_type='mlp') for _ in range(1)
])
self.final_layer = T2IFinalLayer(hidden_size, patch_size, self.out_channels)
self.decoder_pos_embed = nn.Parameter(th.zeros(
1, num_patches, hidden_size), requires_grad=True)
if mask_ratio is not None:
self.mask_token = nn.Parameter(th.zeros(1, 1, hidden_size))
self.mask_ratio = float(mask_ratio)
self.decode_layer = int(decode_layer)
else:
self.mask_token = nn.Parameter(th.zeros(
1, 1, hidden_size), requires_grad=False)
self.mask_ratio = None
self.decode_layer = int(decode_layer)
print("mask ratio:", self.mask_ratio, "decode_layer:", self.decode_layer)
self.initialize_weights()
def forward(self, x, timestep, context_list, context_mask_list=None, enable_mask=False, **kwargs):
"""
Forward pass of PixArt.
        x: (N, C, H, W) tensor of spatial inputs (e.g. latent spectrograms)
        timestep: (N,) tensor of diffusion timesteps
        context_list: list of conditioning tensors; context_list[0] is (N, L, cond_dim)
        context_mask_list: list of masks; context_mask_list[0] is (N, L)
"""
x = x.to(self.dtype)
timestep = timestep.to(self.dtype)
y = context_list[0].to(self.dtype)
pos_embed = self.pos_embed.to(self.dtype)
self.h, self.w = self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]
# import pdb
# pdb.set_trace()
x = self.x_embedder(x) + pos_embed # (N, T, D), where T = H * W / patch_size ** 2
t = self.t_embedder(timestep.to(x.dtype)) # (N, D)
t0 = self.t_block(t)
y = self.y_embedder(y) # (N, L, D)
# if not self.training:
        try:
            mask = context_mask_list[0]  # (N, L)
        except (TypeError, IndexError):
            mask = th.ones(x.shape[0], 1).to(x.device)
            print("Warning: no context mask provided; falling back to an all-ones mask.")
assert mask is not None
# if mask is not None:
y = y.masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
y_lens = mask.sum(dim=1).tolist()
y_lens = [int(_) for _ in y_lens]
input_skip = x
masked_stage = False
skips = []
# TODO : masking op for training
if self.mask_ratio is not None and self.training:
# masking: length -> length * mask_ratio
rand_mask_ratio = th.rand(1, device=x.device) # noise in [0, 1]
rand_mask_ratio = rand_mask_ratio * 0.2 + self.mask_ratio # mask_ratio, mask_ratio + 0.2
# print(rand_mask_ratio)
x, mask, ids_restore, ids_keep = self.random_masking(
x, rand_mask_ratio)
masked_stage = True
for block in self.en_inblocks:
if masked_stage:
x = auto_grad_checkpoint(block, x, y, t0, y_lens, ids_keep=ids_keep)
else:
x = auto_grad_checkpoint(block, x, y, t0, y_lens, ids_keep=None)
skips.append(x)
for block in self.en_outblocks:
if masked_stage:
x = auto_grad_checkpoint(block, x, y, t0, y_lens, skip=skips.pop(), ids_keep=ids_keep)
else:
x = auto_grad_checkpoint(block, x, y, t0, y_lens, skip=skips.pop(), ids_keep=None)
if self.mask_ratio is not None and self.training:
x = self.forward_side_interpolater(x, y, t0, y_lens, mask, ids_restore)
masked_stage = False
else:
# add pos embed
x = x + self.decoder_pos_embed
bs = x.shape[0]
bs, D, L = x.shape
T, F = self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]
# reshaped = x.reshape(bs, T, F, L).to(x.device)
end = th.zeros(bs, T, 1, L).to(x.device)
# x = th.cat((reshaped, zero_tensor), 2)
# import pdb;pdb.set_trace()
# assert x.shape == [bs, T, F + 1, L]
for i in range(len(self.de_blocks)):
block = self.de_blocks[i]
this_skip = input_skip
x, end = auto_grad_checkpoint(block, x, end, y, t0, y_lens, skip=this_skip, ids_keep=None)
x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels)
x = self.unpatchify(x) # (N, out_channels, H, W)
return x
def forward_with_dpmsolver(self, x, timestep, y, mask=None, **kwargs):
"""
        DPM-Solver does not need variance prediction
"""
# https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
model_out = self.forward(x, timestep, y, mask)
return model_out.chunk(2, dim=1)[0]
def forward_with_cfg(self, x, timestep, context_list, context_mask_list=None, cfg_scale=4.0, **kwargs):
"""
Forward pass of PixArt, but also batches the unconditional forward pass for classifier-free guidance.
"""
# https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
# import pdb
# pdb.set_trace()
half = x[: len(x) // 2]
combined = th.cat([half, half], dim=0)
model_out = self.forward(combined, timestep, context_list, context_mask_list=None)
model_out = model_out['x'] if isinstance(model_out, dict) else model_out
eps, rest = model_out[:, :8], model_out[:, 8:]
cond_eps, uncond_eps = th.split(eps, len(eps) // 2, dim=0)
half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
eps = th.cat([half_eps, half_eps], dim=0)
return eps
# return th.cat([eps, rest], dim=1)
def unpatchify(self, x):
"""
x: (N, T, patch_size[0] * patch_size[1] * C)
imgs: (N, C, H, W), e.g. (bs, 8, 256, 16)
"""
if self.x_embedder.ol == (0, 0) or self.x_embedder.ol == [0, 0]:
c = self.out_channels
p0 = self.x_embedder.patch_size[0]
p1 = self.x_embedder.patch_size[1]
h, w = self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]
x = x.reshape(shape=(x.shape[0], h, w, p0, p1, c))
x = th.einsum('nhwpqc->nchpwq', x)
imgs = x.reshape(shape=(x.shape[0], c, h * p0, w * p1))
return imgs
lf = self.x_embedder.grid_size[0]
rf = self.x_embedder.grid_size[1]
lp = self.x_embedder.patch_size[0]
rp = self.x_embedder.patch_size[1]
lo = self.x_embedder.ol[0]
ro = self.x_embedder.ol[1]
lm = self.x_embedder.img_size[0]
rm = self.x_embedder.img_size[1]
lpad = self.x_embedder.pad_size[0]
rpad = self.x_embedder.pad_size[1]
bs = x.shape[0]
torch_map = self.torch_map
c = self.out_channels
x = x.reshape(shape=(bs, lf, rf, lp, rp, c))
x = th.einsum('nhwpqc->nchwpq', x)
added_map = th.zeros(bs, c, lm+2*lpad, rm+2*rpad).to(x.device)
for i in range(lf):
for j in range(rf):
xx = (i) * (lp - lo)
yy = (j) * (rp - ro)
added_map[:, :, xx:(xx+lp), yy:(yy+rp)] += \
x[:, :, i, j, :, :]
# crop the padded border; the original `added_map[:][:][...]` indexing sliced the
# batch/channel dims instead of the spatial dims
added_map = added_map[:, :, lpad:lm+lpad, rpad:rm+rpad]
return th.mul(added_map.to(x.device), torch_map.to(x.device))
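# Worked example of the overlap handling above (illustrative numbers, not taken from
# any config): with patch length lp = 16 and overlap lo = 2 the stride is 14, so
# adjacent patches share a 2-sample band; added_map accumulates both contributions
# there, and the element-wise multiply by self.torch_map (the reciprocal of the
# per-pixel patch count built in initialize_weights) turns those sums into averages.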
def random_masking(self, x, mask_ratio):
"""
Perform per-sample random masking by per-sample shuffling.
Per-sample shuffling is done by argsort random noise.
x: [N, L, D], sequence
"""
N, L, D = x.shape # batch, length, dim
len_keep = int(L * (1 - mask_ratio))
noise = th.rand(N, L, device=x.device) # noise in [0, 1]
# sort noise for each sample
# ascend: small is keep, large is remove
ids_shuffle = th.argsort(noise, dim=1)
ids_restore = th.argsort(ids_shuffle, dim=1)
# keep the first subset
ids_keep = ids_shuffle[:, :len_keep]
x_masked = th.gather(
x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
# generate the binary mask: 0 is keep, 1 is remove
mask = th.ones([N, L], device=x.device)
mask[:, :len_keep] = 0
# unshuffle to get the binary mask
mask = th.gather(mask, dim=1, index=ids_restore)
return x_masked, mask, ids_restore, ids_keep
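# Shape sketch (toy sizes, for illustration only): with x of shape (2, 6, 4) and
# mask_ratio = 0.5, len_keep = 3, so x_masked is (2, 3, 4), ids_keep is (2, 3), and
# mask / ids_restore are (2, 6); mask holds 1 at the three dropped positions of each
# row and 0 at the kept ones.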
def forward_side_interpolater(self, x, y, t0, y_lens, mask, ids_restore):
# append mask tokens to sequence
mask_tokens = self.mask_token.repeat(
x.shape[0], ids_restore.shape[1] - x.shape[1], 1)
x_ = th.cat([x, mask_tokens], dim=1)
x = th.gather(
x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle
# add pos embed
x = x + self.decoder_pos_embed
# pass to the basic block
x_before = x
for sideblock in self.sideblocks:
x = sideblock(x, y, t0, y_lens, ids_keep=None)
# masked shortcut
mask = mask.unsqueeze(dim=-1)
x = x*mask + (1-mask)*x_before
return x
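# The "masked shortcut" above keeps the encoder output at visible positions (mask == 0)
# and uses the side-block prediction only at masked positions (mask == 1), so the
# interpolater learns to fill in the dropped tokens without disturbing the rest.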
def initialize_weights(self):
# Initialize transformer layers:
def _basic_init(module):
if isinstance(module, nn.Linear):
th.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
self.apply(_basic_init)
# Initialize (and freeze) pos_embed by sin-cos embedding:
pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], self.x_embedder.grid_size, lewei_scale=self.lewei_scale, base_size=self.base_size)
self.pos_embed.data.copy_(th.from_numpy(pos_embed).float().unsqueeze(0))
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
w = self.x_embedder.proj.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
# Initialize timestep embedding MLP:
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
nn.init.normal_(self.t_block[1].weight, std=0.02)
# Initialize caption embedding MLP:
nn.init.normal_(self.y_embedder.weight, std=0.02)
# nn.init.normal_(self.y_embedder.y_proj.fc2.weight, std=0.02)
# Zero-out adaLN modulation layers in PixArt blocks:
for block in self.en_inblocks:
nn.init.constant_(block.cross_attn.proj.weight, 0)
nn.init.constant_(block.cross_attn.proj.bias, 0)
for block in self.en_outblocks:
nn.init.constant_(block.cross_attn.proj.weight, 0)
nn.init.constant_(block.cross_attn.proj.bias, 0)
for block in self.de_blocks:
nn.init.constant_(block.cross_attn.proj.weight, 0)
nn.init.constant_(block.cross_attn.proj.bias, 0)
for block in self.sideblocks:
nn.init.constant_(block.cross_attn.proj.weight, 0)
nn.init.constant_(block.cross_attn.proj.bias, 0)
# Zero-out output layers:
nn.init.constant_(self.final_layer.linear.weight, 0)
nn.init.constant_(self.final_layer.linear.bias, 0)
if self.x_embedder.ol == [0, 0] or self.x_embedder.ol == (0, 0):
return
lf = self.x_embedder.grid_size[0]
rf = self.x_embedder.grid_size[1]
lp = self.x_embedder.patch_size[0]
rp = self.x_embedder.patch_size[1]
lo = self.x_embedder.ol[0]
ro = self.x_embedder.ol[1]
lm = self.x_embedder.img_size[0]
rm = self.x_embedder.img_size[1]
lpad = self.x_embedder.pad_size[0]
rpad = self.x_embedder.pad_size[1]
torch_map = th.zeros(lm+2*lpad, rm+2*rpad) # kept on CPU; moved to the input device when used
for i in range(lf):
for j in range(rf):
xx = (i) * (lp - lo)
yy = (j) * (rp - ro)
torch_map[xx:(xx+lp), yy:(yy+rp)]+=1
torch_map = torch_map[lpad:lm+lpad, rpad:rm+rpad]
self.torch_map = th.reciprocal(torch_map)
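# torch_map counts how many patches cover each pixel of the padded canvas; its
# reciprocal is cached so that unpatchify can average, rather than sum, the
# contributions of overlapping patches.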
@property
def dtype(self):
return next(self.parameters()).dtype
class PixArt_MDT(nn.Module):
"""
Diffusion model with a Transformer backbone.
"""
def __init__(self, input_size=(256,16), patch_size=(16,4), overlap=(0, 0), in_channels=8, hidden_size=1152, depth=28, num_heads=16, mlp_ratio=4.0, class_dropout_prob=0.1, pred_sigma=False, drop_path: float = 0., window_size=0, window_block_indexes=None, use_rel_pos=False, cond_dim=1024, lewei_scale=1.0,
use_cfg=True, cfg_scale=4.0, config=None, model_max_length=120, mask_ratio=None, decode_layer=4,**kwargs):
if window_block_indexes is None:
window_block_indexes = []
super().__init__()
self.use_cfg = use_cfg
self.cfg_scale = cfg_scale
self.input_size = input_size
self.pred_sigma = pred_sigma
self.in_channels = in_channels
self.out_channels = in_channels
self.patch_size = patch_size
self.num_heads = num_heads
self.lewei_scale = lewei_scale
decode_layer = int(decode_layer)
self.x_embedder = PatchEmbed(input_size, patch_size, overlap, in_channels, hidden_size, bias=True)
# self.x_embedder = PatchEmbed_1D(input)
self.t_embedder = TimestepEmbedder(hidden_size)
num_patches = self.x_embedder.num_patches
self.base_size = input_size[0] // self.patch_size[0] * 2
# Will use fixed sin-cos embedding:
self.register_buffer("pos_embed", th.zeros(1, num_patches, hidden_size))
approx_gelu = lambda: nn.GELU(approximate="tanh")
self.t_block = nn.Sequential(
nn.SiLU(),
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
)
self.y_embedder = nn.Linear(cond_dim, hidden_size)
half_depth = (depth - decode_layer)//2
self.half_depth=half_depth
drop_path_half = [x.item() for x in th.linspace(0, drop_path, half_depth)] # stochastic depth decay rule
drop_path_decode = [x.item() for x in th.linspace(0, drop_path, decode_layer)]
self.en_inblocks = nn.ModuleList([
MDTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path_half[i],
input_size=(self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]),
window_size=0,
use_rel_pos=False) for i in range(half_depth)
])
self.en_outblocks = nn.ModuleList([
MDTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path_half[i],
input_size=(self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]),
window_size=0,
use_rel_pos=False, skip=True) for i in range(half_depth)
])
self.de_blocks = nn.ModuleList([
MDTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path_decode[i],
input_size=(self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]),
window_size=0,
use_rel_pos=False, skip=True) for i in range(decode_layer)
])
self.sideblocks = nn.ModuleList([
MDTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio,
input_size=(self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]),
window_size=0,
use_rel_pos=False) for _ in range(1)
])
self.final_layer = T2IFinalLayer(hidden_size, patch_size, self.out_channels)
self.decoder_pos_embed = nn.Parameter(th.zeros(
1, num_patches, hidden_size), requires_grad=True)
if mask_ratio is not None:
self.mask_token = nn.Parameter(th.zeros(1, 1, hidden_size))
self.mask_ratio = float(mask_ratio)
self.decode_layer = int(decode_layer)
else:
self.mask_token = nn.Parameter(th.zeros(
1, 1, hidden_size), requires_grad=False)
self.mask_ratio = None
self.decode_layer = int(decode_layer)
print("mask ratio:", self.mask_ratio, "decode_layer:", self.decode_layer)
self.initialize_weights()
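# Commented-out usage sketch in the style of the example block at the bottom of this
# file; the shapes follow the default constructor arguments, and the batch size, text
# length and device are illustrative assumptions:
# model = PixArt_MDT().to('cuda')
# x = th.rand(2, 8, 256, 16).to('cuda') # (N, in_channels, T, F) latents
# t = th.tensor([10, 20]).to('cuda') # (N,) diffusion timesteps
# c = th.rand(2, 20, 1024).to('cuda') # (N, L, cond_dim) text embeddings
# c_mask = th.ones(2, 20).to('cuda') # (N, L) validity mask
# out = model(x, t, [c], [c_mask]) # -> (2, 8, 256, 16) since pred_sigma=False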
def forward(self, x, timestep, context_list, context_mask_list=None, enable_mask=False, **kwargs):
"""
Forward pass of PixArt.
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
t: (N,) tensor of diffusion timesteps
context_list[0]: (N, L, cond_dim) tensor of conditioning embeddings (e.g. text features)
"""
x = x.to(self.dtype)
timestep = timestep.to(self.dtype)
y = context_list[0].to(self.dtype)
pos_embed = self.pos_embed.to(self.dtype)
self.h, self.w = self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]
x = self.x_embedder(x) + pos_embed # (N, T, D), where T = H * W / patch_size ** 2
t = self.t_embedder(timestep.to(x.dtype)) # (N, D)
t0 = self.t_block(t)
y = self.y_embedder(y) # (N, L, D)
# if not self.training:
try:
mask = context_mask_list[0] # (N, L)
except (TypeError, IndexError):
# no context mask was provided; fall back to an all-ones mask
mask = th.ones(x.shape[0], 1).to(x.device)
print("Warning: context_mask_list is missing or empty; using an all-ones mask")
assert mask is not None
# if mask is not None:
y = y.masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
y_lens = mask.sum(dim=1).tolist()
y_lens = [int(_) for _ in y_lens]
input_skip = x
masked_stage = False
skips = []
# TODO : masking op for training
if self.mask_ratio is not None and self.training:
# sample a masking ratio uniformly in [mask_ratio, mask_ratio + 0.2)
rand_mask_ratio = th.rand(1, device=x.device) # noise in [0, 1]
rand_mask_ratio = rand_mask_ratio * 0.2 + self.mask_ratio
x, mask, ids_restore, ids_keep = self.random_masking(
x, rand_mask_ratio)
masked_stage = True
for block in self.en_inblocks:
if masked_stage:
x = auto_grad_checkpoint(block, x, y, t0, y_lens, ids_keep=ids_keep)
else:
x = auto_grad_checkpoint(block, x, y, t0, y_lens, ids_keep=None)
skips.append(x)
for block in self.en_outblocks:
if masked_stage:
x = auto_grad_checkpoint(block, x, y, t0, y_lens, skip=skips.pop(), ids_keep=ids_keep)
else:
x = auto_grad_checkpoint(block, x, y, t0, y_lens, skip=skips.pop(), ids_keep=None)
if self.mask_ratio is not None and self.training:
x = self.forward_side_interpolater(x, y, t0, y_lens, mask, ids_restore)
masked_stage = False
else:
# add pos embed
x = x + self.decoder_pos_embed
for block in self.de_blocks:
this_skip = input_skip
x = auto_grad_checkpoint(block, x, y, t0, y_lens, skip=this_skip, ids_keep=None)
x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels)
x = self.unpatchify(x) # (N, out_channels, H, W)
return x
def forward_with_dpmsolver(self, x, timestep, y, mask=None, **kwargs):
"""
DPM-Solver does not need the variance prediction.
"""
# https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
model_out = self.forward(x, timestep, y, mask)
return model_out.chunk(2, dim=1)[0]
def forward_with_cfg(self, x, timestep, y, cfg_scale, mask=None, **kwargs):
"""
Forward pass of PixArt, but also batches the unconditional forward pass for classifier-free guidance.
"""
# https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
half = x[: len(x) // 2]
combined = th.cat([half, half], dim=0)
model_out = self.forward(combined, timestep, y, mask)
model_out = model_out['x'] if isinstance(model_out, dict) else model_out
eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
cond_eps, uncond_eps = th.split(eps, len(eps) // 2, dim=0)
half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
eps = th.cat([half_eps, half_eps], dim=0)
return eps
# return th.cat([eps, rest], dim=1)
def unpatchify(self, x):
"""
x: (N, T, patch_size[0] * patch_size[1] * C)
imgs: (N, C, H, W), e.g. (bs, 8, 256, 16)
"""
if self.x_embedder.ol == (0, 0) or self.x_embedder.ol == [0, 0]:
c = self.out_channels
p0 = self.x_embedder.patch_size[0]
p1 = self.x_embedder.patch_size[1]
h, w = self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]
x = x.reshape(shape=(x.shape[0], h, w, p0, p1, c))
x = th.einsum('nhwpqc->nchpwq', x)
imgs = x.reshape(shape=(x.shape[0], c, h * p0, w * p1))
return imgs
lf = self.x_embedder.grid_size[0]
rf = self.x_embedder.grid_size[1]
lp = self.x_embedder.patch_size[0]
rp = self.x_embedder.patch_size[1]
lo = self.x_embedder.ol[0]
ro = self.x_embedder.ol[1]
lm = self.x_embedder.img_size[0]
rm = self.x_embedder.img_size[1]
lpad = self.x_embedder.pad_size[0]
rpad = self.x_embedder.pad_size[1]
bs = x.shape[0]
torch_map = self.torch_map
c = self.out_channels
x = x.reshape(shape=(bs, lf, rf, lp, rp, c))
x = th.einsum('nhwpqc->nchwpq', x)
added_map = th.zeros(bs, c, lm+2*lpad, rm+2*rpad).to(x.device)
for i in range(lf):
for j in range(rf):
xx = (i) * (lp - lo)
yy = (j) * (rp - ro)
added_map[:, :, xx:(xx+lp), yy:(yy+rp)] += \
x[:, :, i, j, :, :]
added_map = added_map[:, :, lpad:lm+lpad, rpad:rm+rpad]
return th.mul(added_map, torch_map.to(added_map.device))
def random_masking(self, x, mask_ratio):
"""
Perform per-sample random masking by per-sample shuffling.
Per-sample shuffling is done by argsort random noise.
x: [N, L, D], sequence
"""
N, L, D = x.shape # batch, length, dim
len_keep = int(L * (1 - mask_ratio))
noise = th.rand(N, L, device=x.device) # noise in [0, 1]
# sort noise for each sample
# ascend: small is keep, large is remove
ids_shuffle = th.argsort(noise, dim=1)
ids_restore = th.argsort(ids_shuffle, dim=1)
# keep the first subset
ids_keep = ids_shuffle[:, :len_keep]
x_masked = th.gather(
x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
# generate the binary mask: 0 is keep, 1 is remove
mask = th.ones([N, L], device=x.device)
mask[:, :len_keep] = 0
# unshuffle to get the binary mask
mask = th.gather(mask, dim=1, index=ids_restore)
return x_masked, mask, ids_restore, ids_keep
def forward_side_interpolater(self, x, y, t0, y_lens, mask, ids_restore):
# append mask tokens to sequence
mask_tokens = self.mask_token.repeat(
x.shape[0], ids_restore.shape[1] - x.shape[1], 1)
x_ = th.cat([x, mask_tokens], dim=1)
x = th.gather(
x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle
# add pos embed
x = x + self.decoder_pos_embed
# pass to the basic block
x_before = x
for sideblock in self.sideblocks:
x = sideblock(x, y, t0, y_lens, ids_keep=None)
# masked shortcut
mask = mask.unsqueeze(dim=-1)
x = x*mask + (1-mask)*x_before
return x
def initialize_weights(self):
# Initialize transformer layers:
def _basic_init(module):
if isinstance(module, nn.Linear):
th.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
self.apply(_basic_init)
# Initialize (and freeze) pos_embed by sin-cos embedding:
pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], self.x_embedder.grid_size, lewei_scale=self.lewei_scale, base_size=self.base_size)
self.pos_embed.data.copy_(th.from_numpy(pos_embed).float().unsqueeze(0))
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
w = self.x_embedder.proj.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
# Initialize timestep embedding MLP:
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
nn.init.normal_(self.t_block[1].weight, std=0.02)
# Initialize caption embedding MLP:
nn.init.normal_(self.y_embedder.weight, std=0.02)
# nn.init.normal_(self.y_embedder.y_proj.fc2.weight, std=0.02)
# Zero-out adaLN modulation layers in PixArt blocks:
for block in self.en_inblocks:
nn.init.constant_(block.cross_attn.proj.weight, 0)
nn.init.constant_(block.cross_attn.proj.bias, 0)
for block in self.en_outblocks:
nn.init.constant_(block.cross_attn.proj.weight, 0)
nn.init.constant_(block.cross_attn.proj.bias, 0)
for block in self.de_blocks:
nn.init.constant_(block.cross_attn.proj.weight, 0)
nn.init.constant_(block.cross_attn.proj.bias, 0)
for block in self.sideblocks:
nn.init.constant_(block.cross_attn.proj.weight, 0)
nn.init.constant_(block.cross_attn.proj.bias, 0)
# Zero-out output layers:
nn.init.constant_(self.final_layer.linear.weight, 0)
nn.init.constant_(self.final_layer.linear.bias, 0)
if self.x_embedder.ol == [0, 0] or self.x_embedder.ol == (0, 0):
return
lf = self.x_embedder.grid_size[0]
rf = self.x_embedder.grid_size[1]
lp = self.x_embedder.patch_size[0]
rp = self.x_embedder.patch_size[1]
lo = self.x_embedder.ol[0]
ro = self.x_embedder.ol[1]
lm = self.x_embedder.img_size[0]
rm = self.x_embedder.img_size[1]
lpad = self.x_embedder.pad_size[0]
rpad = self.x_embedder.pad_size[1]
torch_map = th.zeros(lm+2*lpad, rm+2*rpad) # kept on CPU; moved to the input device when used
for i in range(lf):
for j in range(rf):
xx = (i) * (lp - lo)
yy = (j) * (rp - ro)
torch_map[xx:(xx+lp), yy:(yy+rp)]+=1
torch_map = torch_map[lpad:lm+lpad, rpad:rm+rpad]
self.torch_map = th.reciprocal(torch_map)
@property
def dtype(self):
return next(self.parameters()).dtype
class PixArt_MDT_FIT(nn.Module):
"""
Diffusion model with a Transformer backbone.
"""
def __init__(self, input_size=(256,16), patch_size=(16,4), overlap=(0, 0), in_channels=8, hidden_size=1152, depth=28, num_heads=16, mlp_ratio=4.0, class_dropout_prob=0.1, pred_sigma=False, drop_path: float = 0., window_size=0, window_block_indexes=None, use_rel_pos=False, cond_dim=1024, lewei_scale=1.0,
use_cfg=True, cfg_scale=4.0, config=None, model_max_length=120, mask_ratio=None, decode_layer=4,**kwargs):
if window_block_indexes is None:
window_block_indexes = []
super().__init__()
self.use_cfg = use_cfg
self.cfg_scale = cfg_scale
self.input_size = input_size
self.pred_sigma = pred_sigma
self.in_channels = in_channels
self.out_channels = in_channels
self.patch_size = patch_size
self.num_heads = num_heads
self.lewei_scale = lewei_scale
decode_layer = int(decode_layer)
self.x_embedder = PatchEmbed(input_size, patch_size, overlap, in_channels, hidden_size, bias=True)
# self.x_embedder = PatchEmbed_1D(input)
self.t_embedder = TimestepEmbedder(hidden_size)
num_patches = self.x_embedder.num_patches
self.base_size = input_size[0] // self.patch_size[0] * 2
# Will use fixed sin-cos embedding:
self.register_buffer("pos_embed", th.zeros(1, num_patches, hidden_size))
approx_gelu = lambda: nn.GELU(approximate="tanh")
self.t_block = nn.Sequential(
nn.SiLU(),
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
)
self.y_embedder = nn.Linear(cond_dim, hidden_size)
half_depth = (depth - decode_layer)//2
self.half_depth=half_depth
drop_path_half = [x.item() for x in th.linspace(0, drop_path, half_depth)] # stochastic depth decay rule
drop_path_decode = [x.item() for x in th.linspace(0, drop_path, decode_layer)]
self.en_inblocks = nn.ModuleList([
MDTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path_half[i],
input_size=(self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]),
window_size=0,
use_rel_pos=False) for i in range(half_depth)
])
self.en_outblocks = nn.ModuleList([
MDTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path_half[i],
input_size=(self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]),
window_size=0,
use_rel_pos=False, skip=True) for i in range(half_depth)
])
self.de_blocks = nn.ModuleList([
MDTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path_decode[i],
input_size=(self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]),
window_size=0,
use_rel_pos=False, skip=True) for i in range(decode_layer)
])
self.sideblocks = nn.ModuleList([
MDTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio,
input_size=(self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]),
window_size=0,
use_rel_pos=False) for _ in range(1)
])
self.final_layer = T2IFinalLayer(hidden_size, patch_size, self.out_channels)
self.decoder_pos_embed = nn.Parameter(th.zeros(
1, num_patches, hidden_size), requires_grad=True)
if mask_ratio is not None:
self.mask_token = nn.Parameter(th.zeros(1, 1, hidden_size))
self.mask_ratio = float(mask_ratio)
self.decode_layer = int(decode_layer)
else:
self.mask_token = nn.Parameter(th.zeros(
1, 1, hidden_size), requires_grad=False)
self.mask_ratio = None
self.decode_layer = int(decode_layer)
print("mask ratio:", self.mask_ratio, "decode_layer:", self.decode_layer)
self.initialize_weights()
def forward(self, x, timestep, context_list, context_mask_list=None, enable_mask=False, **kwargs):
"""
Forward pass of PixArt.
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
t: (N,) tensor of diffusion timesteps
context_list[0]: (N, L, cond_dim) tensor of conditioning embeddings (e.g. text features)
"""
x = x.to(self.dtype)
timestep = timestep.to(self.dtype)
y = context_list[0].to(self.dtype)
pos_embed = self.pos_embed.to(self.dtype)
self.h, self.w = self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]
x = self.x_embedder(x) + pos_embed # (N, T, D), where T = H * W / patch_size ** 2
t = self.t_embedder(timestep.to(x.dtype)) # (N, D)
t0 = self.t_block(t)
y = self.y_embedder(y) # (N, L, D)
# if not self.training:
try:
mask = context_mask_list[0] # (N, L)
except (TypeError, IndexError):
# no context mask was provided; fall back to an all-ones mask
mask = th.ones(x.shape[0], 1).to(x.device)
print("Warning: context_mask_list is missing or empty; using an all-ones mask")
assert mask is not None
# if mask is not None:
y = y.masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
y_lens = mask.sum(dim=1).tolist()
y_lens = [int(_) for _ in y_lens]
input_skip = x
masked_stage = False
skips = []
# TODO : masking op for training
if self.mask_ratio is not None and self.training:
# sample a masking ratio uniformly in [mask_ratio, mask_ratio + 0.2)
rand_mask_ratio = th.rand(1, device=x.device) # noise in [0, 1]
rand_mask_ratio = rand_mask_ratio * 0.2 + self.mask_ratio
x, mask, ids_restore, ids_keep = self.random_masking(
x, rand_mask_ratio)
masked_stage = True
for block in self.en_inblocks:
if masked_stage:
x = auto_grad_checkpoint(block, x, y, t0, y_lens, ids_keep=ids_keep)
else:
x = auto_grad_checkpoint(block, x, y, t0, y_lens, ids_keep=None)
skips.append(x)
for block in self.en_outblocks:
if masked_stage:
x = auto_grad_checkpoint(block, x, y, t0, y_lens, skip=skips.pop(), ids_keep=ids_keep)
else:
x = auto_grad_checkpoint(block, x, y, t0, y_lens, skip=skips.pop(), ids_keep=None)
if self.mask_ratio is not None and self.training:
x = self.forward_side_interpolater(x, y, t0, y_lens, mask, ids_restore)
masked_stage = False
else:
# add pos embed
x = x + self.decoder_pos_embed
for block in self.de_blocks:
this_skip = input_skip
x = auto_grad_checkpoint(block, x, y, t0, y_lens, skip=this_skip, ids_keep=None)
x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels)
x = self.unpatchify(x) # (N, out_channels, H, W)
return x
def forward_with_dpmsolver(self, x, timestep, y, mask=None, **kwargs):
"""
DPM-Solver does not need the variance prediction.
"""
# https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
model_out = self.forward(x, timestep, y, mask)
return model_out.chunk(2, dim=1)[0]
def forward_with_cfg(self, x, timestep, y, cfg_scale, mask=None, **kwargs):
"""
Forward pass of PixArt, but also batches the unconditional forward pass for classifier-free guidance.
"""
# https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
half = x[: len(x) // 2]
combined = th.cat([half, half], dim=0)
model_out = self.forward(combined, timestep, y, mask)
model_out = model_out['x'] if isinstance(model_out, dict) else model_out
eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
cond_eps, uncond_eps = th.split(eps, len(eps) // 2, dim=0)
half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
eps = th.cat([half_eps, half_eps], dim=0)
return eps
# return th.cat([eps, rest], dim=1)
def unpatchify(self, x):
"""
x: (N, T, patch_size[0] * patch_size[1] * C)
imgs: (N, C, H, W), e.g. (bs, 8, 256, 16)
"""
c = self.out_channels
p0 = self.x_embedder.patch_size[0]
p1 = self.x_embedder.patch_size[1]
h, w = self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]
x = x.reshape(shape=(x.shape[0], h, w, p0, p1, c))
x = th.einsum('nhwpqc->nchpwq', x)
imgs = x.reshape(shape=(x.shape[0], c, h * p0, w * p1))
return imgs
def random_masking(self, x, mask_ratio):
"""
Perform per-sample random masking by per-sample shuffling.
Per-sample shuffling is done by argsort random noise.
x: [N, L, D], sequence
"""
N, L, D = x.shape # batch, length, dim
len_keep = int(L * (1 - mask_ratio))
noise = th.rand(N, L, device=x.device) # noise in [0, 1]
# sort noise for each sample
# ascend: small is keep, large is remove
ids_shuffle = th.argsort(noise, dim=1)
ids_restore = th.argsort(ids_shuffle, dim=1)
# keep the first subset
ids_keep = ids_shuffle[:, :len_keep]
x_masked = th.gather(
x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
# generate the binary mask: 0 is keep, 1 is remove
mask = th.ones([N, L], device=x.device)
mask[:, :len_keep] = 0
# unshuffle to get the binary mask
mask = th.gather(mask, dim=1, index=ids_restore)
return x_masked, mask, ids_restore, ids_keep
def forward_side_interpolater(self, x, y, t0, y_lens, mask, ids_restore):
# append mask tokens to sequence
mask_tokens = self.mask_token.repeat(
x.shape[0], ids_restore.shape[1] - x.shape[1], 1)
x_ = th.cat([x, mask_tokens], dim=1)
x = th.gather(
x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle
# add pos embed
x = x + self.decoder_pos_embed
# pass to the basic block
x_before = x
for sideblock in self.sideblocks:
x = sideblock(x, y, t0, y_lens, ids_keep=None)
# masked shortcut
mask = mask.unsqueeze(dim=-1)
x = x*mask + (1-mask)*x_before
return x
def initialize_weights(self):
# Initialize transformer layers:
def _basic_init(module):
if isinstance(module, nn.Linear):
th.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
self.apply(_basic_init)
# Initialize (and freeze) pos_embed by sin-cos embedding:
pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], self.x_embedder.grid_size, lewei_scale=self.lewei_scale, base_size=self.base_size)
# TODO: optionally replace the absolute sin-cos embedding with a 2D RoPE positional embedding
self.pos_embed.data.copy_(th.from_numpy(pos_embed).float().unsqueeze(0))
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
w = self.x_embedder.proj.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
# Initialize timestep embedding MLP:
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
nn.init.normal_(self.t_block[1].weight, std=0.02)
# Initialize caption embedding MLP:
nn.init.normal_(self.y_embedder.weight, std=0.02)
# nn.init.normal_(self.y_embedder.y_proj.fc2.weight, std=0.02)
# Zero-out adaLN modulation layers in PixArt blocks:
for block in self.en_inblocks:
nn.init.constant_(block.cross_attn.proj.weight, 0)
nn.init.constant_(block.cross_attn.proj.bias, 0)
for block in self.en_outblocks:
nn.init.constant_(block.cross_attn.proj.weight, 0)
nn.init.constant_(block.cross_attn.proj.bias, 0)
for block in self.de_blocks:
nn.init.constant_(block.cross_attn.proj.weight, 0)
nn.init.constant_(block.cross_attn.proj.bias, 0)
for block in self.sideblocks:
nn.init.constant_(block.cross_attn.proj.weight, 0)
nn.init.constant_(block.cross_attn.proj.bias, 0)
# Zero-out output layers:
nn.init.constant_(self.final_layer.linear.weight, 0)
nn.init.constant_(self.final_layer.linear.bias, 0)
@property
def dtype(self):
return next(self.parameters()).dtype
class PixArt_Slow(nn.Module):
"""
Diffusion model with a Transformer backbone.
"""
def __init__(self, input_size=(256,16), patch_size=(16,4), overlap=(0, 0), in_channels=8, hidden_size=1152, depth=28, num_heads=16, mlp_ratio=4.0, class_dropout_prob=0.1, pred_sigma=True, drop_path: float = 0., window_size=0, window_block_indexes=None, use_rel_pos=False, cond_dim=1024, lewei_scale=1.0,
use_cfg=True, cfg_scale=4.0, config=None, model_max_length=120, **kwargs):
if window_block_indexes is None:
window_block_indexes = []
super().__init__()
self.use_cfg = use_cfg
self.cfg_scale = cfg_scale
self.input_size = input_size
self.pred_sigma = pred_sigma
self.in_channels = in_channels
self.out_channels = in_channels * 2 if pred_sigma else in_channels
self.patch_size = patch_size
self.num_heads = num_heads
self.lewei_scale = lewei_scale
self.x_embedder = PatchEmbed(input_size, patch_size, overlap, in_channels, hidden_size, bias=True)
# self.x_embedder = PatchEmbed_1D(input)
self.t_embedder = TimestepEmbedder(hidden_size)
num_patches = self.x_embedder.num_patches
self.base_size = input_size[0] // self.patch_size[0] * 2
# Will use fixed sin-cos embedding:
self.register_buffer("pos_embed", th.zeros(1, num_patches, hidden_size))
approx_gelu = lambda: nn.GELU(approximate="tanh")
self.t_block = nn.Sequential(
nn.SiLU(),
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
)
self.y_embedder = nn.Linear(cond_dim, hidden_size)
drop_path = [x.item() for x in th.linspace(0, drop_path, depth)] # stochastic depth decay rule
self.blocks = nn.ModuleList([
PixArtBlock_Slow(hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path[i],
input_size=(self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]),
window_size=0,
use_rel_pos=False)
for i in range(depth)
])
self.final_layer = T2IFinalLayer(hidden_size, patch_size, self.out_channels)
self.initialize_weights()
def forward(self, x, timestep, context_list, context_mask_list=None, **kwargs):
"""
Forward pass of PixArt.
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
t: (N,) tensor of diffusion timesteps
context_list[0]: (N, L, cond_dim) tensor of conditioning embeddings (e.g. text features)
"""
x = x.to(self.dtype)
timestep = timestep.to(self.dtype)
y = context_list[0].to(self.dtype)
pos_embed = self.pos_embed.to(self.dtype)
self.h, self.w = self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]
x = self.x_embedder(x) + pos_embed # (N, T, D), where T = H * W / patch_size ** 2
t = self.t_embedder(timestep.to(x.dtype)) # (N, D)
t0 = self.t_block(t)
y = self.y_embedder(y) # (N, L, D)
mask = context_mask_list[0] # (N, L)
assert mask is not None
# if mask is not None:
# y = y.masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
# y_lens = mask.sum(dim=1).tolist()
# y_lens = [int(_) for _ in y_lens]
for block in self.blocks:
x = auto_grad_checkpoint(block, x, y, t0, mask) # (N, T, D) #support grad checkpoint
x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels)
x = self.unpatchify(x) # (N, out_channels, H, W)
return x
def forward_with_dpmsolver(self, x, timestep, y, mask=None, **kwargs):
"""
DPM-Solver does not need the variance prediction.
"""
# https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
model_out = self.forward(x, timestep, y, mask)
return model_out.chunk(2, dim=1)[0]
def forward_with_cfg(self, x, timestep, y, cfg_scale, mask=None, **kwargs):
"""
Forward pass of PixArt, but also batches the unconditional forward pass for classifier-free guidance.
"""
# https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
half = x[: len(x) // 2]
combined = th.cat([half, half], dim=0)
model_out = self.forward(combined, timestep, y, mask)
model_out = model_out['x'] if isinstance(model_out, dict) else model_out
eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
cond_eps, uncond_eps = th.split(eps, len(eps) // 2, dim=0)
half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
eps = th.cat([half_eps, half_eps], dim=0)
return eps
# return th.cat([eps, rest], dim=1)
def unpatchify(self, x):
"""
x: (N, T, patch_size[0] * patch_size[1] * C)
imgs: (N, C, H, W), e.g. (bs, 16, 256, 16) when pred_sigma doubles the channels
"""
c = self.out_channels
p0 = self.x_embedder.patch_size[0]
p1 = self.x_embedder.patch_size[1]
h, w = self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]
x = x.reshape(shape=(x.shape[0], h, w, p0, p1, c))
x = th.einsum('nhwpqc->nchpwq', x)
imgs = x.reshape(shape=(x.shape[0], c, h * p0, w * p1))
return imgs
def initialize_weights(self):
# Initialize transformer layers:
def _basic_init(module):
if isinstance(module, nn.Linear):
th.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
self.apply(_basic_init)
# Initialize (and freeze) pos_embed by sin-cos embedding:
pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], self.x_embedder.grid_size, lewei_scale=self.lewei_scale, base_size=self.base_size)
self.pos_embed.data.copy_(th.from_numpy(pos_embed).float().unsqueeze(0))
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
w = self.x_embedder.proj.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
# Initialize timestep embedding MLP:
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
nn.init.normal_(self.t_block[1].weight, std=0.02)
# Initialize caption embedding MLP:
nn.init.normal_(self.y_embedder.weight, std=0.02)
# nn.init.normal_(self.y_embedder.y_proj.fc2.weight, std=0.02)
# Zero-out adaLN modulation layers in PixArt blocks:
# for block in self.blocks:
# nn.init.constant_(block.cross_attn.proj.weight, 0)
# nn.init.constant_(block.cross_attn.proj.bias, 0)
# Zero-out output layers:
nn.init.constant_(self.final_layer.linear.weight, 0)
nn.init.constant_(self.final_layer.linear.bias, 0)
@property
def dtype(self):
return next(self.parameters()).dtype
class PixArtBlock_1D(nn.Module):
"""
A PixArt block with adaptive layer norm (adaLN-single) conditioning.
"""
def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, drop_path=0., window_size=0, use_rel_pos=False, **block_kwargs):
super().__init__()
self.hidden_size = hidden_size
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.attn = WindowAttention(hidden_size, num_heads=num_heads, qkv_bias=True,
input_size=None,
use_rel_pos=use_rel_pos, **block_kwargs)
# self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
self.cross_attn = MultiHeadCrossAttention(hidden_size, num_heads, **block_kwargs)
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
# use tanh-approximated GELU to stay compatible with older PyTorch versions
approx_gelu = lambda: nn.GELU(approximate="tanh")
self.mlp = Mlp(in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu, drop=0)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.window_size = window_size
self.scale_shift_table = nn.Parameter(th.randn(6, hidden_size) / hidden_size ** 0.5)
def forward(self, x, y, t, mask=None, **kwargs):
B, N, C = x.shape
# x [3, 133, 1152]
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None] + t.reshape(B, 6, -1)).chunk(6, dim=1)
x = x + self.drop_path(gate_msa * self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa)).reshape(B, N, C))
x = x + self.cross_attn(x, y, mask)
x = x + self.drop_path(gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp)))
return x
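# Conditioning detail for the block above: t is the (B, 6 * hidden_size) output of
# t_block, reshaped to (B, 6, hidden_size), added to the learned scale_shift_table,
# and split into shift/scale/gate triples for the self-attention and MLP branches;
# t2i_modulate is assumed (as in PixArt) to apply x * (1 + scale) + shift.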
class PixArt_1D(nn.Module):
"""
Diffusion model with a Transformer backbone.
"""
def __init__(self, input_size=(256,16), in_channels=8, hidden_size=1152, depth=28, num_heads=16, mlp_ratio=4.0, class_dropout_prob=0.1, pred_sigma=True, drop_path: float = 0., window_size=0, window_block_indexes=None, use_rel_pos=False, cond_dim=1024, lewei_scale=1.0,
use_cfg=True, cfg_scale=4.0, config=None, model_max_length=120, **kwargs):
if window_block_indexes is None:
window_block_indexes = []
super().__init__()
self.use_cfg = use_cfg
self.cfg_scale = cfg_scale
self.input_size = input_size
self.pred_sigma = pred_sigma
self.in_channels = in_channels
self.out_channels = in_channels
self.num_heads = num_heads
self.lewei_scale = lewei_scale
self.x_embedder = PatchEmbed_1D(input_size, in_channels, hidden_size)
# self.x_embedder = PatchEmbed(input_size, patch_size, overlap, in_channels, hidden_size, bias=True)
self.t_embedder = TimestepEmbedder(hidden_size)
self.p_enc_1d_model = PositionalEncoding1D(hidden_size)
approx_gelu = lambda: nn.GELU(approximate="tanh")
self.t_block = nn.Sequential(
nn.SiLU(),
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
)
self.y_embedder = nn.Linear(cond_dim, hidden_size)
drop_path = [x.item() for x in th.linspace(0, drop_path, depth)] # stochastic depth decay rule
self.blocks = nn.ModuleList([
PixArtBlock_1D(hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path[i],
window_size=0,
use_rel_pos=False)
for i in range(depth)
])
self.final_layer = T2IFinalLayer(hidden_size, (1, input_size[1]), self.out_channels)
self.initialize_weights()
# if config:
# logger = get_root_logger(os.path.join(config.work_dir, 'train_log.log'))
# logger.warning(f"lewei scale: {self.lewei_scale}, base size: {self.base_size}")
# else:
# print(f'Warning: lewei scale: {self.lewei_scale}, base size: {self.base_size}')
def forward(self, x, timestep, context_list, context_mask_list=None, **kwargs):
"""
Forward pass of PixArt.
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
t: (N,) tensor of diffusion timesteps
context_list[0]: (N, L, cond_dim) tensor of conditioning embeddings (e.g. text features)
"""
x = x.to(self.dtype)
timestep = timestep.to(self.dtype)
y = context_list[0].to(self.dtype)
x = self.x_embedder(x) # (N, T, D)
pos_embed = self.p_enc_1d_model(x)
x = x + pos_embed
t = self.t_embedder(timestep.to(x.dtype)) # (N, D)
t0 = self.t_block(t)
y = self.y_embedder(y) # (N, L, D)
try:
mask = context_mask_list[0] # (N, L)
except (TypeError, IndexError):
# no context mask was provided; fall back to an all-ones mask
mask = th.ones(x.shape[0], 1).to(x.device)
print("Warning: context_mask_list is missing or empty; using an all-ones mask")
assert mask is not None
# if mask is not None:
y = y.masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
y_lens = mask.sum(dim=1).tolist()
y_lens = [int(_) for _ in y_lens]
for block in self.blocks:
x = auto_grad_checkpoint(block, x, y, t0, y_lens) # (N, T, D) #support grad checkpoint
x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels)
x = self.unpatchify_1D(x) # (N, out_channels, H, W)
return x
def forward_with_dpmsolver(self, x, timestep, y, mask=None, **kwargs):
"""
DPM-Solver does not need the variance prediction.
"""
# https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
model_out = self.forward(x, timestep, y, mask)
return model_out.chunk(2, dim=1)[0]
def forward_with_cfg(self, x, timestep, y, cfg_scale, mask=None, **kwargs):
"""
Forward pass of PixArt, but also batches the unconditional forward pass for classifier-free guidance.
"""
# https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
half = x[: len(x) // 2]
combined = th.cat([half, half], dim=0)
model_out = self.forward(combined, timestep, y, mask)
model_out = model_out['x'] if isinstance(model_out, dict) else model_out
eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
cond_eps, uncond_eps = th.split(eps, len(eps) // 2, dim=0)
half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
eps = th.cat([half_eps, half_eps], dim=0)
return eps
# return th.cat([eps, rest], dim=1)
def unpatchify_1D(self, x):
"""
"""
c = self.out_channels
x = x.reshape(shape=(x.shape[0], self.input_size[0], self.input_size[1], c))
x = th.einsum('btfc->bctf', x)
# imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p))
return x
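# Layout note: PatchEmbed_1D treats each time frame as one token, flattening the F
# frequency bins and C channels into a single vector, so the sequence length equals
# input_size[0] and unpatchify_1D simply reshapes (N, T, F * C) back to (N, C, T, F).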
def initialize_weights(self):
# Initialize transformer layers:
def _basic_init(module):
if isinstance(module, nn.Linear):
th.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
self.apply(_basic_init)
# Initialize (and freeze) pos_embed by sin-cos embedding:
# pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], self.x_embedder.grid_size, lewei_scale=self.lewei_scale, base_size=self.base_size)
# self.pos_embed.data.copy_(th.from_numpy(pos_embed).float().unsqueeze(0))
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
w = self.x_embedder.proj.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
# Initialize timestep embedding MLP:
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
nn.init.normal_(self.t_block[1].weight, std=0.02)
# Initialize caption embedding MLP:
nn.init.normal_(self.y_embedder.weight, std=0.02)
# nn.init.normal_(self.y_embedder.y_proj.fc2.weight, std=0.02)
# Zero-out adaLN modulation layers in PixArt blocks:
for block in self.blocks:
nn.init.constant_(block.cross_attn.proj.weight, 0)
nn.init.constant_(block.cross_attn.proj.bias, 0)
# Zero-out output layers:
nn.init.constant_(self.final_layer.linear.weight, 0)
nn.init.constant_(self.final_layer.linear.bias, 0)
@property
def dtype(self):
return next(self.parameters()).dtype
class PixArt_Slow_1D(nn.Module):
"""
Diffusion model with a Transformer backbone.
"""
def __init__(self, input_size=(256,16), in_channels=8, hidden_size=1152, depth=28, num_heads=16, mlp_ratio=4.0, class_dropout_prob=0.1, pred_sigma=True, drop_path: float = 0., window_size=0, window_block_indexes=None, use_rel_pos=False, cond_dim=1024, lewei_scale=1.0,
use_cfg=True, cfg_scale=4.0, config=None, model_max_length=120, **kwargs):
if window_block_indexes is None:
window_block_indexes = []
super().__init__()
self.use_cfg = use_cfg
self.cfg_scale = cfg_scale
self.input_size = input_size
self.pred_sigma = pred_sigma
self.in_channels = in_channels
self.out_channels = in_channels * 2 if pred_sigma else in_channels
self.num_heads = num_heads
self.lewei_scale = lewei_scale
self.x_embedder = PatchEmbed_1D(input_size, in_channels, hidden_size)
# self.x_embedder = PatchEmbed(input_size, patch_size, overlap, in_channels, hidden_size, bias=True)
self.t_embedder = TimestepEmbedder(hidden_size)
self.p_enc_1d_model = PositionalEncoding1D(hidden_size)
approx_gelu = lambda: nn.GELU(approximate="tanh")
self.t_block = nn.Sequential(
nn.SiLU(),
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
)
self.y_embedder = nn.Linear(cond_dim, hidden_size)
drop_path = [x.item() for x in th.linspace(0, drop_path, depth)] # stochastic depth decay rule
self.blocks = nn.ModuleList([
PixArtBlock_Slow(hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path[i],
window_size=0,
use_rel_pos=False)
for i in range(depth)
])
self.final_layer = T2IFinalLayer(hidden_size, (1, input_size[1]), self.out_channels)
self.initialize_weights()
# if config:
# logger = get_root_logger(os.path.join(config.work_dir, 'train_log.log'))
# logger.warning(f"lewei scale: {self.lewei_scale}, base size: {self.base_size}")
# else:
# print(f'Warning: lewei scale: {self.lewei_scale}, base size: {self.base_size}')
def forward(self, x, timestep, context_list, context_mask_list=None, **kwargs):
"""
Forward pass of PixArt.
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
t: (N,) tensor of diffusion timesteps
context_list[0]: (N, L, cond_dim) tensor of conditioning embeddings (e.g. text features)
"""
x = x.to(self.dtype)
timestep = timestep.to(self.dtype)
y = context_list[0].to(self.dtype)
x = self.x_embedder(x) # (N, T, D)
pos_embed = self.p_enc_1d_model(x)
x = x + pos_embed
t = self.t_embedder(timestep.to(x.dtype)) # (N, D)
t0 = self.t_block(t)
y = self.y_embedder(y) # (N, L, D)
mask = context_mask_list[0] # (N, L)
assert mask is not None
# if mask is not None:
# y = y.masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
# y_lens = mask.sum(dim=1).tolist()
# y_lens = [int(_) for _ in y_lens]
for block in self.blocks:
x = auto_grad_checkpoint(block, x, y, t0, mask) # (N, T, D) #support grad checkpoint
x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels)
x = self.unpatchify_1D(x) # (N, out_channels, H, W)
return x
def forward_with_dpmsolver(self, x, timestep, y, mask=None, **kwargs):
"""
DPM-Solver does not need the variance prediction.
"""
# https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
model_out = self.forward(x, timestep, y, mask)
return model_out.chunk(2, dim=1)[0]
def forward_with_cfg(self, x, timestep, y, cfg_scale, mask=None, **kwargs):
"""
Forward pass of PixArt, but also batches the unconditional forward pass for classifier-free guidance.
"""
# https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
half = x[: len(x) // 2]
combined = th.cat([half, half], dim=0)
model_out = self.forward(combined, timestep, y, mask)
model_out = model_out['x'] if isinstance(model_out, dict) else model_out
eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
cond_eps, uncond_eps = th.split(eps, len(eps) // 2, dim=0)
half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
eps = th.cat([half_eps, half_eps], dim=0)
return eps
# return th.cat([eps, rest], dim=1)
def unpatchify_1D(self, x):
"""
"""
c = self.out_channels
x = x.reshape(shape=(x.shape[0], self.input_size[0], self.input_size[1], c))
x = th.einsum('btfc->bctf', x)
# imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p))
return x
def initialize_weights(self):
# Initialize transformer layers:
def _basic_init(module):
if isinstance(module, nn.Linear):
th.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
self.apply(_basic_init)
# Initialize (and freeze) pos_embed by sin-cos embedding:
# pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], self.x_embedder.grid_size, lewei_scale=self.lewei_scale, base_size=self.base_size)
# self.pos_embed.data.copy_(th.from_numpy(pos_embed).float().unsqueeze(0))
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
w = self.x_embedder.proj.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
# Initialize timestep embedding MLP:
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
nn.init.normal_(self.t_block[1].weight, std=0.02)
# Initialize caption embedding MLP:
nn.init.normal_(self.y_embedder.weight, std=0.02)
# nn.init.normal_(self.y_embedder.y_proj.fc2.weight, std=0.02)
# Zero-out adaLN modulation layers in PixArt blocks:
# for block in self.blocks:
# nn.init.constant_(block.cross_attn.proj.weight, 0)
# nn.init.constant_(block.cross_attn.proj.bias, 0)
# Zero-out output layers:
nn.init.constant_(self.final_layer.linear.weight, 0)
nn.init.constant_(self.final_layer.linear.bias, 0)
@property
def dtype(self):
return next(self.parameters()).dtype
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0, lewei_scale=1.0, base_size=16):
"""
grid_size: int or (grid_h, grid_w) tuple giving the token grid size
return:
pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
"""
if isinstance(grid_size, int):
grid_size = to_2tuple(grid_size)
grid_h = np.arange(grid_size[0], dtype=np.float32) / (grid_size[0]/base_size) / lewei_scale
grid_w = np.arange(grid_size[1], dtype=np.float32) / (grid_size[1]/base_size) / lewei_scale
grid = np.meshgrid(grid_w, grid_h) # here w goes first
grid = np.stack(grid, axis=0)
grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
if cls_token and extra_tokens > 0:
pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
return pos_embed
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
assert embed_dim % 2 == 0
# use half of dimensions to encode grid_h
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
return np.concatenate([emb_h, emb_w], axis=1)
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
"""
embed_dim: output dimension for each position
pos: a list of positions to be encoded: size (M,)
out: (M, D)
"""
assert embed_dim % 2 == 0
omega = np.arange(embed_dim // 2, dtype=np.float64)
omega /= embed_dim / 2.
omega = 1. / 10000 ** omega # (D/2,)
pos = pos.reshape(-1) # (M,)
out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
emb_sin = np.sin(out) # (M, D/2)
emb_cos = np.cos(out) # (M, D/2)
return np.concatenate([emb_sin, emb_cos], axis=1)
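# Small (commented-out) sanity sketch for the helpers above; the sizes are arbitrary
# illustration values, not tied to any model configuration:
# pe = get_2d_sincos_pos_embed(embed_dim=8, grid_size=(4, 2))
# pe.shape == (8, 8) # one row per token (4 * 2 positions), one column per embed dim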
# if __name__ == '__main__' :
# import pdb
# pdb.set_trace()
# model = PixArt_1D().to('cuda')
# # x: (N, T, patch_size 0 * patch_size 1 * C)
# th.manual_seed(233)
# # x = th.rand(1, 4*16, 16*4*16).to('cuda')
# x = th.rand(3, 8, 256, 16).to('cuda')
# t = th.tensor([1, 2, 3]).to('cuda')
# c = th.rand(3, 20, 1024).to('cuda')
# c_mask = th.ones(3, 20).to('cuda')
# c_list = [c]
# c_mask_list = [c_mask]
# y = model.forward(x, t, c_list, c_mask_list)
# res = model.unpatchify(x)
# class DiTModel(nn.Module):
# """
# The full UNet model with attention and timestep embedding.
# :param in_channels: channels in the input Tensor.
# :param model_channels: base channel count for the model.
# :param out_channels: channels in the output Tensor.
# :param num_res_blocks: number of residual blocks per downsample.
# :param attention_resolutions: a collection of downsample rates at which
# attention will take place. May be a set, list, or tuple.
# For example, if this contains 4, then at 4x downsampling, attention
# will be used.
# :param dropout: the dropout probability.
# :param channel_mult: channel multiplier for each level of the UNet.
# :param conv_resample: if True, use learned convolutions for upsampling and
# downsampling.
# :param dims: determines if the signal is 1D, 2D, or 3D.
# :param num_classes: if specified (as an int), then this model will be
# class-conditional with `num_classes` classes.
# :param use_checkpoint: use gradient checkpointing to reduce memory usage.
# :param num_heads: the number of attention heads in each attention layer.
# :param num_heads_channels: if specified, ignore num_heads and instead use
# a fixed channel width per attention head.
# :param num_heads_upsample: works with num_heads to set a different number
# of heads for upsampling. Deprecated.
# :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
# :param resblock_updown: use residual blocks for up/downsampling.
# :param use_new_attention_order: use a different attention pattern for potentially
# increased efficiency.
# """
# def __init__(
# self,
# input_size,
# patch_size,
# overlap,
# in_channels,
# embed_dim,
# model_channels,
# out_channels,
# dims=2,
# extra_film_condition_dim=None,
# use_checkpoint=False,
# use_fp16=False,
# num_heads=-1,
# num_head_channels=-1,
# use_scale_shift_norm=False,
# use_new_attention_order=False,
# transformer_depth=1, # custom transformer support
# context_dim=None, # custom transformer support
# n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
# legacy=True,
# ):
# super().__init__()
# self.x_embedder = PatchEmbed(input_size, patch_size, overlap, in_channels, embed_dim, bias=True)
# num_patches = self.x_embedder.num_patches
# self.pos_embed = nn.Parameter(th.zeros(1, num_patches, embed_dim), requires_grad=False)
# self.blocks = nn.ModuleList([
# DiTBlock_crossattn
# ])
# def convert_to_fp16(self):
# """
# Convert the torso of the model to float16.
# """
# # self.input_blocks.apply(convert_module_to_f16)
# # self.middle_block.apply(convert_module_to_f16)
# # self.output_blocks.apply(convert_module_to_f16)
# def convert_to_fp32(self):
# """
# Convert the torso of the model to float32.
# """
# # self.input_blocks.apply(convert_module_to_f32)
# # self.middle_block.apply(convert_module_to_f32)
# # self.output_blocks.apply(convert_module_to_f32)
# def forward(
# self,
# x,
# timesteps=None,
# y=None,
# context_list=None,
# context_attn_mask_list=None,
# **kwargs,
# ):
# """
# Apply the model to an input batch.
# :param x: an [N x C x ...] Tensor of inputs.
# :param timesteps: a 1-D batch of timesteps.
# :param context: conditioning plugged in via crossattn
# :param y: an [N] Tensor of labels, if class-conditional. an [N, extra_film_condition_dim] Tensor if film-embed conditional
# :return: an [N x C x ...] Tensor of outputs.
# """
# x = self.x_embedder(x) + self.pos_embed # (N, T, D), where T = H * W / patch_size ** 2
# t = self.t_embedder(timesteps) # (N, D)
# y = self.y_embedder(y, self.training) # (N, D)
# c = t + y # (N, D)
# for block in self.blocks:
# x = block(x, c) # (N, T, D)
# x = self.final_layer(x, c) # (N, T, patch_size ** 2 * out_channels)
# x = self.unpatchify(x) # (N, out_channels, H, W)