File size: 20,522 Bytes
3ed0796 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 |
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# GLIDE: https://github.com/openai/glide-text2im
# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
# --------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from timm.layers.mlp import SwiGLU
from timm.models.vision_transformer import PatchEmbed, Attention
from tim.models.utils.funcs import build_mlp, modulate, get_parameter_dtype
from tim.models.utils.rope import VisionRotaryEmbedding, rotate_half
#################################################################################
# Embedding Layers for Timesteps and Captions #
#################################################################################
class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.

    A fixed sinusoidal encoding of size ``frequency_embedding_size`` is passed
    through a two-layer MLP (Linear -> SiLU -> Linear) that outputs
    ``hidden_size`` channels.
    """

    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def positional_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element.
                  These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) Tensor of positional embeddings.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period)
            * torch.arange(start=0, end=half, dtype=torch.float32)
            / half
        ).to(device=t.device)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            # Odd output size: pad with a single zero column.
            embedding = torch.cat(
                [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
            )
        return embedding

    # Alias kept for callers that access ``timestep_embedding``; defined once on
    # the class rather than being re-assigned on the instance in forward().
    timestep_embedding = positional_embedding

    def forward(self, t):
        """
        :param t: a 1-D Tensor of N timesteps (floating point or integer).
        :return: an (N, hidden_size) Tensor.
        """
        t_freq = self.positional_embedding(t, dim=self.frequency_embedding_size)
        if t.is_floating_point():
            # Follow the caller's floating dtype (e.g. bf16 timesteps).
            t_freq = t_freq.to(t.dtype)
        else:
            # Casting sin/cos values to an integer dtype would truncate them and
            # the float Linear would reject the tensor; use the MLP's dtype.
            t_freq = t_freq.to(self.mlp[0].weight.dtype)
        t_emb = self.mlp(t_freq)
        return t_emb
class CaptionEmbedder(nn.Module):
    """
    Projects caption feature vectors into the transformer's hidden size.

    The features are layer-normalized over their last dimension
    (``cap_feat_dim``). They then go through a SwiGLU MLP with
    ``4 * hidden_size`` hidden units and ``hidden_size`` outputs. No caption
    dropout / classifier-free-guidance masking is performed here.
    """

    def __init__(self, cap_feat_dim, hidden_size):
        super().__init__()
        self.norm = nn.LayerNorm(cap_feat_dim)
        self.mlp = SwiGLU(
            in_features=cap_feat_dim,
            hidden_features=4 * hidden_size,
            out_features=hidden_size,
        )

    def forward(self, cap_feats):
        """
        Normalize ``cap_feats`` (last dim ``cap_feat_dim``) and project it to
        ``hidden_size`` channels.
        """
        normalized = self.norm(cap_feats)
        return self.mlp(normalized)
#################################################################################
# Attention Block #
#################################################################################
class Attention(nn.Module):
    """
    Multi-head self-attention with rotary position embedding on q and k.

    When ``distance_aware`` is set, an extra bias-free projection of
    ``delta_t`` is added to the fused qkv projection before the heads are split.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        qk_norm: bool = False,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        norm_layer: nn.Module = nn.LayerNorm,
        distance_aware: bool = False,
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, "dim should be divisible by num_heads"
        head_dim = dim // num_heads
        self.num_heads = num_heads
        self.head_dim = head_dim
        # Only used by the explicit (non-fused) softmax path.
        self.scale = head_dim**-0.5
        self.distance_aware = distance_aware
        self.qkv = nn.Linear(dim, 3 * dim, bias=qkv_bias)
        if distance_aware:
            self.qkv_d = nn.Linear(dim, 3 * dim, bias=False)
        make_norm = (lambda: norm_layer(head_dim)) if qk_norm else nn.Identity
        self.q_norm = make_norm()
        self.k_norm = make_norm()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(
        self,
        x: torch.Tensor,
        freqs_cos,
        freqs_sin,
        attn_type="fused_attn",
        delta_t=None,
    ) -> torch.Tensor:
        """
        x: (B, N, C) tokens. freqs_cos / freqs_sin must broadcast against the
        per-head q/k layout chosen by ``attn_type``. ``delta_t`` is required
        only when the module was built with ``distance_aware=True``.
        """
        batch, seq_len, channels = x.shape
        projected = self.qkv(x)
        if self.distance_aware:
            projected = projected + self.qkv_d(delta_t)

        heads_last = attn_type == "flash_attn"
        # flash_attn expects (B, N, n_head, d_head); the other paths use
        # (B, n_head, N, d_head). Axis 0 is the q/k/v selector in both cases.
        axes = (2, 0, 1, 3, 4) if heads_last else (2, 0, 3, 1, 4)
        packed = projected.reshape(
            batch, seq_len, 3, self.num_heads, self.head_dim
        ).permute(*axes)
        in_dtype = packed.dtype
        q, k, v = packed.unbind(0)

        q = self.q_norm(q)
        k = self.k_norm(k)
        # Norm / RoPE may promote the dtype; cast back so q, k, v agree.
        q = (q * freqs_cos + rotate_half(q) * freqs_sin).to(in_dtype)
        k = (k * freqs_cos + rotate_half(k) * freqs_sin).to(in_dtype)

        drop_p = self.attn_drop.p if self.training else 0.0
        if heads_last:
            from flash_attn import flash_attn_func

            out = flash_attn_func(q, k, v, dropout_p=drop_p)
            out = out.reshape(batch, seq_len, channels)
        elif attn_type == "fused_attn":
            out = F.scaled_dot_product_attention(q, k, v, dropout_p=drop_p)
            out = out.transpose(1, 2).reshape(batch, seq_len, channels)
        else:
            weights = (q * self.scale) @ k.transpose(-2, -1)
            weights = self.attn_drop(weights.softmax(dim=-1))
            out = (weights @ v).transpose(1, 2).reshape(batch, seq_len, channels)

        return self.proj_drop(self.proj(out))
