from typing import List, Dict, Optional, Tuple, Union import random import torch import timm import numpy as np from einops import rearrange import torch.distributed as dist from timm.layers import drop_path, DropPath, Mlp, trunc_normal_ import logging from torchvision import models from mae_dino.model_layers.layers import PatchEmbed, Attention, Block, PatchEmbed from mae_dino.model_layers.pos_embed import get_3d_sincos_pos_embed import torch.nn as nn Shape3d = Union[List[int], Tuple[int, int, int]] logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) class TemporalEncoder(torch.nn.Module): def __init__(self, embed_dim, tokens_per_frame): super().__init__() self.embed_dim = embed_dim self.tokens_per_frame = tokens_per_frame # Define embedding sizes for each temporal component self.year_embed_dim = embed_dim // 4 self.doy_embed_dim = embed_dim // 4 self.hour_embed_dim = embed_dim // 4 self.minute_embed_dim = embed_dim - ( self.year_embed_dim + self.doy_embed_dim + self.hour_embed_dim ) # Embedding layers for categorical temporal features self.year_embedding = torch.nn.Embedding(3000, self.year_embed_dim) # Years from 0000 to 2999 self.doy_embedding = torch.nn.Embedding(367, self.doy_embed_dim) # Day of Year 0-365 self.hour_embedding = torch.nn.Embedding(24, self.hour_embed_dim) # Hours 0-23 self.minute_embedding = torch.nn.Embedding(60, self.minute_embed_dim) # Minutes 0-59 # Initialize embeddings self._init_weights() def _init_weights(self): torch.nn.init.xavier_uniform_(self.year_embedding.weight) torch.nn.init.xavier_uniform_(self.doy_embedding.weight) torch.nn.init.xavier_uniform_(self.hour_embedding.weight) torch.nn.init.xavier_uniform_(self.minute_embedding.weight) def forward(self, year, doy, hour, minute): """ Args: year (torch.Tensor): Shape (batch_size, time), integer years doy (torch.Tensor): Shape (batch_size, time), values [1, 366] hour (torch.Tensor): Shape (batch_size, time), values [0, 23] minute (torch.Tensor): Shape (batch_size, time), values [0, 59] Returns: torch.Tensor: Temporal embeddings of shape (batch_size, time * tokens_per_frame, embed_dim) """ # Ensure inputs are of type Long year = year.long() doy = doy.long() hour = hour.long() minute = minute.long() # Get embeddings for each temporal component year_emb = self.year_embedding(year) # (batch_size, time, year_embed_dim) doy_emb = self.doy_embedding(doy) # (batch_size, time, doy_embed_dim) hour_emb = self.hour_embedding(hour) # (batch_size, time, hour_embed_dim) minute_emb = self.minute_embedding(minute) # (batch_size, time, minute_embed_dim) # Concatenate embeddings along the last dimension temporal_emb = torch.cat( [year_emb, doy_emb, hour_emb, minute_emb], dim=-1 ) # (batch_size, time, embed_dim) # Reshape to (batch_size, time * tokens_per_frame, embed_dim) batch_size, time_steps, _ = temporal_emb.shape temporal_emb = torch.repeat_interleave(temporal_emb, self.tokens_per_frame, dim=1) return temporal_emb class Encoder(torch.nn.Module): def __init__(self, img_size: Shape3d = [4, 224, 224], patch_size: Shape3d = [1, 16, 16], in_chans: int = 3, encoder_embed_dim: int = 1024, encoder_depth: int = 8, encoder_num_heads: int = 16, mlp_ratio: float = 4., norm_layer: torch.nn.Module = torch.nn.LayerNorm, drop_channels_rate: float = 0.0, adjacent_masking: bool = False, ): super().__init__() self.img_size = img_size self.patch_size = patch_size self.in_chans = in_chans self.encoder_embed_dim = encoder_embed_dim self.encoder_depth = encoder_depth self.encoder_num_heads = encoder_num_heads self.mlp_ratio = mlp_ratio self.norm_layer = norm_layer self.drop_channels_rate = drop_channels_rate self.adjacent_masking = adjacent_masking # -------------------------------------------------------------------------- # # MAE encoder self.drop_channels = torch.nn.Dropout3d(self.drop_channels_rate) if self.drop_channels_rate > 0 else torch.nn.Identity() self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, encoder_embed_dim) num_patches = self.patch_embed.num_patches tokens_per_frame = num_patches // img_size[0] self.cls_token = torch.nn.Parameter(torch.zeros(1, 1, encoder_embed_dim)) self.register_buffer("encoder_pos_embed", torch.zeros(1, num_patches + 1, encoder_embed_dim)) self.encoder_blocks = torch.nn.ModuleList([ Block(encoder_embed_dim, encoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=self.norm_layer) for i in range(self.encoder_depth)]) self.norm = norm_layer(encoder_embed_dim) self.temporal_embed_enc = TemporalEncoder(embed_dim=encoder_embed_dim, tokens_per_frame=tokens_per_frame) # Initialize weights self.initialize_weights() def initialize_weights(self): encoder_pos_embed = get_3d_sincos_pos_embed( self.encoder_pos_embed.shape[-1], self.patch_embed.grid_size, cls_token=True ) self.encoder_pos_embed.data.copy_(torch.from_numpy(encoder_pos_embed).float().unsqueeze(0)) # initialize patch_embed like nn.Linear (instead of nn.Conv2d) w = self.patch_embed.proj.weight.data torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1])) # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.) torch.nn.init.normal_(self.cls_token, std=0.02) # initialize nn.Linear and nn.LayerNorm self.apply(self._init_weights) def _init_weights(self, m): if isinstance(m, torch.nn.Linear): # we use xavier_uniform following official JAX ViT: torch.nn.init.xavier_uniform_(m.weight) if isinstance(m, torch.nn.Linear) and m.bias is not None: torch.nn.init.constant_(m.bias, 0) elif isinstance(m, torch.nn.LayerNorm): torch.nn.init.constant_(m.bias, 0) torch.nn.init.constant_(m.weight, 1.0) def patchify(self, imgs): """ imgs: B, C, T, H, W x: B, L, D """ s, p, q = self.patch_embed.patch_size x = rearrange(imgs, 'b c (t s) (h p) (w q) -> b (t h w) (s p q c)', s=s, p=p, q=q) return x def unpatchify(self, x): """ x: B, L, D imgs: B, C, T, H, W """ s, p, q = self.patch_embed.patch_size gs = self.patch_embed.grid_size imgs = rearrange(x, 'b (t h w) (s p q c) -> b c (t s) (h p) (w q)', h=gs[1], w=gs[2], t=gs[0], s=s, p=p, q=q) return imgs def log_helper(self, message): logger.info(message) def hinted_random_masking(self, x: torch.Tensor, x_mask: torch.Tensor, mask_ratio: int): """ Perform per-sample random masking by per-sample shuffling. Per-sample shuffling is done by argsort random noise on patches without missing value pixels (indicated by x_mask). x: [N, L, D], sequence """ N, L, D = x.shape # batch, length, dim # Calculate missing data ratio from x_mask x_mask = x_mask.sum(dim=-1) > 0 # [N, L] missing_ratio = x_mask.float().mean(dim=1) # [N] adjusted_mask_ratio = min(max(missing_ratio).item() + mask_ratio, 1) self.log_helper(f"missing_data_ratio: {missing_ratio}") self.log_helper(f"adjusted_mask_ratio: {adjusted_mask_ratio}") len_keep = int(L * (1 - adjusted_mask_ratio)) noise = torch.rand(N, L, device=x.device) # noise in [0, 1] ids_x_mask = torch.where(x_mask) # ids_x_mask = list(zip(ids_x_mask[0].tolist(), ids_x_mask[1].tolist())) # TODO translate to NP ids_x_mask_dict = {} for bs, p in zip(ids_x_mask[0], ids_x_mask[1]): bs, p = int(bs), int(p) if bs not in ids_x_mask_dict: ids_x_mask_dict[bs] = [] ids_x_mask_dict[bs].append(p) noise += x_mask # sort noise for each sample ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove ids_restore = torch.argsort(ids_shuffle, dim=1) # keep the first subset ids_keep = ids_shuffle[:, :len_keep] x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) # generate the binary mask: 0 is keep, 1 is remove mask = torch.ones([N, L], device=x.device) mask[:, :len_keep] = 0 # unshuffle to get the binary mask mask = torch.gather(mask, dim=1, index=ids_restore) return x_masked, mask, ids_restore, ids_x_mask_dict def hinted_adjacent_masking(self, x: torch.Tensor, x_mask: torch.Tensor, mask_ratio: float): """ Perform masking by keeping contiguous blocks of patches in time and space, excluding patches where x_mask is 1 in each patch, and adjust to keep mask_ratio consistent. x: [N, L, D], sequence x_mask: [N, L, D], mask indicating invalid patches (1 where invalid) patch_embed: object containing grid_size attribute (T, H, W) """ N, L, D = x.shape # batch size, sequence length, feature dimension # Get grid sizes gs = self.patch_embed.grid_size # (T, H, W) T, H, W = gs # Calculate total number of patches and number to keep total_patches = T * H * W len_keep = int(total_patches * (1 - mask_ratio)) # Reshape x to [N, T, H, W, D] x = x.view(N, T, H, W, D) # Process x_mask to identify invalid patches # x_mask: [N, L, D] -> [N, L], True where invalid x_mask = x_mask.sum(dim=-1) > 0 ids_x_mask = torch.where(x_mask) ids_x_mask_dict = {} for bs, p in zip(ids_x_mask[0], ids_x_mask[1]): bs, p = int(bs), int(p) if bs not in ids_x_mask_dict: ids_x_mask_dict[bs] = [] ids_x_mask_dict[bs].append(p) x_mask = x_mask.view(N, T, H, W) # Reshape to [N, T, H, W] # Initialize mask of zeros [N, T, H, W]; 0 is remove, 1 is keep mask = torch.zeros((N, T, H, W), device=x.device) ids_keep_list = [] # For each sample in the batch for i in range(N): # Valid patches are where x_mask is False (0) valid_patches_mask = ~x_mask[i] # [T, H, W], True where valid num_valid_patches = valid_patches_mask.sum().item() # Adjust len_keep if not enough valid patches len_keep_i = min(len_keep, num_valid_patches) if len_keep_i == 0: ids_keep_list.append(torch.tensor([], device=x.device, dtype=torch.long)) continue # Skip if no valid patches patches_selected = 0 attempts = 0 max_attempts = 1000 # Prevent infinite loops while patches_selected < len_keep_i and attempts < max_attempts: attempts += 1 # Compute block sizes block_size = int(round((len_keep_i - patches_selected) ** (1/3))) t_size_block = min(T, block_size) h_size_block = min(H, block_size) w_size_block = min(W, block_size) # Randomly select starting positions t0 = random.randint(0, T - t_size_block) h0 = random.randint(0, H - h_size_block) w0 = random.randint(0, W - w_size_block) # Get indices for the block t_indices = slice(t0, t0 + t_size_block) h_indices = slice(h0, h0 + h_size_block) w_indices = slice(w0, w0 + w_size_block) # Extract the block of valid patches block_valid_mask = valid_patches_mask[t_indices, h_indices, w_indices] block_already_selected = mask[i, t_indices, h_indices, w_indices] # Find valid, unselected patches in the block selectable_mask = block_valid_mask & (block_already_selected == 0) num_selectable = selectable_mask.sum().item() if num_selectable == 0: continue # Try another block # Determine how many patches to select from this block num_to_select = min(len_keep_i - patches_selected, num_selectable) # Get indices of selectable patches selectable_indices = selectable_mask.nonzero(as_tuple=False) # Randomly select patches from the selectable ones selected_indices = selectable_indices[torch.randperm(num_selectable)[:num_to_select]] # Update the mask to keep the selected patches for idx in selected_indices: t_idx, h_idx, w_idx = idx mask[i, t0 + t_idx, h0 + h_idx, w0 + w_idx] = 1 patches_selected += num_to_select # If not enough patches were selected, randomly select from remaining valid patches if patches_selected < len_keep_i: remaining_selectable = (valid_patches_mask & (mask[i] == 0)).nonzero(as_tuple=False) num_remaining = remaining_selectable.size(0) num_needed = len_keep_i - patches_selected num_to_select = min(num_remaining, num_needed) if num_to_select > 0: selected_indices = remaining_selectable[torch.randperm(num_remaining)[:num_to_select]] mask[i][selected_indices[:, 0], selected_indices[:, 1], selected_indices[:, 2]] = 1 patches_selected += num_to_select # Get the indices of the kept patches mask_i_flat = mask[i].view(-1) ids_keep_i = torch.where(mask_i_flat == 1)[0] ids_keep_list.append(ids_keep_i) # Determine the maximum length of ids_keep across all samples for padding max_len_keep = max(len(ids) for ids in ids_keep_list) # Pad ids_keep to have the same length ids_keep_padded = torch.zeros((N, max_len_keep), dtype=torch.long, device=x.device) for i, ids in enumerate(ids_keep_list): ids_keep_padded[i, :len(ids)] = ids # Flatten x back to [N, L, D] x_flat = x.view(N, L, D) # Gather x_masked with padding x_masked_list = [] for i in range(N): ids = ids_keep_list[i] x_masked_i = torch.index_select(x_flat[i], dim=0, index=ids) x_masked_list.append(x_masked_i) x_masked = torch.nn.utils.rnn.pad_sequence(x_masked_list, batch_first=True) # Generate ids_restore (identity mapping) ids_restore = torch.arange(L, device=x.device).unsqueeze(0).repeat(N, 1) # Generate the binary mask: 0 is keep, 1 is remove mask_flat = mask.view(N, L) final_mask = 1 - mask_flat # Invert mask to match expected output return x_masked, final_mask, ids_restore, ids_x_mask_dict def forward(self, x: torch.Tensor, x_mask: torch.Tensor, mask_ratio: float, # temporal_pos: Optional[torch.Tensor]): temporal_pos: Optional[List]): # Drop input channels x = self.drop_channels(x) # embed patches x = self.patch_embed(x) # add pos embed w/o cls token x = x + self.encoder_pos_embed[:, 1:, :] if temporal_pos: temporal_encoding = self.temporal_embed_enc(*temporal_pos) # temporal_encoding = self.drop_temporal(temporal_encoding, new_mask=True) x = x + temporal_encoding # masking: length -> length * mask_ratio x_mask = self.patchify(x_mask) if self.adjacent_masking: x, mask, ids_restore, ids_x_mask_dict = self.hinted_adjacent_masking(x, x_mask, mask_ratio) else: x, mask, ids_restore, ids_x_mask_dict = self.hinted_random_masking(x, x_mask, mask_ratio) # append cls token cls_token = self.cls_token + self.encoder_pos_embed[:, :1, :] cls_tokens = cls_token.expand(x.shape[0], -1, -1) x = torch.cat((cls_tokens, x), dim=1) # apply Transformer blocks for blk in self.encoder_blocks: x = blk(x) x = self.norm(x) return x, mask, ids_restore, ids_x_mask_dict class Decoder(torch.nn.Module): def __init__(self, img_size: Shape3d = [4, 224, 224], patch_size: Shape3d = [1, 16, 16], in_chans: int = 3, encoder_embed_dim: int = 1024, decoder_embed_dim: int = 512, decoder_depth: int = 8, decoder_num_heads: int = 16, mlp_ratio: float = 4., norm_layer: torch.nn.Module = torch.nn.LayerNorm, norm_pix_loss: bool = False, ): super().__init__() self.img_size = img_size self.patch_size = patch_size self.encoder_embed_dim = encoder_embed_dim self.decoder_embed_dim = decoder_embed_dim self.decoder_depth = decoder_depth self.decoder_num_heads = decoder_num_heads self.mlp_ratio = mlp_ratio self.norm_layer = norm_layer self.norm_pix_loss = norm_pix_loss # -------------------------------------------------------------------------- # # MAE decoder self.decoder_embed = torch.nn.Linear(encoder_embed_dim, decoder_embed_dim, bias=True) self.mask_token = torch.nn.Parameter(torch.zeros(1, 1, decoder_embed_dim)) self.grid_size = [d//p for d, p in zip(self.img_size, self.patch_size)] num_patches = np.prod(self.grid_size) self.register_buffer("decoder_pos_embed", torch.zeros(1, num_patches + 1, decoder_embed_dim)) self.decoder_blocks = torch.nn.ModuleList([ Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer) for i in range(decoder_depth)]) self.decoder_norm = norm_layer(decoder_embed_dim) self.decoder_pred = torch.nn.Linear(decoder_embed_dim, patch_size[0] * patch_size[1] * patch_size[2] * in_chans, bias=True) # decoder to patch tokens_per_frame = num_patches // img_size[0] self.temporal_embed_dec = TemporalEncoder(embed_dim=decoder_embed_dim, tokens_per_frame=tokens_per_frame) # Initialize weights self.initialize_weights() def initialize_weights(self): decoder_pos_embed = get_3d_sincos_pos_embed( self.decoder_pos_embed.shape[-1], self.grid_size, cls_token=True ) self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0)) # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.) torch.nn.init.normal_(self.mask_token, std=0.02) # initialize nn.Linear and nn.LayerNorm self.apply(self._init_weights) def _init_weights(self, m): if isinstance(m, torch.nn.Linear): # we use xavier_uniform following official JAX ViT: torch.nn.init.xavier_uniform_(m.weight) if isinstance(m, torch.nn.Linear) and m.bias is not None: torch.nn.init.constant_(m.bias, 0) elif isinstance(m, torch.nn.LayerNorm): torch.nn.init.constant_(m.bias, 0) torch.nn.init.constant_(m.weight, 1.0) def forward(self, x: torch.Tensor, ids_restore: torch.Tensor, temporal_pos: Optional[torch.Tensor]): # embed tokens x = self.decoder_embed(x) # append mask tokens to sequence mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1) x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # no cls token x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle x = torch.cat([x[:, :1, :], x_], dim=1) # append cls token # add pos embed x = x + self.decoder_pos_embed # remove cls token x_ = x[:, 1:, :] if temporal_pos: temporal_encoding = self.temporal_embed_dec(*temporal_pos) # Add temporal encoding w/o cls token x_ = x_ + temporal_encoding x = torch.cat([x[:, :1, :], x_], dim=1) # append cls token # apply Transformer blocks for blk in self.decoder_blocks: x = blk(x) x = self.decoder_norm(x) # predictor projection x = self.decoder_pred(x) # remove cls token x = x[:, 1:, :] return x class GapFill(torch.nn.Module): def __init__(self, img_size: Shape3d = [4, 224, 224], patch_size: Shape3d = [1, 16, 16], in_chans: int = 3, encoder_embed_dim: int = 1024, encoder_depth: int = 8, encoder_num_heads: int = 16, decoder_embed_dim: int = 512, decoder_depth: int = 8, decoder_num_heads: int = 16, mlp_ratio: float = 4., norm_layer: torch.nn.Module = torch.nn.LayerNorm, norm_pix_loss: bool = False, drop_channels_rate: float = 0.0, mask_loss_weight: float = 1, adjacent_masking: bool = False, ): super().__init__() self.img_size = img_size self.patch_size = patch_size self.in_chans = in_chans self.encoder_embed_dim = encoder_embed_dim self.encoder_depth = encoder_depth self.encoder_num_heads = encoder_num_heads self.decoder_embed_dim = decoder_embed_dim self.decoder_depth = decoder_depth self.decoder_num_heads = decoder_num_heads self.mlp_ratio = mlp_ratio self.norm_layer = norm_layer self.norm_pix_loss = norm_pix_loss self.drop_channels_rate = drop_channels_rate self.mask_loss_weight = mask_loss_weight self.adjacent_masking = adjacent_masking self.encoder = Encoder( img_size=img_size, patch_size=patch_size, in_chans=in_chans, encoder_embed_dim=encoder_embed_dim, encoder_depth=encoder_depth, encoder_num_heads=encoder_num_heads, mlp_ratio=mlp_ratio, norm_layer=norm_layer, drop_channels_rate=drop_channels_rate, adjacent_masking=adjacent_masking, ) self.decoder = Decoder( img_size=img_size, patch_size=patch_size, in_chans=in_chans, encoder_embed_dim=encoder_embed_dim, decoder_embed_dim=decoder_embed_dim, decoder_depth=decoder_depth, decoder_num_heads=decoder_num_heads, mlp_ratio=mlp_ratio, norm_layer=norm_layer, norm_pix_loss=norm_pix_loss, ) def forward_mae(self, imgs: torch.Tensor, img_masks: torch.Tensor, temporal_pos: Optional[torch.Tensor], mask_ratio: float = 0.75): latent, mask, ids_restore, ids_x_mask_dict = self.encoder(imgs, img_masks, temporal_pos=None, mask_ratio=mask_ratio) pred = self.decoder(latent, ids_restore, temporal_pos=None) loss = self.forward_mae_loss(imgs, pred, mask, ids_x_mask_dict) return loss, pred, mask def log_helper(self, message): logger.info(message) def forward_mae_loss(self, imgs: torch.Tensor, pred: torch.Tensor, mask: torch.Tensor, ids_x_mask_dict: Dict[str, List]): """ imgs: B, C, T, H, W target: B, L, D pred: B, L, D mask: B, L. 0 is keep, 1 is remove, """ target = self.encoder.patchify(imgs) if self.decoder.norm_pix_loss: mean = target.mean(dim=-1, keepdim=True) var = target.var(dim=-1, keepdim=True) target = (target - mean) / (var + 1.e-6)**.5 loss = ((pred - target).to(torch.float) ** 2).mean(dim=-1) # MSE non_missing_data_mask = torch.ones_like(mask) for index, patch in ids_x_mask_dict.items(): non_missing_data_mask[index, patch] = 0 # Loss over non-missing, non-masked patches loss_unmasked = torch.clamp(loss * non_missing_data_mask * (1-mask), max=1e8).sum() / (non_missing_data_mask * (1-mask)).sum() # Loss over non-missing, masked patches loss_masked = torch.clamp(loss * non_missing_data_mask * mask, max=1e8).sum() / (non_missing_data_mask * mask).sum() loss = loss_unmasked + self.mask_loss_weight * loss_masked return loss def forward(self, imgs: torch.Tensor, img_masks: torch.Tensor, temporal_pos: Optional[torch.Tensor], mask_ratio: float = 0.75, epoch=None): if epoch is None: raise ValueError(f"epoch value is invalid") # MAE mae_loss, pred, mask = self.forward_mae(imgs, img_masks, temporal_pos=temporal_pos, mask_ratio=mask_ratio) return (mae_loss), (pred, mask)