# type: ignore
# Copyright (c) IBM Corp. 2024. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# --------------------------------------------------------
# References:
# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
# transformers: https://github.com/huggingface/transformers
# --------------------------------------------------------
import logging
from functools import partial
from pathlib import Path
from typing import List, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import yaml
from einops import rearrange, repeat
from timm.layers import to_2tuple
from timm.models.vision_transformer import Block


class PrithviWrapper(nn.Module):
    # we assume any data passed to this wrapper
    # will contain S2 data with the following channels
    INPUT_S2_BAND_ORDERING = [
        "B01",
        "B02",
        "B03",
        "B04",
        "B05",
        "B06",
        "B07",
        "B08",
        "B08A",
        "B09",
        "B10",
        "B11",
        "B12",
    ]
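
    # Illustrative note (added; the band subset below is an assumption, not taken from a real
    # checkpoint config): if the loaded config listed bands = ["B02", "B03", "B04"], then
    # __init__ would compute
    #     inputs_to_prithvi = [INPUT_S2_BAND_ORDERING.index(b) for b in bands]  # -> [1, 2, 3]
    # i.e. only those channels of the 13-band input are forwarded to the Prithvi encoder.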

    def __init__(self, weights_path: Path, do_pool=True, temporal_pooling: str = "mean"):
        super().__init__()
        with (weights_path / "prithvi/config.json").open("r") as f:
            config = yaml.safe_load(f)["pretrained_cfg"]
        config["num_frames"] = 1
        self.model = PrithviMAE(**config)
        state_dict = torch.load(weights_path / "prithvi/Prithvi_EO_V2_300M.pt", map_location="cpu")
        # discard fixed pos_embedding weight, following
        # https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M/blob/e4aabdc440c8ee703a749def8af5bf4700dee35b/inference.py#L362
        for k in list(state_dict.keys()):
            if "pos_embed" in k:
                del state_dict[k]
        self.model.load_state_dict(state_dict, strict=False)
        self.image_resolution = config["img_size"]
        self.grid_size = int(config["img_size"] // config["patch_size"][-1])
        self.bands = config["bands"]
        self.inputs_to_prithvi = [self.INPUT_S2_BAND_ORDERING.index(b) for b in self.bands]
        self.do_pool = do_pool
        if temporal_pooling not in ["mean", "max"]:
            raise ValueError(
                f"Expected temporal_pooling to be in ['mean', 'max'], got {temporal_pooling}"
            )
        self.temporal_pooling = temporal_pooling
        self.dim = config["embed_dim"]

    def resize(self, images):
        images = F.interpolate(
            images,
            size=(self.image_resolution, self.image_resolution),
            mode="bilinear",
            align_corners=False,
        )
        return images

    def preprocess(self, images):
        if len(images.shape) == 5:
            # take the mean along the temporal dimension, which is dim 3 for
            # inputs laid out as (batch, height, width, time, channels)
            images = torch.mean(images, dim=3)
        images = rearrange(images, "b h w c -> b c h w")
        assert images.shape[1] == 13
        images = images[:, self.inputs_to_prithvi, :, :]
        images = self.resize(images)  # (bsz, C, H, W)
        return repeat(images, "b c h w -> b c t h w", t=1)

    def forward(self, s2=None, s1=None, months=None):
        if s2 is None:
            raise ValueError("S2 can't be None for Prithvi")
        if len(s2.shape) == 5:
            outputs_l: List[torch.Tensor] = []
            for timestep in range(s2.shape[3]):
                image = self.preprocess(s2[:, :, :, timestep])
                output = self.model.forward_features(image)[-1]
                # following
                # https://github.com/IBM/terratorch/blob/main/terratorch/models/backbones/prithvi_mae.py#L449
                # we remove the class token. This is also the approach they
                # take for classification: https://github.com/IBM/terratorch/blob/main/terratorch/models/scalar_output_model.py#L19
                output = output[:, 1:, :]
                # output shape: (bsz, num_tokens, dim)
                if self.do_pool:
                    output = output.mean(dim=1)
                outputs_l.append(output)
            outputs_t = torch.stack(outputs_l, dim=-1)  # (bsz, [num_tokens,] dim, timesteps)
            if self.temporal_pooling == "mean":
                return outputs_t.mean(dim=-1)
            else:
                return torch.amax(outputs_t, dim=-1)
        else:
            s2 = self.preprocess(s2)
            output = self.model.forward_features(s2)[-1]
            output = output[:, 1:, :]
            if self.do_pool:
                return output.mean(dim=1)
            else:
                return output
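
# Usage sketch (illustrative; the weights directory and tensor shapes are assumptions, not taken
# from this file):
#
#     wrapper = PrithviWrapper(weights_path=Path("/path/to/weights"), do_pool=True)
#     s2 = torch.randn(2, 224, 224, 13)        # (batch, H, W, 13 S2 bands)
#     feats = wrapper(s2=s2)                   # (batch, embed_dim) when do_pool=True
#
#     s2_t = torch.randn(2, 224, 224, 4, 13)   # (batch, H, W, timesteps, 13 bands)
#     feats_t = wrapper(s2=s2_t)               # temporally pooled to (batch, embed_dim)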


def get_3d_sincos_pos_embed(embed_dim, grid_size, add_cls_token=False):
    """
    Create 3D sin/cos positional embeddings.
    Args:
        embed_dim (int):
            Embedding dimension.
        grid_size (tuple[int, int, int] | list[int]):
            The grid depth, height and width.
        add_cls_token (bool, *optional*, defaults to False):
            Whether or not to add a classification (CLS) token.
    Returns:
        `torch.FloatTensor` of shape (grid_size[0]*grid_size[1]*grid_size[2], embed_dim) or
        (1 + grid_size[0]*grid_size[1]*grid_size[2], embed_dim): the position embeddings (with or without cls token).
    """
    assert embed_dim % 16 == 0
    t_size, h_size, w_size = grid_size
    w_embed_dim = embed_dim // 16 * 6
    h_embed_dim = embed_dim // 16 * 6
    t_embed_dim = embed_dim // 16 * 4
    w_pos_embed = get_1d_sincos_pos_embed_from_grid(w_embed_dim, np.arange(w_size))
    h_pos_embed = get_1d_sincos_pos_embed_from_grid(h_embed_dim, np.arange(h_size))
    t_pos_embed = get_1d_sincos_pos_embed_from_grid(t_embed_dim, np.arange(t_size))
    w_pos_embed = np.tile(w_pos_embed, (t_size * h_size, 1))
    h_pos_embed = np.tile(np.repeat(h_pos_embed, w_size, axis=0), (t_size, 1))
    t_pos_embed = np.repeat(t_pos_embed, h_size * w_size, axis=0)
    pos_embed = np.concatenate((w_pos_embed, h_pos_embed, t_pos_embed), axis=1)
    if add_cls_token:
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed
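
# Worked example (added note): with embed_dim=768 the split above gives 288 dimensions each for
# the width and height encodings (768 // 16 * 6) and 192 for the temporal encoding
# (768 // 16 * 4), summing back to 768. For grid_size=(1, 14, 14) the returned array has shape
# (196, 768), or (197, 768) when add_cls_token=True.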


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    if embed_dim % 2 != 0:
        raise ValueError("embed_dim must be even")
    omega = np.arange(embed_dim // 2, dtype=float)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)
    pos = pos.reshape(-1)  # (M,)
    out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product
    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)
    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb
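
# Sanity-check example (added note): for embed_dim=4 and pos=np.array([0.0, 1.0]), omega equals
# [1.0, 0.01], so the row for pos=0 is [sin(0), sin(0), cos(0), cos(0)] == [0, 0, 1, 1] and the
# output has shape (2, 4) == (M, D).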


def _get_1d_sincos_embed_from_grid_torch(embed_dim: int, pos: torch.Tensor):
    """This is the torch version of *get_1d_sincos_pos_embed_from_grid()*. However,
    it was modified to cast omega values to pos.dtype, which must be float (and not int as in
    regular positional embeddings). This was required to allow for native FSDP mixed-precision
    support: omega is cast to the appropriate dtype (pos carries the correct float dtype)
    instead of being manually forced to float32.
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,) - must be float dtype!
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    assert pos.dtype in [torch.float32, torch.float16, torch.bfloat16]
    omega = torch.arange(embed_dim // 2, dtype=pos.dtype).to(pos.device)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)
    pos = pos.reshape(-1)  # (M,)
    out = torch.einsum("m,d->md", pos, omega)  # (M, D/2), outer product
    emb_sin = torch.sin(out)  # (M, D/2)
    emb_cos = torch.cos(out)  # (M, D/2)
    emb = torch.cat([emb_sin, emb_cos], dim=1)  # (M, D)
    return emb


def _init_weights(module):
    """Initialize the weights"""
    if isinstance(module, nn.Linear):
        nn.init.xavier_uniform_(module.weight)
        if module.bias is not None:
            module.bias.data.zero_()
    elif isinstance(module, nn.LayerNorm):
        module.bias.data.zero_()
        module.weight.data.fill_(1.0)


class PatchEmbed(nn.Module):
    """3D version of timm.models.vision_transformer.PatchEmbed"""

    def __init__(
        self,
        input_size: Tuple[int, int, int] = (1, 224, 224),
        patch_size: Tuple[int, int, int] = (1, 16, 16),
        in_chans: int = 3,
        embed_dim: int = 768,
        norm_layer: nn.Module | None = None,
        flatten: bool = True,
        bias: bool = True,
    ):
        super().__init__()
        self.input_size = input_size
        self.patch_size = patch_size
        self.grid_size = [s // p for s, p in zip(self.input_size, self.patch_size)]
        self.num_patches = self.grid_size[0] * self.grid_size[1] * self.grid_size[2]
        self.flatten = flatten
        self.proj = nn.Conv3d(
            in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias
        )
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        B, C, T, H, W = x.shape
        if T / self.patch_size[0] % 1 or H / self.patch_size[1] % 1 or W / self.patch_size[2] % 1:
            logging.warning(
                f"Input {x.shape[-3:]} is not divisible by patch size {self.patch_size}. "
                f"The border will be ignored, add backbone_padding for pixel-wise tasks."
            )
        x = self.proj(x)
        if self.flatten:
            x = x.flatten(2).transpose(1, 2)  # B,C,T,H,W -> B,C,L -> B,L,C
        x = self.norm(x)
        return x
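
# Shape example (added note; in_chans=6 and embed_dim=768 are assumed for illustration): with
# input_size=(1, 224, 224) and patch_size=(1, 16, 16), an input of shape (B, 6, 1, 224, 224) is
# projected to (B, 768, 1, 14, 14) and flattened to (B, 196, 768), i.e.
# num_patches == 1 * 14 * 14 == 196 tokens per sample.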


class TemporalEncoder(nn.Module):
    def __init__(self, embed_dim: int, trainable_scale: bool = False):
        super().__init__()
        self.embed_dim = embed_dim
        self.year_embed_dim = embed_dim // 2
        self.julian_day_embed_dim = embed_dim - self.year_embed_dim
        # If trainable, initialize scale with small number
        if trainable_scale:
            self.scale = nn.Parameter(torch.full((1,), 0.1))
        else:
            self.register_buffer("scale", torch.ones(1))

    def forward(self, temporal_coords: torch.Tensor, tokens_per_frame: int | None = None):
        """
        temporal_coords: year and day-of-year info with shape (B, T, 2).
        tokens_per_frame: number of tokens for each frame in the sample. If provided, embeddings will be
            repeated over T dimension, and final shape is (B, T*tokens_per_frame, embed_dim).
        """
        shape = temporal_coords.shape[:2] + (-1,)  # B, T, -1
        year = _get_1d_sincos_embed_from_grid_torch(
            self.year_embed_dim, temporal_coords[:, :, 0].flatten()
        ).reshape(shape)
        julian_day = _get_1d_sincos_embed_from_grid_torch(
            self.julian_day_embed_dim, temporal_coords[:, :, 1].flatten()
        ).reshape(shape)
        embedding = self.scale * torch.cat([year, julian_day], dim=-1)
        if tokens_per_frame is not None:
            embedding = torch.repeat_interleave(embedding, tokens_per_frame, dim=1)
        return embedding  # B, T*tokens_per_frame, embed_dim
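
# Example (added note): temporal_coords of shape (B, T, 2) holds (year, day-of-year) pairs per
# frame; with tokens_per_frame=196 the embedding is repeated along the token axis to shape
# (B, T * 196, embed_dim) so it can be added directly to the patch tokens.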


class LocationEncoder(nn.Module):
    def __init__(self, embed_dim: int, trainable_scale: bool = False):
        super().__init__()
        self.embed_dim = embed_dim
        self.lat_embed_dim = embed_dim // 2
        self.lon_embed_dim = embed_dim - self.lat_embed_dim
        # If trainable, initialize scale with small number
        if trainable_scale:
            self.scale = nn.Parameter(torch.full((1,), 0.1))
        else:
            self.register_buffer("scale", torch.ones(1))

    def forward(self, location_coords: torch.Tensor):
        """
        location_coords: lat and lon info with shape (B, 2).
        """
        shape = location_coords.shape[:1] + (1, -1)  # B, 1, -1
        lat = _get_1d_sincos_embed_from_grid_torch(
            self.lat_embed_dim, location_coords[:, 0].flatten()
        ).reshape(shape)
        lon = _get_1d_sincos_embed_from_grid_torch(
            self.lon_embed_dim, location_coords[:, 1].flatten()
        ).reshape(shape)
        embedding = self.scale * torch.cat([lat, lon], dim=-1)
        return embedding  # B, 1, embed_dim


class PrithviViT(nn.Module):
    """Prithvi ViT Encoder"""

    def __init__(
        self,
        img_size: int | Tuple[int, int] = 224,
        patch_size: int | Tuple[int, int, int] = (1, 16, 16),
        num_frames: int = 1,
        in_chans: int = 3,
        embed_dim: int = 1024,
        depth: int = 24,
        num_heads: int = 16,
        mlp_ratio: float = 4.0,
        norm_layer: nn.Module = partial(torch.nn.LayerNorm, eps=1e-6),
        coords_encoding: List[str] | None = None,
        coords_scale_learn: bool = False,
        encoder_only: bool = True,  # needed for timm
        **kwargs,
    ):
        super().__init__()
        self.feature_info = []
        self.encoder_only = encoder_only
        self.in_chans = in_chans
        self.num_frames = num_frames
        self.embed_dim = embed_dim
        self.img_size = to_2tuple(img_size)
        if isinstance(patch_size, int):
            patch_size = (1, patch_size, patch_size)
        # 3D patch embedding
        self.patch_embed = PatchEmbed(
            input_size=(num_frames,) + self.img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
        )
        # Optional temporal and location embedding
        coords_encoding = coords_encoding or []
        self.temporal_encoding = "time" in coords_encoding
        self.location_encoding = "location" in coords_encoding
        if self.temporal_encoding:
            assert (
                patch_size[0] == 1
            ), f"With temporal encoding, patch_size[0] must be 1, received {patch_size[0]}"
            self.temporal_embed_enc = TemporalEncoder(embed_dim, coords_scale_learn)
        if self.location_encoding:
            self.location_embed_enc = LocationEncoder(embed_dim, coords_scale_learn)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.register_buffer(
            "pos_embed", torch.zeros(1, self.patch_embed.num_patches + 1, embed_dim)
        )
        # Transformer layers
        self.blocks = []
        for i in range(depth):
            self.blocks.append(
                Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
            )
            self.feature_info.append(
                {
                    "num_chs": embed_dim * self.patch_embed.patch_size[0],
                    "reduction": 1,
                    "module": f"blocks.{i}",
                }
            )
        self.blocks = nn.ModuleList(self.blocks)
        self.norm = norm_layer(embed_dim)
        self.initialize_weights()

    def initialize_weights(self):
        # initialize (and freeze) position embeddings by sin-cos embedding
        pos_embed = get_3d_sincos_pos_embed(
            self.pos_embed.shape[-1], self.patch_embed.grid_size, add_cls_token=True
        )
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
        # initialize patch_embeddings like nn.Linear (instead of nn.Conv2d)
        w = self.patch_embed.proj.weight.data
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
        torch.nn.init.normal_(self.cls_token, std=0.02)
        self.apply(_init_weights)

    def random_masking(self, sequence, mask_ratio, noise=None):
        """
        Perform per-sample random masking by per-sample shuffling. Per-sample shuffling is done by argsort random
        noise.
        Args:
            sequence (`torch.FloatTensor` of shape `(batch_size, sequence_length, dim)`)
            mask_ratio (float): mask ratio to use.
            noise (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*) which is
                mainly used for testing purposes to control randomness and maintain the reproducibility
        """
        batch_size, seq_length, dim = sequence.shape
        len_keep = int(seq_length * (1 - mask_ratio))
        if noise is None:
            noise = torch.rand(batch_size, seq_length, device=sequence.device)  # noise in [0, 1]
        # sort noise for each sample
        ids_shuffle = torch.argsort(noise, dim=1).to(
            sequence.device
        )  # ascend: small is keep, large is remove
        ids_restore = torch.argsort(ids_shuffle, dim=1).to(sequence.device)
        # keep the first subset
        ids_keep = ids_shuffle[:, :len_keep]
        sequence_unmasked = torch.gather(
            sequence, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, dim)
        )
        # generate the binary mask: 0 is keep, 1 is remove
        mask = torch.ones([batch_size, seq_length], device=sequence.device)
        mask[:, :len_keep] = 0
        # unshuffle to get the binary mask
        mask = torch.gather(mask, dim=1, index=ids_restore)
        return sequence_unmasked, mask, ids_restore
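
    # Illustrative numbers (added note): for a 196-token sequence and mask_ratio=0.75,
    # len_keep == 49, so sequence_unmasked has shape (B, 49, dim) and mask has shape (B, 196)
    # with 147 entries set to 1 (removed) and 49 set to 0 (kept).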

    def _get_pos_embed(self, x):
        t, h, w = x.shape[-3:]
        pos_embed = (
            torch.from_numpy(
                get_3d_sincos_pos_embed(
                    self.embed_dim,
                    (
                        t // self.patch_embed.patch_size[0],
                        h // self.patch_embed.patch_size[1],
                        w // self.patch_embed.patch_size[2],
                    ),
                    add_cls_token=True,
                )
            )
            .float()
            .unsqueeze(0)
            .to(x)
        )
        return pos_embed

    def forward(
        self,
        x: torch.Tensor,
        temporal_coords: None | torch.Tensor = None,
        location_coords: None | torch.Tensor = None,
        mask_ratio=0.0,
    ):
        if x.shape[-3:] != self.patch_embed.input_size:
            # changed input size
            pos_embed = self._get_pos_embed(x)
        else:
            pos_embed = self.pos_embed
        # embed patches
        x = self.patch_embed(x)
        # add pos embed w/o cls token
        x = x + pos_embed[:, 1:, :]
        if self.temporal_encoding:
            num_tokens_per_frame = x.shape[1] // self.num_frames
            temporal_encoding = self.temporal_embed_enc(temporal_coords, num_tokens_per_frame)
            x = x + temporal_encoding
        if self.location_encoding:
            location_encoding = self.location_embed_enc(location_coords)
            x = x + location_encoding
        # masking: length -> length * (1 - mask_ratio)
        x, mask, ids_restore = self.random_masking(x, mask_ratio)
        # append cls token
        cls_token = self.cls_token + pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        # apply Transformer blocks
        for block in self.blocks:
            x = block(x)
        x = self.norm(x)
        return x, mask, ids_restore

    def forward_features(
        self,
        x: torch.Tensor,
        temporal_coords: None | torch.Tensor = None,
        location_coords: None | torch.Tensor = None,
    ) -> list[torch.Tensor]:
        if len(x.shape) == 4 and self.patch_embed.input_size[0] == 1:
            # add time dim
            x = x.unsqueeze(2)
        if x.shape[-3:] != self.patch_embed.input_size:
            pos_embed = self._get_pos_embed(x)
        else:
            pos_embed = self.pos_embed
        # embed patches
        x = self.patch_embed(x)
        # add pos embed w/o cls token
        x = x + pos_embed[:, 1:, :]
        if self.temporal_encoding:
            num_tokens_per_frame = x.shape[1] // self.num_frames
            temporal_encoding = self.temporal_embed_enc(temporal_coords, num_tokens_per_frame)
            x = x + temporal_encoding
        if self.location_encoding:
            location_encoding = self.location_embed_enc(location_coords)
            x = x + location_encoding
        # append cls token
        cls_token = self.cls_token + pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        # apply Transformer blocks
        out = []
        for block in self.blocks:
            x = block(x)
            out.append(x.clone())
        x = self.norm(x)
        out[-1] = x
        return out

    def prepare_features_for_image_model(self, features: list[torch.Tensor]) -> list[torch.Tensor]:
        out = []
        effective_time_dim = self.patch_embed.input_size[0] // self.patch_embed.patch_size[0]
        for x in features:
            x_no_token = x[:, 1:, :]
            number_of_tokens = x_no_token.shape[1]
            tokens_per_timestep = number_of_tokens // effective_time_dim
            h = int(np.sqrt(tokens_per_timestep))
            encoded = rearrange(
                x_no_token,
                "batch (t h w) e -> batch (t e) h w",
                e=self.embed_dim,
                t=effective_time_dim,
                h=h,
            )
            out.append(encoded)
        return out
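
# Shape example (added note; embed_dim=1024 is assumed for illustration): a feature of shape
# (B, 197, 1024) from a single-frame 224x224 input drops its cls token to (B, 196, 1024); with
# effective_time_dim == 1 and h == w == 14 it is rearranged to a CNN-style map of shape
# (B, 1024, 14, 14).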


class MAEDecoder(nn.Module):
    """Transformer Decoder used in the Prithvi MAE"""

    def __init__(
        self,
        patch_size: int | Tuple[int, int, int] = (1, 16, 16),
        grid_size: List[int] | Tuple[int, int, int] = (3, 14, 14),
        in_chans: int = 3,
        encoder_embed_dim: int = 1024,
        decoder_embed_dim: int = 512,
        depth: int = 8,
        num_heads: int = 16,
        mlp_ratio: float = 4.0,
        norm_layer: nn.Module = nn.LayerNorm,
        coords_encoding: List[str] | None = None,
        coords_scale_learn: bool = False,
    ):
        super().__init__()
        self.decoder_embed = nn.Linear(encoder_embed_dim, decoder_embed_dim, bias=True)
        self.decoder_embed_dim = decoder_embed_dim
        self.grid_size = grid_size
        if isinstance(patch_size, int):
            patch_size = (1, patch_size, patch_size)
        self.patch_size = patch_size
        self.num_frames = self.grid_size[0] * patch_size[0]
        num_patches = self.grid_size[0] * self.grid_size[1] * self.grid_size[2]
        # Optional temporal and location embedding
        coords_encoding = coords_encoding or []
        self.temporal_encoding = "time" in coords_encoding
        self.location_encoding = "location" in coords_encoding
        if self.temporal_encoding:
            self.temporal_embed_dec = TemporalEncoder(decoder_embed_dim, coords_scale_learn)
        if self.location_encoding:
            self.location_embed_dec = LocationEncoder(decoder_embed_dim, coords_scale_learn)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
        self.register_buffer(
            "decoder_pos_embed", torch.zeros(1, num_patches + 1, decoder_embed_dim)
        )
        self.decoder_blocks = nn.ModuleList(
            [
                Block(
                    decoder_embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer
                )
                for _ in range(depth)
            ]
        )
        self.decoder_norm = norm_layer(decoder_embed_dim)
        self.decoder_pred = nn.Linear(
            decoder_embed_dim, patch_size[0] * patch_size[1] * patch_size[2] * in_chans, bias=True
        )
        self.initialize_weights()

    def initialize_weights(self):
        # initialize (and freeze) position embeddings by sin-cos embedding
        decoder_pos_embed = get_3d_sincos_pos_embed(
            self.decoder_pos_embed.shape[-1], self.grid_size, add_cls_token=True
        )
        self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))
        # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
        torch.nn.init.normal_(self.mask_token, std=0.02)
        self.apply(_init_weights)

    def forward(
        self,
        hidden_states: torch.Tensor,
        ids_restore: torch.Tensor,
        temporal_coords: None | torch.Tensor = None,
        location_coords: None | torch.Tensor = None,
        input_size: list[int] | None = None,
    ):
        # embed tokens
        x = self.decoder_embed(hidden_states)
        t, h, w = input_size[-3:]
        decoder_pos_embed = torch.from_numpy(
            get_3d_sincos_pos_embed(
                self.decoder_embed_dim,
                (
                    t // self.patch_size[0],
                    h // self.patch_size[1],
                    w // self.patch_size[2],
                ),
                add_cls_token=True,
            )
        ).to(x)
        # append mask tokens to sequence
        mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
        x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)  # no cls token
        # unshuffle
        x_ = torch.gather(
            x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]).to(x_.device)
        )
        x = torch.cat([x[:, :1, :], x_], dim=1)  # append cls token
        # add pos embed
        x = x + decoder_pos_embed
        # remove cls token
        x_ = x[:, 1:, :]
        if self.temporal_encoding:
            num_tokens_per_frame = x_.shape[1] // self.num_frames
            temporal_encoding = self.temporal_embed_dec(temporal_coords, num_tokens_per_frame)
            # Add temporal encoding w/o cls token
            x_ = x_ + temporal_encoding
        if self.location_encoding:
            location_encoding = self.location_embed_dec(location_coords)
            # Add location encoding w/o cls token
            x_ = x_ + location_encoding
        # append cls token
        x = torch.cat([x[:, :1, :], x_], dim=1)
        # apply Transformer layers (blocks)
        for block in self.decoder_blocks:
            x = block(x)
        x = self.decoder_norm(x)
        # predictor projection
        pred = self.decoder_pred(x)
        # remove cls token
        pred = pred[:, 1:, :]
        return pred
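
# Shape example (added note; the dimensions are assumptions for illustration): with
# encoder_embed_dim=1024, decoder_embed_dim=512, grid_size=(1, 14, 14), patch_size=(1, 16, 16)
# and in_chans=6, the encoder output (B, 1 + len_keep, 1024) is projected to 512 dimensions,
# padded back to all 197 positions with mask tokens, and decoded to pred of shape
# (B, 196, 1 * 16 * 16 * 6) == (B, 196, 1536).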


class PrithviMAE(nn.Module):
    """Prithvi Masked Autoencoder"""

    def __init__(
        self,
        img_size: int | Tuple[int, int] = 224,
        patch_size: int | Tuple[int, int, int] = (1, 16, 16),
        num_frames: int = 3,
        in_chans: int = 3,
        embed_dim: int = 1024,
        depth: int = 24,
        num_heads: int = 16,
        decoder_embed_dim: int = 512,
        decoder_depth: int = 8,
        decoder_num_heads: int = 16,
        mlp_ratio: float = 4.0,
        norm_layer: nn.Module = partial(torch.nn.LayerNorm, eps=1e-6),
        norm_pix_loss: bool = False,
        coords_encoding: List[str] | None = None,
        coords_scale_learn: bool = False,
        encoder_only: bool = False,
        **kwargs,
    ):
        super().__init__()
        self.encoder = PrithviViT(
            img_size=img_size,
            num_frames=num_frames,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            depth=depth,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            norm_layer=norm_layer,
            coords_encoding=coords_encoding,
            coords_scale_learn=coords_scale_learn,
        )
        self.encoder_only = encoder_only
        if not encoder_only:
            self.decoder = MAEDecoder(
                patch_size=patch_size,
                grid_size=self.encoder.patch_embed.grid_size,
                in_chans=in_chans,
                encoder_embed_dim=embed_dim,
                decoder_embed_dim=decoder_embed_dim,
                depth=decoder_depth,
                num_heads=decoder_num_heads,
                mlp_ratio=mlp_ratio,
                norm_layer=norm_layer,
                coords_encoding=coords_encoding,
                coords_scale_learn=coords_scale_learn,
            )
        else:
            self.decoder = nn.Identity()
        self.norm_pix_loss = norm_pix_loss

    def patchify(self, pixel_values):
        """
        Args:
            pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, time, height, width)`):
                Pixel values.
        Returns:
            `torch.FloatTensor` of shape `(batch_size, num_patches, patch_size[0]*patch_size[1]*patch_size[2] * num_channels)`:
                Patchified pixel values.
        """
        patch_size_t, patch_size_h, patch_size_w = self.encoder.patch_embed.patch_size
        num_channels = self.encoder.in_chans
        # patchify
        patchified_pixel_values = rearrange(
            pixel_values,
            "b c (t s) (h p) (w q) -> b (t h w) (s p q c)",
            c=num_channels,
            s=patch_size_t,
            p=patch_size_h,
            q=patch_size_w,
        )
        return patchified_pixel_values
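
    # Worked example (added note; shapes assumed for illustration): pixel_values of shape
    # (B, 6, 1, 224, 224) with patch_size (1, 16, 16) is patchified to
    # (B, 1 * 14 * 14, 1 * 16 * 16 * 6) == (B, 196, 1536); unpatchify below inverts this mapping.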

    def unpatchify(self, patchified_pixel_values, image_size: Tuple[int, int] | None = None):
        """
        Args:
            patchified_pixel_values (`torch.FloatTensor` of shape
                `(batch_size, num_patches, patch_size[0]*patch_size[1]*patch_size[2] * num_channels)`):
                Patchified pixel values.
            image_size (`Tuple[int, int]`, *optional*):
                Original image size.
        Returns:
            `torch.FloatTensor` of shape `(batch_size, num_channels, time, height, width)`:
                Pixel values.
        """
        patch_size_t, patch_size_h, patch_size_w = self.encoder.patch_embed.patch_size
        image_size = to_2tuple(image_size) if image_size is not None else self.encoder.img_size
        original_height, original_width = image_size
        num_patches_h = original_height // patch_size_h
        num_patches_w = original_width // patch_size_w
        num_channels = self.encoder.in_chans
        pixel_values = rearrange(
            patchified_pixel_values,
            "b (t h w) (s p q c) -> b c (t s) (h p) (w q)",
            c=num_channels,
            h=num_patches_h,
            w=num_patches_w,
            s=patch_size_t,
            p=patch_size_h,
            q=patch_size_w,
        )
        return pixel_values

    def forward_loss(self, pixel_values, pred, mask):
        """
        Args:
            pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, time, height, width)`):
                Pixel values.
            pred (`torch.FloatTensor` of shape `(batch_size, num_patches, patch_size[0]*patch_size[1]*patch_size[2] * num_channels)`):
                Predicted pixel values.
            mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
                Tensor indicating which patches are masked (1) and which are not (0).
        Returns:
            `torch.FloatTensor`: Pixel reconstruction loss.
        """
        target = self.patchify(pixel_values)
        if self.norm_pix_loss:
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1.0e-6) ** 0.5
        loss = (pred - target) ** 2
        loss = loss.mean(dim=-1)  # [N, L], mean loss per patch
        loss = (loss * mask).sum() / mask.sum()  # mean loss on removed patches
        return loss

    def forward(
        self,
        pixel_values: torch.Tensor,
        temporal_coords: None | torch.Tensor = None,
        location_coords: None | torch.Tensor = None,
        mask_ratio: float = 0.75,
    ):
        if len(pixel_values.shape) == 4 and self.encoder.patch_embed.input_size[0] == 1:
            # add time dim
            pixel_values = pixel_values.unsqueeze(2)
        latent, mask, ids_restore = self.encoder(
            pixel_values, temporal_coords, location_coords, mask_ratio
        )
        pred = self.decoder(
            latent, ids_restore, temporal_coords, location_coords, input_size=pixel_values.shape
        )
        loss = self.forward_loss(pixel_values, pred, mask)
        return loss, pred, mask
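
    # Pre-training usage sketch (added note; the arguments and shapes are assumptions, not taken
    # from this file):
    #     model = PrithviMAE(img_size=224, patch_size=(1, 16, 16), num_frames=1, in_chans=6)
    #     pixel_values = torch.randn(2, 6, 224, 224)   # time dim is added automatically
    #     loss, pred, mask = model(pixel_values, mask_ratio=0.75)
    # Here pred has shape (2, 196, 1536) and mask has shape (2, 196).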

    def forward_features(
        self,
        x: torch.Tensor,
        temporal_coords: None | torch.Tensor = None,
        location_coords: None | torch.Tensor = None,
    ) -> List[torch.Tensor]:
        return self.encoder.forward_features(x, temporal_coords, location_coords)