# Harmon-1_5B / mar.py
# wusize's picture
# Add files using upload-large-folder tool
# 22a2012 verified
from functools import partial
import numpy as np
from tqdm import tqdm
import scipy.stats as stats
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch.utils.checkpoint import checkpoint
from timm.models.vision_transformer import Block
from .diffloss import DiffLoss
def mask_by_order(mask_len, order, bsz, seq_len):
    """Return a boolean mask that flags the first ``mask_len`` positions of each order.

    Args:
        mask_len: single-element tensor; ``mask_len.long()`` is used as the
            slice bound into ``order``.
        order: integer index tensor with one row of sequence positions per
            sample (indexed as ``order[:, :k]``).
        bsz: number of rows of the returned mask.
        seq_len: number of columns of the returned mask.

    Returns:
        Bool tensor of shape ``(bsz, seq_len)`` on ``order.device`` that is
        ``True`` exactly at the positions listed in ``order[:, :mask_len]``.
    """
    selected = order[:, :mask_len.long()]
    flags = torch.zeros(bsz, seq_len, device=order.device)
    # Writing a scalar 1 at the selected indices is equivalent to scattering
    # from an all-ones source tensor.
    flags.scatter_(-1, selected, 1.0)
    return flags.bool()
class MAR(nn.Module):
    """Masked autoencoder with a VisionTransformer backbone and a diffusion loss head.

    The model works on a sequence of ``seq_len = seq_h * seq_w`` token latents.
    Both encoder and decoder prepend ``buffer_size`` extra positions to that
    sequence; the encoder writes the class embedding into them.  Throughout the
    class a mask value of 1 means "masked" (dropped from the encoder input and
    replaced by ``mask_token`` in the decoder) and 0 means "visible".

    Args:
        img_size: input image side; with ``vae_stride`` and ``patch_size`` it
            fixes ``seq_h = seq_w = img_size // vae_stride // patch_size``.
        vae_stride: see ``img_size``.
        patch_size: side of the square patch folded into one token.
        encoder_embed_dim / encoder_depth / encoder_num_heads: encoder width,
            number of ``Block`` layers and attention heads.
        decoder_embed_dim / decoder_depth / decoder_num_heads: same for the
            decoder.
        mlp_ratio: forwarded to every ``Block``.
        norm_layer: constructor for the normalisation layers.
        vae_embed_dim: channel count of the input latents; a token has
            ``vae_embed_dim * patch_size ** 2`` features.
        mask_ratio_min: lower bound of the truncated Gaussian from which the
            training mask ratio is drawn.
        label_drop_prob: stored on the module; the label-dropping code in
            ``forward_mae_encoder`` is commented out, so it is not used here.
        class_num: number of rows of the class embedding table.
        attn_dropout / proj_dropout: dropout rates forwarded to every ``Block``.
        buffer_size: number of buffer positions prepended to the sequence.
        diffloss_d / diffloss_w / num_sampling_steps: forwarded to ``DiffLoss``
            as ``depth`` / ``width`` / ``num_sampling_steps``.
        diffusion_batch_mul: how many times every token is repeated before it
            is handed to the diffusion loss.
        grad_checkpointing: run the transformer blocks under
            ``torch.utils.checkpoint.checkpoint``.
    """

    def __init__(self, img_size=256, vae_stride=16, patch_size=1,
                 encoder_embed_dim=1024, encoder_depth=16, encoder_num_heads=16,
                 decoder_embed_dim=1024, decoder_depth=16, decoder_num_heads=16,
                 mlp_ratio=4., norm_layer=nn.LayerNorm,
                 vae_embed_dim=16,
                 mask_ratio_min=0.7,
                 label_drop_prob=0.1,
                 class_num=1000,
                 attn_dropout=0.1,
                 proj_dropout=0.1,
                 buffer_size=64,
                 diffloss_d=3,
                 diffloss_w=1024,
                 num_sampling_steps='100',
                 diffusion_batch_mul=4,
                 grad_checkpointing=False,
                 ):
        super().__init__()

        # --------------------------------------------------------------------------
        # VAE and patchify specifics
        self.vae_embed_dim = vae_embed_dim
        self.img_size = img_size
        self.vae_stride = vae_stride
        self.patch_size = patch_size
        self.seq_h = self.seq_w = img_size // vae_stride // patch_size
        self.seq_len = self.seq_h * self.seq_w
        self.token_embed_dim = vae_embed_dim * patch_size**2
        self.grad_checkpointing = grad_checkpointing

        # --------------------------------------------------------------------------
        # Class Embedding
        self.num_classes = class_num
        self.class_emb = nn.Embedding(class_num, encoder_embed_dim)
        self.label_drop_prob = label_drop_prob
        # Fake class embedding for CFG's unconditional generation
        self.fake_latent = nn.Parameter(torch.zeros(1, encoder_embed_dim))

        # --------------------------------------------------------------------------
        # MAR variant masking ratio, a left-half truncated Gaussian centered at
        # 100% masking ratio with std 0.25
        self.mask_ratio_generator = stats.truncnorm((mask_ratio_min - 1.0) / 0.25, 0, loc=1.0, scale=0.25)

        # --------------------------------------------------------------------------
        # MAR encoder specifics
        self.encoder_embed_dim = encoder_embed_dim
        self.z_proj = nn.Linear(self.token_embed_dim, encoder_embed_dim, bias=True)
        self.z_proj_ln = nn.LayerNorm(encoder_embed_dim, eps=1e-6)
        self.buffer_size = buffer_size
        self.encoder_pos_embed_learned = nn.Parameter(
            torch.zeros(1, self.seq_len + self.buffer_size, encoder_embed_dim))
        self.encoder_blocks = nn.ModuleList([
            Block(encoder_embed_dim, encoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer,
                  proj_drop=proj_dropout, attn_drop=attn_dropout) for _ in range(encoder_depth)])
        self.encoder_norm = norm_layer(encoder_embed_dim)

        # --------------------------------------------------------------------------
        # MAR decoder specifics
        self.decoder_embed_dim = decoder_embed_dim
        self.decoder_embed = nn.Linear(encoder_embed_dim, decoder_embed_dim, bias=True)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
        self.decoder_pos_embed_learned = nn.Parameter(
            torch.zeros(1, self.seq_len + self.buffer_size, decoder_embed_dim))
        self.decoder_blocks = nn.ModuleList([
            Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer,
                  proj_drop=proj_dropout, attn_drop=attn_dropout) for _ in range(decoder_depth)])
        self.decoder_norm = norm_layer(decoder_embed_dim)
        self.diffusion_pos_embed_learned = nn.Parameter(torch.zeros(1, self.seq_len, decoder_embed_dim))

        # Must run before the diffusion head is attached: ``self.apply`` visits
        # every registered sub-module, and the head is not meant to be
        # re-initialised by ``_init_weights``.
        self.initialize_weights()

        # --------------------------------------------------------------------------
        # Diffusion Loss
        self.diffloss = DiffLoss(
            target_channels=self.token_embed_dim,
            z_channels=decoder_embed_dim,
            width=diffloss_w,
            depth=diffloss_d,
            num_sampling_steps=num_sampling_steps,
            grad_checkpointing=self.grad_checkpointing
        )
        self.diffusion_batch_mul = diffusion_batch_mul

    # ------------------------------------------------------------------
    # positional embeddings
    def _resize_image_pos_embed(self, image_pe, h, w):
        """Bilinearly resample a ``(1, seq_h * seq_w, C)`` embedding to ``(1, h * w, C)``."""
        image_pe = rearrange(image_pe, 'b (h w) c -> b c h w',
                             h=self.seq_h, w=self.seq_w)
        image_pe = F.interpolate(image_pe, size=(h, w), mode='bilinear')
        return rearrange(image_pe, 'b c h w -> b (h w) c')

    def _resize_buffered_pos_embed(self, pos_embed, h, w):
        """Resample only the image part of a buffer-prefixed positional embedding."""
        buffer_pe, image_pe = pos_embed.split([self.buffer_size, self.seq_len], dim=1)
        return torch.cat([buffer_pe, self._resize_image_pos_embed(image_pe, h, w)], dim=1)

    def get_encoder_pos_embed(self, h, w):
        """Return the encoder positional embedding for an ``h x w`` token grid.

        The learned parameter is returned as is when the grid matches
        ``(seq_h, seq_w)``; otherwise its image part is bilinearly interpolated
        and the buffer part is kept unchanged.
        """
        if h == self.seq_h and w == self.seq_w:
            return self.encoder_pos_embed_learned
        return self._resize_buffered_pos_embed(self.encoder_pos_embed_learned, h, w)

    def get_decoder_pos_embed(self, h, w):
        """Return the decoder positional embedding for an ``h x w`` token grid.

        Same rule as ``get_encoder_pos_embed``, applied to the decoder parameter.
        """
        if h == self.seq_h and w == self.seq_w:
            return self.decoder_pos_embed_learned
        return self._resize_buffered_pos_embed(self.decoder_pos_embed_learned, h, w)

    def get_diffusion_pos_embed(self, h, w):
        """Return the diffusion positional embedding for an ``h x w`` token grid.

        This embedding has no buffer part, so the whole tensor is interpolated
        when the grid differs from ``(seq_h, seq_w)``.
        """
        if h == self.seq_h and w == self.seq_w:
            return self.diffusion_pos_embed_learned
        return self._resize_image_pos_embed(self.diffusion_pos_embed_learned, h, w)

    # ------------------------------------------------------------------
    # initialisation
    def initialize_weights(self):
        """Initialise embeddings/tokens with N(0, 0.02) and then all Linear/LayerNorm layers."""
        # parameters
        torch.nn.init.normal_(self.class_emb.weight, std=.02)
        torch.nn.init.normal_(self.fake_latent, std=.02)
        torch.nn.init.normal_(self.mask_token, std=.02)
        torch.nn.init.normal_(self.encoder_pos_embed_learned, std=.02)
        torch.nn.init.normal_(self.decoder_pos_embed_learned, std=.02)
        torch.nn.init.normal_(self.diffusion_pos_embed_learned, std=.02)

        # initialize nn.Linear and nn.LayerNorm
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Per-module initialiser used with ``nn.Module.apply``."""
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)

    @property
    def device(self):
        """Device of the model, read from ``fake_latent``."""
        return self.fake_latent.data.device

    @property
    def dtype(self):
        """Floating dtype of the model, read from ``fake_latent``."""
        return self.fake_latent.data.dtype

    # ------------------------------------------------------------------
    # token <-> image layout
    def patchify(self, x):
        """Fold a ``(n, c, h, w)`` tensor into tokens of shape ``(n, h/p * w/p, c * p * p)``."""
        bsz, c, h, w = x.shape
        p = self.patch_size
        h_, w_ = h // p, w // p

        x = x.reshape(bsz, c, h_, p, w_, p)
        x = torch.einsum('nchpwq->nhwcpq', x)
        x = x.reshape(bsz, h_ * w_, c * p ** 2)
        return x  # [n, l, d]

    def unpatchify(self, x):
        """Inverse of ``patchify`` for the model's native ``seq_h x seq_w`` grid."""
        bsz = x.shape[0]
        p = self.patch_size
        c = self.vae_embed_dim
        h_, w_ = self.seq_h, self.seq_w

        x = x.reshape(bsz, h_, w_, c, p, p)
        x = torch.einsum('nhwcpq->nchpwq', x)
        x = x.reshape(bsz, c, h_ * p, w_ * p)
        return x  # [n, c, h, w]

    # ------------------------------------------------------------------
    # masking
    def sample_orders(self, bsz, seq_len=None):
        """Draw ``bsz`` independent random permutations of ``range(seq_len)``.

        Uses the global NumPy RNG.  ``seq_len`` defaults to ``self.seq_len``.
        Returns a ``long`` tensor on ``self.device``.
        """
        if seq_len is None:
            seq_len = self.seq_len
        # generate a batch of random generation orders
        orders = []
        for _ in range(bsz):
            order = np.arange(seq_len)
            np.random.shuffle(order)
            orders.append(order)
        return torch.Tensor(np.array(orders)).to(self.device).long()

    def random_masking(self, x, orders):
        """Build a training mask that hides the first tokens of every order.

        One mask ratio is drawn per call from ``mask_ratio_generator`` and shared
        by the whole batch, so every sample has the same number of masked
        tokens.  Returns a float tensor of shape ``(bsz, seq_len)`` with 1 at
        masked positions.
        """
        bsz, seq_len, embed_dim = x.shape
        assert seq_len == orders.shape[1]
        mask_rate = self.mask_ratio_generator.rvs(1)[0]
        num_masked_tokens = int(np.ceil(seq_len * mask_rate))
        mask = torch.zeros(bsz, seq_len, device=x.device)
        mask = torch.scatter(mask, dim=-1, index=orders[:, :num_masked_tokens],
                             src=torch.ones(bsz, seq_len, device=x.device))
        return mask

    # ------------------------------------------------------------------
    # encoder
    def forward_mae_encoder(self, x, mask, class_embedding, image_shape=None):
        """Encode the visible tokens together with the class buffer.

        Args:
            x: token latents, ``(bsz, seq_len, token_embed_dim)``.
            mask: ``(bsz, seq_len)``, 1 at masked positions.  Every row must
                mask the same number of tokens because the kept tokens are
                reshaped back into a dense batch.
            class_embedding: conditioning written into the buffer positions; it
                is viewed as ``(bsz, -1, embed_dim)`` and assigned to the first
                ``buffer_size`` positions.
            image_shape: optional ``(h, w)`` token grid; when given, the
                positional embedding is taken from ``get_encoder_pos_embed``.

        Returns:
            Encoded buffer + visible tokens, ``(bsz, buffer_size + n_visible, encoder_embed_dim)``.
        """
        x = x.to(self.dtype)
        x = self.z_proj(x)
        bsz, seq_len, embed_dim = x.shape

        # concat buffer
        x = torch.cat([x.new_zeros(bsz, self.buffer_size, embed_dim), x], dim=1)
        mask_with_buffer = torch.cat([mask.new_zeros(x.size(0), self.buffer_size), mask], dim=1)

        # The buffer positions carry the conditioning.
        x[:, :self.buffer_size] = class_embedding.view(bsz, -1, embed_dim)

        # encoder position embedding
        if image_shape is None:
            x = x + self.encoder_pos_embed_learned
        else:
            h, w = image_shape
            assert h * w == seq_len
            x = x + self.get_encoder_pos_embed(h=h, w=w)

        x = self.z_proj_ln(x)

        # dropping: keep only positions whose mask is 0 (buffer is never masked)
        x = x[(1 - mask_with_buffer).nonzero(as_tuple=True)].reshape(bsz, -1, embed_dim)

        # apply Transformer blocks
        if self.grad_checkpointing and not torch.jit.is_scripting():
            for block in self.encoder_blocks:
                x = checkpoint(block, x, use_reentrant=False)
        else:
            for block in self.encoder_blocks:
                x = block(x)
        x = self.encoder_norm(x)

        return x

    # ------------------------------------------------------------------
    # decoder
    def _pad_decoder_input(self, x, mask, x_con=None):
        """Project encoder output and scatter it into a full-length decoder sequence.

        Positions with mask 0 (including the buffer) receive the projected
        encoder tokens; the remaining positions hold ``mask_token`` or, when
        ``x_con`` is given, the projection of ``x_con``.
        """
        x = self.decoder_embed(x)
        mask_with_buffer = torch.cat(
            [torch.zeros(x.size(0), self.buffer_size, device=x.device), mask], dim=1)

        if x_con is not None:
            x_after_pad = self.decoder_embed(x_con)
        else:
            # ``repeat`` already materialises a fresh tensor, so it can be
            # written into in place without an additional copy.
            x_after_pad = self.mask_token.repeat(
                mask_with_buffer.shape[0], mask_with_buffer.shape[1], 1).to(x.dtype)
        x_after_pad[(1 - mask_with_buffer).nonzero(as_tuple=True)] = \
            x.reshape(x.shape[0] * x.shape[1], x.shape[2])
        return x_after_pad

    def _run_decoder_blocks(self, x):
        """Apply the decoder blocks and final norm, then strip the buffer positions."""
        if self.grad_checkpointing and not torch.jit.is_scripting():
            for block in self.decoder_blocks:
                # Unlike the encoder, ``use_reentrant`` is left at the
                # library default here.
                x = checkpoint(block, x)
        else:
            for block in self.decoder_blocks:
                x = block(x)
        x = self.decoder_norm(x)
        return x[:, self.buffer_size:]

    def forward_mae_decoder(self, x, mask, image_shape=None, x_con=None):
        """Decode the encoder output into one conditioning vector per token.

        Args:
            x: output of ``forward_mae_encoder``.
            mask: ``(bsz, seq_len)``, 1 at masked positions.
            image_shape: optional ``(h, w)`` token grid used to pick/interpolate
                the decoder and diffusion positional embeddings.
            x_con: optional full-length sequence; when given, its
                ``decoder_embed`` projection is used at masked positions instead
                of ``mask_token``.

        Returns:
            ``(bsz, seq_len, decoder_embed_dim)`` tensor with the diffusion
            positional embedding added.
        """
        bsz, seq_len = mask.shape
        x_after_pad = self._pad_decoder_input(x, mask, x_con=x_con)

        # decoder position embedding
        if image_shape is None:
            x = x_after_pad + self.decoder_pos_embed_learned
        else:
            h, w = image_shape
            assert h * w == seq_len
            x = x_after_pad + self.get_decoder_pos_embed(h=h, w=w)

        x = self._run_decoder_blocks(x)

        if image_shape is None:
            x = x + self.diffusion_pos_embed_learned
        else:
            h, w = image_shape
            assert h * w == seq_len
            x = x + self.get_diffusion_pos_embed(h=h, w=w)
        return x

    def mae_decoder_prepare(self, x, mask):
        """First half of the decoder: project, pad with mask tokens, add positions.

        Always uses the learned (non-interpolated) decoder positional embedding.
        """
        return self._pad_decoder_input(x, mask) + self.decoder_pos_embed_learned

    def mae_decoder_forward(self, x):
        """Second half of the decoder: blocks, norm, buffer removal, diffusion positions.

        Always uses the learned (non-interpolated) diffusion positional embedding.
        """
        return self._run_decoder_blocks(x) + self.diffusion_pos_embed_learned

    # ------------------------------------------------------------------
    # training
    def forward_loss(self, z, target, mask):
        """Diffusion loss over all tokens, each repeated ``diffusion_batch_mul`` times.

        ``z``, ``target`` and ``mask`` are flattened over batch and sequence
        before being passed to ``self.diffloss``.
        """
        bsz, seq_len, _ = target.shape
        target = target.reshape(bsz * seq_len, -1).repeat(self.diffusion_batch_mul, 1)
        z = z.reshape(bsz * seq_len, -1).repeat(self.diffusion_batch_mul, 1)
        mask = mask.reshape(bsz * seq_len).repeat(self.diffusion_batch_mul)
        loss = self.diffloss(z=z, target=target, mask=mask)
        return loss

    def forward(self, imgs, labels):
        """Training step: mask tokens, run encoder/decoder and return the diffusion loss.

        Args:
            imgs: 4-D tensor ``(n, c, h, w)`` that is patchified into tokens
                (presumably VAE latents, given ``vae_embed_dim`` - confirm
                against the caller).
            labels: class indices looked up in ``class_emb``.
        """
        # class embed
        class_embedding = self.class_emb(labels)

        # patchify and mask (drop) tokens
        x = self.patchify(imgs)
        gt_latents = x.clone().detach()
        orders = self.sample_orders(bsz=x.size(0))
        mask = self.random_masking(x, orders)

        # mae encoder
        x = self.forward_mae_encoder(x, mask, class_embedding)

        # mae decoder
        z = self.forward_mae_decoder(x, mask)

        # diffloss
        loss = self.forward_loss(z=z, target=gt_latents, mask=mask)

        return loss

    # ------------------------------------------------------------------
    # sampling
    def sample_tokens(self, bsz, num_iter=64, cfg=1.0, cfg_schedule="linear", labels=None, temperature=1.0, progress=False):
        """Iteratively generate token latents and return them in image layout.

        Starting from a fully masked sequence, every iteration predicts the
        tokens that leave the mask according to a cosine schedule and a random
        per-sample generation order.

        Args:
            bsz: number of samples to generate.
            num_iter: number of generation iterations.
            cfg: classifier-free guidance scale; ``1.0`` disables the extra
                unconditional pass.
            cfg_schedule: ``"linear"`` scales the guidance with the fraction of
                already generated tokens, ``"constant"`` uses ``cfg`` as is.
            labels: optional class indices; without them ``fake_latent`` is used
                as conditioning.
            temperature: forwarded to ``diffloss.sample``.
            progress: wrap the iteration in a ``tqdm`` progress bar.

        Returns:
            ``unpatchify`` of the generated tokens,
            ``(bsz, vae_embed_dim, seq_h * patch_size, seq_w * patch_size)``.

        Raises:
            NotImplementedError: for an unknown ``cfg_schedule``.
        """
        use_cfg = not cfg == 1.0

        # init and sample generation orders
        mask = torch.ones(bsz, self.seq_len).to(self.device)
        tokens = torch.zeros(bsz, self.seq_len, self.token_embed_dim).to(self.device)
        orders = self.sample_orders(bsz)

        # The conditioning does not change between iterations, so build it once.
        if labels is not None:
            class_embedding = self.class_emb(labels)
        else:
            class_embedding = self.fake_latent.repeat(bsz, 1)
        if use_cfg:
            # Second half of the batch is the unconditional branch.
            class_embedding = torch.cat([class_embedding, self.fake_latent.repeat(bsz, 1)], dim=0)

        indices = list(range(num_iter))
        if progress:
            indices = tqdm(indices)

        # generate latents
        for step in indices:
            if use_cfg:
                net_tokens = torch.cat([tokens, tokens], dim=0)
                net_mask = torch.cat([mask, mask], dim=0)
            else:
                net_tokens, net_mask = tokens, mask

            # mae encoder
            x = self.forward_mae_encoder(net_tokens, net_mask.to(self.dtype), class_embedding)

            # mae decoder
            z = self.forward_mae_decoder(x, net_mask.to(self.dtype))

            # mask ratio for the next round, following MaskGIT and MAGE.
            mask_ratio = np.cos(math.pi / 2. * (step + 1) / num_iter)
            mask_len = torch.Tensor([np.floor(self.seq_len * mask_ratio)]).to(self.device)
            # masks out at least one for the next iteration
            mask_len = torch.maximum(torch.Tensor([1]).to(self.device),
                                     torch.minimum(torch.sum(mask, dim=-1, keepdims=True) - 1, mask_len))

            # get masking for next iteration and locations to be predicted in this iteration
            mask_next = mask_by_order(mask_len[0], orders, bsz, self.seq_len)
            if step >= num_iter - 1:
                # last iteration: predict everything that is still masked
                mask_to_pred = mask.bool()
            else:
                mask_to_pred = torch.logical_xor(mask.bool(), mask_next.bool())
            mask = mask_next

            # sample token latents for this step
            if use_cfg:
                z = z[torch.cat([mask_to_pred, mask_to_pred], dim=0).nonzero(as_tuple=True)]
            else:
                z = z[mask_to_pred.nonzero(as_tuple=True)]

            # cfg schedule follow Muse
            if cfg_schedule == "linear":
                cfg_iter = 1 + (cfg - 1) * (self.seq_len - mask_len[0]) / self.seq_len
            elif cfg_schedule == "constant":
                cfg_iter = cfg
            else:
                raise NotImplementedError
            sampled_token_latent = self.diffloss.sample(z, temperature, cfg_iter)
            if use_cfg:
                sampled_token_latent, _ = sampled_token_latent.chunk(2, dim=0)  # Remove null class samples

            # Write into a copy so tensors already fed to the network this
            # iteration are left untouched.
            tokens = tokens.clone()
            tokens[mask_to_pred.nonzero(as_tuple=True)] = sampled_token_latent

        # unpatchify
        tokens = self.unpatchify(tokens)
        return tokens

    def gradient_checkpointing_enable(self):
        """Turn on activation checkpointing for the encoder and decoder blocks."""
        self.grad_checkpointing = True

    def gradient_checkpointing_disable(self):
        """Turn off activation checkpointing for the encoder and decoder blocks."""
        self.grad_checkpointing = False
