|
from abc import abstractmethod |
|
from functools import partial |
|
import math |
|
from typing import Iterable |
|
|
|
import numpy as np |
|
import torch as th |
|
|
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import sys |
|
from fairscale.nn.model_parallel.layers import ( |
|
ColumnParallelLinear, |
|
ParallelEmbedding, |
|
RowParallelLinear, |
|
) |
|
|
|
from timm.models.layers import DropPath |
|
from .utils import auto_grad_checkpoint, to_2tuple |
|
from .PixArt_blocks import t2i_modulate, WindowAttention, MultiHeadCrossAttention, T2IFinalLayer, TimestepEmbedder, FinalLayer |
|
import xformers.ops |
|
|
|
class PatchEmbed(nn.Module): |
|
""" 2D Image to Patch Embedding |
|
""" |
|
def __init__( |
|
self, |
|
img_size=(256, 16), |
|
patch_size=(16, 4), |
|
overlap = (0, 0), |
|
in_chans=128, |
|
embed_dim=768, |
|
norm_layer=None, |
|
flatten=True, |
|
bias=True, |
|
): |
|
super().__init__() |
|
self.img_size = img_size |
|
self.patch_size = patch_size |
|
self.ol = overlap |
|
self.grid_size = (math.ceil((img_size[0] - patch_size[0]) / (patch_size[0]-overlap[0])) + 1, |
|
math.ceil((img_size[1] - patch_size[1]) / (patch_size[1]-overlap[1])) + 1) |
|
self.pad_size = ((self.grid_size[0]-1) * (self.patch_size[0]-overlap[0])+self.patch_size[0]-self.img_size[0], |
|
+(self.grid_size[1]-1)*(self.patch_size[1]-overlap[1])+self.patch_size[1]-self.img_size[1]) |
|
self.pad_size = (self.pad_size[0] // 2, self.pad_size[1] // 2) |
|
|
|
self.num_patches = self.grid_size[0] * self.grid_size[1] |
|
self.flatten = flatten |
|
|
|
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=(patch_size[0]-overlap[0], patch_size[1]-overlap[1]), bias=bias) |
|
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() |
|
|
|
def forward(self, x): |
|
x = F.pad(x, (self.pad_size[-1], self.pad_size[-1], self.pad_size[-2], self.pad_size[-2]), "constant", 0) |
|
x = self.proj(x) |
|
if self.flatten: |
|
x = x.flatten(2).transpose(1, 2) |
|
x = self.norm(x) |
|
return x |
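# Worked example (illustrative only, not executed) of the overlapping-patch grid arithmetic above,
# assuming the defaults of this class:
#   img_size=(256, 16), patch_size=(16, 4), overlap=(0, 0)
#     -> grid_size = (ceil(240/16)+1, ceil(12/4)+1) = (16, 4), num_patches = 64, pad_size = (0, 0)
#   img_size=(256, 16), patch_size=(16, 4), overlap=(8, 2)
#     -> stride = (8, 2), grid_size = (31, 7), num_patches = 217, pad_size = (0, 0)
# The Conv2d stride equals patch_size - overlap, so neighbouring patches share `overlap` rows/columns.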
|
|
|
class PatchEmbed_1D(nn.Module): |
|
def __init__( |
|
self, |
|
img_size=(256, 16), |
|
|
|
|
|
in_chans=8, |
|
embed_dim=1152, |
|
norm_layer=None, |
|
|
|
bias=True, |
|
): |
|
super().__init__() |
|
|
|
self.proj = nn.Linear(in_chans*img_size[1], embed_dim, bias=bias) |
|
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() |
|
|
|
def forward(self, x): |
|
x = th.einsum('bctf->btfc', x) |
|
x = x.flatten(2) |
|
x = self.proj(x) |
|
x = self.norm(x) |
|
return x |
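# Shape walkthrough for PatchEmbed_1D with its default arguments (a sketch, not executed):
#   x: (B, C=8, T=256, F=16) latent
#   einsum 'bctf->btfc' -> (B, 256, 16, 8); flatten(2) -> (B, 256, 128)
#   nn.Linear(in_chans * img_size[1] = 128, embed_dim = 1152) -> (B, 256, 1152)
# i.e. every time frame (all frequency bins and channels) becomes one token.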
|
|
from timm.models.vision_transformer import Attention, Mlp |
|
|
|
def modulate(x, shift, scale): |
|
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) |
|
|
|
from positional_encodings.torch_encodings import PositionalEncoding1D |
|
|
|
def t2i_modulate(x, shift, scale): |
|
return x * (1 + scale) + shift |
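# Note on the two modulation helpers (both apply affine timestep conditioning):
#   modulate(x, shift, scale)     expects shift/scale of shape (B, C) and unsqueezes a token dim;
#   t2i_modulate(x, shift, scale) expects shift/scale already broadcastable to x, e.g. the (B, 1, C)
#                                 chunks produced from scale_shift_table + the timestep embedding.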
|
|
|
class PixArtBlock(nn.Module): |
|
""" |
|
A PixArt block with adaptive layer norm (adaLN-single) conditioning. |
|
""" |
|
|
|
def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, drop_path=0., window_size=0, input_size=None, use_rel_pos=False, **block_kwargs): |
|
super().__init__() |
|
self.hidden_size = hidden_size |
|
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) |
|
self.attn = WindowAttention(hidden_size, num_heads=num_heads, qkv_bias=True, |
|
input_size=input_size if window_size == 0 else (window_size, window_size), |
|
use_rel_pos=use_rel_pos, **block_kwargs) |
|
self.cross_attn = MultiHeadCrossAttention(hidden_size, num_heads, **block_kwargs) |
|
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) |
|
|
|
approx_gelu = lambda: nn.GELU(approximate="tanh") |
|
self.mlp = Mlp(in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu, drop=0) |
|
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() |
|
self.window_size = window_size |
|
self.scale_shift_table = nn.Parameter(th.randn(6, hidden_size) / hidden_size ** 0.5) |
|
|
|
def forward(self, x, y, t, mask=None, **kwargs): |
|
B, N, C = x.shape |
|
|
|
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None] + t.reshape(B, 6, -1)).chunk(6, dim=1) |
|
x = x + self.drop_path(gate_msa * self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa)).reshape(B, N, C)) |
|
x = x + self.cross_attn(x, y, mask) |
|
x = x + self.drop_path(gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp))) |
|
|
|
return x |
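# Minimal sketch (assumed shapes, not called anywhere in this file) of the adaLN-single split used in
# PixArtBlock.forward: the timestep vector t0 of shape (B, 6*hidden) is added to the learned
# scale_shift_table and chunked into shift/scale/gate triples for the attention and MLP branches.
def _demo_adaln_single_split(B=2, hidden=1152):
    scale_shift_table = th.randn(6, hidden) / hidden ** 0.5   # mirrors PixArtBlock.__init__
    t0 = th.randn(B, 6 * hidden)                              # mirrors the output of a t_block
    shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
        scale_shift_table[None] + t0.reshape(B, 6, -1)
    ).chunk(6, dim=1)
    # each piece has shape (B, 1, hidden) and broadcasts over the token dimension in t2i_modulate
    return shift_msa.shape, gate_mlp.shape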
|
|
|
from qa_mdt.audioldm_train.modules.diffusionmodules.attention import CrossAttention_1D |
|
|
|
class PixArtBlock_Slow(nn.Module): |
|
""" |
|
    A PixArt block with adaLN-single conditioning that uses CrossAttention_1D for both the
    self-attention and the cross-attention (slower fallback path).
|
""" |
|
|
|
def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, drop_path=0., window_size=0, input_size=None, use_rel_pos=False, **block_kwargs): |
|
super().__init__() |
|
self.hidden_size = hidden_size |
|
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) |
|
self.attn = CrossAttention_1D(query_dim=hidden_size, context_dim=hidden_size, heads=num_heads, dim_head=int(hidden_size/num_heads)) |
|
self.cross_attn = CrossAttention_1D(query_dim=hidden_size, context_dim=hidden_size, heads=num_heads, dim_head=int(hidden_size/num_heads)) |
|
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) |
|
|
|
approx_gelu = lambda: nn.GELU(approximate="tanh") |
|
self.mlp = Mlp(in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu, drop=0) |
|
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() |
|
self.window_size = window_size |
|
self.scale_shift_table = nn.Parameter(th.randn(6, hidden_size) / hidden_size ** 0.5) |
|
|
|
def forward(self, x, y, t, mask=None, **kwargs): |
|
B, N, C = x.shape |
|
|
|
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None] + t.reshape(B, 6, -1)).chunk(6, dim=1) |
|
x = x + self.drop_path(gate_msa * self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa)).reshape(B, N, C)) |
|
x = x + self.cross_attn(x, y, mask) |
|
x = x + self.drop_path(gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp))) |
|
|
|
return x |
|
|
|
class PixArt(nn.Module): |
|
""" |
|
Diffusion model with a Transformer backbone. |
|
""" |
|
|
|
def __init__(self, input_size=(256,16), patch_size=(16,4), overlap=(0, 0), in_channels=8, hidden_size=1152, depth=28, num_heads=16, mlp_ratio=4.0, class_dropout_prob=0.1, pred_sigma=True, drop_path: float = 0., window_size=0, window_block_indexes=None, use_rel_pos=False, cond_dim=1024, lewei_scale=1.0, |
|
use_cfg=True, cfg_scale=4.0, config=None, model_max_length=120, **kwargs): |
|
if window_block_indexes is None: |
|
window_block_indexes = [] |
|
super().__init__() |
|
self.use_cfg = use_cfg |
|
self.cfg_scale = cfg_scale |
|
self.input_size = input_size |
|
self.pred_sigma = pred_sigma |
|
self.in_channels = in_channels |
|
self.out_channels = in_channels * 2 if pred_sigma else in_channels |
|
self.patch_size = patch_size |
|
self.num_heads = num_heads |
|
        self.lewei_scale = lewei_scale,  # note: trailing comma stores this as a 1-tuple (kept from upstream PixArt)
|
|
|
self.x_embedder = PatchEmbed(input_size, patch_size, overlap, in_channels, hidden_size, bias=True) |
|
|
|
self.t_embedder = TimestepEmbedder(hidden_size) |
|
num_patches = self.x_embedder.num_patches |
|
self.base_size = input_size[0] // self.patch_size[0] * 2 |
|
|
|
self.register_buffer("pos_embed", th.zeros(1, num_patches, hidden_size)) |
|
|
|
approx_gelu = lambda: nn.GELU(approximate="tanh") |
|
self.t_block = nn.Sequential( |
|
nn.SiLU(), |
|
nn.Linear(hidden_size, 6 * hidden_size, bias=True) |
|
) |
|
self.y_embedder = nn.Linear(cond_dim, hidden_size) |
|
drop_path = [x.item() for x in th.linspace(0, drop_path, depth)] |
|
self.blocks = nn.ModuleList([ |
|
PixArtBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path[i], |
|
input_size=(self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]), |
|
window_size=0, |
|
use_rel_pos=False) |
|
for i in range(depth) |
|
]) |
|
self.final_layer = T2IFinalLayer(hidden_size, patch_size, self.out_channels) |
|
|
|
self.initialize_weights() |
|
def forward(self, x, timestep, context_list, context_mask_list=None, **kwargs): |
|
""" |
|
Forward pass of PixArt. |
|
        x: (N, C, H, W) tensor of spatial inputs (latent spectrogram representations)
        timestep: (N,) tensor of diffusion timesteps
        context_list: list of conditioning tensors; context_list[0] holds the text embeddings of shape (N, L, cond_dim)
        context_mask_list: list of attention masks; context_mask_list[0] is (N, L)
|
""" |
|
x = x.to(self.dtype) |
|
timestep = timestep.to(self.dtype) |
|
y = context_list[0].to(self.dtype) |
|
pos_embed = self.pos_embed.to(self.dtype) |
|
self.h, self.w = self.x_embedder.grid_size[0], self.x_embedder.grid_size[1] |
|
|
|
x = self.x_embedder(x) + pos_embed |
|
t = self.t_embedder(timestep.to(x.dtype)) |
|
t0 = self.t_block(t) |
|
y = self.y_embedder(y) |
|
mask = context_mask_list[0] |
|
|
|
assert mask is not None |
|
|
|
|
|
y = y.masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1]) |
|
y_lens = mask.sum(dim=1).tolist() |
|
y_lens = [int(_) for _ in y_lens] |
|
for block in self.blocks: |
|
x = auto_grad_checkpoint(block, x, y, t0, y_lens) |
|
x = self.final_layer(x, t) |
|
x = self.unpatchify(x) |
|
return x |
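    # Token/shape flow through forward() with the default configuration (a sketch, not executed):
    #   x (N, 8, 256, 16) --x_embedder--> (N, 64, 1152) tokens, plus the fixed sin-cos pos_embed
    #   timestep (N,) --t_embedder--> t (N, 1152) --t_block--> t0 (N, 6*1152) for adaLN-single
    #   context_list[0] (N, L, cond_dim) --y_embedder--> y (N, L, 1152), flattened across the batch
    #   with per-sample lengths y_lens for the cross-attention
    #   blocks -> final_layer -> (N, 64, 16*4*out_channels) --unpatchify--> (N, out_channels, 256, 16)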
|
|
|
def forward_with_dpmsolver(self, x, timestep, y, mask=None, **kwargs): |
|
""" |
|
        DPM-Solver does not need the variance prediction, so only the first half of the output channels is returned.
|
""" |
|
|
|
model_out = self.forward(x, timestep, y, mask) |
|
return model_out.chunk(2, dim=1)[0] |
|
|
|
def forward_with_cfg(self, x, timestep, y, cfg_scale, mask=None, **kwargs): |
|
""" |
|
Forward pass of PixArt, but also batches the unconditional forward pass for classifier-free guidance. |
|
""" |
|
|
|
half = x[: len(x) // 2] |
|
combined = th.cat([half, half], dim=0) |
|
model_out = self.forward(combined, timestep, y, mask) |
|
model_out = model_out['x'] if isinstance(model_out, dict) else model_out |
|
eps, rest = model_out[:, :8], model_out[:, 8:] |
|
cond_eps, uncond_eps = th.split(eps, len(eps) // 2, dim=0) |
|
half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps) |
|
eps = th.cat([half_eps, half_eps], dim=0) |
|
return eps |
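    # Classifier-free guidance combination used above (standard formulation):
    #   eps_guided = eps_uncond + cfg_scale * (eps_cond - eps_uncond)
    # cfg_scale = 1.0 recovers the conditional prediction and cfg_scale = 0.0 the unconditional one;
    # by construction the first half of the batch is conditional and the second half unconditional.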
|
|
|
|
|
def unpatchify(self, x): |
|
|
|
""" |
|
        x: (N, T, patch_size[0] * patch_size[1] * C)
        imgs: (N, out_channels, H, W); with the default input size this is (N, out_channels, 256, 16)
|
""" |
|
c = self.out_channels |
|
p0 = self.x_embedder.patch_size[0] |
|
p1 = self.x_embedder.patch_size[1] |
|
h, w = self.x_embedder.grid_size[0], self.x_embedder.grid_size[1] |
|
x = x.reshape(shape=(x.shape[0], h, w, p0, p1, c)) |
|
x = th.einsum('nhwpqc->nchpwq', x) |
|
imgs = x.reshape(shape=(x.shape[0], c, h * p0, w * p1)) |
|
return imgs |
|
|
|
def initialize_weights(self): |
|
|
|
def _basic_init(module): |
|
if isinstance(module, nn.Linear): |
|
th.nn.init.xavier_uniform_(module.weight) |
|
if module.bias is not None: |
|
nn.init.constant_(module.bias, 0) |
|
|
|
self.apply(_basic_init) |
|
|
|
|
|
pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], self.x_embedder.grid_size, lewei_scale=self.lewei_scale, base_size=self.base_size) |
|
self.pos_embed.data.copy_(th.from_numpy(pos_embed).float().unsqueeze(0)) |
|
|
|
|
|
w = self.x_embedder.proj.weight.data |
|
nn.init.xavier_uniform_(w.view([w.shape[0], -1])) |
|
|
|
|
|
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) |
|
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) |
|
nn.init.normal_(self.t_block[1].weight, std=0.02) |
|
|
|
|
|
nn.init.normal_(self.y_embedder.weight, std=0.02) |
|
|
|
|
|
|
|
for block in self.blocks: |
|
nn.init.constant_(block.cross_attn.proj.weight, 0) |
|
nn.init.constant_(block.cross_attn.proj.bias, 0) |
|
|
|
|
|
nn.init.constant_(self.final_layer.linear.weight, 0) |
|
nn.init.constant_(self.final_layer.linear.bias, 0) |
|
|
|
@property |
|
def dtype(self): |
|
return next(self.parameters()).dtype |
|
|
|
class SwiGLU(nn.Module): |
|
def __init__( |
|
self, |
|
dim: int, |
|
hidden_dim: int, |
|
multiple_of: int, |
|
): |
|
super().__init__() |
|
hidden_dim = int(2 * hidden_dim / 3) |
|
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) |
|
|
|
self.w1 = ColumnParallelLinear( |
|
dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x |
|
) |
|
self.w2 = RowParallelLinear( |
|
hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x |
|
) |
|
self.w3 = ColumnParallelLinear( |
|
dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x |
|
) |
|
|
|
def forward(self, x): |
|
return self.w2(F.silu(self.w1(x)) * self.w3(x)) |
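# SwiGLU feed-forward with LLaMA-style sizing: the requested hidden_dim is scaled by 2/3 and rounded
# up to a multiple of `multiple_of`, so the gated MLP has roughly the same parameter count as a plain
# 4x MLP. Worked example for hidden_size = 1152, mlp_ratio = 4, multiple_of = 1 (as used below):
#   hidden_dim = int(2 * 4608 / 3) = 3072, and forward computes w2(silu(w1(x)) * w3(x)).
# Note that w1/w2/w3 are fairscale model-parallel layers, so a model-parallel group must be
# initialised before this module is constructed.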
|
class DEBlock(nn.Module): |
|
""" |
|
    Decoder block with an added SpecTNT-style transformer: a frequency transformer within each
    time frame followed by a temporal transformer across frames.
|
""" |
|
|
|
def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, FFN_type='SwiGLU', drop_path=0., window_size=0, input_size=None, use_rel_pos=False, skip=False, num_f=None, num_t=None, **block_kwargs): |
|
super().__init__() |
|
self.hidden_size = hidden_size |
|
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) |
|
self.attn = WindowAttention(hidden_size, num_heads=num_heads, qkv_bias=True, |
|
input_size=input_size if window_size == 0 else (window_size, window_size), |
|
use_rel_pos=use_rel_pos, **block_kwargs) |
|
self.cross_attn = MultiHeadCrossAttention(hidden_size, num_heads, **block_kwargs) |
|
|
|
|
|
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) |
|
self.norm3 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) |
|
self.norm4 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) |
|
self.norm5 = nn.LayerNorm(hidden_size * num_f, elementwise_affine=False, eps=1e-6) |
|
self.norm6 = nn.LayerNorm(hidden_size * num_f, elementwise_affine=False, eps=1e-6) |
|
|
|
approx_gelu = lambda: nn.GELU(approximate="tanh") |
|
if FFN_type == 'mlp': |
|
self.mlp = Mlp(in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu, drop=0) |
|
|
|
|
|
elif FFN_type == 'SwiGLU': |
|
self.mlp = SwiGLU(hidden_size, int(hidden_size * mlp_ratio), 1) |
|
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() |
|
self.window_size = window_size |
|
self.scale_shift_table = nn.Parameter(th.randn(6, hidden_size) / hidden_size ** 0.5) |
|
|
|
|
|
self.skip_linear = nn.Linear(2 * hidden_size, hidden_size) if skip else None |
|
|
|
self.F_transformer = WindowAttention(hidden_size, num_heads=4, qkv_bias=True, |
|
input_size=input_size if window_size == 0 else (window_size, window_size), |
|
use_rel_pos=use_rel_pos, **block_kwargs) |
|
|
|
self.T_transformer = WindowAttention(hidden_size * num_f, num_heads=16, qkv_bias=True, |
|
input_size=input_size if window_size == 0 else (window_size, window_size), |
|
use_rel_pos=use_rel_pos, **block_kwargs) |
|
|
|
self.f_pos = nn.Embedding(num_f, hidden_size) |
|
self.t_pos = nn.Embedding(num_t, hidden_size * num_f) |
|
self.num_f = num_f |
|
self.num_t = num_t |
|
|
|
def forward(self, x_normal, end, y, t, mask=None, skip=None, ids_keep=None, **kwargs): |
|
|
|
|
|
B, D, C = x_normal.shape |
|
T = self.num_t |
|
F_add_1 = self.num_f |
|
|
|
|
|
|
|
|
|
|
|
if self.skip_linear is not None: |
|
x_normal = self.skip_linear(th.cat([x_normal, skip], dim=-1)) |
|
|
|
D = T * (F_add_1 - 1) |
|
|
|
|
|
|
|
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None] + t.reshape(B, 6, -1)).chunk(6, dim=1) |
|
x_normal = x_normal + self.drop_path(gate_msa * self.attn(t2i_modulate(self.norm1(x_normal), shift_msa, scale_msa)).reshape(B, D, C)) |
|
|
|
x_normal = x_normal.reshape(B, T, F_add_1-1, C) |
|
x_normal = th.cat((x_normal, end), 2) |
|
|
|
|
|
x_normal = x_normal.reshape(B*T, F_add_1, C) |
|
        pos_f = th.arange(self.num_f, device=x_normal.device).unsqueeze(0).expand(B*T, -1)
|
|
|
x_normal = x_normal + self.f_pos(pos_f) |
|
|
|
x_normal = x_normal + self.F_transformer(self.norm3(x_normal)) |
|
|
|
|
|
|
|
|
|
x_normal = x_normal.reshape(B, T, F_add_1 * C) |
|
        pos_t = th.arange(self.num_t, device=x_normal.device).unsqueeze(0).expand(B, -1)
|
x_normal = x_normal + self.t_pos(pos_t) |
|
x_normal = x_normal + self.T_transformer(self.norm5(x_normal)) |
|
|
|
|
|
|
|
        x_normal = x_normal.reshape(B, T, F_add_1, C)
|
end = x_normal[:, :, -1, :].unsqueeze(2) |
|
x_normal = x_normal[:, :, :-1, :] |
|
x_normal = x_normal.reshape(B, T*(F_add_1 - 1), C) |
|
|
|
x_normal = x_normal + self.cross_attn(x_normal, y, mask) |
|
x_normal = x_normal + self.drop_path(gate_mlp * self.mlp(t2i_modulate(self.norm2(x_normal), shift_mlp, scale_mlp))) |
|
|
|
|
|
return x_normal, end |
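# DEBlock token layout (a sketch of the reshapes above, with T = num_t frames and F + 1 = num_f bins):
#   input x_normal: (B, T*F, C) patch tokens plus a per-frame `end` column of shape (B, T, 1, C)
#   frequency transformer: reshape to (B*T, F+1, C), add f_pos, self-attend within each frame
#   temporal transformer:  reshape to (B, T, (F+1)*C), add t_pos, self-attend across frames
# The `end` column is then split off again and (B, T*F, C) continues through cross-attention + MLP.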
|
class MDTBlock(nn.Module): |
|
""" |
|
    An MDT block: a PixArt block with adaLN-single conditioning and an optional long skip connection.
|
""" |
|
|
|
def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, FFN_type='mlp', drop_path=0., window_size=0, input_size=None, use_rel_pos=False, skip=False, **block_kwargs): |
|
super().__init__() |
|
self.hidden_size = hidden_size |
|
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) |
|
self.attn = WindowAttention(hidden_size, num_heads=num_heads, qkv_bias=True, |
|
input_size=input_size if window_size == 0 else (window_size, window_size), |
|
use_rel_pos=use_rel_pos, **block_kwargs) |
|
self.cross_attn = MultiHeadCrossAttention(hidden_size, num_heads, **block_kwargs) |
|
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) |
|
|
|
approx_gelu = lambda: nn.GELU(approximate="tanh") |
|
if FFN_type == 'mlp': |
|
self.mlp = Mlp(in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu, drop=0) |
|
elif FFN_type == 'SwiGLU': |
|
self.mlp = SwiGLU(hidden_size, int(hidden_size * mlp_ratio), 1) |
|
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() |
|
self.window_size = window_size |
|
self.scale_shift_table = nn.Parameter(th.randn(6, hidden_size) / hidden_size ** 0.5) |
|
|
|
self.skip_linear = nn.Linear(2 * hidden_size, hidden_size) if skip else None |
|
|
|
def forward(self, x, y, t, mask=None, skip=None, ids_keep=None, **kwargs): |
|
B, N, C = x.shape |
|
if self.skip_linear is not None: |
|
x = self.skip_linear(th.cat([x, skip], dim=-1)) |
|
|
|
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None] + t.reshape(B, 6, -1)).chunk(6, dim=1) |
|
x = x + self.drop_path(gate_msa * self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa)).reshape(B, N, C)) |
|
x = x + self.cross_attn(x, y, mask) |
|
x = x + self.drop_path(gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp))) |
|
|
|
return x |
|
class PixArt_MDT_MASK_TF(nn.Module): |
|
""" |
|
Diffusion model with a Transformer backbone. |
|
""" |
|
|
|
def __init__(self, input_size=(256,16), patch_size=(16,4), overlap=(0, 0), in_channels=8, hidden_size=1152, depth=28, num_heads=16, mlp_ratio=4.0, class_dropout_prob=0.1, pred_sigma=False, drop_path: float = 0., window_size=0, window_block_indexes=None, use_rel_pos=False, cond_dim=1024, lewei_scale=1.0, |
|
use_cfg=True, cfg_scale=4.0, config=None, model_max_length=120, mask_t=0.17, mask_f=0.15, decode_layer=4,**kwargs): |
|
if window_block_indexes is None: |
|
window_block_indexes = [] |
|
super().__init__() |
|
self.use_cfg = use_cfg |
|
self.cfg_scale = cfg_scale |
|
self.input_size = input_size |
|
self.pred_sigma = pred_sigma |
|
self.in_channels = in_channels |
|
self.out_channels = in_channels |
|
self.patch_size = patch_size |
|
self.num_heads = num_heads |
|
self.lewei_scale = lewei_scale, |
|
decode_layer = int(decode_layer) |
|
|
|
self.x_embedder = PatchEmbed(input_size, patch_size, overlap, in_channels, hidden_size, bias=True) |
|
|
|
self.t_embedder = TimestepEmbedder(hidden_size) |
|
num_patches = self.x_embedder.num_patches |
|
self.base_size = input_size[0] // self.patch_size[0] * 2 |
|
|
|
self.register_buffer("pos_embed", th.zeros(1, num_patches, hidden_size)) |
|
|
|
approx_gelu = lambda: nn.GELU(approximate="tanh") |
|
self.t_block = nn.Sequential( |
|
nn.SiLU(), |
|
nn.Linear(hidden_size, 6 * hidden_size, bias=True) |
|
) |
|
self.y_embedder = nn.Linear(cond_dim, hidden_size) |
|
|
|
half_depth = (depth - decode_layer)//2 |
|
self.half_depth=half_depth |
|
|
|
drop_path_half = [x.item() for x in th.linspace(0, drop_path, half_depth)] |
|
drop_path_decode = [x.item() for x in th.linspace(0, drop_path, decode_layer)] |
|
self.en_inblocks = nn.ModuleList([ |
|
MDTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path_half[i], |
|
input_size=(self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]), |
|
window_size=0, |
|
use_rel_pos=False, FFN_type='mlp') for i in range(half_depth) |
|
]) |
|
self.en_outblocks = nn.ModuleList([ |
|
MDTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path_half[i], |
|
input_size=(self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]), |
|
window_size=0, |
|
use_rel_pos=False, skip=True, FFN_type='mlp') for i in range(half_depth) |
|
]) |
|
self.de_blocks = nn.ModuleList([ |
|
MDTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path_decode[i], |
|
input_size=(self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]), |
|
window_size=0, |
|
use_rel_pos=False, skip=True, FFN_type='mlp') for i in range(decode_layer) |
|
]) |
|
self.sideblocks = nn.ModuleList([ |
|
MDTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, |
|
input_size=(self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]), |
|
window_size=0, |
|
use_rel_pos=False, FFN_type='mlp') for _ in range(1) |
|
]) |
|
|
|
self.final_layer = T2IFinalLayer(hidden_size, patch_size, self.out_channels) |
|
|
|
self.decoder_pos_embed = nn.Parameter(th.zeros( |
|
1, num_patches, hidden_size), requires_grad=True) |
|
assert mask_t != 0 and mask_f != 0 |
|
self.mask_token = nn.Parameter(th.zeros(1, 1, hidden_size)) |
|
self.mask_t = mask_t |
|
self.mask_f = mask_f |
|
self.decode_layer = int(decode_layer) |
|
print(f"mask ratio: T-{self.mask_t} F-{self.mask_f}", "decode_layer:", self.decode_layer) |
|
self.initialize_weights() |
|
|
|
|
|
def forward(self, x, timestep, context_list, context_mask_list=None, enable_mask=False, **kwargs): |
|
""" |
|
Forward pass of PixArt. |
|
        x: (N, C, H, W) tensor of spatial inputs (latent spectrogram representations)
        timestep: (N,) tensor of diffusion timesteps
        context_list: list of conditioning tensors; context_list[0] holds the text embeddings of shape (N, L, cond_dim)
        context_mask_list: list of attention masks; context_mask_list[0] is (N, L)
|
""" |
|
x = x.to(self.dtype) |
|
timestep = timestep.to(self.dtype) |
|
y = context_list[0].to(self.dtype) |
|
pos_embed = self.pos_embed.to(self.dtype) |
|
self.h, self.w = self.x_embedder.grid_size[0], self.x_embedder.grid_size[1] |
|
|
|
|
|
x = self.x_embedder(x) + pos_embed |
|
t = self.t_embedder(timestep.to(x.dtype)) |
|
t0 = self.t_block(t) |
|
y = self.y_embedder(y) |
|
|
|
try: |
|
mask = context_mask_list[0] |
|
except: |
|
mask = th.ones(x.shape[0], 1).to(x.device) |
|
print("MASK !!!!!!!!!!!!!!!!!!!!!!!!!") |
|
|
|
assert mask is not None |
|
|
|
|
|
y = y.masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1]) |
|
y_lens = mask.sum(dim=1).tolist() |
|
y_lens = [int(_) for _ in y_lens] |
|
|
|
input_skip = x |
|
|
|
masked_stage = False |
|
skips = [] |
|
|
|
if self.mask_t is not None and self.training: |
|
|
|
rand_mask_ratio = th.rand(1, device=x.device) |
|
rand_mask_ratio_t = rand_mask_ratio * 0.13 + self.mask_t |
|
rand_mask_ratio_f = rand_mask_ratio * 0.13 + self.mask_f |
|
|
|
x, mask, ids_restore, ids_keep = self.random_masking_2d( |
|
x, rand_mask_ratio_t, rand_mask_ratio_f) |
|
masked_stage = True |
|
|
|
|
|
for block in self.en_inblocks: |
|
if masked_stage: |
|
x = auto_grad_checkpoint(block, x, y, t0, y_lens, ids_keep=ids_keep) |
|
else: |
|
x = auto_grad_checkpoint(block, x, y, t0, y_lens, ids_keep=None) |
|
skips.append(x) |
|
|
|
for block in self.en_outblocks: |
|
if masked_stage: |
|
x = auto_grad_checkpoint(block, x, y, t0, y_lens, skip=skips.pop(), ids_keep=ids_keep) |
|
else: |
|
x = auto_grad_checkpoint(block, x, y, t0, y_lens, skip=skips.pop(), ids_keep=None) |
|
|
|
if self.mask_t is not None and self.mask_f is not None and self.training: |
|
x = self.forward_side_interpolater(x, y, t0, y_lens, mask, ids_restore) |
|
masked_stage = False |
|
else: |
|
|
|
x = x + self.decoder_pos_embed |
|
|
|
for i in range(len(self.de_blocks)): |
|
block = self.de_blocks[i] |
|
this_skip = input_skip |
|
|
|
x = auto_grad_checkpoint(block, x, y, t0, y_lens, skip=this_skip, ids_keep=None) |
|
|
|
x = self.final_layer(x, t) |
|
x = self.unpatchify(x) |
|
return x |
|
|
|
def forward_with_dpmsolver(self, x, timestep, y, mask=None, **kwargs): |
|
""" |
|
        DPM-Solver does not need the variance prediction, so only the first half of the output channels is returned.
|
""" |
|
|
|
model_out = self.forward(x, timestep, y, mask) |
|
return model_out.chunk(2, dim=1)[0] |
|
|
|
def forward_with_cfg(self, x, timestep, context_list, context_mask_list=None, cfg_scale=4.0, **kwargs): |
|
""" |
|
Forward pass of PixArt, but also batches the unconditional forward pass for classifier-free guidance. |
|
""" |
|
|
|
|
|
|
|
half = x[: len(x) // 2] |
|
combined = th.cat([half, half], dim=0) |
|
        model_out = self.forward(combined, timestep, context_list, context_mask_list=context_mask_list, **kwargs)
|
model_out = model_out['x'] if isinstance(model_out, dict) else model_out |
|
eps, rest = model_out[:, :8], model_out[:, 8:] |
|
cond_eps, uncond_eps = th.split(eps, len(eps) // 2, dim=0) |
|
half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps) |
|
eps = th.cat([half_eps, half_eps], dim=0) |
|
return eps |
|
|
|
|
|
def unpatchify(self, x): |
|
|
|
""" |
|
        x: (N, T, patch_size[0] * patch_size[1] * C)
        imgs: (N, out_channels, H, W); with the default input size this is (N, out_channels, 256, 16)
|
""" |
|
if self.x_embedder.ol == (0, 0) or self.x_embedder.ol == [0, 0]: |
|
c = self.out_channels |
|
p0 = self.x_embedder.patch_size[0] |
|
p1 = self.x_embedder.patch_size[1] |
|
h, w = self.x_embedder.grid_size[0], self.x_embedder.grid_size[1] |
|
|
|
x = x.reshape(shape=(x.shape[0], h, w, p0, p1, c)) |
|
x = th.einsum('nhwpqc->nchpwq', x) |
|
imgs = x.reshape(shape=(x.shape[0], c, h * p0, w * p1)) |
|
return imgs |
|
|
|
lf = self.x_embedder.grid_size[0] |
|
rf = self.x_embedder.grid_size[1] |
|
lp = self.x_embedder.patch_size[0] |
|
rp = self.x_embedder.patch_size[1] |
|
lo = self.x_embedder.ol[0] |
|
ro = self.x_embedder.ol[1] |
|
lm = self.x_embedder.img_size[0] |
|
rm = self.x_embedder.img_size[1] |
|
lpad = self.x_embedder.pad_size[0] |
|
rpad = self.x_embedder.pad_size[1] |
|
bs = x.shape[0] |
|
|
|
torch_map = self.torch_map |
|
|
|
c = self.out_channels |
|
|
|
x = x.reshape(shape=(bs, lf, rf, lp, rp, c)) |
|
x = th.einsum('nhwpqc->nchwpq', x) |
|
|
|
added_map = th.zeros(bs, c, lm+2*lpad, rm+2*rpad).to(x.device) |
|
|
|
for i in range(lf): |
|
for j in range(rf): |
|
xx = (i) * (lp - lo) |
|
yy = (j) * (rp - ro) |
|
added_map[:, :, xx:(xx+lp), yy:(yy+rp)] += \ |
|
x[:, :, i, j, :, :] |
|
|
|
|
|
        added_map = added_map[:, :, lpad:lm + lpad, rpad:rm + rpad]
|
return th.mul(added_map.to(x.device), torch_map.to(x.device)) |
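    # Overlap-add reconstruction: with overlapping patches each output pixel is covered by several
    # patches, so the summed `added_map` is multiplied by self.torch_map (built in initialize_weights
    # as the reciprocal of the per-pixel patch-coverage count) to average the overlapping contributions.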
|
|
|
def random_masking_2d(self, x, mask_t_prob, mask_f_prob): |
|
""" |
|
        2D: spectrogram masking (mask time and frequency patches with mask_t_prob and mask_f_prob).
|
Perform per-sample random masking by per-sample shuffling. |
|
Per-sample shuffling is done by argsort random noise. |
|
x: [N, L, D], sequence |
|
""" |
|
N, L, D = x.shape |
|
|
|
|
|
|
|
|
|
|
|
|
|
T = self.x_embedder.grid_size[0] |
|
F = self.x_embedder.grid_size[1] |
|
|
|
len_keep_t = int(T * (1 - mask_t_prob)) |
|
len_keep_f = int(F * (1 - mask_f_prob)) |
|
|
|
|
|
noise_t = th.rand(N, T, device=x.device) |
|
|
|
ids_shuffle_t = th.argsort(noise_t, dim=1) |
|
ids_restore_t = th.argsort(ids_shuffle_t, dim=1) |
|
ids_keep_t = ids_shuffle_t[:,:len_keep_t] |
|
|
|
noise_f = th.rand(N, F, device=x.device) |
|
ids_shuffle_f = th.argsort(noise_f, dim=1) |
|
ids_restore_f = th.argsort(ids_shuffle_f, dim=1) |
|
ids_keep_f = ids_shuffle_f[:,:len_keep_f] |
|
|
|
|
|
|
|
mask_f = th.ones(N, F, device=x.device) |
|
mask_f[:,:len_keep_f] = 0 |
|
mask_f = th.gather(mask_f, dim=1, index=ids_restore_f).unsqueeze(1).repeat(1,T,1) |
|
|
|
mask_t = th.ones(N, T, device=x.device) |
|
mask_t[:,:len_keep_t] = 0 |
|
mask_t = th.gather(mask_t, dim=1, index=ids_restore_t).unsqueeze(1).repeat(1,F,1).permute(0,2,1) |
|
mask = 1-(1-mask_t)*(1-mask_f) |
|
|
|
|
|
id2res=th.Tensor(list(range(N*T*F))).reshape(N,T,F).to(x.device) |
|
id2res = id2res + 999*mask |
|
id2res2 = th.argsort(id2res.flatten(start_dim=1)) |
|
ids_keep=id2res2.flatten(start_dim=1)[:,:len_keep_f*len_keep_t] |
|
x_masked = th.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) |
|
|
|
ids_restore = th.argsort(id2res2.flatten(start_dim=1)) |
|
mask = mask.flatten(start_dim=1) |
|
|
|
return x_masked, mask, ids_restore, ids_keep |
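    # Small worked example of the masking above (illustrative numbers only): with T = 16, F = 4 and
    # mask_t_prob = mask_f_prob = 0.25, len_keep_t = 12 and len_keep_f = 3, so 12 * 3 = 36 of the 64
    # tokens survive. The `id2res + 999 * mask` trick pushes masked positions to the end of each
    # per-sample argsort (valid because a sample has far fewer than 999 tokens here), so ids_keep
    # selects exactly the unmasked time/frequency intersections.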
|
|
|
def random_masking(self, x, mask_ratio): |
|
""" |
|
Perform per-sample random masking by per-sample shuffling. |
|
Per-sample shuffling is done by argsort random noise. |
|
x: [N, L, D], sequence |
|
""" |
|
N, L, D = x.shape |
|
len_keep = int(L * (1 - mask_ratio)) |
|
|
|
noise = th.rand(N, L, device=x.device) |
|
|
|
|
|
|
|
ids_shuffle = th.argsort(noise, dim=1) |
|
ids_restore = th.argsort(ids_shuffle, dim=1) |
|
|
|
|
|
ids_keep = ids_shuffle[:, :len_keep] |
|
x_masked = th.gather( |
|
x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) |
|
|
|
|
|
mask = th.ones([N, L], device=x.device) |
|
mask[:, :len_keep] = 0 |
|
|
|
mask = th.gather(mask, dim=1, index=ids_restore) |
|
|
|
return x_masked, mask, ids_restore, ids_keep |
|
|
|
def forward_side_interpolater(self, x, y, t0, y_lens, mask, ids_restore): |
|
|
|
mask_tokens = self.mask_token.repeat( |
|
x.shape[0], ids_restore.shape[1] - x.shape[1], 1) |
|
x_ = th.cat([x, mask_tokens], dim=1) |
|
x = th.gather( |
|
x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) |
|
|
|
|
|
x = x + self.decoder_pos_embed |
|
|
|
|
|
x_before = x |
|
for sideblock in self.sideblocks: |
|
x = sideblock(x, y, t0, y_lens, ids_keep=None) |
|
|
|
|
|
mask = mask.unsqueeze(dim=-1) |
|
x = x*mask + (1-mask)*x_before |
|
|
|
return x |
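    # Side-interpolater summary: mask tokens are appended, scattered back to their original positions
    # via ids_restore, refined by one side block, and blended as x*mask + (1-mask)*x_before so only
    # the originally-masked positions take the newly predicted values.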
|
|
|
def initialize_weights(self): |
|
|
|
def _basic_init(module): |
|
if isinstance(module, nn.Linear): |
|
th.nn.init.xavier_uniform_(module.weight) |
|
if module.bias is not None: |
|
nn.init.constant_(module.bias, 0) |
|
|
|
self.apply(_basic_init) |
|
|
|
|
|
pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], self.x_embedder.grid_size, lewei_scale=self.lewei_scale, base_size=self.base_size) |
|
self.pos_embed.data.copy_(th.from_numpy(pos_embed).float().unsqueeze(0)) |
|
|
|
|
|
w = self.x_embedder.proj.weight.data |
|
nn.init.xavier_uniform_(w.view([w.shape[0], -1])) |
|
|
|
|
|
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) |
|
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) |
|
nn.init.normal_(self.t_block[1].weight, std=0.02) |
|
|
|
|
|
nn.init.normal_(self.y_embedder.weight, std=0.02) |
|
|
|
|
|
|
|
for block in self.en_inblocks: |
|
nn.init.constant_(block.cross_attn.proj.weight, 0) |
|
nn.init.constant_(block.cross_attn.proj.bias, 0) |
|
for block in self.en_outblocks: |
|
nn.init.constant_(block.cross_attn.proj.weight, 0) |
|
nn.init.constant_(block.cross_attn.proj.bias, 0) |
|
for block in self.de_blocks: |
|
nn.init.constant_(block.cross_attn.proj.weight, 0) |
|
nn.init.constant_(block.cross_attn.proj.bias, 0) |
|
for block in self.sideblocks: |
|
nn.init.constant_(block.cross_attn.proj.weight, 0) |
|
nn.init.constant_(block.cross_attn.proj.bias, 0) |
|
|
|
|
|
nn.init.constant_(self.final_layer.linear.weight, 0) |
|
nn.init.constant_(self.final_layer.linear.bias, 0) |
|
|
|
if self.x_embedder.ol == [0, 0] or self.x_embedder.ol == (0, 0): |
|
return |
|
|
|
lf = self.x_embedder.grid_size[0] |
|
rf = self.x_embedder.grid_size[1] |
|
lp = self.x_embedder.patch_size[0] |
|
rp = self.x_embedder.patch_size[1] |
|
lo = self.x_embedder.ol[0] |
|
ro = self.x_embedder.ol[1] |
|
lm = self.x_embedder.img_size[0] |
|
rm = self.x_embedder.img_size[1] |
|
lpad = self.x_embedder.pad_size[0] |
|
rpad = self.x_embedder.pad_size[1] |
|
|
|
        torch_map = th.zeros(lm + 2 * lpad, rm + 2 * rpad)
|
for i in range(lf): |
|
for j in range(rf): |
|
xx = (i) * (lp - lo) |
|
yy = (j) * (rp - ro) |
|
torch_map[xx:(xx+lp), yy:(yy+rp)]+=1 |
|
torch_map = torch_map[lpad:lm+lpad, rpad:rm+rpad] |
|
self.torch_map = th.reciprocal(torch_map) |
|
|
|
@property |
|
def dtype(self): |
|
return next(self.parameters()).dtype |
|
|
|
class PixArt_MDT_MOS_AS_TOKEN(nn.Module): |
|
""" |
|
Diffusion model with a Transformer backbone. |
|
""" |
|
|
|
def __init__(self, input_size=(256,16), patch_size=(16,4), overlap=(0, 0), in_channels=8, hidden_size=1152, depth=28, num_heads=16, mlp_ratio=4.0, class_dropout_prob=0.1, pred_sigma=False, drop_path: float = 0., window_size=0, window_block_indexes=None, use_rel_pos=False, cond_dim=1024, lewei_scale=1.0, |
|
use_cfg=True, cfg_scale=4.0, config=None, model_max_length=120, mask_ratio=None, decode_layer=4,**kwargs): |
|
if window_block_indexes is None: |
|
window_block_indexes = [] |
|
super().__init__() |
|
self.use_cfg = use_cfg |
|
self.cfg_scale = cfg_scale |
|
self.input_size = input_size |
|
self.pred_sigma = pred_sigma |
|
self.in_channels = in_channels |
|
self.out_channels = in_channels |
|
self.patch_size = patch_size |
|
self.num_heads = num_heads |
|
self.lewei_scale = lewei_scale, |
|
decode_layer = int(decode_layer) |
|
|
|
self.mos_embed = nn.Embedding(num_embeddings=5, embedding_dim=hidden_size) |
|
|
|
self.x_embedder = PatchEmbed(input_size, patch_size, overlap, in_channels, hidden_size, bias=True) |
|
|
|
self.t_embedder = TimestepEmbedder(hidden_size) |
|
num_patches = self.x_embedder.num_patches |
|
self.base_size = input_size[0] // self.patch_size[0] * 2 |
|
|
|
self.register_buffer("pos_embed", th.zeros(1, num_patches, hidden_size)) |
|
|
|
approx_gelu = lambda: nn.GELU(approximate="tanh") |
|
self.t_block = nn.Sequential( |
|
nn.SiLU(), |
|
nn.Linear(hidden_size, 6 * hidden_size, bias=True) |
|
) |
|
|
|
|
|
|
|
|
|
self.y_embedder = nn.Linear(cond_dim, hidden_size) |
|
|
|
half_depth = (depth - decode_layer)//2 |
|
self.half_depth=half_depth |
|
|
|
drop_path_half = [x.item() for x in th.linspace(0, drop_path, half_depth)] |
|
drop_path_decode = [x.item() for x in th.linspace(0, drop_path, decode_layer)] |
|
self.en_inblocks = nn.ModuleList([ |
|
MDTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path_half[i], |
|
input_size=(self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]), |
|
window_size=0, |
|
use_rel_pos=False, FFN_type='mlp') for i in range(half_depth) |
|
]) |
|
self.en_outblocks = nn.ModuleList([ |
|
MDTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path_half[i], |
|
input_size=(self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]), |
|
window_size=0, |
|
use_rel_pos=False, skip=True, FFN_type='mlp') for i in range(half_depth) |
|
]) |
|
self.de_blocks = nn.ModuleList([ |
|
MDTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path_decode[i], |
|
input_size=(self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]), |
|
window_size=0, |
|
use_rel_pos=False, skip=True, FFN_type='mlp') for i in range(decode_layer) |
|
]) |
|
self.sideblocks = nn.ModuleList([ |
|
MDTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, |
|
input_size=(self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]), |
|
window_size=0, |
|
use_rel_pos=False, FFN_type='mlp') for _ in range(1) |
|
]) |
|
|
|
self.final_layer = T2IFinalLayer(hidden_size, patch_size, self.out_channels) |
|
|
|
self.decoder_pos_embed = nn.Parameter(th.zeros( |
|
1, num_patches, hidden_size), requires_grad=True) |
|
if mask_ratio is not None: |
|
self.mask_token = nn.Parameter(th.zeros(1, 1, hidden_size)) |
|
self.mask_ratio = float(mask_ratio) |
|
self.decode_layer = int(decode_layer) |
|
else: |
|
self.mask_token = nn.Parameter(th.zeros( |
|
1, 1, hidden_size), requires_grad=False) |
|
self.mask_ratio = None |
|
self.decode_layer = int(decode_layer) |
|
print("mask ratio:", self.mask_ratio, "decode_layer:", self.decode_layer) |
|
self.initialize_weights() |
|
|
|
|
|
def forward(self, x, timestep, context_list, context_mask_list=None, enable_mask=False, mos=None, **kwargs): |
|
""" |
|
Forward pass of PixArt. |
|
        x: (N, C, H, W) tensor of spatial inputs (latent spectrogram representations)
        timestep: (N,) tensor of diffusion timesteps
        context_list: list of conditioning tensors; context_list[0] holds the text embeddings of shape (N, L, cond_dim)
        context_mask_list: list of attention masks; context_mask_list[0] is (N, L)
        mos: (N,) or (N, 1) tensor of MOS quality scores in 1..5, embedded and prepended as an extra token
|
""" |
|
|
|
|
|
assert mos.shape[0] == x.shape[0] |
|
|
|
mos = mos - 1 |
|
mos = self.mos_embed(mos.to(x.device).to(th.int)) |
|
x = x.to(self.dtype) |
|
timestep = timestep.to(self.dtype) |
|
y = context_list[0].to(self.dtype) |
|
pos_embed = self.pos_embed.to(self.dtype) |
|
self.h, self.w = self.x_embedder.grid_size[0], self.x_embedder.grid_size[1] |
|
|
|
|
|
x = self.x_embedder(x) + pos_embed |
|
|
|
t = self.t_embedder(timestep.to(x.dtype)) |
|
t0 = self.t_block(t) |
|
y = self.y_embedder(y) |
|
|
|
try: |
|
mask = context_mask_list[0] |
|
except: |
|
mask = th.ones(x.shape[0], 1).to(x.device) |
|
|
|
assert mask is not None |
|
|
|
|
|
y = y.masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1]) |
|
y_lens = mask.sum(dim=1).tolist() |
|
y_lens = [int(_) for _ in y_lens] |
|
|
|
masked_stage = False |
|
skips = [] |
|
|
|
try: |
|
x = th.cat([mos, x], dim=1) |
|
except: |
|
x = th.cat([mos.unsqueeze(1), x], dim=1) |
|
input_skip = x |
|
|
|
if self.mask_ratio is not None and self.training: |
|
|
|
rand_mask_ratio = th.rand(1, device=x.device) |
|
rand_mask_ratio = rand_mask_ratio * 0.2 + self.mask_ratio |
|
|
|
x, mask, ids_restore, ids_keep = self.random_masking( |
|
x, rand_mask_ratio) |
|
masked_stage = True |
|
for block in self.en_inblocks: |
|
if masked_stage: |
|
x = auto_grad_checkpoint(block, x, y, t0, y_lens, ids_keep=ids_keep) |
|
else: |
|
x = auto_grad_checkpoint(block, x, y, t0, y_lens, ids_keep=None) |
|
skips.append(x) |
|
|
|
for block in self.en_outblocks: |
|
if masked_stage: |
|
x = auto_grad_checkpoint(block, x, y, t0, y_lens, skip=skips.pop(), ids_keep=ids_keep) |
|
else: |
|
x = auto_grad_checkpoint(block, x, y, t0, y_lens, skip=skips.pop(), ids_keep=None) |
|
|
|
if self.mask_ratio is not None and self.training: |
|
x = self.forward_side_interpolater(x, y, t0, y_lens, mask, ids_restore) |
|
masked_stage = False |
|
else: |
|
|
|
x[:, 1:, :] = x[:, 1:, :] + self.decoder_pos_embed |
|
|
|
|
|
|
|
for i in range(len(self.de_blocks)): |
|
block = self.de_blocks[i] |
|
this_skip = input_skip |
|
x = auto_grad_checkpoint(block, x, y, t0, y_lens, skip=this_skip, ids_keep=None) |
|
x = x[:, 1:, :] |
|
x = self.final_layer(x, t) |
|
x = self.unpatchify(x) |
|
|
|
|
|
return x |
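    # MOS-as-token conditioning (as implemented above): the per-sample MOS score in 1..5 is shifted to
    # 0..4, embedded to hidden_size, prepended as one extra token before the encoder blocks, excluded
    # from training-time token masking, and stripped off (x[:, 1:, :]) before the final layer.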
|
|
|
def forward_with_dpmsolver(self, x, timestep, y, mask=None, **kwargs): |
|
""" |
|
        DPM-Solver does not need the variance prediction, so only the first half of the output channels is returned.
|
""" |
|
|
|
model_out = self.forward(x, timestep, y, mask) |
|
return model_out.chunk(2, dim=1)[0] |
|
|
|
def forward_with_cfg(self, x, timestep, context_list, context_mask_list=None, cfg_scale=4.0, **kwargs): |
|
""" |
|
Forward pass of PixArt, but also batches the unconditional forward pass for classifier-free guidance. |
|
""" |
|
|
|
|
|
|
|
half = x[: len(x) // 2] |
|
combined = th.cat([half, half], dim=0) |
|
        model_out = self.forward(combined, timestep, context_list, context_mask_list=context_mask_list, **kwargs)
|
model_out = model_out['x'] if isinstance(model_out, dict) else model_out |
|
eps, rest = model_out[:, :8], model_out[:, 8:] |
|
cond_eps, uncond_eps = th.split(eps, len(eps) // 2, dim=0) |
|
half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps) |
|
eps = th.cat([half_eps, half_eps], dim=0) |
|
return eps |
|
|
|
|
|
def unpatchify(self, x): |
|
|
|
""" |
|
        x: (N, T, patch_size[0] * patch_size[1] * C)
        imgs: (N, out_channels, H, W); with the default input size this is (N, out_channels, 256, 16)
|
""" |
|
if self.x_embedder.ol == (0, 0) or self.x_embedder.ol == [0, 0]: |
|
c = self.out_channels |
|
p0 = self.x_embedder.patch_size[0] |
|
p1 = self.x_embedder.patch_size[1] |
|
h, w = self.x_embedder.grid_size[0], self.x_embedder.grid_size[1] |
|
|
|
x = x.reshape(shape=(x.shape[0], h, w, p0, p1, c)) |
|
x = th.einsum('nhwpqc->nchpwq', x) |
|
imgs = x.reshape(shape=(x.shape[0], c, h * p0, w * p1)) |
|
return imgs |
|
|
|
lf = self.x_embedder.grid_size[0] |
|
rf = self.x_embedder.grid_size[1] |
|
lp = self.x_embedder.patch_size[0] |
|
rp = self.x_embedder.patch_size[1] |
|
lo = self.x_embedder.ol[0] |
|
ro = self.x_embedder.ol[1] |
|
lm = self.x_embedder.img_size[0] |
|
rm = self.x_embedder.img_size[1] |
|
lpad = self.x_embedder.pad_size[0] |
|
rpad = self.x_embedder.pad_size[1] |
|
bs = x.shape[0] |
|
|
|
torch_map = self.torch_map |
|
|
|
c = self.out_channels |
|
|
|
x = x.reshape(shape=(bs, lf, rf, lp, rp, c)) |
|
x = th.einsum('nhwpqc->nchwpq', x) |
|
|
|
added_map = th.zeros(bs, c, lm+2*lpad, rm+2*rpad).to(x.device) |
|
|
|
for i in range(lf): |
|
for j in range(rf): |
|
xx = (i) * (lp - lo) |
|
yy = (j) * (rp - ro) |
|
added_map[:, :, xx:(xx+lp), yy:(yy+rp)] += \ |
|
x[:, :, i, j, :, :] |
|
|
|
|
|
        added_map = added_map[:, :, lpad:lm + lpad, rpad:rm + rpad]
|
return th.mul(added_map.to(x.device), torch_map.to(x.device)) |
|
|
|
|
|
def random_masking(self, x, mask_ratio): |
|
""" |
|
Perform per-sample random masking by per-sample shuffling. |
|
Per-sample shuffling is done by argsort random noise. |
|
x: [N, L, D], sequence |
|
""" |
|
N, L, D = x.shape |
|
L = L - 1 |
|
len_keep = int(L * (1 - mask_ratio)) |
|
|
|
noise = th.rand(N, L, device=x.device) |
|
|
|
|
|
|
|
ids_shuffle = th.argsort(noise, dim=1) |
|
ids_restore = th.argsort(ids_shuffle, dim=1) |
|
|
|
|
|
ids_keep = ids_shuffle[:, :len_keep] |
|
x_masked = th.gather( |
|
x[:, 1:, :], dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) |
|
|
|
x_masked = th.cat([x[:, 0, :].unsqueeze(1), x_masked], dim=1) |
|
|
|
|
|
|
|
mask = th.ones([N, L], device=x.device) |
|
mask[:, :len_keep] = 0 |
|
|
|
mask = th.gather(mask, dim=1, index=ids_restore) |
|
|
|
return x_masked, mask, ids_restore, ids_keep |
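    # Unlike random_masking in the other classes, index 0 (the MOS token) is excluded here: L is
    # reduced by one, masking is applied to x[:, 1:, :] only, and the MOS token is re-prepended.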
|
|
|
def forward_side_interpolater(self, x, y, t0, y_lens, mask, ids_restore): |
|
|
|
mask_tokens = self.mask_token.repeat( |
|
x.shape[0], ids_restore.shape[1] - x.shape[1] + 1, 1) |
|
x_ = th.cat([x[:, 1:, :], mask_tokens], dim=1) |
|
|
|
x_ = th.gather( |
|
x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) |
|
|
|
|
|
x_ = x_ + self.decoder_pos_embed |
|
x = th.cat([x[:, 0, :].unsqueeze(1), x_], dim=1) |
|
|
|
|
|
|
|
x_before = x |
|
for sideblock in self.sideblocks: |
|
x = sideblock(x, y, t0, y_lens, ids_keep=None) |
|
|
|
|
|
mask = mask.unsqueeze(dim=-1) |
|
|
|
mask = th.cat([th.ones(mask.shape[0], 1, 1).to(mask.device), mask], dim=1) |
|
x = x*mask + (1-mask)*x_before |
|
|
|
return x |
|
|
|
def initialize_weights(self): |
|
|
|
def _basic_init(module): |
|
if isinstance(module, nn.Linear): |
|
th.nn.init.xavier_uniform_(module.weight) |
|
if module.bias is not None: |
|
nn.init.constant_(module.bias, 0) |
|
|
|
self.apply(_basic_init) |
|
|
|
|
|
pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], self.x_embedder.grid_size, lewei_scale=self.lewei_scale, base_size=self.base_size) |
|
self.pos_embed.data.copy_(th.from_numpy(pos_embed).float().unsqueeze(0)) |
|
|
|
|
|
w = self.x_embedder.proj.weight.data |
|
nn.init.xavier_uniform_(w.view([w.shape[0], -1])) |
|
|
|
|
|
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) |
|
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) |
|
nn.init.normal_(self.t_block[1].weight, std=0.02) |
|
|
|
|
|
nn.init.normal_(self.y_embedder.weight, std=0.02) |
|
|
|
|
|
|
|
for block in self.en_inblocks: |
|
nn.init.constant_(block.cross_attn.proj.weight, 0) |
|
nn.init.constant_(block.cross_attn.proj.bias, 0) |
|
for block in self.en_outblocks: |
|
nn.init.constant_(block.cross_attn.proj.weight, 0) |
|
nn.init.constant_(block.cross_attn.proj.bias, 0) |
|
for block in self.de_blocks: |
|
nn.init.constant_(block.cross_attn.proj.weight, 0) |
|
nn.init.constant_(block.cross_attn.proj.bias, 0) |
|
for block in self.sideblocks: |
|
nn.init.constant_(block.cross_attn.proj.weight, 0) |
|
nn.init.constant_(block.cross_attn.proj.bias, 0) |
|
|
|
|
|
nn.init.constant_(self.final_layer.linear.weight, 0) |
|
nn.init.constant_(self.final_layer.linear.bias, 0) |
|
|
|
if self.x_embedder.ol == [0, 0] or self.x_embedder.ol == (0, 0): |
|
return |
|
|
|
lf = self.x_embedder.grid_size[0] |
|
rf = self.x_embedder.grid_size[1] |
|
lp = self.x_embedder.patch_size[0] |
|
rp = self.x_embedder.patch_size[1] |
|
lo = self.x_embedder.ol[0] |
|
ro = self.x_embedder.ol[1] |
|
lm = self.x_embedder.img_size[0] |
|
rm = self.x_embedder.img_size[1] |
|
lpad = self.x_embedder.pad_size[0] |
|
rpad = self.x_embedder.pad_size[1] |
|
|
|
        torch_map = th.zeros(lm + 2 * lpad, rm + 2 * rpad)
|
for i in range(lf): |
|
for j in range(rf): |
|
xx = (i) * (lp - lo) |
|
yy = (j) * (rp - ro) |
|
torch_map[xx:(xx+lp), yy:(yy+rp)]+=1 |
|
torch_map = torch_map[lpad:lm+lpad, rpad:rm+rpad] |
|
self.torch_map = th.reciprocal(torch_map) |
|
|
|
@property |
|
def dtype(self): |
|
return next(self.parameters()).dtype |
|
|
|
|
|
class PixArt_MDT_LC(nn.Module): |
|
""" |
|
Diffusion model with a Transformer backbone. |
|
""" |
|
|
|
def __init__(self, input_size=(256,16), patch_size=(16,4), overlap=(0, 0), in_channels=8, hidden_size=1152, depth=28, num_heads=16, mlp_ratio=4.0, class_dropout_prob=0.1, pred_sigma=False, drop_path: float = 0., window_size=0, window_block_indexes=None, use_rel_pos=False, cond_dim=1024, lewei_scale=1.0, |
|
use_cfg=True, cfg_scale=4.0, config=None, model_max_length=120, mask_ratio=None, decode_layer=4,**kwargs): |
|
if window_block_indexes is None: |
|
window_block_indexes = [] |
|
super().__init__() |
|
self.use_cfg = use_cfg |
|
self.cfg_scale = cfg_scale |
|
self.input_size = input_size |
|
self.pred_sigma = pred_sigma |
|
self.in_channels = in_channels |
|
self.out_channels = in_channels |
|
self.patch_size = patch_size |
|
self.num_heads = num_heads |
|
self.lewei_scale = lewei_scale, |
|
decode_layer = int(decode_layer) |
|
|
|
self.x_embedder = PatchEmbed(input_size, patch_size, overlap, in_channels, hidden_size, bias=True) |
|
|
|
self.t_embedder = TimestepEmbedder(hidden_size) |
|
num_patches = self.x_embedder.num_patches |
|
self.base_size = input_size[0] // self.patch_size[0] * 2 |
|
|
|
self.register_buffer("pos_embed", th.zeros(1, num_patches, hidden_size)) |
|
|
|
approx_gelu = lambda: nn.GELU(approximate="tanh") |
|
self.t_block = nn.Sequential( |
|
nn.SiLU(), |
|
nn.Linear(hidden_size, 6 * hidden_size, bias=True) |
|
) |
|
self.y_embedder = nn.Linear(cond_dim, hidden_size) |
|
|
|
half_depth = (depth - decode_layer)//2 |
|
self.half_depth=half_depth |
|
|
|
drop_path_half = [x.item() for x in th.linspace(0, drop_path, half_depth)] |
|
drop_path_decode = [x.item() for x in th.linspace(0, drop_path, decode_layer)] |
|
self.en_inblocks = nn.ModuleList([ |
|
MDTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path_half[i], |
|
input_size=(self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]), |
|
window_size=0, |
|
use_rel_pos=False, FFN_type='mlp') for i in range(half_depth) |
|
]) |
|
self.en_outblocks = nn.ModuleList([ |
|
MDTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path_half[i], |
|
input_size=(self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]), |
|
window_size=0, |
|
use_rel_pos=False, skip=True, FFN_type='mlp') for i in range(half_depth) |
|
]) |
|
self.de_blocks = nn.ModuleList([ |
|
DEBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path_decode[i], |
|
input_size=(self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]), |
|
window_size=0, |
|
use_rel_pos=False, skip=True, FFN_type='mlp', num_f=self.x_embedder.grid_size[1]+1, num_t=self.x_embedder.grid_size[0]) for i in range(decode_layer) |
|
]) |
|
self.sideblocks = nn.ModuleList([ |
|
MDTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, |
|
input_size=(self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]), |
|
window_size=0, |
|
use_rel_pos=False, FFN_type='mlp') for _ in range(1) |
|
]) |
|
|
|
self.final_layer = T2IFinalLayer(hidden_size, patch_size, self.out_channels) |
|
|
|
self.decoder_pos_embed = nn.Parameter(th.zeros( |
|
1, num_patches, hidden_size), requires_grad=True) |
|
if mask_ratio is not None: |
|
self.mask_token = nn.Parameter(th.zeros(1, 1, hidden_size)) |
|
self.mask_ratio = float(mask_ratio) |
|
self.decode_layer = int(decode_layer) |
|
else: |
|
self.mask_token = nn.Parameter(th.zeros( |
|
1, 1, hidden_size), requires_grad=False) |
|
self.mask_ratio = None |
|
self.decode_layer = int(decode_layer) |
|
print("mask ratio:", self.mask_ratio, "decode_layer:", self.decode_layer) |
|
self.initialize_weights() |
|
|
|
|
|
def forward(self, x, timestep, context_list, context_mask_list=None, enable_mask=False, **kwargs): |
|
""" |
|
Forward pass of PixArt. |
|
        x: (N, C, H, W) tensor of spatial inputs (latent spectrogram representations)
        timestep: (N,) tensor of diffusion timesteps
        context_list: list of conditioning tensors; context_list[0] holds the text embeddings of shape (N, L, cond_dim)
        context_mask_list: list of attention masks; context_mask_list[0] is (N, L)
|
""" |
|
x = x.to(self.dtype) |
|
timestep = timestep.to(self.dtype) |
|
y = context_list[0].to(self.dtype) |
|
pos_embed = self.pos_embed.to(self.dtype) |
|
self.h, self.w = self.x_embedder.grid_size[0], self.x_embedder.grid_size[1] |
|
|
|
|
|
x = self.x_embedder(x) + pos_embed |
|
t = self.t_embedder(timestep.to(x.dtype)) |
|
t0 = self.t_block(t) |
|
y = self.y_embedder(y) |
|
|
|
try: |
|
mask = context_mask_list[0] |
|
except: |
|
mask = th.ones(x.shape[0], 1).to(x.device) |
|
print("MASK !!!!!!!!!!!!!!!!!!!!!!!!!") |
|
|
|
assert mask is not None |
|
|
|
|
|
y = y.masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1]) |
|
y_lens = mask.sum(dim=1).tolist() |
|
y_lens = [int(_) for _ in y_lens] |
|
|
|
input_skip = x |
|
|
|
masked_stage = False |
|
skips = [] |
|
|
|
if self.mask_ratio is not None and self.training: |
|
|
|
rand_mask_ratio = th.rand(1, device=x.device) |
|
rand_mask_ratio = rand_mask_ratio * 0.2 + self.mask_ratio |
|
|
|
x, mask, ids_restore, ids_keep = self.random_masking( |
|
x, rand_mask_ratio) |
|
masked_stage = True |
|
|
|
|
|
for block in self.en_inblocks: |
|
if masked_stage: |
|
x = auto_grad_checkpoint(block, x, y, t0, y_lens, ids_keep=ids_keep) |
|
else: |
|
x = auto_grad_checkpoint(block, x, y, t0, y_lens, ids_keep=None) |
|
skips.append(x) |
|
|
|
for block in self.en_outblocks: |
|
if masked_stage: |
|
x = auto_grad_checkpoint(block, x, y, t0, y_lens, skip=skips.pop(), ids_keep=ids_keep) |
|
else: |
|
x = auto_grad_checkpoint(block, x, y, t0, y_lens, skip=skips.pop(), ids_keep=None) |
|
|
|
if self.mask_ratio is not None and self.training: |
|
x = self.forward_side_interpolater(x, y, t0, y_lens, mask, ids_restore) |
|
masked_stage = False |
|
else: |
|
|
|
x = x + self.decoder_pos_embed |
|
bs = x.shape[0] |
|
|
|
bs, D, L = x.shape |
|
T, F = self.x_embedder.grid_size[0], self.x_embedder.grid_size[1] |
|
|
|
end = th.zeros(bs, T, 1, L).to(x.device) |
|
|
|
|
|
|
|
for i in range(len(self.de_blocks)): |
|
block = self.de_blocks[i] |
|
this_skip = input_skip |
|
x, end = auto_grad_checkpoint(block, x, end, y, t0, y_lens, skip=this_skip, ids_keep=None) |
|
|
|
x = self.final_layer(x, t) |
|
x = self.unpatchify(x) |
|
return x |
|
|
|
def forward_with_dpmsolver(self, x, timestep, y, mask=None, **kwargs): |
|
""" |
|
        DPM-Solver does not need the variance prediction, so only the first half of the output channels is returned.
|
""" |
|
|
|
model_out = self.forward(x, timestep, y, mask) |
|
return model_out.chunk(2, dim=1)[0] |
|
|
|
def forward_with_cfg(self, x, timestep, context_list, context_mask_list=None, cfg_scale=4.0, **kwargs): |
|
""" |
|
Forward pass of PixArt, but also batches the unconditional forward pass for classifier-free guidance. |
|
""" |
|
|
|
|
|
|
|
half = x[: len(x) // 2] |
|
combined = th.cat([half, half], dim=0) |
|
        model_out = self.forward(combined, timestep, context_list, context_mask_list=context_mask_list, **kwargs)
|
model_out = model_out['x'] if isinstance(model_out, dict) else model_out |
|
eps, rest = model_out[:, :8], model_out[:, 8:] |
|
cond_eps, uncond_eps = th.split(eps, len(eps) // 2, dim=0) |
|
half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps) |
|
eps = th.cat([half_eps, half_eps], dim=0) |
|
return eps |
|
|
|
|
|
def unpatchify(self, x): |
|
|
|
""" |
|
        x: (N, T, patch_size[0] * patch_size[1] * C)
        imgs: (N, out_channels, H, W); with the default input size this is (N, out_channels, 256, 16)
|
""" |
|
if self.x_embedder.ol == (0, 0) or self.x_embedder.ol == [0, 0]: |
|
c = self.out_channels |
|
p0 = self.x_embedder.patch_size[0] |
|
p1 = self.x_embedder.patch_size[1] |
|
h, w = self.x_embedder.grid_size[0], self.x_embedder.grid_size[1] |
|
|
|
x = x.reshape(shape=(x.shape[0], h, w, p0, p1, c)) |
|
x = th.einsum('nhwpqc->nchpwq', x) |
|
imgs = x.reshape(shape=(x.shape[0], c, h * p0, w * p1)) |
|
return imgs |
|
|
|
lf = self.x_embedder.grid_size[0] |
|
rf = self.x_embedder.grid_size[1] |
|
lp = self.x_embedder.patch_size[0] |
|
rp = self.x_embedder.patch_size[1] |
|
lo = self.x_embedder.ol[0] |
|
ro = self.x_embedder.ol[1] |
|
lm = self.x_embedder.img_size[0] |
|
rm = self.x_embedder.img_size[1] |
|
lpad = self.x_embedder.pad_size[0] |
|
rpad = self.x_embedder.pad_size[1] |
|
bs = x.shape[0] |
|
|
|
torch_map = self.torch_map |
|
|
|
c = self.out_channels |
|
|
|
x = x.reshape(shape=(bs, lf, rf, lp, rp, c)) |
|
x = th.einsum('nhwpqc->nchwpq', x) |
|
|
|
added_map = th.zeros(bs, c, lm+2*lpad, rm+2*rpad).to(x.device) |
|
|
|
for i in range(lf): |
|
for j in range(rf): |
|
xx = (i) * (lp - lo) |
|
yy = (j) * (rp - ro) |
|
added_map[:, :, xx:(xx+lp), yy:(yy+rp)] += \ |
|
x[:, :, i, j, :, :] |
|
|
|
|
|
        added_map = added_map[:, :, lpad:lm + lpad, rpad:rm + rpad]
|
return th.mul(added_map.to(x.device), torch_map.to(x.device)) |
|
|
|
def random_masking(self, x, mask_ratio): |
|
""" |
|
Perform per-sample random masking by per-sample shuffling. |
|
Per-sample shuffling is done by argsort random noise. |
|
x: [N, L, D], sequence |
|
""" |
|
N, L, D = x.shape |
|
len_keep = int(L * (1 - mask_ratio)) |
|
|
|
noise = th.rand(N, L, device=x.device) |
|
|
|
|
|
|
|
ids_shuffle = th.argsort(noise, dim=1) |
|
ids_restore = th.argsort(ids_shuffle, dim=1) |
|
|
|
|
|
ids_keep = ids_shuffle[:, :len_keep] |
|
x_masked = th.gather( |
|
x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) |
|
|
|
|
|
mask = th.ones([N, L], device=x.device) |
|
mask[:, :len_keep] = 0 |
|
|
|
mask = th.gather(mask, dim=1, index=ids_restore) |
|
|
|
return x_masked, mask, ids_restore, ids_keep |
|
|
|
def forward_side_interpolater(self, x, y, t0, y_lens, mask, ids_restore): |
|
|
|
mask_tokens = self.mask_token.repeat( |
|
x.shape[0], ids_restore.shape[1] - x.shape[1], 1) |
|
x_ = th.cat([x, mask_tokens], dim=1) |
|
x = th.gather( |
|
x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) |
|
|
|
|
|
x = x + self.decoder_pos_embed |
|
|
|
|
|
x_before = x |
|
for sideblock in self.sideblocks: |
|
x = sideblock(x, y, t0, y_lens, ids_keep=None) |
|
|
|
|
|
mask = mask.unsqueeze(dim=-1) |
|
x = x*mask + (1-mask)*x_before |
|
|
|
return x |
|
|
|
def initialize_weights(self): |
|
|
|
def _basic_init(module): |
|
if isinstance(module, nn.Linear): |
|
th.nn.init.xavier_uniform_(module.weight) |
|
if module.bias is not None: |
|
nn.init.constant_(module.bias, 0) |
|
|
|
self.apply(_basic_init) |
|
|
|
|
|
pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], self.x_embedder.grid_size, lewei_scale=self.lewei_scale, base_size=self.base_size) |
|
self.pos_embed.data.copy_(th.from_numpy(pos_embed).float().unsqueeze(0)) |
|
|
|
|
|
w = self.x_embedder.proj.weight.data |
|
nn.init.xavier_uniform_(w.view([w.shape[0], -1])) |
|
|
|
|
|
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) |
|
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) |
|
nn.init.normal_(self.t_block[1].weight, std=0.02) |
|
|
|
|
|
nn.init.normal_(self.y_embedder.weight, std=0.02) |
|
|
|
|
|
|
|
for block in self.en_inblocks: |
|
nn.init.constant_(block.cross_attn.proj.weight, 0) |
|
nn.init.constant_(block.cross_attn.proj.bias, 0) |
|
for block in self.en_outblocks: |
|
nn.init.constant_(block.cross_attn.proj.weight, 0) |
|
nn.init.constant_(block.cross_attn.proj.bias, 0) |
|
for block in self.de_blocks: |
|
nn.init.constant_(block.cross_attn.proj.weight, 0) |
|
nn.init.constant_(block.cross_attn.proj.bias, 0) |
|
for block in self.sideblocks: |
|
nn.init.constant_(block.cross_attn.proj.weight, 0) |
|
nn.init.constant_(block.cross_attn.proj.bias, 0) |
|
|
|
|
|
nn.init.constant_(self.final_layer.linear.weight, 0) |
|
nn.init.constant_(self.final_layer.linear.bias, 0) |
|
|
|
if self.x_embedder.ol == [0, 0] or self.x_embedder.ol == (0, 0): |
|
return |
|
|
|
lf = self.x_embedder.grid_size[0] |
|
rf = self.x_embedder.grid_size[1] |
|
lp = self.x_embedder.patch_size[0] |
|
rp = self.x_embedder.patch_size[1] |
|
lo = self.x_embedder.ol[0] |
|
ro = self.x_embedder.ol[1] |
|
lm = self.x_embedder.img_size[0] |
|
rm = self.x_embedder.img_size[1] |
|
lpad = self.x_embedder.pad_size[0] |
|
rpad = self.x_embedder.pad_size[1] |
|
|
|
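# For overlapping patches, build a per-pixel coverage count over the padded |
# canvas; its reciprocal (self.torch_map) turns the overlap-add sum in |
# unpatchify into an average. |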
torch_map = th.zeros(lm+2*lpad, rm+2*rpad)  # built on CPU; moved to the active device when used in unpatchify |
|
for i in range(lf): |
|
for j in range(rf): |
|
xx = (i) * (lp - lo) |
|
yy = (j) * (rp - ro) |
|
torch_map[xx:(xx+lp), yy:(yy+rp)]+=1 |
|
torch_map = torch_map[lpad:lm+lpad, rpad:rm+rpad] |
|
self.torch_map = th.reciprocal(torch_map) |
|
|
|
@property |
|
def dtype(self): |
|
return next(self.parameters()).dtype |
|
|
|
|
|
class PixArt_MDT(nn.Module): |
|
""" |
|
Diffusion model with a Transformer backbone. |
|
""" |
|
|
|
def __init__(self, input_size=(256,16), patch_size=(16,4), overlap=(0, 0), in_channels=8, hidden_size=1152, depth=28, num_heads=16, mlp_ratio=4.0, class_dropout_prob=0.1, pred_sigma=False, drop_path: float = 0., window_size=0, window_block_indexes=None, use_rel_pos=False, cond_dim=1024, lewei_scale=1.0, |
|
use_cfg=True, cfg_scale=4.0, config=None, model_max_length=120, mask_ratio=None, decode_layer=4,**kwargs): |
|
if window_block_indexes is None: |
|
window_block_indexes = [] |
|
super().__init__() |
|
self.use_cfg = use_cfg |
|
self.cfg_scale = cfg_scale |
|
self.input_size = input_size |
|
self.pred_sigma = pred_sigma |
|
self.in_channels = in_channels |
|
self.out_channels = in_channels |
|
self.patch_size = patch_size |
|
self.num_heads = num_heads |
|
self.lewei_scale = lewei_scale |
|
decode_layer = int(decode_layer) |
|
|
|
self.x_embedder = PatchEmbed(input_size, patch_size, overlap, in_channels, hidden_size, bias=True) |
|
|
|
self.t_embedder = TimestepEmbedder(hidden_size) |
|
num_patches = self.x_embedder.num_patches |
|
self.base_size = input_size[0] // self.patch_size[0] * 2 |
|
|
|
self.register_buffer("pos_embed", th.zeros(1, num_patches, hidden_size)) |
|
|
|
approx_gelu = lambda: nn.GELU(approximate="tanh") |
|
self.t_block = nn.Sequential( |
|
nn.SiLU(), |
|
nn.Linear(hidden_size, 6 * hidden_size, bias=True) |
|
) |
|
self.y_embedder = nn.Linear(cond_dim, hidden_size) |
|
|
|
half_depth = (depth - decode_layer)//2 |
|
self.half_depth=half_depth |
|
|
|
drop_path_half = [x.item() for x in th.linspace(0, drop_path, half_depth)] |
|
drop_path_decode = [x.item() for x in th.linspace(0, drop_path, decode_layer)] |
|
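# MDT-style layout: (depth - decode_layer)/2 encoder "in" blocks push skip |
# activations, the same number of "out" blocks consume them, decode_layer |
# decoder blocks follow, and a single side-interpolater block fills in the |
# masked tokens during training. |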
self.en_inblocks = nn.ModuleList([ |
|
MDTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path_half[i], |
|
input_size=(self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]), |
|
window_size=0, |
|
use_rel_pos=False) for i in range(half_depth) |
|
]) |
|
self.en_outblocks = nn.ModuleList([ |
|
MDTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path_half[i], |
|
input_size=(self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]), |
|
window_size=0, |
|
use_rel_pos=False, skip=True) for i in range(half_depth) |
|
]) |
|
self.de_blocks = nn.ModuleList([ |
|
MDTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path_decode[i], |
|
input_size=(self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]), |
|
window_size=0, |
|
use_rel_pos=False, skip=True) for i in range(decode_layer) |
|
]) |
|
self.sideblocks = nn.ModuleList([ |
|
MDTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, |
|
input_size=(self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]), |
|
window_size=0, |
|
use_rel_pos=False) for _ in range(1) |
|
]) |
|
|
|
self.final_layer = T2IFinalLayer(hidden_size, patch_size, self.out_channels) |
|
|
|
self.decoder_pos_embed = nn.Parameter(th.zeros( |
|
1, num_patches, hidden_size), requires_grad=True) |
|
if mask_ratio is not None: |
|
self.mask_token = nn.Parameter(th.zeros(1, 1, hidden_size)) |
|
self.mask_ratio = float(mask_ratio) |
|
self.decode_layer = int(decode_layer) |
|
else: |
|
self.mask_token = nn.Parameter(th.zeros( |
|
1, 1, hidden_size), requires_grad=False) |
|
self.mask_ratio = None |
|
self.decode_layer = int(decode_layer) |
|
print("mask ratio:", self.mask_ratio, "decode_layer:", self.decode_layer) |
|
self.initialize_weights() |
|
|
|
|
|
def forward(self, x, timestep, context_list, context_mask_list=None, enable_mask=False, **kwargs): |
|
""" |
|
Forward pass of PixArt. |
|
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) |
|
timestep: (N,) tensor of diffusion timesteps |
|
context_list[0]: (N, 1, L, cond_dim) tensor of conditioning embeddings (e.g. text encoder outputs) |
|
""" |
|
|
|
x = x.to(self.dtype) |
|
timestep = timestep.to(self.dtype) |
|
y = context_list[0].to(self.dtype) |
|
pos_embed = self.pos_embed.to(self.dtype) |
|
self.h, self.w = self.x_embedder.grid_size[0], self.x_embedder.grid_size[1] |
|
|
|
|
|
|
|
x = self.x_embedder(x) + pos_embed |
|
|
|
t = self.t_embedder(timestep.to(x.dtype)) |
|
t0 = self.t_block(t) |
|
|
|
y = self.y_embedder(y) |
|
|
|
try: |
|
mask = context_mask_list[0] |
|
except (TypeError, IndexError): |
|
mask = th.ones(x.shape[0], 1).to(x.device) |
|
print("context_mask_list is missing or empty; falling back to an all-ones mask") |
|
|
|
assert mask is not None |
|
|
|
|
|
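# Pack the valid conditioning tokens of the whole batch into one sequence; |
# y_lens records each sample's number of valid tokens for the cross-attention. |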
y = y.masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1]) |
|
y_lens = mask.sum(dim=1).tolist() |
|
y_lens = [int(_) for _ in y_lens] |
|
|
|
input_skip = x |
|
|
|
masked_stage = False |
|
skips = [] |
|
|
|
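# During training, drop a random subset of tokens (MAE-style); the effective |
# mask ratio is sampled uniformly from [mask_ratio, mask_ratio + 0.2). |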
if self.mask_ratio is not None and self.training: |
|
|
|
rand_mask_ratio = th.rand(1, device=x.device) |
|
rand_mask_ratio = rand_mask_ratio * 0.2 + self.mask_ratio |
|
|
|
x, mask, ids_restore, ids_keep = self.random_masking( |
|
x, rand_mask_ratio) |
|
masked_stage = True |
|
|
|
|
|
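# Encoder: "in" blocks stash their outputs as long skip connections that the |
# "out" blocks pop in reverse order. |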
for block in self.en_inblocks: |
|
if masked_stage: |
|
x = auto_grad_checkpoint(block, x, y, t0, y_lens, ids_keep=ids_keep) |
|
else: |
|
x = auto_grad_checkpoint(block, x, y, t0, y_lens, ids_keep=None) |
|
skips.append(x) |
|
|
|
for block in self.en_outblocks: |
|
if masked_stage: |
|
x = auto_grad_checkpoint(block, x, y, t0, y_lens, skip=skips.pop(), ids_keep=ids_keep) |
|
else: |
|
x = auto_grad_checkpoint(block, x, y, t0, y_lens, skip=skips.pop(), ids_keep=None) |
|
|
|
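# Training with masking: the side interpolater restores the full-length token |
# sequence. Otherwise (inference / no masking) just add the decoder |
# positional embedding. |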
if self.mask_ratio is not None and self.training: |
|
x = self.forward_side_interpolater(x, y, t0, y_lens, mask, ids_restore) |
|
masked_stage = False |
|
else: |
|
|
|
x = x + self.decoder_pos_embed |
|
|
|
for i in range(len(self.de_blocks)): |
|
block = self.de_blocks[i] |
|
this_skip = input_skip |
|
|
|
x = auto_grad_checkpoint(block, x, y, t0, y_lens, skip=this_skip, ids_keep=None) |
|
|
|
x = self.final_layer(x, t) |
|
|
|
x = self.unpatchify(x) |
|
|
|
return x |
|
|
|
def forward_with_dpmsolver(self, x, timestep, y, mask=None, **kwargs): |
|
""" |
|
DPM-Solver does not need a variance prediction, so only the noise channels are returned. |
|
""" |
|
|
|
model_out = self.forward(x, timestep, y, mask) |
|
return model_out.chunk(2, dim=1)[0] if self.pred_sigma else model_out |
|
|
|
def forward_with_cfg(self, x, timestep, y, cfg_scale, mask=None, **kwargs): |
|
""" |
|
Forward pass of PixArt, but also batches the unconditional forward pass for classifier-free guidance. |
|
""" |
|
|
|
half = x[: len(x) // 2] |
|
combined = th.cat([half, half], dim=0) |
|
model_out = self.forward(combined, timestep, y, mask) |
|
model_out = model_out['x'] if isinstance(model_out, dict) else model_out |
|
eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:] |
|
cond_eps, uncond_eps = th.split(eps, len(eps) // 2, dim=0) |
|
half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps) |
|
eps = th.cat([half_eps, half_eps], dim=0) |
|
return eps |
|
|
|
|
|
def unpatchify(self, x): |
|
|
|
""" |
|
x: (N, T, patch_size[0] * patch_size[1] * C) token sequence |
|
imgs: (N, C, H, W), e.g. (N, 8, 256, 16) for the default configuration |
|
""" |
|
if self.x_embedder.ol == (0, 0) or self.x_embedder.ol == [0, 0]: |
|
|
|
c = self.out_channels |
|
|
|
p0 = self.x_embedder.patch_size[0] |
|
p1 = self.x_embedder.patch_size[1] |
|
h, w = self.x_embedder.grid_size[0], self.x_embedder.grid_size[1] |
|
|
|
x = x.reshape(shape=(x.shape[0], h, w, p0, p1, c)) |
|
|
|
x = th.einsum('nhwpqc->nchpwq', x) |
|
|
|
imgs = x.reshape(shape=(x.shape[0], c, h * p0, w * p1)) |
|
|
|
return imgs |
|
|
|
lf = self.x_embedder.grid_size[0] |
|
rf = self.x_embedder.grid_size[1] |
|
lp = self.x_embedder.patch_size[0] |
|
rp = self.x_embedder.patch_size[1] |
|
lo = self.x_embedder.ol[0] |
|
ro = self.x_embedder.ol[1] |
|
lm = self.x_embedder.img_size[0] |
|
rm = self.x_embedder.img_size[1] |
|
lpad = self.x_embedder.pad_size[0] |
|
rpad = self.x_embedder.pad_size[1] |
|
bs = x.shape[0] |
|
|
|
torch_map = self.torch_map |
|
|
|
c = self.out_channels |
|
|
|
x = x.reshape(shape=(bs, lf, rf, lp, rp, c)) |
|
x = th.einsum('nhwpqc->nchwpq', x) |
|
|
|
added_map = th.zeros(bs, c, lm+2*lpad, rm+2*rpad).to(x.device) |
|
|
|
for i in range(lf): |
|
for j in range(rf): |
|
xx = (i) * (lp - lo) |
|
yy = (j) * (rp - ro) |
|
added_map[:, :, xx:(xx+lp), yy:(yy+rp)] += \ |
|
x[:, :, i, j, :, :] |
|
|
|
added_map = added_map[:, :, lpad:lm+lpad, rpad:rm+rpad] |
|
return th.mul(added_map, torch_map.to(added_map.device)) |
|
|
|
|
|
def random_masking(self, x, mask_ratio): |
|
""" |
|
Perform per-sample random masking by per-sample shuffling. |
|
Per-sample shuffling is done by argsort random noise. |
|
x: [N, L, D], sequence |
|
""" |
|
N, L, D = x.shape |
|
len_keep = int(L * (1 - mask_ratio)) |
|
|
|
noise = th.rand(N, L, device=x.device) |
|
|
|
|
|
|
|
ids_shuffle = th.argsort(noise, dim=1) |
|
ids_restore = th.argsort(ids_shuffle, dim=1) |
|
|
|
|
|
ids_keep = ids_shuffle[:, :len_keep] |
|
x_masked = th.gather( |
|
x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) |
|
|
|
|
|
mask = th.ones([N, L], device=x.device) |
|
mask[:, :len_keep] = 0 |
|
|
|
mask = th.gather(mask, dim=1, index=ids_restore) |
|
|
|
return x_masked, mask, ids_restore, ids_keep |
|
|
|
def forward_side_interpolater(self, x, y, t0, y_lens, mask, ids_restore): |
|
|
|
mask_tokens = self.mask_token.repeat( |
|
x.shape[0], ids_restore.shape[1] - x.shape[1], 1) |
|
x_ = th.cat([x, mask_tokens], dim=1) |
|
x = th.gather( |
|
x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) |
|
|
|
|
|
x = x + self.decoder_pos_embed |
|
|
|
|
|
x_before = x |
|
for sideblock in self.sideblocks: |
|
x = sideblock(x, y, t0, y_lens, ids_keep=None) |
|
|
|
|
|
mask = mask.unsqueeze(dim=-1) |
|
x = x*mask + (1-mask)*x_before |
|
|
|
return x |
|
|
|
def initialize_weights(self): |
|
|
|
def _basic_init(module): |
|
if isinstance(module, nn.Linear): |
|
th.nn.init.xavier_uniform_(module.weight) |
|
if module.bias is not None: |
|
nn.init.constant_(module.bias, 0) |
|
|
|
self.apply(_basic_init) |
|
|
|
|
|
pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], self.x_embedder.grid_size, lewei_scale=self.lewei_scale, base_size=self.base_size) |
|
self.pos_embed.data.copy_(th.from_numpy(pos_embed).float().unsqueeze(0)) |
|
|
|
|
|
w = self.x_embedder.proj.weight.data |
|
nn.init.xavier_uniform_(w.view([w.shape[0], -1])) |
|
|
|
|
|
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) |
|
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) |
|
nn.init.normal_(self.t_block[1].weight, std=0.02) |
|
|
|
|
|
nn.init.normal_(self.y_embedder.weight, std=0.02) |
|
|
|
|
|
|
|
for block in self.en_inblocks: |
|
nn.init.constant_(block.cross_attn.proj.weight, 0) |
|
nn.init.constant_(block.cross_attn.proj.bias, 0) |
|
for block in self.en_outblocks: |
|
nn.init.constant_(block.cross_attn.proj.weight, 0) |
|
nn.init.constant_(block.cross_attn.proj.bias, 0) |
|
for block in self.de_blocks: |
|
nn.init.constant_(block.cross_attn.proj.weight, 0) |
|
nn.init.constant_(block.cross_attn.proj.bias, 0) |
|
for block in self.sideblocks: |
|
nn.init.constant_(block.cross_attn.proj.weight, 0) |
|
nn.init.constant_(block.cross_attn.proj.bias, 0) |
|
|
|
|
|
nn.init.constant_(self.final_layer.linear.weight, 0) |
|
nn.init.constant_(self.final_layer.linear.bias, 0) |
|
|
|
if self.x_embedder.ol == [0, 0] or self.x_embedder.ol == (0, 0): |
|
return |
|
|
|
lf = self.x_embedder.grid_size[0] |
|
rf = self.x_embedder.grid_size[1] |
|
lp = self.x_embedder.patch_size[0] |
|
rp = self.x_embedder.patch_size[1] |
|
lo = self.x_embedder.ol[0] |
|
ro = self.x_embedder.ol[1] |
|
lm = self.x_embedder.img_size[0] |
|
rm = self.x_embedder.img_size[1] |
|
lpad = self.x_embedder.pad_size[0] |
|
rpad = self.x_embedder.pad_size[1] |
|
|
|
torch_map = th.zeros(lm+2*lpad, rm+2*rpad)  # built on CPU; moved to the active device when used in unpatchify |
|
for i in range(lf): |
|
for j in range(rf): |
|
xx = (i) * (lp - lo) |
|
yy = (j) * (rp - ro) |
|
torch_map[xx:(xx+lp), yy:(yy+rp)]+=1 |
|
torch_map = torch_map[lpad:lm+lpad, rpad:rm+rpad] |
|
self.torch_map = th.reciprocal(torch_map) |
|
|
|
@property |
|
def dtype(self): |
|
return next(self.parameters()).dtype |
|
|
|
class PixArt_MDT_FIT(nn.Module): |
|
""" |
|
Diffusion model with a Transformer backbone. |
|
""" |
|
|
|
def __init__(self, input_size=(256,16), patch_size=(16,4), overlap=(0, 0), in_channels=8, hidden_size=1152, depth=28, num_heads=16, mlp_ratio=4.0, class_dropout_prob=0.1, pred_sigma=False, drop_path: float = 0., window_size=0, window_block_indexes=None, use_rel_pos=False, cond_dim=1024, lewei_scale=1.0, |
|
use_cfg=True, cfg_scale=4.0, config=None, model_max_length=120, mask_ratio=None, decode_layer=4,**kwargs): |
|
if window_block_indexes is None: |
|
window_block_indexes = [] |
|
super().__init__() |
|
self.use_cfg = use_cfg |
|
self.cfg_scale = cfg_scale |
|
self.input_size = input_size |
|
self.pred_sigma = pred_sigma |
|
self.in_channels = in_channels |
|
self.out_channels = in_channels |
|
self.patch_size = patch_size |
|
self.num_heads = num_heads |
|
self.lewei_scale = lewei_scale |
|
decode_layer = int(decode_layer) |
|
|
|
self.x_embedder = PatchEmbed(input_size, patch_size, overlap, in_channels, hidden_size, bias=True) |
|
|
|
self.t_embedder = TimestepEmbedder(hidden_size) |
|
num_patches = self.x_embedder.num_patches |
|
self.base_size = input_size[0] // self.patch_size[0] * 2 |
|
|
|
self.register_buffer("pos_embed", th.zeros(1, num_patches, hidden_size)) |
|
|
|
approx_gelu = lambda: nn.GELU(approximate="tanh") |
|
self.t_block = nn.Sequential( |
|
nn.SiLU(), |
|
nn.Linear(hidden_size, 6 * hidden_size, bias=True) |
|
) |
|
self.y_embedder = nn.Linear(cond_dim, hidden_size) |
|
|
|
half_depth = (depth - decode_layer)//2 |
|
self.half_depth=half_depth |
|
|
|
drop_path_half = [x.item() for x in th.linspace(0, drop_path, half_depth)] |
|
drop_path_decode = [x.item() for x in th.linspace(0, drop_path, decode_layer)] |
|
self.en_inblocks = nn.ModuleList([ |
|
MDTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path_half[i], |
|
input_size=(self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]), |
|
window_size=0, |
|
use_rel_pos=False) for i in range(half_depth) |
|
]) |
|
self.en_outblocks = nn.ModuleList([ |
|
MDTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path_half[i], |
|
input_size=(self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]), |
|
window_size=0, |
|
use_rel_pos=False, skip=True) for i in range(half_depth) |
|
]) |
|
self.de_blocks = nn.ModuleList([ |
|
MDTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path_decode[i], |
|
input_size=(self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]), |
|
window_size=0, |
|
use_rel_pos=False, skip=True) for i in range(decode_layer) |
|
]) |
|
self.sideblocks = nn.ModuleList([ |
|
MDTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, |
|
input_size=(self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]), |
|
window_size=0, |
|
use_rel_pos=False) for _ in range(1) |
|
]) |
|
|
|
self.final_layer = T2IFinalLayer(hidden_size, patch_size, self.out_channels) |
|
|
|
self.decoder_pos_embed = nn.Parameter(th.zeros( |
|
1, num_patches, hidden_size), requires_grad=True) |
|
if mask_ratio is not None: |
|
self.mask_token = nn.Parameter(th.zeros(1, 1, hidden_size)) |
|
self.mask_ratio = float(mask_ratio) |
|
self.decode_layer = int(decode_layer) |
|
else: |
|
self.mask_token = nn.Parameter(th.zeros( |
|
1, 1, hidden_size), requires_grad=False) |
|
self.mask_ratio = None |
|
self.decode_layer = int(decode_layer) |
|
print("mask ratio:", self.mask_ratio, "decode_layer:", self.decode_layer) |
|
self.initialize_weights() |
|
|
|
|
|
def forward(self, x, timestep, context_list, context_mask_list=None, enable_mask=False, **kwargs): |
|
""" |
|
Forward pass of PixArt. |
|
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) |
|
timestep: (N,) tensor of diffusion timesteps |
|
context_list[0]: (N, 1, L, cond_dim) tensor of conditioning embeddings (e.g. text encoder outputs) |
|
""" |
|
x = x.to(self.dtype) |
|
timestep = timestep.to(self.dtype) |
|
y = context_list[0].to(self.dtype) |
|
pos_embed = self.pos_embed.to(self.dtype) |
|
self.h, self.w = self.x_embedder.grid_size[0], self.x_embedder.grid_size[1] |
|
|
|
|
|
x = self.x_embedder(x) + pos_embed |
|
t = self.t_embedder(timestep.to(x.dtype)) |
|
t0 = self.t_block(t) |
|
y = self.y_embedder(y) |
|
|
|
try: |
|
mask = context_mask_list[0] |
|
except (TypeError, IndexError): |
|
mask = th.ones(x.shape[0], 1).to(x.device) |
|
print("context_mask_list is missing or empty; falling back to an all-ones mask") |
|
|
|
assert mask is not None |
|
|
|
|
|
y = y.masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1]) |
|
y_lens = mask.sum(dim=1).tolist() |
|
y_lens = [int(_) for _ in y_lens] |
|
|
|
input_skip = x |
|
|
|
masked_stage = False |
|
skips = [] |
|
|
|
if self.mask_ratio is not None and self.training: |
|
|
|
rand_mask_ratio = th.rand(1, device=x.device) |
|
rand_mask_ratio = rand_mask_ratio * 0.2 + self.mask_ratio |
|
|
|
x, mask, ids_restore, ids_keep = self.random_masking( |
|
x, rand_mask_ratio) |
|
masked_stage = True |
|
|
|
|
|
for block in self.en_inblocks: |
|
if masked_stage: |
|
x = auto_grad_checkpoint(block, x, y, t0, y_lens, ids_keep=ids_keep) |
|
else: |
|
x = auto_grad_checkpoint(block, x, y, t0, y_lens, ids_keep=None) |
|
skips.append(x) |
|
|
|
for block in self.en_outblocks: |
|
if masked_stage: |
|
x = auto_grad_checkpoint(block, x, y, t0, y_lens, skip=skips.pop(), ids_keep=ids_keep) |
|
else: |
|
x = auto_grad_checkpoint(block, x, y, t0, y_lens, skip=skips.pop(), ids_keep=None) |
|
|
|
if self.mask_ratio is not None and self.training: |
|
x = self.forward_side_interpolater(x, y, t0, y_lens, mask, ids_restore) |
|
masked_stage = False |
|
else: |
|
|
|
x = x + self.decoder_pos_embed |
|
|
|
for i in range(len(self.de_blocks)): |
|
block = self.de_blocks[i] |
|
this_skip = input_skip |
|
|
|
x = auto_grad_checkpoint(block, x, y, t0, y_lens, skip=this_skip, ids_keep=None) |
|
|
|
x = self.final_layer(x, t) |
|
x = self.unpatchify(x) |
|
return x |
|
|
|
def forward_with_dpmsolver(self, x, timestep, y, mask=None, **kwargs): |
|
""" |
|
DPM-Solver does not need a variance prediction, so only the noise channels are returned. |
|
""" |
|
|
|
model_out = self.forward(x, timestep, y, mask) |
|
return model_out.chunk(2, dim=1)[0] if self.pred_sigma else model_out |
|
|
|
def forward_with_cfg(self, x, timestep, y, cfg_scale, mask=None, **kwargs): |
|
""" |
|
Forward pass of PixArt, but also batches the unconditional forward pass for classifier-free guidance. |
|
""" |
|
|
|
half = x[: len(x) // 2] |
|
combined = th.cat([half, half], dim=0) |
|
model_out = self.forward(combined, timestep, y, mask) |
|
model_out = model_out['x'] if isinstance(model_out, dict) else model_out |
|
eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:] |
|
cond_eps, uncond_eps = th.split(eps, len(eps) // 2, dim=0) |
|
half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps) |
|
eps = th.cat([half_eps, half_eps], dim=0) |
|
return eps |
|
|
|
|
|
def unpatchify(self, x): |
|
|
|
""" |
|
x: (N, T, patch_size[0] * patch_size[1] * C) token sequence |
|
imgs: (N, C, H, W), e.g. (N, 8, 256, 16) for the default configuration |
|
""" |
|
|
|
c = self.out_channels |
|
p0 = self.x_embedder.patch_size[0] |
|
p1 = self.x_embedder.patch_size[1] |
|
h, w = self.x_embedder.grid_size[0], self.x_embedder.grid_size[1] |
|
|
|
x = x.reshape(shape=(x.shape[0], h, w, p0, p1, c)) |
|
x = th.einsum('nhwpqc->nchpwq', x) |
|
imgs = x.reshape(shape=(x.shape[0], c, h * p0, w * p1)) |
|
return imgs |
|
|
|
def random_masking(self, x, mask_ratio): |
|
""" |
|
Perform per-sample random masking by per-sample shuffling. |
|
Per-sample shuffling is done by argsort random noise. |
|
x: [N, L, D], sequence |
|
""" |
|
N, L, D = x.shape |
|
len_keep = int(L * (1 - mask_ratio)) |
|
|
|
noise = th.rand(N, L, device=x.device) |
|
|
|
|
|
|
|
ids_shuffle = th.argsort(noise, dim=1) |
|
ids_restore = th.argsort(ids_shuffle, dim=1) |
|
|
|
|
|
ids_keep = ids_shuffle[:, :len_keep] |
|
x_masked = th.gather( |
|
x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) |
|
|
|
|
|
mask = th.ones([N, L], device=x.device) |
|
mask[:, :len_keep] = 0 |
|
|
|
mask = th.gather(mask, dim=1, index=ids_restore) |
|
|
|
return x_masked, mask, ids_restore, ids_keep |
|
|
|
def forward_side_interpolater(self, x, y, t0, y_lens, mask, ids_restore): |
|
|
|
mask_tokens = self.mask_token.repeat( |
|
x.shape[0], ids_restore.shape[1] - x.shape[1], 1) |
|
x_ = th.cat([x, mask_tokens], dim=1) |
|
x = th.gather( |
|
x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) |
|
|
|
|
|
x = x + self.decoder_pos_embed |
|
|
|
|
|
x_before = x |
|
for sideblock in self.sideblocks: |
|
x = sideblock(x, y, t0, y_lens, ids_keep=None) |
|
|
|
|
|
mask = mask.unsqueeze(dim=-1) |
|
x = x*mask + (1-mask)*x_before |
|
|
|
return x |
|
|
|
def initialize_weights(self): |
|
|
|
def _basic_init(module): |
|
if isinstance(module, nn.Linear): |
|
th.nn.init.xavier_uniform_(module.weight) |
|
if module.bias is not None: |
|
nn.init.constant_(module.bias, 0) |
|
|
|
self.apply(_basic_init) |
|
|
|
|
|
pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], self.x_embedder.grid_size, lewei_scale=self.lewei_scale, base_size=self.base_size) |
|
|
|
|
|
|
|
self.pos_embed.data.copy_(th.from_numpy(pos_embed).float().unsqueeze(0)) |
|
|
|
|
|
w = self.x_embedder.proj.weight.data |
|
nn.init.xavier_uniform_(w.view([w.shape[0], -1])) |
|
|
|
|
|
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) |
|
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) |
|
nn.init.normal_(self.t_block[1].weight, std=0.02) |
|
|
|
|
|
nn.init.normal_(self.y_embedder.weight, std=0.02) |
|
|
|
|
|
|
|
for block in self.en_inblocks: |
|
nn.init.constant_(block.cross_attn.proj.weight, 0) |
|
nn.init.constant_(block.cross_attn.proj.bias, 0) |
|
for block in self.en_outblocks: |
|
nn.init.constant_(block.cross_attn.proj.weight, 0) |
|
nn.init.constant_(block.cross_attn.proj.bias, 0) |
|
for block in self.de_blocks: |
|
nn.init.constant_(block.cross_attn.proj.weight, 0) |
|
nn.init.constant_(block.cross_attn.proj.bias, 0) |
|
for block in self.sideblocks: |
|
nn.init.constant_(block.cross_attn.proj.weight, 0) |
|
nn.init.constant_(block.cross_attn.proj.bias, 0) |
|
|
|
|
|
nn.init.constant_(self.final_layer.linear.weight, 0) |
|
nn.init.constant_(self.final_layer.linear.bias, 0) |
|
|
|
@property |
|
def dtype(self): |
|
return next(self.parameters()).dtype |
|
|
|
class PixArt_Slow(nn.Module): |
|
""" |
|
Diffusion model with a Transformer backbone. |
|
""" |
|
|
|
def __init__(self, input_size=(256,16), patch_size=(16,4), overlap=(0, 0), in_channels=8, hidden_size=1152, depth=28, num_heads=16, mlp_ratio=4.0, class_dropout_prob=0.1, pred_sigma=True, drop_path: float = 0., window_size=0, window_block_indexes=None, use_rel_pos=False, cond_dim=1024, lewei_scale=1.0, |
|
use_cfg=True, cfg_scale=4.0, config=None, model_max_length=120, **kwargs): |
|
if window_block_indexes is None: |
|
window_block_indexes = [] |
|
super().__init__() |
|
self.use_cfg = use_cfg |
|
self.cfg_scale = cfg_scale |
|
self.input_size = input_size |
|
self.pred_sigma = pred_sigma |
|
self.in_channels = in_channels |
|
self.out_channels = in_channels * 2 if pred_sigma else in_channels |
|
self.patch_size = patch_size |
|
self.num_heads = num_heads |
|
self.lewei_scale = lewei_scale |
|
|
|
self.x_embedder = PatchEmbed(input_size, patch_size, overlap, in_channels, hidden_size, bias=True) |
|
|
|
self.t_embedder = TimestepEmbedder(hidden_size) |
|
num_patches = self.x_embedder.num_patches |
|
self.base_size = input_size[0] // self.patch_size[0] * 2 |
|
|
|
self.register_buffer("pos_embed", th.zeros(1, num_patches, hidden_size)) |
|
|
|
approx_gelu = lambda: nn.GELU(approximate="tanh") |
|
self.t_block = nn.Sequential( |
|
nn.SiLU(), |
|
nn.Linear(hidden_size, 6 * hidden_size, bias=True) |
|
) |
|
self.y_embedder = nn.Linear(cond_dim, hidden_size) |
|
drop_path = [x.item() for x in th.linspace(0, drop_path, depth)] |
|
self.blocks = nn.ModuleList([ |
|
PixArtBlock_Slow(hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path[i], |
|
input_size=(self.x_embedder.grid_size[0], self.x_embedder.grid_size[1]), |
|
window_size=0, |
|
use_rel_pos=False) |
|
for i in range(depth) |
|
]) |
|
self.final_layer = T2IFinalLayer(hidden_size, patch_size, self.out_channels) |
|
|
|
self.initialize_weights() |
|
|
|
def forward(self, x, timestep, context_list, context_mask_list=None, **kwargs): |
|
""" |
|
Forward pass of PixArt. |
|
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) |
|
timestep: (N,) tensor of diffusion timesteps |
|
context_list[0]: (N, 1, L, cond_dim) tensor of conditioning embeddings (e.g. text encoder outputs) |
|
""" |
|
x = x.to(self.dtype) |
|
timestep = timestep.to(self.dtype) |
|
y = context_list[0].to(self.dtype) |
|
pos_embed = self.pos_embed.to(self.dtype) |
|
self.h, self.w = self.x_embedder.grid_size[0], self.x_embedder.grid_size[1] |
|
|
|
x = self.x_embedder(x) + pos_embed |
|
t = self.t_embedder(timestep.to(x.dtype)) |
|
t0 = self.t_block(t) |
|
y = self.y_embedder(y) |
|
mask = context_mask_list[0] |
|
|
|
assert mask is not None |
|
|
|
|
|
|
|
|
|
|
|
for block in self.blocks: |
|
x = auto_grad_checkpoint(block, x, y, t0, mask) |
|
x = self.final_layer(x, t) |
|
x = self.unpatchify(x) |
|
return x |
|
|
|
def forward_with_dpmsolver(self, x, timestep, y, mask=None, **kwargs): |
|
""" |
|
DPM-Solver does not need a variance prediction, so only the noise channels are returned. |
|
""" |
|
|
|
model_out = self.forward(x, timestep, y, mask) |
|
return model_out.chunk(2, dim=1)[0] if self.pred_sigma else model_out |
|
|
|
def forward_with_cfg(self, x, timestep, y, cfg_scale, mask=None, **kwargs): |
|
""" |
|
Forward pass of PixArt, but also batches the unconditional forward pass for classifier-free guidance. |
|
""" |
|
|
|
half = x[: len(x) // 2] |
|
combined = th.cat([half, half], dim=0) |
|
model_out = self.forward(combined, timestep, y, mask) |
|
model_out = model_out['x'] if isinstance(model_out, dict) else model_out |
|
eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:] |
|
cond_eps, uncond_eps = th.split(eps, len(eps) // 2, dim=0) |
|
half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps) |
|
eps = th.cat([half_eps, half_eps], dim=0) |
|
return eps |
|
|
|
|
|
def unpatchify(self, x): |
|
|
|
""" |
|
x: (N, T, patch_size[0] * patch_size[1] * out_channels) token sequence |
|
imgs: (N, out_channels, H, W), e.g. (N, 16, 256, 16) since pred_sigma doubles the channels by default |
|
""" |
|
c = self.out_channels |
|
p0 = self.x_embedder.patch_size[0] |
|
p1 = self.x_embedder.patch_size[1] |
|
h, w = self.x_embedder.grid_size[0], self.x_embedder.grid_size[1] |
|
|
|
x = x.reshape(shape=(x.shape[0], h, w, p0, p1, c)) |
|
x = th.einsum('nhwpqc->nchpwq', x) |
|
imgs = x.reshape(shape=(x.shape[0], c, h * p0, w * p1)) |
|
return imgs |
|
|
|
def initialize_weights(self): |
|
|
|
def _basic_init(module): |
|
if isinstance(module, nn.Linear): |
|
th.nn.init.xavier_uniform_(module.weight) |
|
if module.bias is not None: |
|
nn.init.constant_(module.bias, 0) |
|
|
|
self.apply(_basic_init) |
|
|
|
|
|
pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], self.x_embedder.grid_size, lewei_scale=self.lewei_scale, base_size=self.base_size) |
|
self.pos_embed.data.copy_(th.from_numpy(pos_embed).float().unsqueeze(0)) |
|
|
|
|
|
w = self.x_embedder.proj.weight.data |
|
nn.init.xavier_uniform_(w.view([w.shape[0], -1])) |
|
|
|
|
|
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) |
|
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) |
|
nn.init.normal_(self.t_block[1].weight, std=0.02) |
|
|
|
|
|
nn.init.normal_(self.y_embedder.weight, std=0.02) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
nn.init.constant_(self.final_layer.linear.weight, 0) |
|
nn.init.constant_(self.final_layer.linear.bias, 0) |
|
|
|
@property |
|
def dtype(self): |
|
return next(self.parameters()).dtype |
|
|
|
class PixArtBlock_1D(nn.Module): |
|
""" |
|
A PixArt block with adaptive layer norm (adaLN-single) conditioning. |
|
""" |
|
|
|
def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, drop_path=0., window_size=0, use_rel_pos=False, **block_kwargs): |
|
super().__init__() |
|
self.hidden_size = hidden_size |
|
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) |
|
self.attn = WindowAttention(hidden_size, num_heads=num_heads, qkv_bias=True, |
|
input_size=None, |
|
use_rel_pos=use_rel_pos, **block_kwargs) |
|
|
|
self.cross_attn = MultiHeadCrossAttention(hidden_size, num_heads, **block_kwargs) |
|
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) |
|
|
|
approx_gelu = lambda: nn.GELU(approximate="tanh") |
|
self.mlp = Mlp(in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu, drop=0) |
|
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() |
|
self.window_size = window_size |
|
self.scale_shift_table = nn.Parameter(th.randn(6, hidden_size) / hidden_size ** 0.5) |
|
|
|
def forward(self, x, y, t, mask=None, **kwargs): |
|
B, N, C = x.shape |
|
|
|
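# adaLN-single conditioning: t carries 6 * hidden_size modulation values that |
# are added to the learned scale_shift_table and split into shift/scale/gate |
# for the attention and MLP branches. |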
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None] + t.reshape(B, 6, -1)).chunk(6, dim=1) |
|
x = x + self.drop_path(gate_msa * self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa)).reshape(B, N, C)) |
|
x = x + self.cross_attn(x, y, mask) |
|
x = x + self.drop_path(gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp))) |
|
|
|
return x |
|
|
|
class PixArt_1D(nn.Module): |
|
""" |
|
Diffusion model with a Transformer backbone. |
|
""" |
|
|
|
def __init__(self, input_size=(256,16), in_channels=8, hidden_size=1152, depth=28, num_heads=16, mlp_ratio=4.0, class_dropout_prob=0.1, pred_sigma=True, drop_path: float = 0., window_size=0, window_block_indexes=None, use_rel_pos=False, cond_dim=1024, lewei_scale=1.0, |
|
use_cfg=True, cfg_scale=4.0, config=None, model_max_length=120, **kwargs): |
|
if window_block_indexes is None: |
|
window_block_indexes = [] |
|
super().__init__() |
|
self.use_cfg = use_cfg |
|
self.cfg_scale = cfg_scale |
|
self.input_size = input_size |
|
self.pred_sigma = pred_sigma |
|
self.in_channels = in_channels |
|
self.out_channels = in_channels |
|
self.num_heads = num_heads |
|
self.lewei_scale = lewei_scale |
|
|
|
self.x_embedder = PatchEmbed_1D(input_size, in_channels, hidden_size) |
|
|
|
self.t_embedder = TimestepEmbedder(hidden_size) |
|
self.p_enc_1d_model = PositionalEncoding1D(hidden_size) |
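# 1D variant: each time frame (all frequency bins x channels) is one token; |
# positional information comes from an on-the-fly 1D sinusoidal encoding |
# instead of a fixed 2D sin-cos table. |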
|
|
|
approx_gelu = lambda: nn.GELU(approximate="tanh") |
|
self.t_block = nn.Sequential( |
|
nn.SiLU(), |
|
nn.Linear(hidden_size, 6 * hidden_size, bias=True) |
|
) |
|
self.y_embedder = nn.Linear(cond_dim, hidden_size) |
|
drop_path = [x.item() for x in th.linspace(0, drop_path, depth)] |
|
self.blocks = nn.ModuleList([ |
|
PixArtBlock_1D(hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path[i], |
|
window_size=0, |
|
use_rel_pos=False) |
|
for i in range(depth) |
|
]) |
|
self.final_layer = T2IFinalLayer(hidden_size, (1, input_size[1]), self.out_channels) |
|
|
|
self.initialize_weights() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def forward(self, x, timestep, context_list, context_mask_list=None, **kwargs): |
|
""" |
|
Forward pass of PixArt. |
|
x: (N, C, T, F) tensor of latent inputs (channels, time frames, frequency bins) |
|
timestep: (N,) tensor of diffusion timesteps |
|
context_list[0]: (N, 1, L, cond_dim) tensor of conditioning embeddings (e.g. text encoder outputs) |
|
""" |
|
x = x.to(self.dtype) |
|
timestep = timestep.to(self.dtype) |
|
y = context_list[0].to(self.dtype) |
|
|
|
x = self.x_embedder(x) |
|
pos_embed = self.p_enc_1d_model(x) |
|
x = x + pos_embed |
|
t = self.t_embedder(timestep.to(x.dtype)) |
|
t0 = self.t_block(t) |
|
y = self.y_embedder(y) |
|
try: |
|
mask = context_mask_list[0] |
|
except (TypeError, IndexError): |
|
mask = th.ones(x.shape[0], 1).to(x.device) |
|
print("context_mask_list is missing or empty; falling back to an all-ones mask") |
|
|
|
assert mask is not None |
|
|
|
y = y.masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1]) |
|
y_lens = mask.sum(dim=1).tolist() |
|
y_lens = [int(_) for _ in y_lens] |
|
for block in self.blocks: |
|
x = auto_grad_checkpoint(block, x, y, t0, y_lens) |
|
x = self.final_layer(x, t) |
|
x = self.unpatchify_1D(x) |
|
return x |
|
|
|
def forward_with_dpmsolver(self, x, timestep, y, mask=None, **kwargs): |
|
""" |
|
DPM-Solver does not need a variance prediction, so only the noise channels are returned. |
|
""" |
|
|
|
model_out = self.forward(x, timestep, y, mask) |
|
return model_out.chunk(2, dim=1)[0] if self.pred_sigma else model_out |
|
|
|
def forward_with_cfg(self, x, timestep, y, cfg_scale, mask=None, **kwargs): |
|
""" |
|
Forward pass of PixArt, but also batches the unconditional forward pass for classifier-free guidance. |
|
""" |
|
|
|
half = x[: len(x) // 2] |
|
combined = th.cat([half, half], dim=0) |
|
model_out = self.forward(combined, timestep, y, mask) |
|
model_out = model_out['x'] if isinstance(model_out, dict) else model_out |
|
eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:] |
|
cond_eps, uncond_eps = th.split(eps, len(eps) // 2, dim=0) |
|
half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps) |
|
eps = th.cat([half_eps, half_eps], dim=0) |
|
return eps |
|
|
|
|
|
def unpatchify_1D(self, x): |
|
|
|
""" |
|
""" |
|
c = self.out_channels |
|
|
|
x = x.reshape(shape=(x.shape[0], self.input_size[0], self.input_size[1], c)) |
|
x = th.einsum('btfc->bctf', x) |
|
|
|
return x |
|
|
|
def initialize_weights(self): |
|
|
|
def _basic_init(module): |
|
if isinstance(module, nn.Linear): |
|
th.nn.init.xavier_uniform_(module.weight) |
|
if module.bias is not None: |
|
nn.init.constant_(module.bias, 0) |
|
|
|
self.apply(_basic_init) |
|
|
|
|
|
|
|
|
|
|
|
|
|
w = self.x_embedder.proj.weight.data |
|
nn.init.xavier_uniform_(w.view([w.shape[0], -1])) |
|
|
|
|
|
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) |
|
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) |
|
nn.init.normal_(self.t_block[1].weight, std=0.02) |
|
|
|
|
|
nn.init.normal_(self.y_embedder.weight, std=0.02) |
|
|
|
|
|
|
|
for block in self.blocks: |
|
nn.init.constant_(block.cross_attn.proj.weight, 0) |
|
nn.init.constant_(block.cross_attn.proj.bias, 0) |
|
|
|
|
|
nn.init.constant_(self.final_layer.linear.weight, 0) |
|
nn.init.constant_(self.final_layer.linear.bias, 0) |
|
|
|
@property |
|
def dtype(self): |
|
return next(self.parameters()).dtype |
|
|
|
class PixArt_Slow_1D(nn.Module): |
|
""" |
|
Diffusion model with a Transformer backbone. |
|
""" |
|
|
|
def __init__(self, input_size=(256,16), in_channels=8, hidden_size=1152, depth=28, num_heads=16, mlp_ratio=4.0, class_dropout_prob=0.1, pred_sigma=True, drop_path: float = 0., window_size=0, window_block_indexes=None, use_rel_pos=False, cond_dim=1024, lewei_scale=1.0, |
|
use_cfg=True, cfg_scale=4.0, config=None, model_max_length=120, **kwargs): |
|
if window_block_indexes is None: |
|
window_block_indexes = [] |
|
super().__init__() |
|
self.use_cfg = use_cfg |
|
self.cfg_scale = cfg_scale |
|
self.input_size = input_size |
|
self.pred_sigma = pred_sigma |
|
self.in_channels = in_channels |
|
self.out_channels = in_channels * 2 if pred_sigma else in_channels |
|
self.num_heads = num_heads |
|
self.lewei_scale = lewei_scale |
|
|
|
self.x_embedder = PatchEmbed_1D(input_size, in_channels, hidden_size) |
|
|
|
self.t_embedder = TimestepEmbedder(hidden_size) |
|
self.p_enc_1d_model = PositionalEncoding1D(hidden_size) |
|
|
|
approx_gelu = lambda: nn.GELU(approximate="tanh") |
|
self.t_block = nn.Sequential( |
|
nn.SiLU(), |
|
nn.Linear(hidden_size, 6 * hidden_size, bias=True) |
|
) |
|
self.y_embedder = nn.Linear(cond_dim, hidden_size) |
|
drop_path = [x.item() for x in th.linspace(0, drop_path, depth)] |
|
self.blocks = nn.ModuleList([ |
|
PixArtBlock_Slow(hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path[i], |
|
window_size=0, |
|
use_rel_pos=False) |
|
for i in range(depth) |
|
]) |
|
self.final_layer = T2IFinalLayer(hidden_size, (1, input_size[1]), self.out_channels) |
|
|
|
self.initialize_weights() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def forward(self, x, timestep, context_list, context_mask_list=None, **kwargs): |
|
""" |
|
Forward pass of PixArt. |
|
x: (N, C, T, F) tensor of latent inputs (channels, time frames, frequency bins) |
|
timestep: (N,) tensor of diffusion timesteps |
|
context_list[0]: (N, 1, L, cond_dim) tensor of conditioning embeddings (e.g. text encoder outputs) |
|
""" |
|
x = x.to(self.dtype) |
|
timestep = timestep.to(self.dtype) |
|
y = context_list[0].to(self.dtype) |
|
|
|
x = self.x_embedder(x) |
|
pos_embed = self.p_enc_1d_model(x) |
|
x = x + pos_embed |
|
t = self.t_embedder(timestep.to(x.dtype)) |
|
t0 = self.t_block(t) |
|
y = self.y_embedder(y) |
|
mask = context_mask_list[0] |
|
|
|
assert mask is not None |
|
|
|
|
|
|
|
|
|
for block in self.blocks: |
|
x = auto_grad_checkpoint(block, x, y, t0, mask) |
|
x = self.final_layer(x, t) |
|
x = self.unpatchify_1D(x) |
|
return x |
|
|
|
def forward_with_dpmsolver(self, x, timestep, y, mask=None, **kwargs): |
|
""" |
|
DPM-Solver does not need a variance prediction, so only the noise channels are returned. |
|
""" |
|
|
|
model_out = self.forward(x, timestep, y, mask) |
|
return model_out.chunk(2, dim=1)[0] if self.pred_sigma else model_out |
|
|
|
def forward_with_cfg(self, x, timestep, y, cfg_scale, mask=None, **kwargs): |
|
""" |
|
Forward pass of PixArt, but also batches the unconditional forward pass for classifier-free guidance. |
|
""" |
|
|
|
half = x[: len(x) // 2] |
|
combined = th.cat([half, half], dim=0) |
|
model_out = self.forward(combined, timestep, y, mask) |
|
model_out = model_out['x'] if isinstance(model_out, dict) else model_out |
|
eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:] |
|
cond_eps, uncond_eps = th.split(eps, len(eps) // 2, dim=0) |
|
half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps) |
|
eps = th.cat([half_eps, half_eps], dim=0) |
|
return eps |
|
|
|
|
|
def unpatchify_1D(self, x): |
|
|
|
""" |
|
""" |
|
c = self.out_channels |
|
|
|
x = x.reshape(shape=(x.shape[0], self.input_size[0], self.input_size[1], c)) |
|
x = th.einsum('btfc->bctf', x) |
|
|
|
return x |
|
|
|
def initialize_weights(self): |
|
|
|
def _basic_init(module): |
|
if isinstance(module, nn.Linear): |
|
th.nn.init.xavier_uniform_(module.weight) |
|
if module.bias is not None: |
|
nn.init.constant_(module.bias, 0) |
|
|
|
self.apply(_basic_init) |
|
|
|
|
|
|
|
|
|
|
|
|
|
w = self.x_embedder.proj.weight.data |
|
nn.init.xavier_uniform_(w.view([w.shape[0], -1])) |
|
|
|
|
|
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) |
|
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) |
|
nn.init.normal_(self.t_block[1].weight, std=0.02) |
|
|
|
|
|
nn.init.normal_(self.y_embedder.weight, std=0.02) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
nn.init.constant_(self.final_layer.linear.weight, 0) |
|
nn.init.constant_(self.final_layer.linear.bias, 0) |
|
|
|
@property |
|
def dtype(self): |
|
return next(self.parameters()).dtype |
|
|
|
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0, lewei_scale=1.0, base_size=16): |
|
""" |
|
grid_size: int or (H, W) tuple giving the grid height and width |
|
return: |
|
pos_embed: [H*W, embed_dim] or [extra_tokens + H*W, embed_dim] (w/o or w/ the extra cls tokens) |
|
""" |
|
|
|
|
|
if isinstance(grid_size, int): |
|
grid_size = to_2tuple(grid_size) |
|
grid_h = np.arange(grid_size[0], dtype=np.float32) / (grid_size[0]/base_size) / lewei_scale |
|
grid_w = np.arange(grid_size[1], dtype=np.float32) / (grid_size[1]/base_size) / lewei_scale |
|
grid = np.meshgrid(grid_w, grid_h) |
|
grid = np.stack(grid, axis=0) |
|
grid = grid.reshape([2, 1, grid_size[1], grid_size[0]]) |
|
|
|
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) |
|
if cls_token and extra_tokens > 0: |
|
pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) |
|
return pos_embed |
|
|
|
|
|
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): |
|
assert embed_dim % 2 == 0 |
|
|
|
|
|
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) |
|
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) |
|
|
|
return np.concatenate([emb_h, emb_w], axis=1) |
|
|
|
|
|
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): |
|
""" |
|
embed_dim: output dimension for each position |
|
pos: a list of positions to be encoded: size (M,) |
|
out: (M, D) |
|
""" |
|
assert embed_dim % 2 == 0 |
|
omega = np.arange(embed_dim // 2, dtype=np.float64) |
|
omega /= embed_dim / 2. |
|
omega = 1. / 10000 ** omega |
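# omega[i] = 1 / 10000^(2i / embed_dim): the standard transformer sinusoidal |
# frequency schedule; the outer product with the positions below gives the |
# arguments for sin and cos. |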
|
|
|
pos = pos.reshape(-1) |
|
out = np.einsum('m,d->md', pos, omega) |
|
|
|
emb_sin = np.sin(out) |
|
emb_cos = np.cos(out) |
|
|
|
return np.concatenate([emb_sin, emb_cos], axis=1) |
|
|
|