from typing import Dict, Tuple, Union

import torch
import torch.nn as nn
from transformers import PretrainedConfig, PreTrainedModel
from transformers.utils import cached_file

from .loss import FourierLoss
from .normalizer import Normalizer
from .mae_modules import CAMAEDecoder, MAEDecoder, MAEEncoder
from .mae_utils import flatten_images
from .vit import (
    generate_2d_sincos_pos_embeddings,
    sincos_positional_encoding_vit,
    vit_small_patch16_256,
)

TensorDict = Dict[str, torch.Tensor]


class MAEConfig(PretrainedConfig):
    model_type = "MAE"

    def __init__(
        self,
        mask_ratio=0.75,
        encoder=None,
        decoder=None,
        loss=None,
        optimizer=None,
        input_norm=None,
        fourier_loss=None,
        fourier_loss_weight=0.0,
        lr_scheduler=None,
        use_MAE_weight_init=False,
        crop_size=-1,
        mask_fourier_loss=True,
        return_channelwise_embeddings=False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.mask_ratio = mask_ratio
        self.encoder = encoder
        self.decoder = decoder
        self.loss = loss
        self.optimizer = optimizer
        self.input_norm = input_norm
        self.fourier_loss = fourier_loss
        self.fourier_loss_weight = fourier_loss_weight
        self.lr_scheduler = lr_scheduler
        self.use_MAE_weight_init = use_MAE_weight_init
        self.crop_size = crop_size
        self.mask_fourier_loss = mask_fourier_loss
        self.return_channelwise_embeddings = return_channelwise_embeddings


class MAEModel(PreTrainedModel):
    config_class = MAEConfig

    # Loss metrics
    TOTAL_LOSS = "loss"
    RECON_LOSS = "reconstruction_loss"
    FOURIER_LOSS = "fourier_loss"

    def __init__(self, config: MAEConfig):
        super().__init__(config)
        self.mask_ratio = config.mask_ratio

        # Could use Hydra to instantiate instead
        self.encoder = MAEEncoder(
            vit_backbone=sincos_positional_encoding_vit(
                vit_backbone=vit_small_patch16_256(global_pool="avg")
            ),
            max_in_chans=11,  # upper limit on number of input channels
            channel_agnostic=True,
        )
        self.decoder = CAMAEDecoder(
            depth=8,
            embed_dim=512,
            mlp_ratio=4,
            norm_layer=nn.LayerNorm,
            num_heads=16,
            num_modalities=6,
            qkv_bias=True,
            tokens_per_modality=256,
        )
        self.input_norm = torch.nn.Sequential(
            Normalizer(),
            nn.InstanceNorm2d(None, affine=False, track_running_stats=False),
        )

        self.fourier_loss_weight = config.fourier_loss_weight
        self.mask_fourier_loss = config.mask_fourier_loss
        self.return_channelwise_embeddings = config.return_channelwise_embeddings
        self.tokens_per_channel = 256  # hardcode the number of tokens per channel since we are patch16 crop 256

        # loss stuff
        self.loss = torch.nn.MSELoss(reduction="none")
        self.fourier_loss = FourierLoss(num_multimodal_modalities=6)

        if self.fourier_loss_weight > 0 and self.fourier_loss is None:
            raise ValueError(
                "FourierLoss weight is activated but no fourier_loss was defined in constructor"
            )
        elif self.fourier_loss_weight >= 1:
            raise ValueError(
                "FourierLoss weight is too large to do mixing factor, weight should be < 1"
            )

        self.patch_size = int(self.encoder.vit_backbone.patch_embed.patch_size[0])

        # projection layer between the encoder and decoder
        self.encoder_decoder_proj = nn.Linear(
            self.encoder.embed_dim, self.decoder.embed_dim, bias=True
        )
        self.decoder_pred = nn.Linear(
            self.decoder.embed_dim,
            self.patch_size**2
            * (1 if self.encoder.channel_agnostic else self.in_chans),
            bias=True,
        )  # linear layer from decoder embedding to input dims

        # overwrite decoder pos embeddings based on encoder params
        self.decoder.pos_embeddings = generate_2d_sincos_pos_embeddings(  # type: ignore[assignment]
            self.decoder.embed_dim,
            length=self.encoder.vit_backbone.patch_embed.grid_size[0],
            use_class_token=self.encoder.vit_backbone.cls_token is not None,
            num_modality=(
                self.decoder.num_modalities if self.encoder.channel_agnostic else 1
            ),
        )

        if config.use_MAE_weight_init:
            w = self.encoder.vit_backbone.patch_embed.proj.weight.data
            torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

            torch.nn.init.normal_(self.encoder.vit_backbone.cls_token, std=0.02)
            torch.nn.init.normal_(self.decoder.mask_token, std=0.02)

            self.apply(self._MAE_init_weights)

    def setup(self, stage: str) -> None:
        super().setup(stage)

    def _MAE_init_weights(self, m):
        if isinstance(m, nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @staticmethod
    def decode_to_reconstruction(
        encoder_latent: torch.Tensor,
        ind_restore: torch.Tensor,
        proj: torch.nn.Module,
        decoder: MAEDecoder | CAMAEDecoder,
        pred: torch.nn.Module,
    ) -> torch.Tensor:
        """Feed the encoder latent through the decoder's projections and transformations."""
        decoder_latent_projection = proj(
            encoder_latent
        )  # projection from encoder.embed_dim to decoder.embed_dim
        decoder_tokens = decoder.forward_masked(
            decoder_latent_projection, ind_restore
        )  # decoder.embed_dim output
        predicted_reconstruction = pred(
            decoder_tokens
        )  # linear projection to input dim
        return predicted_reconstruction[:, 1:, :]  # drop class token

    def forward(
        self, imgs: torch.Tensor, constant_noise: Union[torch.Tensor, None] = None
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Mask and encode the input, then decode to a per-patch reconstruction.

        Returns the encoder latent, the predicted reconstruction, and the binary
        mask marking which patches were dropped.
        """
        imgs = self.input_norm(imgs)
        latent, mask, ind_restore = self.encoder.forward_masked(
            imgs, self.mask_ratio, constant_noise
        )  # encoder blocks
        reconstruction = self.decode_to_reconstruction(
            latent,
            ind_restore,
            self.encoder_decoder_proj,
            self.decoder,
            self.decoder_pred,
        )
        return latent, reconstruction, mask

    def compute_MAE_loss(
        self,
        reconstruction: torch.Tensor,
        img: torch.Tensor,
        mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, Dict[str, float]]:
        """Computes the final loss and returns component losses for metric reporting."""
        loss_dict = {}
        img = self.input_norm(img)
        target_flattened = flatten_images(
            img,
            patch_size=self.patch_size,
            channel_agnostic=self.encoder.channel_agnostic,
        )
        loss: torch.Tensor = self.loss(
            reconstruction, target_flattened
        )  # should be MSE or MAE (L1) with reduction='none'
        loss = loss.mean(
            dim=-1
        )  # average over embedding dim -> mean loss per patch (N, L)
        loss = (loss * mask).sum() / mask.sum()  # mean loss on masked patches only
        loss_dict[self.RECON_LOSS] = loss.item()

        # compute fourier loss
        if self.fourier_loss_weight > 0:
            floss: torch.Tensor = self.fourier_loss(reconstruction, target_flattened)
            if not self.mask_fourier_loss:
                floss = floss.mean()
            else:
                floss = floss.mean(dim=-1)
                floss = (floss * mask).sum() / mask.sum()
            loss_dict[self.FOURIER_LOSS] = floss.item()

        # here we use a mixing factor to keep the loss magnitude appropriate with fourier
        if self.fourier_loss_weight > 0:
            loss = (1 - self.fourier_loss_weight) * loss + (
                self.fourier_loss_weight * floss
            )

        return loss, loss_dict

    def training_step(self, batch: TensorDict, batch_idx: int) -> TensorDict:
        img = batch["pixels"]
        latent, reconstruction, mask = self(img.clone())
        full_loss, loss_dict = self.compute_MAE_loss(reconstruction, img.float(), mask)
        return {
            "loss": full_loss,
            **loss_dict,  # type: ignore[dict-item]
        }

    def validation_step(self, batch: TensorDict, batch_idx: int) -> TensorDict:
        return self.training_step(batch, batch_idx)

    def update_metrics(self, outputs: TensorDict, batch: TensorDict) -> None:
        self.metrics["lr"].update(value=self.lr_scheduler.get_last_lr())
        for key, value in outputs.items():
            if key.endswith("loss"):
                self.metrics[key].update(value)

    def on_validation_batch_end(  # type: ignore[override]
        self,
        outputs: TensorDict,
        batch: TensorDict,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        super().on_validation_batch_end(outputs, batch, batch_idx, dataloader_idx)

    def predict(self, imgs: torch.Tensor) -> torch.Tensor:
        """Return pooled embeddings for inference (no masking applied)."""
        imgs = self.input_norm(imgs)
        X = self.encoder.vit_backbone.forward_features(
            imgs
        )  # 3d tensor N x num_tokens x dim
        if self.return_channelwise_embeddings:
            N, _, d = X.shape
            num_channels = imgs.shape[1]
            X_reshaped = X[:, 1:, :].view(N, num_channels, self.tokens_per_channel, d)
            pooled_segments = X_reshaped.mean(
                dim=2
            )  # resulting shape: (N, num_channels, d)
            latent = pooled_segments.view(N, num_channels * d).contiguous()
        else:
            latent = X[:, 1:, :].mean(dim=1)  # 1 + 256 * C tokens
        return latent

    def save_pretrained(self, save_directory: str, **kwargs):
        filename = kwargs.pop("filename", "model.safetensors")
        modelpath = f"{save_directory}/{filename}"
        self.config.save_pretrained(save_directory)
        # NOTE: the checkpoint is written with torch.save (a pickle) even though the
        # default filename carries a .safetensors extension; from_pretrained below
        # loads it with torch.load accordingly.
        torch.save({"state_dict": self.state_dict()}, modelpath)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        filename = kwargs.pop("filename", "model.safetensors")
        config = MAEConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
        modelpath = cached_file(pretrained_model_name_or_path, filename=filename)
        state_dict = torch.load(modelpath, map_location="cpu")
        model = cls(config)
        model.load_state_dict(state_dict["state_dict"])
        return model
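

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the model definition).
# Assumptions: the companion modules of this package (mae_modules, vit, loss,
# normalizer, mae_utils) are importable, and inputs are 6-channel 256x256
# crops to match the hardcoded decoder settings (num_modalities=6,
# tokens_per_modality=256, patch size 16). Because the file uses relative
# imports, run it as a module (python -m <package>.<this_module>).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    config = MAEConfig(mask_ratio=0.75, return_channelwise_embeddings=False)
    model = MAEModel(config)
    model.eval()

    imgs = torch.randn(2, 6, 256, 256)  # (batch, channels, height, width)

    with torch.no_grad():
        # Masked forward pass: encoder latent, per-patch reconstruction,
        # and the binary mask of dropped patches.
        latent, reconstruction, mask = model(imgs)

        # Inference-time embeddings with no masking; shape (N, dim), or
        # (N, num_channels * dim) when return_channelwise_embeddings=True.
        embeddings = model.predict(imgs)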