def mar_base(**kwargs):
    """Build the base-size ``MAR``: width 768, depth 12 and 12 heads for encoder and decoder.

    Extra keyword arguments are forwarded to ``MAR``.
    """
    preset = dict(
        encoder_embed_dim=768,
        encoder_depth=12,
        encoder_num_heads=12,
        decoder_embed_dim=768,
        decoder_depth=12,
        decoder_num_heads=12,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
    )
    return MAR(**preset, **kwargs)
def mar_large(**kwargs):
    """Build the large-size ``MAR``: width 1024, depth 16 and 16 heads for encoder and decoder.

    Extra keyword arguments are forwarded to ``MAR``.
    """
    preset = dict(
        encoder_embed_dim=1024,
        encoder_depth=16,
        encoder_num_heads=16,
        decoder_embed_dim=1024,
        decoder_depth=16,
        decoder_num_heads=16,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
    )
    return MAR(**preset, **kwargs)
def mar_huge(**kwargs):
    """Build the huge-size ``MAR``: width 1280, depth 20 and 16 heads for encoder and decoder.

    Extra keyword arguments are forwarded to ``MAR``.
    """
    preset = dict(
        encoder_embed_dim=1280,
        encoder_depth=20,
        encoder_num_heads=16,
        decoder_embed_dim=1280,
        decoder_depth=20,
        decoder_num_heads=16,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
    )
    return MAR(**preset, **kwargs)