"""mae_dino/models.py

GAIA: A Foundation Model for Operational Atmospheric Dynamics.
"""
from typing import List, Dict, Optional, Tuple, Union
import random
import torch
import timm
import numpy as np
from einops import rearrange
import torch.distributed as dist
from timm.layers import drop_path, DropPath, Mlp, trunc_normal_
import logging
from mae_dino.model_layers.layers import PatchEmbed, Attention, Block
from mae_dino.model_layers.pos_embed import get_3d_sincos_pos_embed
Shape3d = Union[List[int], Tuple[int, int, int]]
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class TemporalEncoder(torch.nn.Module):
def __init__(self, embed_dim, tokens_per_frame):
super().__init__()
self.embed_dim = embed_dim
self.tokens_per_frame = tokens_per_frame
# Define embedding sizes for each temporal component
self.year_embed_dim = embed_dim // 4
self.doy_embed_dim = embed_dim // 4
self.hour_embed_dim = embed_dim // 4
self.minute_embed_dim = embed_dim - (
self.year_embed_dim + self.doy_embed_dim + self.hour_embed_dim
)
# Embedding layers for categorical temporal features
self.year_embedding = torch.nn.Embedding(3000, self.year_embed_dim) # Years from 0000 to 2999
        self.doy_embedding = torch.nn.Embedding(367, self.doy_embed_dim)  # Day of year 1-366
self.hour_embedding = torch.nn.Embedding(24, self.hour_embed_dim) # Hours 0-23
self.minute_embedding = torch.nn.Embedding(60, self.minute_embed_dim) # Minutes 0-59
# Initialize embeddings
self._init_weights()
def _init_weights(self):
torch.nn.init.xavier_uniform_(self.year_embedding.weight)
torch.nn.init.xavier_uniform_(self.doy_embedding.weight)
torch.nn.init.xavier_uniform_(self.hour_embedding.weight)
torch.nn.init.xavier_uniform_(self.minute_embedding.weight)
def forward(self, year, doy, hour, minute):
"""
Args:
year (torch.Tensor): Shape (batch_size, time), integer years
doy (torch.Tensor): Shape (batch_size, time), values [1, 366]
hour (torch.Tensor): Shape (batch_size, time), values [0, 23]
minute (torch.Tensor): Shape (batch_size, time), values [0, 59]
Returns:
torch.Tensor: Temporal embeddings of shape (batch_size, time * tokens_per_frame, embed_dim)
"""
# Ensure inputs are of type Long
year = year.long()
doy = doy.long()
hour = hour.long()
minute = minute.long()
# Get embeddings for each temporal component
year_emb = self.year_embedding(year) # (batch_size, time, year_embed_dim)
doy_emb = self.doy_embedding(doy) # (batch_size, time, doy_embed_dim)
hour_emb = self.hour_embedding(hour) # (batch_size, time, hour_embed_dim)
minute_emb = self.minute_embedding(minute) # (batch_size, time, minute_embed_dim)
# Concatenate embeddings along the last dimension
temporal_emb = torch.cat(
[year_emb, doy_emb, hour_emb, minute_emb], dim=-1
) # (batch_size, time, embed_dim)
        # Repeat each frame's embedding for all of its tokens:
        # (batch_size, time, embed_dim) -> (batch_size, time * tokens_per_frame, embed_dim)
        temporal_emb = torch.repeat_interleave(temporal_emb, self.tokens_per_frame, dim=1)
return temporal_emb
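
# --------------------------------------------------------------------------- #
# Illustrative sketch (not part of the original model code): a hypothetical
# helper showing the shapes produced by TemporalEncoder. Each (batch, time)
# timestamp component is embedded, the four embeddings are concatenated to
# `embed_dim`, and the result is repeated `tokens_per_frame` times per frame.
def _demo_temporal_encoder() -> torch.Tensor:
    enc = TemporalEncoder(embed_dim=64, tokens_per_frame=4)
    year = torch.full((2, 3), 2021, dtype=torch.long)   # batch=2, time=3
    doy = torch.full((2, 3), 180, dtype=torch.long)
    hour = torch.full((2, 3), 12, dtype=torch.long)
    minute = torch.zeros((2, 3), dtype=torch.long)
    emb = enc(year, doy, hour, minute)
    assert emb.shape == (2, 3 * 4, 64)  # (batch, time * tokens_per_frame, embed_dim)
    return emb
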
class Encoder(torch.nn.Module):
def __init__(self,
img_size: Shape3d = [4, 224, 224],
patch_size: Shape3d = [1, 16, 16],
in_chans: int = 3,
encoder_embed_dim: int = 1024,
encoder_depth: int = 8,
encoder_num_heads: int = 16,
mlp_ratio: float = 4.,
norm_layer: torch.nn.Module = torch.nn.LayerNorm,
drop_channels_rate: float = 0.0,
adjacent_masking: bool = False,
):
super().__init__()
self.img_size = img_size
self.patch_size = patch_size
self.in_chans = in_chans
self.encoder_embed_dim = encoder_embed_dim
self.encoder_depth = encoder_depth
self.encoder_num_heads = encoder_num_heads
self.mlp_ratio = mlp_ratio
self.norm_layer = norm_layer
self.drop_channels_rate = drop_channels_rate
self.adjacent_masking = adjacent_masking
# -------------------------------------------------------------------------- #
# MAE encoder
self.drop_channels = torch.nn.Dropout3d(self.drop_channels_rate) if self.drop_channels_rate > 0 else torch.nn.Identity()
self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, encoder_embed_dim)
num_patches = self.patch_embed.num_patches
tokens_per_frame = num_patches // img_size[0]
self.cls_token = torch.nn.Parameter(torch.zeros(1, 1, encoder_embed_dim))
self.register_buffer("encoder_pos_embed", torch.zeros(1, num_patches + 1, encoder_embed_dim))
self.encoder_blocks = torch.nn.ModuleList([
Block(encoder_embed_dim, encoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=self.norm_layer)
for i in range(self.encoder_depth)])
self.norm = norm_layer(encoder_embed_dim)
self.temporal_embed_enc = TemporalEncoder(embed_dim=encoder_embed_dim, tokens_per_frame=tokens_per_frame)
# Initialize weights
self.initialize_weights()
def initialize_weights(self):
encoder_pos_embed = get_3d_sincos_pos_embed(
self.encoder_pos_embed.shape[-1], self.patch_embed.grid_size, cls_token=True
)
self.encoder_pos_embed.data.copy_(torch.from_numpy(encoder_pos_embed).float().unsqueeze(0))
# initialize patch_embed like nn.Linear (instead of nn.Conv2d)
w = self.patch_embed.proj.weight.data
torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
# timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
torch.nn.init.normal_(self.cls_token, std=0.02)
# initialize nn.Linear and nn.LayerNorm
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, torch.nn.Linear):
# we use xavier_uniform following official JAX ViT:
torch.nn.init.xavier_uniform_(m.weight)
if isinstance(m, torch.nn.Linear) and m.bias is not None:
torch.nn.init.constant_(m.bias, 0)
elif isinstance(m, torch.nn.LayerNorm):
torch.nn.init.constant_(m.bias, 0)
torch.nn.init.constant_(m.weight, 1.0)
def patchify(self, imgs):
"""
imgs: B, C, T, H, W
x: B, L, D
"""
s, p, q = self.patch_embed.patch_size
x = rearrange(imgs, 'b c (t s) (h p) (w q) -> b (t h w) (s p q c)', s=s, p=p, q=q)
return x
def unpatchify(self, x):
"""
x: B, L, D
imgs: B, C, T, H, W
"""
s, p, q = self.patch_embed.patch_size
gs = self.patch_embed.grid_size
imgs = rearrange(x, 'b (t h w) (s p q c) -> b c (t s) (h p) (w q)', h=gs[1], w=gs[2], t=gs[0], s=s, p=p, q=q)
return imgs
def log_helper(self, message):
logger.info(message)
    def hinted_random_masking(self, x: torch.Tensor, x_mask: torch.Tensor, mask_ratio: float):
        """
        Perform per-sample random masking by per-sample shuffling.
        Shuffling is done by argsort of random noise; patches containing missing-value
        pixels (indicated by x_mask) are pushed to the end so they are removed first.
        x: [N, L, D], sequence
        """
N, L, D = x.shape # batch, length, dim
# Calculate missing data ratio from x_mask
x_mask = x_mask.sum(dim=-1) > 0 # [N, L]
missing_ratio = x_mask.float().mean(dim=1) # [N]
        # Never keep more patches than the missing data allows: raise the mask ratio
        # to at least the largest per-sample missing-data ratio
        adjusted_mask_ratio = max(missing_ratio.max().item(), mask_ratio)
len_keep = int(L * (1 - adjusted_mask_ratio))
        noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]
        # Record which patches contain missing data, per batch sample
        ids_x_mask = torch.where(x_mask)
        # TODO translate to NP
        ids_x_mask_dict = {}
        for bs, p in zip(ids_x_mask[0], ids_x_mask[1]):
            bs, p = int(bs), int(p)
            if bs not in ids_x_mask_dict:
                ids_x_mask_dict[bs] = []
            ids_x_mask_dict[bs].append(p)
        # Offset the noise of missing-data patches so they sort to the end and are removed first
        noise += x_mask.float()
# sort noise for each sample
ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove
ids_restore = torch.argsort(ids_shuffle, dim=1)
# keep the first subset
ids_keep = ids_shuffle[:, :len_keep]
x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
# generate the binary mask: 0 is keep, 1 is remove
mask = torch.ones([N, L], device=x.device)
mask[:, :len_keep] = 0
# unshuffle to get the binary mask
mask = torch.gather(mask, dim=1, index=ids_restore)
return x_masked, mask, ids_restore, ids_x_mask_dict
    def hinted_adjacent_masking(self, x: torch.Tensor, x_mask: torch.Tensor, mask_ratio: float):
        """
        Perform masking by keeping contiguous blocks of patches in time and space,
        excluding patches with missing-value pixels (indicated by x_mask), while
        keeping the overall mask_ratio consistent.
        x: [N, L, D], sequence
        x_mask: [N, L, D], mask indicating invalid patches (nonzero where invalid)
        The block grid (T, H, W) is taken from self.patch_embed.grid_size.
        """
N, L, D = x.shape # batch size, sequence length, feature dimension
# Get grid sizes
gs = self.patch_embed.grid_size # (T, H, W)
T, H, W = gs
# Calculate total number of patches and number to keep
total_patches = T * H * W
len_keep = int(total_patches * (1 - mask_ratio))
# Reshape x to [N, T, H, W, D]
x = x.view(N, T, H, W, D)
# Process x_mask to identify invalid patches
# x_mask: [N, L, D] -> [N, L], True where invalid
x_mask = x_mask.sum(dim=-1) > 0
ids_x_mask = torch.where(x_mask)
ids_x_mask_dict = {}
for bs, p in zip(ids_x_mask[0], ids_x_mask[1]):
bs, p = int(bs), int(p)
if bs not in ids_x_mask_dict:
ids_x_mask_dict[bs] = []
ids_x_mask_dict[bs].append(p)
x_mask = x_mask.view(N, T, H, W) # Reshape to [N, T, H, W]
# Initialize mask of zeros [N, T, H, W]; 0 is remove, 1 is keep
mask = torch.zeros((N, T, H, W), device=x.device)
ids_keep_list = []
# For each sample in the batch
for i in range(N):
# Valid patches are where x_mask is False (0)
valid_patches_mask = ~x_mask[i] # [T, H, W], True where valid
num_valid_patches = valid_patches_mask.sum().item()
# Adjust len_keep if not enough valid patches
len_keep_i = min(len_keep, num_valid_patches)
if len_keep_i == 0:
ids_keep_list.append(torch.tensor([], device=x.device, dtype=torch.long))
continue # Skip if no valid patches
patches_selected = 0
attempts = 0
max_attempts = 1000 # Prevent infinite loops
while patches_selected < len_keep_i and attempts < max_attempts:
attempts += 1
# Compute block sizes
block_size = int(round((len_keep_i - patches_selected) ** (1/3)))
t_size_block = min(T, block_size)
h_size_block = min(H, block_size)
w_size_block = min(W, block_size)
# Randomly select starting positions
t0 = random.randint(0, T - t_size_block)
h0 = random.randint(0, H - h_size_block)
w0 = random.randint(0, W - w_size_block)
# Get indices for the block
t_indices = slice(t0, t0 + t_size_block)
h_indices = slice(h0, h0 + h_size_block)
w_indices = slice(w0, w0 + w_size_block)
# Extract the block of valid patches
block_valid_mask = valid_patches_mask[t_indices, h_indices, w_indices]
block_already_selected = mask[i, t_indices, h_indices, w_indices]
# Find valid, unselected patches in the block
selectable_mask = block_valid_mask & (block_already_selected == 0)
num_selectable = selectable_mask.sum().item()
if num_selectable == 0:
continue # Try another block
# Determine how many patches to select from this block
num_to_select = min(len_keep_i - patches_selected, num_selectable)
# Get indices of selectable patches
selectable_indices = selectable_mask.nonzero(as_tuple=False)
# Randomly select patches from the selectable ones
selected_indices = selectable_indices[torch.randperm(num_selectable)[:num_to_select]]
# Update the mask to keep the selected patches
for idx in selected_indices:
t_idx, h_idx, w_idx = idx
mask[i, t0 + t_idx, h0 + h_idx, w0 + w_idx] = 1
patches_selected += num_to_select
# If not enough patches were selected, randomly select from remaining valid patches
if patches_selected < len_keep_i:
remaining_selectable = (valid_patches_mask & (mask[i] == 0)).nonzero(as_tuple=False)
num_remaining = remaining_selectable.size(0)
num_needed = len_keep_i - patches_selected
num_to_select = min(num_remaining, num_needed)
if num_to_select > 0:
selected_indices = remaining_selectable[torch.randperm(num_remaining)[:num_to_select]]
mask[i][selected_indices[:, 0], selected_indices[:, 1], selected_indices[:, 2]] = 1
patches_selected += num_to_select
# Get the indices of the kept patches
mask_i_flat = mask[i].view(-1)
ids_keep_i = torch.where(mask_i_flat == 1)[0]
ids_keep_list.append(ids_keep_i)
# Determine the maximum length of ids_keep across all samples for padding
max_len_keep = max(len(ids) for ids in ids_keep_list)
# Pad ids_keep to have the same length
ids_keep_padded = torch.zeros((N, max_len_keep), dtype=torch.long, device=x.device)
for i, ids in enumerate(ids_keep_list):
ids_keep_padded[i, :len(ids)] = ids
# Flatten x back to [N, L, D]
x_flat = x.view(N, L, D)
# Gather x_masked with padding
x_masked_list = []
for i in range(N):
ids = ids_keep_list[i]
x_masked_i = torch.index_select(x_flat[i], dim=0, index=ids)
x_masked_list.append(x_masked_i)
x_masked = torch.nn.utils.rnn.pad_sequence(x_masked_list, batch_first=True)
# Generate ids_restore (identity mapping)
ids_restore = torch.arange(L, device=x.device).unsqueeze(0).repeat(N, 1)
# Generate the binary mask: 0 is keep, 1 is remove
mask_flat = mask.view(N, L)
final_mask = 1 - mask_flat # Invert mask to match expected output
return x_masked, final_mask, ids_restore, ids_x_mask_dict
def forward(self,
x: torch.Tensor,
x_mask: torch.Tensor,
mask_ratio: float,
                temporal_pos: Optional[List]):
# Drop input channels
x = self.drop_channels(x)
# embed patches
x = self.patch_embed(x)
# add pos embed w/o cls token
x = x + self.encoder_pos_embed[:, 1:, :]
if temporal_pos:
temporal_encoding = self.temporal_embed_enc(*temporal_pos)
# temporal_encoding = self.drop_temporal(temporal_encoding, new_mask=True)
x = x + temporal_encoding
# masking: length -> length * mask_ratio
x_mask = self.patchify(x_mask)
if self.adjacent_masking:
x, mask, ids_restore, ids_x_mask_dict = self.hinted_adjacent_masking(x, x_mask, mask_ratio)
else:
x, mask, ids_restore, ids_x_mask_dict = self.hinted_random_masking(x, x_mask, mask_ratio)
# append cls token
cls_token = self.cls_token + self.encoder_pos_embed[:, :1, :]
cls_tokens = cls_token.expand(x.shape[0], -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
# apply Transformer blocks
for blk in self.encoder_blocks:
x = blk(x)
x = self.norm(x)
return x, mask, ids_restore, ids_x_mask_dict
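
# --------------------------------------------------------------------------- #
# Illustrative sketches (not part of the original model code): hypothetical
# helpers that demonstrate, in isolation, (a) the einops rearrangement behind
# Encoder.patchify/unpatchify and (b) the noise-offset trick used by
# hinted_random_masking to drop missing-data patches first.
def _demo_patchify_roundtrip() -> None:
    # Toy volume: batch=2, channels=3, T=4, H=W=32, with (1, 16, 16) patches.
    imgs = torch.randn(2, 3, 4, 32, 32)
    s, p, q = 1, 16, 16
    t, h, w = 4, 32 // p, 32 // q
    patches = rearrange(imgs, 'b c (t s) (h p) (w q) -> b (t h w) (s p q c)', s=s, p=p, q=q)
    assert patches.shape == (2, t * h * w, s * p * q * 3)
    recon = rearrange(patches, 'b (t h w) (s p q c) -> b c (t s) (h p) (w q)',
                      t=t, h=h, w=w, s=s, p=p, q=q)
    assert torch.equal(imgs, recon)  # patchify/unpatchify is a lossless reshuffle


def _demo_hinted_noise_masking() -> None:
    # Keep the patches with the smallest noise values; adding the missing-data
    # mask to the noise pushes invalid patches above 1.0 so they sort to the end.
    N, L, mask_ratio = 1, 8, 0.5
    invalid = torch.zeros(N, L, dtype=torch.bool)
    invalid[:, :2] = True                        # pretend patches 0 and 1 have missing data
    noise = torch.rand(N, L) + invalid.float()   # invalid patches now have noise > 1
    ids_shuffle = torch.argsort(noise, dim=1)
    ids_keep = ids_shuffle[:, : int(L * (1 - mask_ratio))]
    assert not bool(invalid.gather(1, ids_keep).any())  # no invalid patch is kept
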
class Decoder(torch.nn.Module):
def __init__(self,
img_size: Shape3d = [4, 224, 224],
patch_size: Shape3d = [1, 16, 16],
in_chans: int = 3,
encoder_embed_dim: int = 1024,
decoder_embed_dim: int = 512,
decoder_depth: int = 8,
decoder_num_heads: int = 16,
mlp_ratio: float = 4.,
norm_layer: torch.nn.Module = torch.nn.LayerNorm,
norm_pix_loss: bool = False,
):
super().__init__()
self.img_size = img_size
self.patch_size = patch_size
self.encoder_embed_dim = encoder_embed_dim
self.decoder_embed_dim = decoder_embed_dim
self.decoder_depth = decoder_depth
self.decoder_num_heads = decoder_num_heads
self.mlp_ratio = mlp_ratio
self.norm_layer = norm_layer
self.norm_pix_loss = norm_pix_loss
# -------------------------------------------------------------------------- #
# MAE decoder
self.decoder_embed = torch.nn.Linear(encoder_embed_dim, decoder_embed_dim, bias=True)
self.mask_token = torch.nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
self.grid_size = [d//p for d, p in zip(self.img_size, self.patch_size)]
        num_patches = int(np.prod(self.grid_size))
self.register_buffer("decoder_pos_embed", torch.zeros(1, num_patches + 1, decoder_embed_dim))
self.decoder_blocks = torch.nn.ModuleList([
Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
for i in range(decoder_depth)])
self.decoder_norm = norm_layer(decoder_embed_dim)
self.decoder_pred = torch.nn.Linear(decoder_embed_dim,
patch_size[0] * patch_size[1] * patch_size[2] * in_chans,
bias=True) # decoder to patch
tokens_per_frame = num_patches // img_size[0]
self.temporal_embed_dec = TemporalEncoder(embed_dim=decoder_embed_dim, tokens_per_frame=tokens_per_frame)
# Initialize weights
self.initialize_weights()
def initialize_weights(self):
decoder_pos_embed = get_3d_sincos_pos_embed(
self.decoder_pos_embed.shape[-1], self.grid_size, cls_token=True
)
self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))
# timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
torch.nn.init.normal_(self.mask_token, std=0.02)
# initialize nn.Linear and nn.LayerNorm
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, torch.nn.Linear):
# we use xavier_uniform following official JAX ViT:
torch.nn.init.xavier_uniform_(m.weight)
if isinstance(m, torch.nn.Linear) and m.bias is not None:
torch.nn.init.constant_(m.bias, 0)
elif isinstance(m, torch.nn.LayerNorm):
torch.nn.init.constant_(m.bias, 0)
torch.nn.init.constant_(m.weight, 1.0)
    def forward(self, x: torch.Tensor,
                ids_restore: torch.Tensor,
                temporal_pos: Optional[List]):
# embed tokens
x = self.decoder_embed(x)
# append mask tokens to sequence
mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # no cls token
x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle
x = torch.cat([x[:, :1, :], x_], dim=1) # append cls token
# add pos embed
x = x + self.decoder_pos_embed
# remove cls token
x_ = x[:, 1:, :]
if temporal_pos:
temporal_encoding = self.temporal_embed_dec(*temporal_pos)
# Reuse drop mask from encoder for consistent dropping
# temporal_encoding = self.drop_temporal(temporal_encoding, new_mask=False)
# Add temporal encoding w/o cls token
x_ = x_ + temporal_encoding
x = torch.cat([x[:, :1, :], x_], dim=1) # append cls token
# apply Transformer blocks
for blk in self.decoder_blocks:
x = blk(x)
x = self.decoder_norm(x)
# predictor projection
x = self.decoder_pred(x)
# remove cls token
x = x[:, 1:, :]
return x
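
# --------------------------------------------------------------------------- #
# Illustrative sketch (not part of the original model code): a standalone toy
# version of the mask-token insertion in Decoder.forward. Visible tokens are
# padded with mask tokens and scattered back to their original positions via
# torch.gather over ids_restore (the inverse permutation of the encoder's shuffle).
def _demo_decoder_unshuffle() -> None:
    N, L, D, len_keep = 1, 6, 2, 3
    noise = torch.rand(N, L)
    ids_shuffle = torch.argsort(noise, dim=1)
    ids_restore = torch.argsort(ids_shuffle, dim=1)
    # Token value i encodes original position i, so the round trip is easy to check.
    tokens = torch.arange(L, dtype=torch.float).repeat(N, 1).unsqueeze(-1).repeat(1, 1, D)
    visible = torch.gather(tokens, 1, ids_shuffle[:, :len_keep].unsqueeze(-1).repeat(1, 1, D))
    mask_tokens = torch.full((N, L - len_keep, D), -1.0)
    x_ = torch.cat([visible, mask_tokens], dim=1)
    x_ = torch.gather(x_, 1, ids_restore.unsqueeze(-1).repeat(1, 1, D))
    kept = ids_shuffle[:, :len_keep]
    # Every kept token is back at its original index; dropped positions hold -1.
    assert torch.equal(x_[0, kept[0], 0], kept[0].float())
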
class GAIABase(torch.nn.Module):
def __init__(self,
img_size: Shape3d = [4, 224, 224],
patch_size: Shape3d = [1, 16, 16],
in_chans: int = 3,
encoder_embed_dim: int = 1024,
encoder_depth: int = 8,
encoder_num_heads: int = 16,
decoder_embed_dim: int = 512,
decoder_depth: int = 8,
decoder_num_heads: int = 16,
mlp_ratio: float = 4.,
norm_layer: torch.nn.Module = torch.nn.LayerNorm,
norm_pix_loss: bool = False,
drop_channels_rate: float = 0.0,
# DINO Args
adjacent_masking: bool = False,
norm_last_layer: bool = True,
dino_head_dim: int = 1024,
warmup_teacher_temp: float = 0.04,
teacher_temp: float = 0.04,
warmup_teacher_temp_epochs: int = 5,
epochs: int = 100,
student_temp: float = 0.1,
center_momentum: float = 0.9,
momentum_teacher: float = 0.996,
):
super().__init__()
self.img_size = img_size
self.patch_size = patch_size
self.in_chans = in_chans
self.encoder_embed_dim = encoder_embed_dim
self.encoder_depth = encoder_depth
self.encoder_num_heads = encoder_num_heads
self.decoder_embed_dim = decoder_embed_dim
self.decoder_depth = decoder_depth
self.decoder_num_heads = decoder_num_heads
self.mlp_ratio = mlp_ratio
self.norm_layer = norm_layer
self.norm_pix_loss = norm_pix_loss
self.drop_channels_rate = drop_channels_rate
# DINO Args
self.adjacent_masking = adjacent_masking
self.norm_last_layer = norm_last_layer
self.dino_head_dim = dino_head_dim
self.warmup_teacher_temp = warmup_teacher_temp
self.teacher_temp = teacher_temp
self.warmup_teacher_temp_epochs = warmup_teacher_temp_epochs
self.epochs = epochs
self.student_temp = student_temp
self.center_momentum = center_momentum
self.momentum_teacher = momentum_teacher
self.log_count = 0
self.encoder = Encoder(
img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
encoder_embed_dim=encoder_embed_dim,
encoder_depth=encoder_depth,
encoder_num_heads=encoder_num_heads,
mlp_ratio=mlp_ratio,
norm_layer=norm_layer,
drop_channels_rate=drop_channels_rate,
adjacent_masking=False,
)
self.decoder = Decoder(
img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
encoder_embed_dim=encoder_embed_dim,
decoder_embed_dim=decoder_embed_dim,
decoder_depth=decoder_depth,
decoder_num_heads=decoder_num_heads,
mlp_ratio=mlp_ratio,
norm_layer=norm_layer,
norm_pix_loss=norm_pix_loss,
)
self.teacher = Encoder(
img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
encoder_embed_dim=encoder_embed_dim,
encoder_depth=encoder_depth,
encoder_num_heads=encoder_num_heads,
mlp_ratio=mlp_ratio,
norm_layer=norm_layer,
drop_channels_rate=drop_channels_rate,
adjacent_masking=self.adjacent_masking,
)
# DINO Head wrappers
self.student = PassThroughHead(
self.encoder,
DINOHead(encoder_embed_dim, dino_head_dim, norm_last_layer=norm_last_layer)
)
self.teacher = PassThroughHead(
self.teacher,
DINOHead(encoder_embed_dim, dino_head_dim, norm_last_layer=norm_last_layer)
)
# teacher and student start with the same weights
self.teacher.load_state_dict(self.student.state_dict())
# teacher frozen
for p in self.teacher.parameters():
p.requires_grad = False
self.dino_loss = DINOLoss(dino_head_dim, warmup_teacher_temp, teacher_temp,
warmup_teacher_temp_epochs, epochs, student_temp, center_momentum)
    def forward_mae(self, imgs: torch.Tensor, img_masks: torch.Tensor, temporal_pos: Optional[List], mask_ratio: float = 0.75):
latent, mask, ids_restore, ids_x_mask_dict = self.encoder(imgs, img_masks, temporal_pos=temporal_pos, mask_ratio=mask_ratio)
pred = self.decoder(latent, ids_restore, temporal_pos)
loss = self.forward_mae_loss(imgs, pred, mask, ids_x_mask_dict)
return loss, pred, mask
    def forward_dino(self, imgs: torch.Tensor, img_masks: torch.Tensor, temporal_pos: Optional[List], mask_ratio: float = 0.75):
with torch.no_grad():
t_pred = self.teacher(imgs, img_masks, temporal_pos=temporal_pos, mask_ratio=mask_ratio)
s_pred = self.student(imgs, img_masks, temporal_pos=temporal_pos, mask_ratio=mask_ratio)
return s_pred, t_pred
def log_helper(self, message):
logger.info(message)
    def forward_mae_loss(self, imgs: torch.Tensor, pred: torch.Tensor, mask: torch.Tensor, ids_x_mask_dict: Dict[int, List[int]]):
"""
imgs: B, C, T, H, W
target: B, L, D
pred: B, L, D
mask: B, L. 0 is keep, 1 is remove,
"""
eps = 1e-6
target = self.encoder.patchify(imgs)
        if self.decoder.norm_pix_loss:
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + eps) ** 0.5  # per-patch normalization of the target
loss = ((pred - target).to(torch.float) ** 2).mean(dim=-1)
# Process mask and loss to exclude ids_x_mask patches (patches that include missing data) in the calculation
for index, patch in ids_x_mask_dict.items():
mask[index, patch] = 0
loss[index, patch] = 0
loss_mask = torch.clamp(loss * mask, max=1e8)
loss = loss_mask.sum() / (mask.sum() + eps) ### mean loss on removed patches
return loss
    def forward(self, imgs: torch.Tensor, img_masks: torch.Tensor, temporal_pos: Optional[List], mask_ratio: float = 0.75, epoch: Optional[int] = None):
        if epoch is None:
            raise ValueError("epoch must be provided; it indexes the teacher temperature schedule")
        # MAE (temporal embeddings are disabled: temporal_pos is overridden to None here)
        temporal_pos = None
mae_loss, pred, mask = self.forward_mae(imgs, img_masks, temporal_pos=temporal_pos, mask_ratio=mask_ratio)
# DINO
student_output, teacher_output = self.forward_dino(imgs, img_masks, temporal_pos=temporal_pos, mask_ratio=mask_ratio)
dino_loss = self.dino_loss(student_output, teacher_output, epoch)
# Total Loss
total_loss = mae_loss + dino_loss
return (total_loss, dino_loss, mae_loss), (pred, mask, student_output, teacher_output)
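
# --------------------------------------------------------------------------- #
# Illustrative usage sketch (not part of the original model code). The sizes
# below are hypothetical and kept small; running it assumes that the repo's
# PatchEmbed/Block layers and get_3d_sincos_pos_embed follow the interfaces
# used throughout this file.
def _demo_gaia_forward() -> None:
    model = GAIABase(img_size=[2, 32, 32], patch_size=[1, 16, 16], in_chans=3,
                     encoder_embed_dim=64, encoder_depth=1, encoder_num_heads=4,
                     decoder_embed_dim=32, decoder_depth=1, decoder_num_heads=4,
                     epochs=10)
    imgs = torch.randn(2, 3, 2, 32, 32)
    img_masks = torch.zeros_like(imgs)  # no missing data in this toy example
    (total_loss, dino_loss, mae_loss), (pred, mask, s_out, t_out) = model(
        imgs, img_masks, temporal_pos=None, mask_ratio=0.75, epoch=0)
    # pred: (B, L, patch_dim) reconstructions; mask: (B, L) with 1 marking removed patches.
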
class DINOHead(torch.nn.Module):
def __init__(self, in_dim, dino_head_dim, norm_last_layer=True, nlayers=3, hidden_dim=2048, bottleneck_dim=256):
super().__init__()
nlayers = max(nlayers, 1)
if nlayers == 1:
self.mlp = torch.nn.Linear(in_dim, bottleneck_dim)
else:
layers = [torch.nn.Linear(in_dim, hidden_dim)]
layers.append(torch.nn.LayerNorm(hidden_dim))
layers.append(torch.nn.GELU())
for _ in range(nlayers - 2):
layers.append(torch.nn.Linear(hidden_dim, hidden_dim))
layers.append(torch.nn.LayerNorm(hidden_dim))
layers.append(torch.nn.GELU())
layers.append(torch.nn.Linear(hidden_dim, bottleneck_dim))
self.mlp = torch.nn.Sequential(*layers)
self.apply(self._init_weights)
self.last_layer = torch.nn.utils.weight_norm(torch.nn.Linear(bottleneck_dim, dino_head_dim, bias=False))
self.last_layer.weight_g.data.fill_(1)
if norm_last_layer:
self.last_layer.weight_g.requires_grad = False
def _init_weights(self, m):
if isinstance(m, torch.nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, torch.nn.Linear) and m.bias is not None:
torch.nn.init.constant_(m.bias, 0)
def forward(self, x):
x = self.mlp(x)
x = torch.nn.functional.normalize(x, dim=-1, p=2)
x = self.last_layer(x)
return x
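
# --------------------------------------------------------------------------- #
# Illustrative sketch (not part of the original model code): DINOHead projects
# backbone features through an MLP bottleneck and a weight-normalized final
# linear layer to `dino_head_dim` logits. The sizes below are hypothetical.
def _demo_dino_head() -> torch.Tensor:
    head = DINOHead(in_dim=32, dino_head_dim=16, hidden_dim=64, bottleneck_dim=8)
    out = head(torch.randn(4, 32))
    assert out.shape == (4, 16)
    return out
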
class PassThroughHead(torch.nn.Module):
def __init__(self, backbone, head):
super().__init__()
self.backbone = backbone
self.head = head
    def forward(self, imgs: torch.Tensor, img_masks: torch.Tensor, temporal_pos: Optional[List], mask_ratio: float = 0.75):
x, _, _, _ = self.backbone(imgs, img_masks, temporal_pos=temporal_pos, mask_ratio=mask_ratio)
        # Either use the cls token or global average pooling
        # CLS token (default)
        x = self.head(x[:, 0])  # Use the cls token
        # Global average pooling (alternative):
        # x = self.head(x.mean(dim=1))
return x
class DINOLoss(torch.nn.Module):
def __init__(self, out_dim, warmup_teacher_temp, teacher_temp, warmup_teacher_temp_epochs, nepochs,
student_temp, center_momentum):
super().__init__()
self.student_temp = student_temp
self.center_momentum = center_momentum
self.register_buffer("center", torch.zeros(1, out_dim))
self.teacher_temp_schedule = np.concatenate(
(np.linspace(warmup_teacher_temp, teacher_temp, warmup_teacher_temp_epochs),
np.ones(nepochs - warmup_teacher_temp_epochs) * teacher_temp))
self.log_count = 0
def forward(self, student_output, teacher_output, epoch):
student_out = student_output / self.student_temp
temp = self.teacher_temp_schedule[epoch]
teacher_out = torch.nn.functional.softmax((teacher_output - self.center) / temp, dim=-1).detach()
loss = torch.sum(-teacher_out * torch.nn.functional.log_softmax(student_out, dim=-1), dim=-1) # Changed from student_output
total_loss = loss.mean()
self.update_center(teacher_output)
return total_loss
@torch.no_grad()
def update_center(self, teacher_output):
"""
Update center used for teacher output.
"""
batch_center = torch.sum(teacher_output, dim=0, keepdim=True)
# TODO: Tom et al.: Will this impact the DeepSpeed and lightning? Should we keep it or remove it?
if dist.is_initialized():
dist.all_reduce(batch_center)
batch_center /= len(teacher_output) * dist.get_world_size()
else:
batch_center /= len(teacher_output) # Use only batch size for single-process mode
self.center = self.center * self.center_momentum + batch_center * (1 - self.center_momentum)
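
# --------------------------------------------------------------------------- #
# Illustrative sketch (not part of the original model code): DINOLoss compares
# sharpened, centered teacher probabilities against the student's
# log-probabilities and updates the running center with an exponential moving
# average. The numbers below are hypothetical.
def _demo_dino_loss() -> torch.Tensor:
    criterion = DINOLoss(out_dim=8, warmup_teacher_temp=0.04, teacher_temp=0.07,
                         warmup_teacher_temp_epochs=2, nepochs=10,
                         student_temp=0.1, center_momentum=0.9)
    student_output = torch.randn(4, 8)
    teacher_output = torch.randn(4, 8)
    return criterion(student_output, teacher_output, epoch=0)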