|
from typing import List, Dict, Optional, Tuple, Union |
|
import random |
|
import torch |
|
import timm |
|
import numpy as np |
|
from einops import rearrange |
|
import torch.distributed as dist |
|
from timm.layers import trunc_normal_
|
import logging |
|
|
|
from mae_dino.model_layers.layers import PatchEmbed, Attention, Block
|
from mae_dino.model_layers.pos_embed import get_3d_sincos_pos_embed |
|
|
|
Shape3d = Union[List[int], Tuple[int, int, int]] |
|
|
|
logging.basicConfig(level=logging.INFO) |
|
logger = logging.getLogger(__name__) |
|
|
|
class TemporalEncoder(torch.nn.Module): |
|
def __init__(self, embed_dim, tokens_per_frame): |
|
super().__init__() |
|
self.embed_dim = embed_dim |
|
self.tokens_per_frame = tokens_per_frame |
|
|
|
|
|
self.year_embed_dim = embed_dim // 4 |
|
self.doy_embed_dim = embed_dim // 4 |
|
self.hour_embed_dim = embed_dim // 4 |
|
self.minute_embed_dim = embed_dim - ( |
|
self.year_embed_dim + self.doy_embed_dim + self.hour_embed_dim |
|
) |
|
|
|
|
|
self.year_embedding = torch.nn.Embedding(3000, self.year_embed_dim) |
|
self.doy_embedding = torch.nn.Embedding(367, self.doy_embed_dim) |
|
self.hour_embedding = torch.nn.Embedding(24, self.hour_embed_dim) |
|
self.minute_embedding = torch.nn.Embedding(60, self.minute_embed_dim) |
|
|
|
|
|
self._init_weights() |
|
|
|
def _init_weights(self): |
|
torch.nn.init.xavier_uniform_(self.year_embedding.weight) |
|
torch.nn.init.xavier_uniform_(self.doy_embedding.weight) |
|
torch.nn.init.xavier_uniform_(self.hour_embedding.weight) |
|
torch.nn.init.xavier_uniform_(self.minute_embedding.weight) |
|
|
|
def forward(self, year, doy, hour, minute): |
|
""" |
|
Args: |
|
year (torch.Tensor): Shape (batch_size, time), integer years |
|
doy (torch.Tensor): Shape (batch_size, time), values [1, 366] |
|
hour (torch.Tensor): Shape (batch_size, time), values [0, 23] |
|
minute (torch.Tensor): Shape (batch_size, time), values [0, 59] |
|
|
|
Returns: |
|
torch.Tensor: Temporal embeddings of shape (batch_size, time * tokens_per_frame, embed_dim) |
|
""" |
|
|
|
year = year.long() |
|
doy = doy.long() |
|
hour = hour.long() |
|
minute = minute.long() |
|
|
|
|
|
year_emb = self.year_embedding(year) |
|
doy_emb = self.doy_embedding(doy) |
|
hour_emb = self.hour_embedding(hour) |
|
minute_emb = self.minute_embedding(minute) |
|
|
|
|
|
temporal_emb = torch.cat( |
|
[year_emb, doy_emb, hour_emb, minute_emb], dim=-1 |
|
) |
|
|
|
|
|
        # Broadcast each frame's embedding to all of its spatial tokens:
        # (B, T, D) -> (B, T * tokens_per_frame, D)
        temporal_emb = torch.repeat_interleave(temporal_emb, self.tokens_per_frame, dim=1)
|
|
|
return temporal_emb |
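
# Usage sketch (illustrative, not part of the module API): embed four date/time
# components and broadcast them to every spatial token of each frame. The sizes
# and values below are assumptions chosen for demonstration.
def _example_temporal_encoder():
    enc = TemporalEncoder(embed_dim=64, tokens_per_frame=196)  # e.g. 14x14 patches per frame
    year = torch.tensor([[2020, 2021]])    # (batch=1, time=2)
    doy = torch.tensor([[152, 153]])       # day of year in [1, 366]
    hour = torch.tensor([[10, 11]])        # hour in [0, 23]
    minute = torch.tensor([[30, 45]])      # minute in [0, 59]
    emb = enc(year, doy, hour, minute)
    # Each frame's embedding is repeated for all of its spatial tokens.
    assert emb.shape == (1, 2 * 196, 64)
    return emb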
|
|
|
|
|
class Encoder(torch.nn.Module): |
|
def __init__(self, |
|
img_size: Shape3d = [4, 224, 224], |
|
patch_size: Shape3d = [1, 16, 16], |
|
in_chans: int = 3, |
|
encoder_embed_dim: int = 1024, |
|
encoder_depth: int = 8, |
|
encoder_num_heads: int = 16, |
|
mlp_ratio: float = 4., |
|
norm_layer: torch.nn.Module = torch.nn.LayerNorm, |
|
drop_channels_rate: float = 0.0, |
|
adjacent_masking: bool = False, |
|
): |
|
super().__init__() |
|
self.img_size = img_size |
|
self.patch_size = patch_size |
|
self.in_chans = in_chans |
|
self.encoder_embed_dim = encoder_embed_dim |
|
self.encoder_depth = encoder_depth |
|
self.encoder_num_heads = encoder_num_heads |
|
self.mlp_ratio = mlp_ratio |
|
self.norm_layer = norm_layer |
|
self.drop_channels_rate = drop_channels_rate |
|
self.adjacent_masking = adjacent_masking |
|
|
|
|
|
|
|
self.drop_channels = torch.nn.Dropout3d(self.drop_channels_rate) if self.drop_channels_rate > 0 else torch.nn.Identity() |
|
|
|
self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, encoder_embed_dim) |
|
num_patches = self.patch_embed.num_patches |
|
tokens_per_frame = num_patches // img_size[0] |
|
|
|
self.cls_token = torch.nn.Parameter(torch.zeros(1, 1, encoder_embed_dim)) |
|
self.register_buffer("encoder_pos_embed", torch.zeros(1, num_patches + 1, encoder_embed_dim)) |
|
|
|
self.encoder_blocks = torch.nn.ModuleList([ |
|
Block(encoder_embed_dim, encoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=self.norm_layer) |
|
for i in range(self.encoder_depth)]) |
|
self.norm = norm_layer(encoder_embed_dim) |
|
|
|
self.temporal_embed_enc = TemporalEncoder(embed_dim=encoder_embed_dim, tokens_per_frame=tokens_per_frame) |
|
|
|
|
|
self.initialize_weights() |
|
|
|
def initialize_weights(self): |
|
encoder_pos_embed = get_3d_sincos_pos_embed( |
|
self.encoder_pos_embed.shape[-1], self.patch_embed.grid_size, cls_token=True |
|
) |
|
self.encoder_pos_embed.data.copy_(torch.from_numpy(encoder_pos_embed).float().unsqueeze(0)) |
|
|
|
|
|
|
|
w = self.patch_embed.proj.weight.data |
|
torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1])) |
|
|
|
|
|
torch.nn.init.normal_(self.cls_token, std=0.02) |
|
|
|
|
|
self.apply(self._init_weights) |
|
|
|
def _init_weights(self, m): |
|
if isinstance(m, torch.nn.Linear): |
|
|
|
torch.nn.init.xavier_uniform_(m.weight) |
|
if isinstance(m, torch.nn.Linear) and m.bias is not None: |
|
torch.nn.init.constant_(m.bias, 0) |
|
elif isinstance(m, torch.nn.LayerNorm): |
|
torch.nn.init.constant_(m.bias, 0) |
|
torch.nn.init.constant_(m.weight, 1.0) |
|
|
|
def patchify(self, imgs): |
|
""" |
|
imgs: B, C, T, H, W |
|
x: B, L, D |
|
""" |
|
s, p, q = self.patch_embed.patch_size |
|
x = rearrange(imgs, 'b c (t s) (h p) (w q) -> b (t h w) (s p q c)', s=s, p=p, q=q) |
|
|
|
return x |
|
|
|
def unpatchify(self, x): |
|
""" |
|
x: B, L, D |
|
imgs: B, C, T, H, W |
|
""" |
|
s, p, q = self.patch_embed.patch_size |
|
gs = self.patch_embed.grid_size |
|
imgs = rearrange(x, 'b (t h w) (s p q c) -> b c (t s) (h p) (w q)', h=gs[1], w=gs[2], t=gs[0], s=s, p=p, q=q) |
|
return imgs |
|
|
|
def log_helper(self, message): |
|
logger.info(message) |
|
|
|
    def hinted_random_masking(self, x: torch.Tensor, x_mask: torch.Tensor, mask_ratio: float):
        """
        Perform per-sample random masking by per-sample shuffling.
        Shuffling argsorts random noise; patches that contain missing pixels
        (indicated by x_mask) are pushed to the end of the sort so they are
        always masked.
        x: [N, L, D], token sequence
        x_mask: [N, L, D], per-patch mask of missing pixels (non-zero = missing)
        """
|
N, L, D = x.shape |
|
|
|
|
|
        # A patch counts as "missing" if any of its pixels are flagged in x_mask.
        x_mask = x_mask.sum(dim=-1) > 0
        missing_ratio = x_mask.float().mean(dim=1)

        # Mask at least as many patches as the worst sample is missing, so missing
        # patches can always be hidden from the encoder.
        adjusted_mask_ratio = max(missing_ratio.max().item(), mask_ratio)

        len_keep = int(L * (1 - adjusted_mask_ratio))
|
|
|
noise = torch.rand(N, L, device=x.device) |
|
|
|
|
|
ids_x_mask = torch.where(x_mask) |
|
|
|
|
|
ids_x_mask_dict = {} |
|
for bs, p in zip(ids_x_mask[0], ids_x_mask[1]): |
|
bs, p = int(bs), int(p) |
|
if bs not in ids_x_mask_dict: |
|
ids_x_mask_dict[bs] = [] |
|
ids_x_mask_dict[bs].append(p) |
|
|
|
        # Add the missing-patch indicator to the noise so missing patches sort to the
        # end of the shuffle and are therefore always masked out.
        noise += x_mask.float()
|
|
|
|
|
ids_shuffle = torch.argsort(noise, dim=1) |
|
ids_restore = torch.argsort(ids_shuffle, dim=1) |
|
|
|
|
|
ids_keep = ids_shuffle[:, :len_keep] |
|
x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) |
|
|
|
|
|
mask = torch.ones([N, L], device=x.device) |
|
mask[:, :len_keep] = 0 |
|
|
|
mask = torch.gather(mask, dim=1, index=ids_restore) |
|
|
|
return x_masked, mask, ids_restore, ids_x_mask_dict |
|
|
|
|
|
def hinted_adjacent_masking(self, x: torch.Tensor, x_mask: torch.Tensor, mask_ratio: float): |
|
""" |
|
Perform masking by keeping contiguous blocks of patches in time and space, |
|
excluding patches where x_mask is 1 in each patch, and adjust to keep mask_ratio consistent. |
|
x: [N, L, D], sequence |
|
x_mask: [N, L, D], mask indicating invalid patches (1 where invalid) |
|
patch_embed: object containing grid_size attribute (T, H, W) |
|
""" |
|
N, L, D = x.shape |
|
|
|
|
|
gs = self.patch_embed.grid_size |
|
T, H, W = gs |
|
|
|
|
|
total_patches = T * H * W |
|
len_keep = int(total_patches * (1 - mask_ratio)) |
|
|
|
|
|
x = x.view(N, T, H, W, D) |
|
|
|
|
|
|
|
x_mask = x_mask.sum(dim=-1) > 0 |
|
|
|
ids_x_mask = torch.where(x_mask) |
|
ids_x_mask_dict = {} |
|
for bs, p in zip(ids_x_mask[0], ids_x_mask[1]): |
|
bs, p = int(bs), int(p) |
|
if bs not in ids_x_mask_dict: |
|
ids_x_mask_dict[bs] = [] |
|
ids_x_mask_dict[bs].append(p) |
|
|
|
x_mask = x_mask.view(N, T, H, W) |
|
|
|
|
|
mask = torch.zeros((N, T, H, W), device=x.device) |
|
|
|
ids_keep_list = [] |
|
|
|
|
|
for i in range(N): |
|
|
|
valid_patches_mask = ~x_mask[i] |
|
num_valid_patches = valid_patches_mask.sum().item() |
|
|
|
|
|
len_keep_i = min(len_keep, num_valid_patches) |
|
if len_keep_i == 0: |
|
ids_keep_list.append(torch.tensor([], device=x.device, dtype=torch.long)) |
|
continue |
|
|
|
patches_selected = 0 |
|
attempts = 0 |
|
max_attempts = 1000 |
|
|
|
while patches_selected < len_keep_i and attempts < max_attempts: |
|
attempts += 1 |
|
|
|
|
|
block_size = int(round((len_keep_i - patches_selected) ** (1/3))) |
|
t_size_block = min(T, block_size) |
|
h_size_block = min(H, block_size) |
|
w_size_block = min(W, block_size) |
|
|
|
|
|
t0 = random.randint(0, T - t_size_block) |
|
h0 = random.randint(0, H - h_size_block) |
|
w0 = random.randint(0, W - w_size_block) |
|
|
|
|
|
t_indices = slice(t0, t0 + t_size_block) |
|
h_indices = slice(h0, h0 + h_size_block) |
|
w_indices = slice(w0, w0 + w_size_block) |
|
|
|
|
|
block_valid_mask = valid_patches_mask[t_indices, h_indices, w_indices] |
|
block_already_selected = mask[i, t_indices, h_indices, w_indices] |
|
|
|
|
|
selectable_mask = block_valid_mask & (block_already_selected == 0) |
|
num_selectable = selectable_mask.sum().item() |
|
|
|
if num_selectable == 0: |
|
continue |
|
|
|
|
|
num_to_select = min(len_keep_i - patches_selected, num_selectable) |
|
|
|
|
|
selectable_indices = selectable_mask.nonzero(as_tuple=False) |
|
|
|
|
|
selected_indices = selectable_indices[torch.randperm(num_selectable)[:num_to_select]] |
|
|
|
|
|
for idx in selected_indices: |
|
t_idx, h_idx, w_idx = idx |
|
mask[i, t0 + t_idx, h0 + h_idx, w0 + w_idx] = 1 |
|
|
|
patches_selected += num_to_select |
|
|
|
|
|
if patches_selected < len_keep_i: |
|
remaining_selectable = (valid_patches_mask & (mask[i] == 0)).nonzero(as_tuple=False) |
|
num_remaining = remaining_selectable.size(0) |
|
num_needed = len_keep_i - patches_selected |
|
num_to_select = min(num_remaining, num_needed) |
|
if num_to_select > 0: |
|
selected_indices = remaining_selectable[torch.randperm(num_remaining)[:num_to_select]] |
|
mask[i][selected_indices[:, 0], selected_indices[:, 1], selected_indices[:, 2]] = 1 |
|
patches_selected += num_to_select |
|
|
|
|
|
mask_i_flat = mask[i].view(-1) |
|
ids_keep_i = torch.where(mask_i_flat == 1)[0] |
|
ids_keep_list.append(ids_keep_i) |
|
|
|
|
|
max_len_keep = max(len(ids) for ids in ids_keep_list) |
|
|
|
|
|
ids_keep_padded = torch.zeros((N, max_len_keep), dtype=torch.long, device=x.device) |
|
for i, ids in enumerate(ids_keep_list): |
|
ids_keep_padded[i, :len(ids)] = ids |
|
|
|
|
|
x_flat = x.view(N, L, D) |
|
|
|
|
|
x_masked_list = [] |
|
for i in range(N): |
|
ids = ids_keep_list[i] |
|
x_masked_i = torch.index_select(x_flat[i], dim=0, index=ids) |
|
x_masked_list.append(x_masked_i) |
|
x_masked = torch.nn.utils.rnn.pad_sequence(x_masked_list, batch_first=True) |
|
|
|
|
|
        # Identity restore indices: adjacent masking is only used by the DINO teacher,
        # whose output never goes through the decoder, so no re-ordering is required.
        ids_restore = torch.arange(L, device=x.device).unsqueeze(0).repeat(N, 1)
|
|
|
|
|
mask_flat = mask.view(N, L) |
|
final_mask = 1 - mask_flat |
|
|
|
return x_masked, final_mask, ids_restore, ids_x_mask_dict |
|
|
|
|
|
def forward(self, |
|
x: torch.Tensor, |
|
x_mask: torch.Tensor, |
|
mask_ratio: float, |
|
|
|
                temporal_pos: Optional[List[torch.Tensor]]):
|
|
|
|
|
x = self.drop_channels(x) |
|
|
|
|
|
x = self.patch_embed(x) |
|
|
|
|
|
x = x + self.encoder_pos_embed[:, 1:, :] |
|
|
|
if temporal_pos: |
|
temporal_encoding = self.temporal_embed_enc(*temporal_pos) |
|
|
|
x = x + temporal_encoding |
|
|
|
|
|
x_mask = self.patchify(x_mask) |
|
if self.adjacent_masking: |
|
x, mask, ids_restore, ids_x_mask_dict = self.hinted_adjacent_masking(x, x_mask, mask_ratio) |
|
else: |
|
x, mask, ids_restore, ids_x_mask_dict = self.hinted_random_masking(x, x_mask, mask_ratio) |
|
|
|
|
|
cls_token = self.cls_token + self.encoder_pos_embed[:, :1, :] |
|
cls_tokens = cls_token.expand(x.shape[0], -1, -1) |
|
x = torch.cat((cls_tokens, x), dim=1) |
|
|
|
|
|
for blk in self.encoder_blocks: |
|
x = blk(x) |
|
x = self.norm(x) |
|
|
|
return x, mask, ids_restore, ids_x_mask_dict |
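
# Usage sketch (illustrative): run the masked encoder on a small random clip. It
# assumes the repo's PatchEmbed/Block layers and get_3d_sincos_pos_embed accept
# these small sizes; the shapes and values are made up for demonstration.
def _example_encoder_forward():
    enc = Encoder(img_size=[2, 32, 32], patch_size=[1, 16, 16], in_chans=3,
                  encoder_embed_dim=64, encoder_depth=2, encoder_num_heads=4)
    imgs = torch.randn(2, 3, 2, 32, 32)        # (B, C, T, H, W)
    img_masks = torch.zeros_like(imgs)         # non-zero marks missing pixels; none here
    # temporal_pos is (year, doy, hour, minute), each of shape (B, T)
    temporal_pos = [torch.tensor([[2020, 2020], [2021, 2021]]),
                    torch.tensor([[100, 101], [200, 201]]),
                    torch.tensor([[10, 10], [12, 12]]),
                    torch.tensor([[0, 30], [15, 45]])]
    latent, mask, ids_restore, _ = enc(imgs, img_masks, mask_ratio=0.75, temporal_pos=temporal_pos)
    # 8 patches in total, 75% masked -> 2 visible patches plus 1 cls token.
    assert latent.shape == (2, 3, 64)
    return latent, mask, ids_restore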
|
|
|
|
|
class Decoder(torch.nn.Module): |
|
def __init__(self, |
|
img_size: Shape3d = [4, 224, 224], |
|
patch_size: Shape3d = [1, 16, 16], |
|
in_chans: int = 3, |
|
encoder_embed_dim: int = 1024, |
|
decoder_embed_dim: int = 512, |
|
decoder_depth: int = 8, |
|
decoder_num_heads: int = 16, |
|
mlp_ratio: float = 4., |
|
norm_layer: torch.nn.Module = torch.nn.LayerNorm, |
|
norm_pix_loss: bool = False, |
|
): |
|
super().__init__() |
|
self.img_size = img_size |
|
self.patch_size = patch_size |
|
self.encoder_embed_dim = encoder_embed_dim |
|
self.decoder_embed_dim = decoder_embed_dim |
|
self.decoder_depth = decoder_depth |
|
self.decoder_num_heads = decoder_num_heads |
|
self.mlp_ratio = mlp_ratio |
|
self.norm_layer = norm_layer |
|
self.norm_pix_loss = norm_pix_loss |
|
|
|
|
|
|
|
self.decoder_embed = torch.nn.Linear(encoder_embed_dim, decoder_embed_dim, bias=True) |
|
self.mask_token = torch.nn.Parameter(torch.zeros(1, 1, decoder_embed_dim)) |
|
|
|
self.grid_size = [d//p for d, p in zip(self.img_size, self.patch_size)] |
|
num_patches = np.prod(self.grid_size) |
|
self.register_buffer("decoder_pos_embed", torch.zeros(1, num_patches + 1, decoder_embed_dim)) |
|
|
|
self.decoder_blocks = torch.nn.ModuleList([ |
|
Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer) |
|
for i in range(decoder_depth)]) |
|
|
|
self.decoder_norm = norm_layer(decoder_embed_dim) |
|
self.decoder_pred = torch.nn.Linear(decoder_embed_dim, |
|
patch_size[0] * patch_size[1] * patch_size[2] * in_chans, |
|
bias=True) |
|
|
|
tokens_per_frame = num_patches // img_size[0] |
|
self.temporal_embed_dec = TemporalEncoder(embed_dim=decoder_embed_dim, tokens_per_frame=tokens_per_frame) |
|
|
|
|
|
self.initialize_weights() |
|
|
|
def initialize_weights(self): |
|
decoder_pos_embed = get_3d_sincos_pos_embed( |
|
self.decoder_pos_embed.shape[-1], self.grid_size, cls_token=True |
|
) |
|
self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0)) |
|
|
|
|
|
torch.nn.init.normal_(self.mask_token, std=0.02) |
|
|
|
|
|
self.apply(self._init_weights) |
|
|
|
def _init_weights(self, m): |
|
if isinstance(m, torch.nn.Linear): |
|
|
|
torch.nn.init.xavier_uniform_(m.weight) |
|
if isinstance(m, torch.nn.Linear) and m.bias is not None: |
|
torch.nn.init.constant_(m.bias, 0) |
|
elif isinstance(m, torch.nn.LayerNorm): |
|
torch.nn.init.constant_(m.bias, 0) |
|
torch.nn.init.constant_(m.weight, 1.0) |
|
|
|
def forward(self, x: torch.Tensor, |
|
ids_restore: torch.Tensor, |
|
                temporal_pos: Optional[List[torch.Tensor]]):
|
|
|
x = self.decoder_embed(x) |
|
|
|
|
|
mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1) |
|
x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) |
|
x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) |
|
x = torch.cat([x[:, :1, :], x_], dim=1) |
|
|
|
|
|
x = x + self.decoder_pos_embed |
|
|
|
x_ = x[:, 1:, :] |
|
|
|
if temporal_pos: |
|
temporal_encoding = self.temporal_embed_dec(*temporal_pos) |
|
|
|
|
|
|
|
x_ = x_ + temporal_encoding |
|
|
|
x = torch.cat([x[:, :1, :], x_], dim=1) |
|
|
|
|
|
for blk in self.decoder_blocks: |
|
x = blk(x) |
|
x = self.decoder_norm(x) |
|
|
|
|
|
x = self.decoder_pred(x) |
|
|
|
|
|
x = x[:, 1:, :] |
|
|
|
return x |
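
# Usage sketch (illustrative): reconstruct per-patch pixels from the masked
# encoder's output. Dimensions match _example_encoder_forward above; the decoder
# width and depth are made up for demonstration.
def _example_encoder_decoder():
    enc = Encoder(img_size=[2, 32, 32], patch_size=[1, 16, 16], in_chans=3,
                  encoder_embed_dim=64, encoder_depth=2, encoder_num_heads=4)
    dec = Decoder(img_size=[2, 32, 32], patch_size=[1, 16, 16], in_chans=3,
                  encoder_embed_dim=64, decoder_embed_dim=32, decoder_depth=2,
                  decoder_num_heads=4)
    imgs = torch.randn(2, 3, 2, 32, 32)
    img_masks = torch.zeros_like(imgs)
    latent, mask, ids_restore, _ = enc(imgs, img_masks, mask_ratio=0.75, temporal_pos=None)
    pred = dec(latent, ids_restore, temporal_pos=None)
    # One prediction of s * p * q * C = 1 * 16 * 16 * 3 = 768 pixel values per patch.
    assert pred.shape == (2, 8, 768)
    return pred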
|
|
|
|
|
|
|
|
|
class GAIABase(torch.nn.Module): |
|
def __init__(self, |
|
img_size: Shape3d = [4, 224, 224], |
|
patch_size: Shape3d = [1, 16, 16], |
|
in_chans: int = 3, |
|
encoder_embed_dim: int = 1024, |
|
encoder_depth: int = 8, |
|
encoder_num_heads: int = 16, |
|
decoder_embed_dim: int = 512, |
|
decoder_depth: int = 8, |
|
decoder_num_heads: int = 16, |
|
mlp_ratio: float = 4., |
|
norm_layer: torch.nn.Module = torch.nn.LayerNorm, |
|
norm_pix_loss: bool = False, |
|
drop_channels_rate: float = 0.0, |
|
|
|
adjacent_masking: bool = False, |
|
norm_last_layer: bool = True, |
|
dino_head_dim: int = 1024, |
|
warmup_teacher_temp: float = 0.04, |
|
teacher_temp: float = 0.04, |
|
warmup_teacher_temp_epochs: int = 5, |
|
epochs: int = 100, |
|
student_temp: float = 0.1, |
|
center_momentum: float = 0.9, |
|
momentum_teacher: float = 0.996, |
|
): |
|
super().__init__() |
|
self.img_size = img_size |
|
self.patch_size = patch_size |
|
self.in_chans = in_chans |
|
self.encoder_embed_dim = encoder_embed_dim |
|
self.encoder_depth = encoder_depth |
|
self.encoder_num_heads = encoder_num_heads |
|
self.decoder_embed_dim = decoder_embed_dim |
|
self.decoder_depth = decoder_depth |
|
self.decoder_num_heads = decoder_num_heads |
|
self.mlp_ratio = mlp_ratio |
|
self.norm_layer = norm_layer |
|
self.norm_pix_loss = norm_pix_loss |
|
self.drop_channels_rate = drop_channels_rate |
|
|
|
self.adjacent_masking = adjacent_masking |
|
self.norm_last_layer = norm_last_layer |
|
self.dino_head_dim = dino_head_dim |
|
self.warmup_teacher_temp = warmup_teacher_temp |
|
self.teacher_temp = teacher_temp |
|
self.warmup_teacher_temp_epochs = warmup_teacher_temp_epochs |
|
self.epochs = epochs |
|
self.student_temp = student_temp |
|
self.center_momentum = center_momentum |
|
self.momentum_teacher = momentum_teacher |
|
self.log_count = 0 |
|
|
|
|
|
self.encoder = Encoder( |
|
img_size=img_size, |
|
patch_size=patch_size, |
|
in_chans=in_chans, |
|
encoder_embed_dim=encoder_embed_dim, |
|
encoder_depth=encoder_depth, |
|
encoder_num_heads=encoder_num_heads, |
|
mlp_ratio=mlp_ratio, |
|
norm_layer=norm_layer, |
|
drop_channels_rate=drop_channels_rate, |
|
adjacent_masking=False, |
|
) |
|
|
|
self.decoder = Decoder( |
|
img_size=img_size, |
|
patch_size=patch_size, |
|
in_chans=in_chans, |
|
encoder_embed_dim=encoder_embed_dim, |
|
decoder_embed_dim=decoder_embed_dim, |
|
decoder_depth=decoder_depth, |
|
decoder_num_heads=decoder_num_heads, |
|
mlp_ratio=mlp_ratio, |
|
norm_layer=norm_layer, |
|
norm_pix_loss=norm_pix_loss, |
|
) |
|
|
|
self.teacher = Encoder( |
|
img_size=img_size, |
|
patch_size=patch_size, |
|
in_chans=in_chans, |
|
encoder_embed_dim=encoder_embed_dim, |
|
encoder_depth=encoder_depth, |
|
encoder_num_heads=encoder_num_heads, |
|
mlp_ratio=mlp_ratio, |
|
norm_layer=norm_layer, |
|
drop_channels_rate=drop_channels_rate, |
|
adjacent_masking=self.adjacent_masking, |
|
) |
|
|
|
|
|
self.student = PassThroughHead( |
|
self.encoder, |
|
DINOHead(encoder_embed_dim, dino_head_dim, norm_last_layer=norm_last_layer) |
|
) |
|
|
|
self.teacher = PassThroughHead( |
|
self.teacher, |
|
DINOHead(encoder_embed_dim, dino_head_dim, norm_last_layer=norm_last_layer) |
|
) |
|
|
|
|
|
self.teacher.load_state_dict(self.student.state_dict()) |
|
|
|
|
|
for p in self.teacher.parameters(): |
|
p.requires_grad = False |
|
|
|
self.dino_loss = DINOLoss(dino_head_dim, warmup_teacher_temp, teacher_temp, |
|
warmup_teacher_temp_epochs, epochs, student_temp, center_momentum) |
|
|
|
|
|
    def forward_mae(self, imgs: torch.Tensor, img_masks: torch.Tensor, temporal_pos: Optional[List[torch.Tensor]], mask_ratio: float = 0.75):
|
|
|
latent, mask, ids_restore, ids_x_mask_dict = self.encoder(imgs, img_masks, temporal_pos=temporal_pos, mask_ratio=mask_ratio) |
|
pred = self.decoder(latent, ids_restore, temporal_pos) |
|
loss = self.forward_mae_loss(imgs, pred, mask, ids_x_mask_dict) |
|
return loss, pred, mask |
|
|
|
    def forward_dino(self, imgs: torch.Tensor, img_masks: torch.Tensor, temporal_pos: Optional[List[torch.Tensor]], mask_ratio: float = 0.75):
|
        # The teacher runs without gradients; the student must keep gradients so the
        # DINO loss can update it through the shared encoder.
        with torch.no_grad():
            t_pred = self.teacher(imgs, img_masks, temporal_pos=temporal_pos, mask_ratio=mask_ratio)
        s_pred = self.student(imgs, img_masks, temporal_pos=temporal_pos, mask_ratio=mask_ratio)
|
return s_pred, t_pred |
|
|
|
def log_helper(self, message): |
|
logger.info(message) |
|
|
|
    def forward_mae_loss(self, imgs: torch.Tensor, pred: torch.Tensor, mask: torch.Tensor, ids_x_mask_dict: Dict[int, List[int]]):
|
""" |
|
imgs: B, C, T, H, W |
|
target: B, L, D |
|
pred: B, L, D |
|
mask: B, L. 0 is keep, 1 is remove, |
|
""" |
|
|
|
eps = 1e-6 |
|
target = self.encoder.patchify(imgs) |
|
|
|
        if self.decoder.norm_pix_loss:
            # Normalize the target per patch (zero mean, unit variance), as in MAE's
            # norm_pix_loss option.
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + eps) ** 0.5

        loss = ((pred - target).to(torch.float) ** 2).mean(dim=-1)

        # Exclude patches that contain missing pixels from both the per-patch loss
        # and the normalizing mask.
        for index, patch in ids_x_mask_dict.items():
            mask[index, patch] = 0
            loss[index, patch] = 0
|
|
|
loss_mask = torch.clamp(loss * mask, max=1e8) |
|
loss = loss_mask.sum() / (mask.sum() + eps) |
|
return loss |
|
|
|
    def forward(self, imgs: torch.Tensor, img_masks: torch.Tensor, temporal_pos: Optional[List[torch.Tensor]], mask_ratio: float = 0.75, epoch: Optional[int] = None):
|
        if epoch is None:
            raise ValueError("epoch must be provided: it indexes the DINO teacher temperature schedule")

        # Temporal encodings are currently disabled for both the MAE and DINO branches.
        temporal_pos = None
|
|
|
mae_loss, pred, mask = self.forward_mae(imgs, img_masks, temporal_pos=temporal_pos, mask_ratio=mask_ratio) |
|
|
|
|
|
student_output, teacher_output = self.forward_dino(imgs, img_masks, temporal_pos=temporal_pos, mask_ratio=mask_ratio) |
|
dino_loss = self.dino_loss(student_output, teacher_output, epoch) |
|
|
|
|
|
total_loss = mae_loss + dino_loss |
|
return (total_loss, dino_loss, mae_loss), (pred, mask, student_output, teacher_output) |
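
# Training-step sketch (illustrative; the real training loop lives outside this
# module): one optimization step followed by the usual DINO momentum (EMA) update
# of the teacher from the student. The optimizer and the placement of the EMA
# update are assumptions, not part of this module.
def _example_training_step(model: "GAIABase", optimizer, imgs, img_masks, epoch):
    (total_loss, dino_loss, mae_loss), _ = model(imgs, img_masks, temporal_pos=None,
                                                 mask_ratio=0.75, epoch=epoch)
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()
    # EMA update: teacher <- m * teacher + (1 - m) * student
    with torch.no_grad():
        m = model.momentum_teacher
        for ps, pt in zip(model.student.parameters(), model.teacher.parameters()):
            pt.data.mul_(m).add_((1 - m) * ps.detach().data)
    return total_loss.item(), dino_loss.item(), mae_loss.item()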
|
|
|
|
|
class DINOHead(torch.nn.Module): |
|
def __init__(self, in_dim, dino_head_dim, norm_last_layer=True, nlayers=3, hidden_dim=2048, bottleneck_dim=256): |
|
super().__init__() |
|
nlayers = max(nlayers, 1) |
|
if nlayers == 1: |
|
self.mlp = torch.nn.Linear(in_dim, bottleneck_dim) |
|
else: |
|
layers = [torch.nn.Linear(in_dim, hidden_dim)] |
|
|
|
layers.append(torch.nn.LayerNorm(hidden_dim)) |
|
layers.append(torch.nn.GELU()) |
|
|
|
for _ in range(nlayers - 2): |
|
layers.append(torch.nn.Linear(hidden_dim, hidden_dim)) |
|
layers.append(torch.nn.LayerNorm(hidden_dim)) |
|
layers.append(torch.nn.GELU()) |
|
|
|
layers.append(torch.nn.Linear(hidden_dim, bottleneck_dim)) |
|
self.mlp = torch.nn.Sequential(*layers) |
|
|
|
self.apply(self._init_weights) |
|
self.last_layer = torch.nn.utils.weight_norm(torch.nn.Linear(bottleneck_dim, dino_head_dim, bias=False)) |
|
self.last_layer.weight_g.data.fill_(1) |
|
if norm_last_layer: |
|
self.last_layer.weight_g.requires_grad = False |
|
|
|
def _init_weights(self, m): |
|
if isinstance(m, torch.nn.Linear): |
|
trunc_normal_(m.weight, std=.02) |
|
if isinstance(m, torch.nn.Linear) and m.bias is not None: |
|
torch.nn.init.constant_(m.bias, 0) |
|
|
|
def forward(self, x): |
|
x = self.mlp(x) |
|
x = torch.nn.functional.normalize(x, dim=-1, p=2) |
|
x = self.last_layer(x) |
|
return x |
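
# Shape sketch (assumed dimensions): the head maps a backbone feature vector to
# DINO logits through the MLP, an L2-normalized bottleneck, and a weight-normed
# last layer.
def _example_dino_head():
    head = DINOHead(in_dim=64, dino_head_dim=1024)
    feats = torch.randn(2, 64)          # e.g. cls-token features for two samples
    logits = head(feats)
    assert logits.shape == (2, 1024)
    return logits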
|
|
|
class PassThroughHead(torch.nn.Module): |
|
def __init__(self, backbone, head): |
|
super().__init__() |
|
self.backbone = backbone |
|
self.head = head |
|
|
|
    def forward(self, imgs: torch.Tensor, img_masks: torch.Tensor, temporal_pos: Optional[List[torch.Tensor]], mask_ratio: float = 0.75):
|
x, _, _, _ = self.backbone(imgs, img_masks, temporal_pos=temporal_pos, mask_ratio=mask_ratio) |
|
|
|
|
|
|
|
x = self.head(x[:, 0]) |
|
|
|
|
|
return x |
|
|
|
class DINOLoss(torch.nn.Module): |
|
def __init__(self, out_dim, warmup_teacher_temp, teacher_temp, warmup_teacher_temp_epochs, nepochs, |
|
student_temp, center_momentum): |
|
super().__init__() |
|
self.student_temp = student_temp |
|
self.center_momentum = center_momentum |
|
self.register_buffer("center", torch.zeros(1, out_dim)) |
|
self.teacher_temp_schedule = np.concatenate( |
|
(np.linspace(warmup_teacher_temp, teacher_temp, warmup_teacher_temp_epochs), |
|
np.ones(nepochs - warmup_teacher_temp_epochs) * teacher_temp)) |
|
self.log_count = 0 |
|
|
|
def forward(self, student_output, teacher_output, epoch): |
|
student_out = student_output / self.student_temp |
|
temp = self.teacher_temp_schedule[epoch] |
|
teacher_out = torch.nn.functional.softmax((teacher_output - self.center) / temp, dim=-1).detach() |
|
|
|
loss = torch.sum(-teacher_out * torch.nn.functional.log_softmax(student_out, dim=-1), dim=-1) |
|
total_loss = loss.mean() |
|
|
|
self.update_center(teacher_output) |
|
return total_loss |
|
|
|
@torch.no_grad() |
|
def update_center(self, teacher_output): |
|
""" |
|
Update center used for teacher output. |
|
""" |
|
batch_center = torch.sum(teacher_output, dim=0, keepdim=True) |
|
|
|
if dist.is_initialized(): |
|
dist.all_reduce(batch_center) |
|
batch_center /= len(teacher_output) * dist.get_world_size() |
|
else: |
|
batch_center /= len(teacher_output) |
|
|
|
self.center = self.center * self.center_momentum + batch_center * (1 - self.center_momentum) |
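
# Usage sketch (assumed dimensions): the loss compares student logits against
# centered, temperature-sharpened teacher logits and updates the running center.
def _example_dino_loss():
    criterion = DINOLoss(out_dim=1024, warmup_teacher_temp=0.04, teacher_temp=0.04,
                         warmup_teacher_temp_epochs=5, nepochs=100,
                         student_temp=0.1, center_momentum=0.9)
    student_out = torch.randn(4, 1024)
    teacher_out = torch.randn(4, 1024)
    loss = criterion(student_out, teacher_out, epoch=0)
    assert loss.dim() == 0              # scalar cross-entropy-style loss
    return loss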
|
|