#################################################################################
# Cross Attention Block #
#################################################################################
class CrossAttention(nn.Module):
    """
    Multi-head cross-attention from tokens ``x`` (queries) to context ``y``
    (keys/values).

    Rotary position embedding is applied to the queries only. The keys come
    from ``y``, which carries no spatial positions here.
    NOTE(review): no attention mask is passed, so every context token
    (including any padding) is attended to.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        qk_norm: bool = False,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        norm_layer: nn.Module = nn.LayerNorm,
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, "dim should be divisible by num_heads"
        head_dim = dim // num_heads
        self.num_heads = num_heads
        self.head_dim = head_dim
        # Only used by the explicit (non-fused) softmax path.
        self.scale = head_dim**-0.5
        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.kv = nn.Linear(dim, 2 * dim, bias=qkv_bias)
        self.q_norm = norm_layer(head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(
        self,
        x: torch.Tensor,
        y: torch.Tensor,
        freqs_cos,
        freqs_sin,
        attn_type="fused_attn",
    ) -> torch.Tensor:
        """
        x: (B, N, C) query tokens; y: (B, M, C) context tokens.
        freqs_cos / freqs_sin must broadcast against the per-head query layout
        chosen by ``attn_type``.
        """
        batch, n_query, channels = x.shape
        _, n_ctx, _ = y.shape

        q = self.q(x).reshape(batch, n_query, self.num_heads, self.head_dim)
        kv = self.kv(y).reshape(batch, n_ctx, 2, self.num_heads, self.head_dim)
        heads_last = attn_type == "flash_attn"
        if heads_last:
            # flash_attn layout: (B, seq, n_head, d_head)
            kv = kv.permute(2, 0, 1, 3, 4)
        else:
            # SDPA / explicit layout: (B, n_head, seq, d_head)
            q = q.permute(0, 2, 1, 3)
            kv = kv.permute(2, 0, 3, 1, 4)

        in_dtype = q.dtype
        k, v = kv.unbind(0)
        q = self.q_norm(q)
        k = self.k_norm(k)
        q = (q * freqs_cos + rotate_half(q) * freqs_sin).to(in_dtype)
        k = k.to(in_dtype)

        drop_p = self.attn_drop.p if self.training else 0.0
        if heads_last:
            from flash_attn import flash_attn_func

            out = flash_attn_func(q, k, v, dropout_p=drop_p)
            out = out.reshape(batch, n_query, channels)
        elif attn_type == "fused_attn":
            out = F.scaled_dot_product_attention(q, k, v, dropout_p=drop_p)
            out = out.transpose(1, 2).reshape(batch, n_query, channels)
        else:
            weights = (q * self.scale) @ k.transpose(-2, -1)
            weights = self.attn_drop(weights.softmax(dim=-1))
            out = (weights @ v).transpose(1, 2).reshape(batch, n_query, channels)

        return self.proj_drop(self.proj(out))
#################################################################################
# Core TiM Model #
#################################################################################
class TiMBlock(nn.Module):
    """
    A TiM block with adaptive layer norm zero (adaLN-Zero) conditioning.

    Three residual branches are applied in order: self-attention (RoPE),
    cross-attention to the context tokens ``y``, and a SwiGLU MLP. Each branch
    gets its own shift/scale/gate from a SwiGLU over the conditioning ``c``
    (9 * hidden_size outputs).

    Recognised ``block_kwargs``:
      qk_norm (bool, default False): LayerNorm on per-head q/k.
      distance_aware (bool, default False): add a ``delta_t`` projection to the
          self-attention qkv.
      lora_hidden_size (int, default (hidden_size // 4) * 3): hidden width of
          the adaLN modulation network.
    """

    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
        super().__init__()
        # Defaults match Attention/CrossAttention so callers that omit these
        # keys (TiM exposes them only via **block_kwargs) do not hit KeyError.
        qk_norm = block_kwargs.get("qk_norm", False)
        distance_aware = block_kwargs.get("distance_aware", False)

        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attn = Attention(
            hidden_size,
            num_heads=num_heads,
            qkv_bias=True,
            qk_norm=qk_norm,
            distance_aware=distance_aware,
        )
        self.norm2_i = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.norm2_t = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.cross_attn = CrossAttention(
            hidden_size,
            num_heads=num_heads,
            qkv_bias=True,
            qk_norm=qk_norm,
        )
        self.norm3 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        # SwiGLU has two input projections, so 2/3 of the nominal width keeps
        # the parameter count close to a plain MLP with mlp_ratio.
        self.mlp = SwiGLU(
            in_features=hidden_size,
            hidden_features=(mlp_hidden_dim * 2) // 3,
            bias=True,
        )
        lora_hidden_size = block_kwargs.get("lora_hidden_size")
        if lora_hidden_size is None:
            lora_hidden_size = (hidden_size // 4) * 3
        self.adaLN_modulation = SwiGLU(
            in_features=hidden_size,
            hidden_features=lora_hidden_size,
            out_features=9 * hidden_size,
            bias=True,
        )

    def forward(self, x, y, c, freqs_cos, freqs_sin, attn_type, delta_t=None):
        """
        x: (B, N, D) image tokens; y: (B, M, D) context tokens; c: conditioning
        whose modulation chunks broadcast against x (TiM passes (B, 1, D)).
        ``delta_t`` is forwarded to self-attention (used when distance_aware).
        """
        (
            shift_msa,
            scale_msa,
            gate_msa,
            shift_msc,
            scale_msc,
            gate_msc,
            shift_mlp,
            scale_mlp,
            gate_mlp,
        ) = self.adaLN_modulation(c).chunk(9, dim=-1)
        x = x + gate_msa * self.attn(
            modulate(self.norm1(x), shift_msa, scale_msa),
            freqs_cos,
            freqs_sin,
            attn_type,
            delta_t,
        )
        x = x + gate_msc * self.cross_attn(
            modulate(self.norm2_i(x), shift_msc, scale_msc),
            self.norm2_t(y),
            freqs_cos,
            freqs_sin,
            attn_type,
        )
        x = x + gate_mlp * self.mlp(modulate(self.norm3(x), shift_mlp, scale_mlp))
        return x
class FinalLayer(nn.Module):
    """
    The final layer of TiM.

    Applies an affine-free LayerNorm and shift/scale modulation derived from
    the conditioning ``c``. A linear layer then maps each token to
    ``patch_size * patch_size * out_channels`` values.
    """

    def __init__(self, hidden_size, patch_size, out_channels):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        patch_values = patch_size * patch_size * out_channels
        self.linear = nn.Linear(hidden_size, patch_values, bias=True)
        self.adaLN_modulation = SwiGLU(
            in_features=hidden_size,
            hidden_features=hidden_size // 2,
            out_features=2 * hidden_size,
            bias=True,
        )

    def forward(self, x, c):
        """Modulate the normalized tokens ``x`` by ``c`` and project them."""
        modulation = self.adaLN_modulation(c)
        shift, scale = modulation.chunk(2, dim=-1)
        conditioned = modulate(self.norm_final(x), shift, scale)
        return self.linear(conditioned)
class TiM(nn.Module):
    """
    Diffusion model with a Transformer backbone.

    The forward pass has four stages:
      1. Inputs are patchified with ``PatchEmbed``.
      2. The tokens pass through ``depth`` TiMBlocks: self-attention with 2-D
         RoPE, cross-attention to the embedded caption features, and an
         adaLN-modulated SwiGLU MLP.
      3. ``FinalLayer`` + ``unpatchify`` map the tokens back to image space.
      4. Optionally, the hidden states after block ``encoder_depth`` are passed
         through ``projector`` and returned as well.

    The conditioning vector is the timestep embedding of ``t`` plus the
    embedding from ``get_delta_embed(t, r)``.
    """

    def __init__(
        self,
        input_size=32,
        patch_size=2,
        in_channels=4,
        hidden_size=1152,
        encoder_depth=8,
        depth=28,
        num_heads=16,
        mlp_ratio=4.0,
        cap_feat_dim=2048,
        z_dim=768,
        projector_dim=2048,
        use_checkpoint: bool = False,
        new_condition: str = "t-r",
        use_new_embed: bool = False,
        **block_kwargs,  # forwarded to every TiMBlock (qk_norm, distance_aware, lora_hidden_size)
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels
        self.patch_size = patch_size
        self.num_heads = num_heads
        self.cap_feat_dim = cap_feat_dim
        self.encoder_depth = encoder_depth
        self.use_checkpoint = use_checkpoint
        self.new_condition = new_condition
        self.use_new_embed = use_new_embed
        self.x_embedder = PatchEmbed(
            input_size,
            patch_size,
            in_channels,
            hidden_size,
            bias=True,
            strict_img_size=False,
        )
        self.t_embedder = TimestepEmbedder(hidden_size)  # timestep embedding type
        if use_new_embed:
            # Separate embedder for the "delta" term of get_delta_embed.
            self.delta_embedder = TimestepEmbedder(hidden_size)
        self.y_embedder = CaptionEmbedder(cap_feat_dim, hidden_size)
        # Rotary position embedding shared by all blocks (per-head dimension).
        self.rope = VisionRotaryEmbedding(head_dim=hidden_size // num_heads)
        self.blocks = nn.ModuleList(
            [
                TiMBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, **block_kwargs)
                for _ in range(depth)
            ]
        )
        self.projector = build_mlp(hidden_size, projector_dim, z_dim)
        self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
        self.initialize_weights()

    def initialize_weights(self):
        """
        Apply the model's custom initialization.

        Linear layers use Xavier-uniform weights and zero bias. The embedding
        MLPs use N(0, 0.02). All adaLN modulation outputs and the final linear
        layer start at zero (adaLN-Zero).
        """
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)
        # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
        w = self.x_embedder.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        nn.init.constant_(self.x_embedder.proj.bias, 0)
        # Initialize caption embedding MLP:
        nn.init.normal_(self.y_embedder.mlp.fc1_g.weight, std=0.02)
        nn.init.normal_(self.y_embedder.mlp.fc1_x.weight, std=0.02)
        nn.init.normal_(self.y_embedder.mlp.fc2.weight, std=0.02)
        # Initialize timestep embedding MLP:
        # NOTE(review): delta_embedder (use_new_embed=True) keeps the Xavier
        # init from _basic_init rather than N(0, 0.02) — confirm intended.
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
        # Zero-out adaLN modulation layers in TiM blocks:
        for block in self.blocks:
            nn.init.constant_(block.adaLN_modulation.fc2.weight, 0)
            nn.init.constant_(block.adaLN_modulation.fc2.bias, 0)
        # Zero-out output layers:
        nn.init.constant_(self.final_layer.adaLN_modulation.fc2.weight, 0)
        nn.init.constant_(self.final_layer.adaLN_modulation.fc2.bias, 0)
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)

    def unpatchify(self, x, H, W):
        """
        x: (N, T, patch_size**2 * C) tokens in row-major (H/p, W/p) order
        imgs: (N, C, H, W)
        """
        c = self.out_channels
        p = self.patch_size
        h, w = H // p, W // p
        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
        x = torch.einsum("nhwpqc->nchpwq", x)
        imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
        return imgs

    def get_rope(self, h, w, attn_type):
        """
        Build RoPE cos/sin tables for an ``h`` x ``w`` patch grid.

        The tables are unsqueezed to broadcast against the per-head q/k layout
        used by ``attn_type``. This assumes ``get_cached_2d_rope_from_grid``
        returns (1, N, d_head) — TODO confirm.
        """
        grid_h = torch.arange(h)
        grid_w = torch.arange(w)
        # NOTE(review): with indexing="xy" the meshgrid outputs have shape
        # (w, h), so the flattened coordinates follow a (w, h) raster. PatchEmbed
        # and unpatchify use a row-major (h, w) token order. The two agree only
        # when h == w. This is left unchanged because existing checkpoints may
        # depend on the current ordering for non-square inputs. Confirm before
        # switching to torch.meshgrid(grid_w, grid_h, indexing="xy"), which is
        # identical for square grids and spatially correct for non-square ones.
        grid = torch.meshgrid(grid_h, grid_w, indexing="xy")
        grid = torch.stack(grid, dim=0).reshape(2, -1).unsqueeze(0)
        freqs_cos, freqs_sin = self.rope.get_cached_2d_rope_from_grid(grid)
        if attn_type == "flash_attn":  # (1, N, 1, d_head)
            return freqs_cos.unsqueeze(2), freqs_sin.unsqueeze(2)
        else:  # (1, 1, N, d_head)
            return freqs_cos.unsqueeze(1), freqs_sin.unsqueeze(1)

    def forward(self, x, t, r, y, attn_type="fused_attn", return_zs=False, jvp=False):
        """
        Forward pass of TiM.
        x: (B, C, H, W) tensor of spatial inputs (images or latent representations of images)
        t: (B,) tensor of timesteps
        r: (B,) tensor of second timesteps, combined with t by get_delta_embed
        y: (B, M, cap_feat_dim) tensor of caption features
        attn_type: "flash_attn", "fused_attn", or anything else for the explicit softmax path
        return_zs: also return the projector output of the hidden states after
            block ``encoder_depth``
        jvp: bypass activation checkpointing even when use_checkpoint is set
        Returns the (B, out_channels, H, W) prediction, or (prediction, zs) when
        return_zs is True.
        Raises ValueError if return_zs is True but encoder_depth does not name a block.
        """
        if return_zs and not 1 <= self.encoder_depth <= len(self.blocks):
            raise ValueError(
                "return_zs=True requires 1 <= encoder_depth <= depth, got "
                f"encoder_depth={self.encoder_depth} with {len(self.blocks)} blocks"
            )
        B, C, H, W = x.shape
        x = self.x_embedder(x)  # (B, T, D), where T = H * W / patch_size ** 2
        # timestep and caption embedding
        t_embed = self.t_embedder(t).unsqueeze(1)  # (B, 1, D)
        delta_embed = self.get_delta_embed(t, r).unsqueeze(1)  # (B, 1, D)
        y = self.y_embedder(y)  # (B, M, D)
        c = t_embed + delta_embed  # (B, 1, D)
        freqs_cos, freqs_sin = self.get_rope(
            H // self.patch_size, W // self.patch_size, attn_type
        )
        h_proj = None
        for i, block in enumerate(self.blocks):
            if not self.use_checkpoint or jvp:
                x = block(
                    x, y, c, freqs_cos, freqs_sin, attn_type, delta_embed
                )  # (B, N, D)
            else:
                # use_reentrant=True is PyTorch's implicit default; passing it
                # explicitly keeps behaviour and silences the deprecation warning.
                x = torch.utils.checkpoint.checkpoint(
                    self.ckpt_wrapper(block),
                    x,
                    y,
                    c,
                    freqs_cos,
                    freqs_sin,
                    attn_type,
                    delta_embed,
                    use_reentrant=True,
                )
            # The projector output is only needed when it is returned.
            if return_zs and (i + 1) == self.encoder_depth:
                h_proj = self.projector(x)
        x = self.final_layer(x, c)  # (B, N, patch_size ** 2 * out_channels)
        x = self.unpatchify(x, H, W)  # (B, out_channels, H, W)
        if return_zs:
            return x, h_proj
        return x

    def get_delta_embed(self, t, r):
        """
        Embed the (t, r) pair according to ``self.new_condition``.

        The last term in each condition uses ``delta_embedder``. That is a
        separate TimestepEmbedder when ``use_new_embed`` is set, otherwise
        ``t_embedder``. Any other terms always use ``t_embedder``.
        Raises NotImplementedError for an unknown ``new_condition``.
        """
        if self.use_new_embed:
            delta_embedder = self.delta_embedder
        else:
            delta_embedder = self.t_embedder
        if self.new_condition == "t-r":
            delta_embed = delta_embedder(t - r)
        elif self.new_condition == "r":
            delta_embed = delta_embedder(r)
        elif self.new_condition == "t,r":
            delta_embed = self.t_embedder(t) + delta_embedder(r)
        elif self.new_condition == "t,t-r":
            delta_embed = self.t_embedder(t) + delta_embedder(t - r)
        elif self.new_condition == "r,t-r":
            delta_embed = self.t_embedder(r) + delta_embedder(t - r)
        elif self.new_condition == "t,r,t-r":
            delta_embed = (
                self.t_embedder(t) + self.t_embedder(r) + delta_embedder(t - r)
            )
        else:
            raise NotImplementedError(
                f"Unsupported new_condition: {self.new_condition!r}"
            )
        return delta_embed

    def ckpt_wrapper(self, module):
        """Wrap ``module`` in a plain function for torch.utils.checkpoint."""
        def ckpt_forward(*inputs):
            outputs = module(*inputs)
            return outputs

        return ckpt_forward

    @property
    def dtype(self) -> torch.dtype:
        """
        `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
        """
        return get_parameter_dtype(self)
|