from typing import List, Optional, Tuple, Union
import timm
import torch
import numpy as np
from einops import rearrange
from timm.layers import Mlp
Shape3d = Union[List[int], Tuple[int, int, int]]
# PatchEmbed
class PatchEmbed(torch.nn.Module):
""" 3D Image to Patch Embedding
"""
def __init__(
self,
img_size: Shape3d = (4, 224, 224),
patch_size: Shape3d = (1, 16, 16),
in_chans: int = 3,
embed_dim: int = 768,
norm_layer: Optional[torch.nn.Module] = None,
flatten: bool = True,
bias: bool = True,
):
super().__init__()
self.img_size = img_size
self.patch_size = patch_size
assert len(self.img_size) == 3 and len(self.patch_size) == 3
self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1], img_size[2] // patch_size[2])
self.num_patches = self.grid_size[0] * self.grid_size[1] * self.grid_size[2]
self.flatten = flatten
self.proj = torch.nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
self.norm = norm_layer(embed_dim) if norm_layer else torch.nn.Identity()
def forward(self, x):
B, C, T, H, W = x.shape
assert H == self.img_size[1], f"Input data height ({H}) doesn't match model ({self.img_size[1]})."
assert W == self.img_size[2], f"Input data width ({W}) doesn't match model ({self.img_size[2]})."
assert T == self.img_size[0], f"Input data timesteps ({T}) doesn't match model ({self.img_size[0]})."
x = self.proj(x)
if self.flatten:
x = x.flatten(2).transpose(1, 2) # BCTHW -> BCL -> BLC
x = self.norm(x)
return x
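
# Illustrative usage sketch (assumption, not part of the original module): with the defaults above,
# a batch of 2 clips of 4 RGB frames at 224x224 yields (4/1) * (224/16) * (224/16) = 784 tokens per clip.
#   patch_embed = PatchEmbed(img_size=(4, 224, 224), patch_size=(1, 16, 16), in_chans=3, embed_dim=768)
#   tokens = patch_embed(torch.randn(2, 3, 4, 224, 224))  # -> shape (2, 784, 768)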
# DropPath
def generate_mask(x, drop_prob: float = 0., scale_by_keep: bool = True):
""" Create drop mask for x. Adapted from timm.models.layers.drop_path. """
keep_prob = 1 - drop_prob
shape = (x.shape[0],) + (1,) * (x.ndim - 1)
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
if keep_prob > 0.0 and scale_by_keep:
random_tensor.div_(keep_prob)
return random_tensor
class DropPath(torch.nn.Module):
""" Adapted from timm.models.layers.DropPath. In this version, drop mask can be saved and reused.
This is useful when applying the same drop mask more than once.
"""
def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
super().__init__()
self.drop_prob = drop_prob
self.scale_by_keep = scale_by_keep
self.drop_mask = None
def generate_mask(self, x: torch.Tensor):
self.drop_mask = generate_mask(x, self.drop_prob, self.scale_by_keep)
def forward(self, x: torch.Tensor, new_mask: bool = True):
if self.drop_prob == 0. or not self.training:
return x
if self.drop_mask is None or new_mask:
self.generate_mask(x)
return self.drop_mask * x
def extra_repr(self):
return f'drop_prob={round(self.drop_prob, 3):0.3f}'
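
# Illustrative usage sketch (assumption, not part of the original module): the stored mask lets the
# same per-sample drop pattern be applied to more than one tensor.
#   drop_path = DropPath(drop_prob=0.1).train()
#   y1 = drop_path(x, new_mask=True)   # samples and stores a fresh per-sample mask
#   y2 = drop_path(x, new_mask=False)  # reuses the stored mask, so the same samples are dropped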
class Attention(torch.nn.Module):
    """ Multi-head self-attention that returns both the output tokens and the attention weights. """
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5
        self.qkv = torch.nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = torch.nn.Dropout(attn_drop)
        self.proj = torch.nn.Linear(dim, dim)
        self.proj_drop = torch.nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        # (B, N, 3*C) -> (B, N, 3, heads, head_dim) -> (3, B, heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x, attn
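
# Illustrative usage sketch (assumption, not part of the original module): the returned attention map
# has shape (B, num_heads, N, N).
#   attn_layer = Attention(dim=768, num_heads=8)
#   out, attn_map = attn_layer(torch.randn(2, 784, 768))  # out: (2, 784, 768), attn_map: (2, 8, 784, 784)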
class Block(torch.nn.Module):
    """ Pre-norm Transformer block: self-attention and an MLP, each wrapped in a residual connection. """
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=torch.nn.GELU, norm_layer=torch.nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else torch.nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x, return_attention=False):
        y, attn = self.attn(self.norm1(x))
        if return_attention:
            return attn
        x = x + self.drop_path(y)
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x
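
# Illustrative usage sketch (assumption, not part of the original module): stack blocks as in a standard
# ViT encoder; return_attention=True short-circuits the block and returns only the attention map.
#   blocks = torch.nn.ModuleList([Block(dim=768, num_heads=8, drop_path=0.1) for _ in range(12)])
#   x = torch.randn(2, 784, 768)
#   for blk in blocks:
#       x = blk(x)
#   attn_map = blocks[-1](x, return_attention=True)  # (2, 8, 784, 784)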