""" | |
https://github.com/nasaharvest/presto/blob/main/single_file_presto.py | |
""" | |
import math
from collections import OrderedDict
from copy import deepcopy
from pathlib import Path
from typing import Optional, Tuple, Union, cast

import numpy as np
import torch
from einops import repeat
from torch import nn
from torch.jit import Final
from torch.nn import functional as F

from src.utils import device

PRESTO_S2_BANDS = ["B2", "B3", "B4", "B5", "B6", "B7", "B8", "B8A", "B9", "B11", "B12"]
PRESTO_S1_BANDS = ["VV", "VH"]
PRESTO_BANDS = (
    PRESTO_S1_BANDS
    + PRESTO_S2_BANDS
    + ["temperature_2m", "total_precipitation", "elevation", "slope", "NDVI"]
)
DEFAULT_MODEL_PATH = Path(__file__).parent / "default_model.pt"

# used in normalization; one entry per band in PRESTO_BANDS
PRESTO_ADD_BY = (
    [25.0, 25.0]  # S1: VV, VH
    + [0.0] * 11  # S2 bands
    + [-272.15, 0.0, 0.0, 0.0, 0.0]  # temperature_2m, total_precipitation, elevation, slope, NDVI
)
PRESTO_DIV_BY = (
    [25.0, 25.0]  # S1: VV, VH
    + [1e4] * 11  # S2 bands
    + [35.0, 0.03, 2000.0, 50.0, 1.0]  # temperature_2m, total_precipitation, elevation, slope, NDVI
)
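
# How these constants are presumably applied (an assumption based on their
# names; the normalization itself lives elsewhere in the repo):
# >>> raw = np.random.rand(12, len(PRESTO_BANDS))  # [timesteps, bands]
# >>> normed = (raw + np.array(PRESTO_ADD_BY)) / np.array(PRESTO_DIV_BY)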

BANDS_GROUPS_IDX = OrderedDict(
    [
        ("S1", [0, 1]),
        ("S2_RGB", [2, 3, 4]),
        ("S2_Red_Edge", [5, 6, 7]),
        ("S2_NIR_10m", [8]),
        ("S2_NIR_20m", [9]),
        ("S2_SWIR", [10, 11]),
        ("ERA5", [12, 13]),
        ("SRTM", [14, 15]),
        ("NDVI", [16]),
    ]
)
NUM_DYNAMIC_WORLD_CLASSES = 9

class Attention(nn.Module):
    # https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py
    fast_attn: Final[bool]

    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        qk_norm=False,
        attn_drop=0.0,
        proj_drop=0.0,
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        assert dim % num_heads == 0, "dim should be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim**-0.5
        self.fast_attn = hasattr(torch.nn.functional, "scaled_dot_product_attention")  # FIXME
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        q, k = self.q_norm(q), self.k_norm(k)
        if self.fast_attn:
            x = F.scaled_dot_product_attention(
                q,
                k,
                v,
                dropout_p=self.attn_drop.p,
            )
        else:
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = attn @ v
        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
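
# Illustrative shape check: Attention maps [batch, tokens, dim] to
# [batch, tokens, dim].
# >>> attn = Attention(dim=128, num_heads=8)
# >>> attn(torch.randn(2, 10, 128)).shape
# torch.Size([2, 10, 128])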

class Mlp(nn.Module):
    """MLP as used in Vision Transformer, MLP-Mixer and related networks"""

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        bias=True,
        drop=0.0,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop)
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
        self.drop2 = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x


class LayerScale(nn.Module):
    def __init__(self, dim, init_values=1e-5, inplace=False):
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x):
        return x.mul_(self.gamma) if self.inplace else x * self.gamma


class Block(nn.Module):
    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_norm=False,
        drop=0.0,
        attn_drop=0.0,
        init_values=None,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            attn_drop=attn_drop,
            proj_drop=drop,
            norm_layer=norm_layer,
        )
        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()

    def forward(self, x):
        x = x + self.ls1(self.attn(self.norm1(x)))
        x = x + self.ls2(self.mlp(self.norm2(x)))
        return x


def get_sinusoid_encoding_table(positions, d_hid, T=1000):
    """Sinusoid position encoding table
    positions: int or list of integer, if int range(positions)"""
    if isinstance(positions, int):
        positions = list(range(positions))

    def cal_angle(position, hid_idx):
        return position / np.power(T, 2 * (hid_idx // 2) / d_hid)

    def get_posi_angle_vec(position):
        return [cal_angle(position, hid_j) for hid_j in range(d_hid)]

    sinusoid_table = np.array([get_posi_angle_vec(pos_i) for pos_i in positions])
    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1
    if torch.cuda.is_available():
        return torch.FloatTensor(sinusoid_table).cuda()
    else:
        return torch.FloatTensor(sinusoid_table)
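
# Example: one row per position, sin in even columns and cos in odd columns.
# >>> get_sinusoid_encoding_table(24, 64).shape
# torch.Size([24, 64])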

def get_month_encoding_table(d_hid):
    """Sinusoid month encoding table, for 12 months indexed from 0-11"""
    assert d_hid % 2 == 0
    angles = np.arange(0, 13) / (12 / (2 * np.pi))
    sin_table = np.sin(np.stack([angles for _ in range(d_hid // 2)], axis=-1))
    cos_table = np.cos(np.stack([angles for _ in range(d_hid // 2)], axis=-1))
    month_table = np.concatenate([sin_table[:-1], cos_table[:-1]], axis=-1)
    if torch.cuda.is_available():
        return torch.FloatTensor(month_table).cuda()
    else:
        return torch.FloatTensor(month_table)
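
# Example: 12 rows (one per month), each of width d_hid. Month 12 would map to
# the same angle as month 0, which is why the 13th row is dropped above.
# >>> get_month_encoding_table(32).shape
# torch.Size([12, 32])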

def month_to_tensor(
    month: Union[torch.Tensor, int], batch_size: int, seq_len: int, device: torch.device
):
    if isinstance(month, int):
        assert cast(int, month) < 12
    else:
        assert max(cast(torch.Tensor, month.flatten())) < 12

    if isinstance(month, int):
        # >>> torch.fmod(torch.tensor([9., 10, 11, 12, 13, 14]), 12)
        # tensor([ 9., 10., 11.,  0.,  1.,  2.])
        month = (
            torch.fmod(torch.arange(month, month + seq_len, dtype=torch.long), 12)
            .expand(batch_size, seq_len)
            .to(device)
        )
    elif len(month.shape) == 1:
        month = torch.stack(
            [torch.fmod(torch.arange(m, m + seq_len, dtype=torch.long), 12) for m in month]
        ).to(device)
    return month
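
# Example: an integer start month is wrapped and expanded per timestep for
# every batch element.
# >>> month_to_tensor(11, batch_size=2, seq_len=3, device=torch.device("cpu"))
# tensor([[11,  0,  1],
#         [11,  0,  1]])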

class Encoder(nn.Module):
    def __init__(
        self,
        embedding_size: int = 128,
        channel_embed_ratio: float = 0.25,
        month_embed_ratio: float = 0.25,
        depth=2,
        mlp_ratio=2,
        num_heads=8,
        max_sequence_length=24,
    ):
        super().__init__()
        self.band_groups = BANDS_GROUPS_IDX
        self.embedding_size = embedding_size

        # this is used for the channel embedding
        self.band_group_to_idx = {
            group_name: idx for idx, (group_name, _) in enumerate(self.band_groups.items())
        }
        self.band_group_to_idx["dynamic_world"] = max(self.band_group_to_idx.values()) + 1

        self.eo_patch_embed = nn.ModuleDict(
            {
                group_name: nn.Linear(len(group), embedding_size)
                for group_name, group in self.band_groups.items()
            }
        )
        self.dw_embed = nn.Embedding(
            num_embeddings=NUM_DYNAMIC_WORLD_CLASSES + 1, embedding_dim=embedding_size
        )
        self.latlon_embed = nn.Linear(3, embedding_size)

        self.blocks = nn.ModuleList(
            [
                Block(
                    embedding_size,
                    num_heads,
                    mlp_ratio,
                    qkv_bias=True,
                    norm_layer=nn.LayerNorm,
                )
                for _ in range(depth)
            ]
        )
        self.norm = nn.LayerNorm(embedding_size)

        # the positional + monthly + channel embedding
        self.max_sequence_length = max_sequence_length
        pos_embedding_size = int(embedding_size * (1 - (channel_embed_ratio + month_embed_ratio)))
        channel_embedding_size = int(embedding_size * channel_embed_ratio)
        month_embedding_size = int(embedding_size * month_embed_ratio)
        self.pos_embed = nn.Parameter(
            torch.zeros(1, max_sequence_length, pos_embedding_size), requires_grad=False
        )
        month_tab = get_month_encoding_table(month_embedding_size)
        self.month_embed = nn.Embedding.from_pretrained(month_tab, freeze=True)
        self.channel_embed = nn.Embedding(
            num_embeddings=len(self.band_groups) + 1, embedding_dim=channel_embedding_size
        )

        self.initialize_weights()

    def initialize_weights(self):
        pos_embed = get_sinusoid_encoding_table(self.pos_embed.shape[1], self.pos_embed.shape[-1])
        self.pos_embed.data.copy_(pos_embed)
        # initialize nn.Linear and nn.LayerNorm
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            torch.nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @staticmethod
    def cartesian(latlons: torch.Tensor) -> torch.Tensor:
        with torch.no_grad():
            # an embedding is calculated for all timesteps. This is then expanded
            # for each timestep in the sequence
            latlon_radians = latlons * math.pi / 180
            lats, lons = latlon_radians[:, 0], latlon_radians[:, 1]
            x = torch.cos(lats) * torch.cos(lons)
            y = torch.cos(lats) * torch.sin(lons)
            z = torch.sin(lats)
        return torch.stack([x, y, z], dim=-1)

    @staticmethod
    def mask_tokens(x, mask):
        # summed tells me the number of masked elements per batch idx
        summed = mask.sum(dim=(1, 2))
        assert summed.max() == summed.min(), f"{summed.max()}, {summed.min()}"
        batch_size = x.shape[0]
        removed_elements_per_batch = int(summed.max() / mask.shape[2])
        kept_elements_per_batch = x.shape[1] - removed_elements_per_batch
        embedding_dim = x.shape[-1]
        # we want the mask to just be the indices of the masked tokens
        indices = repeat(torch.arange(0, x.shape[1]).long().to(x.device), "d -> b d", b=x.shape[0])
        x = x[~mask.bool()].view(batch_size, kept_elements_per_batch, embedding_dim)
        mask = mask[:, :, 0]
        kept_indices = indices[~mask.bool()].view(batch_size, kept_elements_per_batch)
        removed_indices = indices[mask.bool()].view(batch_size, removed_elements_per_batch)
        return x, kept_indices, removed_indices
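
    # Shape example for mask_tokens (illustrative): with x of shape [2, 10, 128]
    # and a mask hiding 4 of the 10 tokens for every batch element, it returns
    # x of shape [2, 6, 128], kept_indices [2, 6] and removed_indices [2, 4].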

    def forward(
        self,
        x: torch.Tensor,
        dynamic_world: torch.Tensor,
        # different from the original
        # presto - latlons can be optionally ignored
        latlons: Optional[torch.Tensor] = None,
        mask: Optional[torch.Tensor] = None,
        month: Union[torch.Tensor, int] = 0,
        eval_task: bool = True,
    ):
        device = x.device

        if mask is None:
            mask = torch.zeros_like(x, device=x.device).float()

        months = month_to_tensor(month, x.shape[0], x.shape[1], device)
        month_embedding = self.month_embed(months)
        positional_embedding = repeat(
            self.pos_embed[:, : x.shape[1], :], "b t d -> (repeat b) t d", repeat=x.shape[0]
        )

        # we assume the number of masked patches is the same
        # for all items in the batch. Otherwise things become a headache
        all_tokens, all_masks = [], []

        for channel_group, channel_idxs in self.band_groups.items():
            tokens = self.eo_patch_embed[channel_group](x[:, :, channel_idxs])
            channel_embedding = self.channel_embed(
                torch.tensor(self.band_group_to_idx[channel_group]).long().to(device)
            )
            channel_embedding = repeat(channel_embedding, "d -> b t d", b=x.shape[0], t=x.shape[1])
            if channel_group == "SRTM":
                # for SRTM, we reduce it to a single token instead of
                # a token per timestep
                channel_wise_positional_embedding = torch.cat(
                    (
                        torch.zeros_like(month_embedding[:, 0:1]),
                        channel_embedding[:, 0:1],
                        torch.zeros_like(positional_embedding[:, 0:1]),
                    ),
                    dim=-1,
                )
                indices = slice(0, 1)
            else:
                channel_wise_positional_embedding = torch.cat(
                    (month_embedding, channel_embedding, positional_embedding), dim=-1
                )
                indices = slice(None)

            tokens = tokens[:, indices]
            tokens += channel_wise_positional_embedding
            all_tokens.append(tokens)
            group_mask = repeat(
                torch.max(mask[:, indices, channel_idxs], dim=-1)[0],
                "b t -> b t d",
                d=tokens.shape[-1],
            )
            all_masks.append(group_mask)

        # then, dynamic world
        tokens = self.dw_embed(dynamic_world)
        channel_embedding = self.channel_embed(
            torch.tensor(self.band_group_to_idx["dynamic_world"]).long().to(device)
        )
        channel_embedding = repeat(channel_embedding, "d -> b t d", b=x.shape[0], t=x.shape[1])
        positional_embedding = torch.cat(
            (month_embedding, channel_embedding, positional_embedding), dim=-1
        )
        tokens += positional_embedding
        all_tokens.append(tokens)

        # now we calculate the mask for these [b, t] tokens
        group_mask = repeat(
            dynamic_world == NUM_DYNAMIC_WORLD_CLASSES, "b t -> b t d", d=tokens.shape[-1]
        )
        all_masks.append(group_mask)

        x = torch.cat(all_tokens, dim=1)  # [batch, timesteps, embedding_dim]
        mask = torch.cat(all_masks, dim=1)  # [batch, timesteps, embedding_dim]
        x, kept_indices, removed_indices = self.mask_tokens(x, mask)

        # append latlon tokens
        if latlons is not None:
            latlon_tokens = self.latlon_embed(self.cartesian(latlons)).unsqueeze(1)
            x = torch.cat((latlon_tokens, x), dim=1)

        # apply Transformer blocks
        for blk in self.blocks:
            x = blk(x)

        # mask will be a boolean of shape [batch, total_num_tokens]
        if eval_task:
            return self.norm(x.mean(dim=1))
        return self.norm(x), kept_indices, removed_indices
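
# Sketch of typical embedding extraction with the Encoder (shapes are
# assumptions inferred from BANDS_GROUPS_IDX, which covers 17 channel indices):
# >>> encoder = Encoder()
# >>> x = torch.randn(2, 12, 17)  # [batch, timesteps, channels]
# >>> dw = torch.full((2, 12), NUM_DYNAMIC_WORLD_CLASSES)  # "missing" class, so masked out
# >>> latlons = torch.tensor([[45.0, 5.0], [-10.0, 120.0]])
# >>> encoder(x, dynamic_world=dw, latlons=latlons, month=3).shape
# torch.Size([2, 128])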

class Decoder(nn.Module):
    def __init__(
        self,
        channel_embeddings: nn.Embedding,
        encoder_embed_dim=128,
        decoder_embed_dim=128,
        decoder_depth=2,
        decoder_num_heads=8,
        mlp_ratio=2,
        max_sequence_length=24,
    ):
        super().__init__()
        self.band_groups = BANDS_GROUPS_IDX

        # this is used for the channel embedding
        self.band_group_to_idx = {
            group_name: idx for idx, (group_name, _) in enumerate(self.band_groups.items())
        }
        self.band_group_to_idx["dynamic_world"] = max(self.band_group_to_idx.values()) + 1

        self.decoder_embed = nn.Linear(encoder_embed_dim, decoder_embed_dim, bias=True)
        self.mask_token = nn.Parameter(torch.zeros(decoder_embed_dim))

        self.decoder_blocks = nn.ModuleList(
            [
                Block(
                    decoder_embed_dim,
                    decoder_num_heads,
                    mlp_ratio,
                    qkv_bias=True,
                    norm_layer=nn.LayerNorm,
                )
                for _ in range(decoder_depth)
            ]
        )
        self.decoder_norm = nn.LayerNorm(decoder_embed_dim)

        self.eo_decoder_pred = nn.ModuleDict(
            {
                group_name: nn.Linear(decoder_embed_dim, len(group))
                for group_name, group in self.band_groups.items()
            }
        )
        self.dw_decoder_pred = nn.Linear(decoder_embed_dim, NUM_DYNAMIC_WORLD_CLASSES)

        self.channel_embeddings = channel_embeddings
        channel_embedding_dims = channel_embeddings.weight.shape[-1]
        remaining_embeddings = decoder_embed_dim - channel_embedding_dims
        # the positional + monthly + channel embedding
        self.max_sequence_length = max_sequence_length
        self.pos_embed = nn.Parameter(
            torch.zeros(1, max_sequence_length, int(remaining_embeddings) // 2),
            requires_grad=False,
        )
        month_tab = get_month_encoding_table(int(remaining_embeddings) // 2)
        self.month_embed = nn.Embedding.from_pretrained(month_tab, freeze=True)

        self.initialize_weights()

    def initialize_weights(self):
        pos_embed = get_sinusoid_encoding_table(self.pos_embed.shape[1], self.pos_embed.shape[-1])
        self.pos_embed.data.copy_(pos_embed)
        # initialize nn.Linear and nn.LayerNorm
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            torch.nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def add_masked_tokens(self, x, kept_indices, removed_indices):
        mask_tokens = repeat(
            self.mask_token, "d -> b t d", b=x.shape[0], t=removed_indices.shape[1]
        )
        x = torch.cat([x, mask_tokens], dim=1)
        # sort according to their indices. Shape is [batch, index]
        combined_indices = torch.cat([kept_indices, removed_indices], dim=1) + 1
        # 0 for latlon index
        combined_indices = torch.sort(
            torch.cat([torch.zeros_like(combined_indices[:, 0:1]), combined_indices], dim=1)
        )[1]
        # and then tile for each dimension
        combined_indices = repeat(combined_indices, "b t -> b t d", d=x.shape[-1])
        x = torch.gather(x, 1, combined_indices)
        return x
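
    # Worked example of the index bookkeeping above (illustrative): with
    # kept_indices [[0, 2]] and removed_indices [[1, 3]], the shifted and
    # argsorted indices are [[0, 1, 3, 2, 4]], so the gather reorders
    # [latlon, kept_0, kept_2, mask, mask] into [latlon, tok_0, mask, tok_2, mask].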

    def add_embeddings(self, x, month: Union[torch.Tensor, int]):
        num_channel_groups = len(self.band_group_to_idx)
        # -2 since we remove srtm and latlon, and -1 since the srtm
        # channel group doesn't have timesteps
        num_timesteps = int((x.shape[1] - 2) / (num_channel_groups - 1))
        srtm_index = self.band_group_to_idx["SRTM"] * num_timesteps
        months = month_to_tensor(month, x.shape[0], num_timesteps, x.device)

        # when we expand the encodings, each channel_group gets num_timesteps
        # encodings. However, there is only one SRTM token so we remove the
        # excess SRTM encodings
        remove_mask = torch.full(size=(num_timesteps * num_channel_groups,), fill_value=False)
        remove_mask[torch.arange(num_timesteps - 1) + srtm_index] = True

        month_embedding = repeat(
            self.month_embed(months), "b t d -> b (repeat t) d", repeat=num_channel_groups
        )
        month_embedding = month_embedding[:, ~remove_mask]
        month_embedding[:, srtm_index] = 0

        positional_embedding = repeat(
            self.pos_embed[:, :num_timesteps, :],
            "b t d -> (b2 b) (t2 t) d",
            b2=x.shape[0],
            t2=num_channel_groups,
        )
        positional_embedding = positional_embedding[:, ~remove_mask]
        positional_embedding[:, srtm_index] = 0

        channel_embeddings = torch.repeat_interleave(
            self.channel_embeddings.weight, repeats=num_timesteps, dim=0
        )
        channel_embeddings = repeat(channel_embeddings, "c d -> b c d", b=x.shape[0])
        channel_embeddings = channel_embeddings[:, ~remove_mask]

        positional_embedding = torch.cat(
            (month_embedding, channel_embeddings, positional_embedding), dim=-1
        )

        # add the zero embedding for the latlon token
        positional_embedding = torch.cat(
            [torch.zeros_like(positional_embedding[:, 0:1, :]), positional_embedding], dim=1
        )

        x += positional_embedding
        return x

    def reconstruct_inputs(self, x) -> Tuple[torch.Tensor, torch.Tensor]:
        # remove the latlon token
        x = x[:, 1:, :]

        # split into channel groups
        num_channel_groups = len(self.band_group_to_idx) - 1
        num_timesteps = int((x.shape[1] - 1) / num_channel_groups)
        srtm_index = self.band_group_to_idx["SRTM"] * num_timesteps
        srtm_token = x[:, srtm_index : srtm_index + 1, :]

        mask = torch.full((x.shape[1],), True, device=x.device)
        mask[torch.tensor(srtm_index)] = False
        x = x[:, mask]
        x = x.view(x.shape[0], num_channel_groups, num_timesteps, x.shape[-1])

        eo_output, dw_output = [], None
        for group_name, idx in self.band_group_to_idx.items():
            if group_name == "SRTM":
                eo_output.append(
                    repeat(
                        self.eo_decoder_pred[group_name](srtm_token),
                        "b t d -> b (t2 t) d",
                        t2=num_timesteps,
                    )
                )
            else:
                if idx > self.band_group_to_idx["SRTM"]:
                    idx -= 1
                group_tokens = x[:, idx]
                if group_name == "dynamic_world":
                    dw_output = self.dw_decoder_pred(group_tokens)
                else:
                    eo_output.append(self.eo_decoder_pred[group_name](group_tokens))

        # we can just do this concatenation because the BANDS_GROUP_IDX
        # is ordered
        return torch.cat(eo_output, dim=-1), cast(torch.Tensor, dw_output)

    def forward(self, x, kept_indices, removed_indices, month):
        x = self.decoder_embed(x)
        x = self.add_masked_tokens(x, kept_indices, removed_indices)
        x = self.add_embeddings(x, month)

        # apply Transformer blocks
        for blk in self.decoder_blocks:
            x = blk(x)
        x = self.decoder_norm(x)
        return self.reconstruct_inputs(x)
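
# Decoder.forward returns a tuple: reconstructed EO values of shape
# [batch, timesteps, 17] (the channel groups concatenated in BANDS_GROUPS_IDX
# order, with the single SRTM token repeated across timesteps) and Dynamic
# World logits of shape [batch, timesteps, NUM_DYNAMIC_WORLD_CLASSES].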

class PrestoFineTuningModel(nn.Module):
    def __init__(self, encoder, head):
        super().__init__()
        self.encoder: Encoder = deepcopy(encoder)
        # make sure the model is trainable, since we can call
        # this having called requires_grad_(False)
        self.encoder.requires_grad_(True)
        # but don't unfreeze the position encoder, which
        # shouldn't be trainable
        self.encoder.pos_embed.requires_grad_(False)
        self.encoder.month_embed.requires_grad_(False)
        self.head = head

    def forward(
        self,
        x: torch.Tensor,
        dynamic_world: torch.Tensor,
        latlons: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        month: Union[torch.Tensor, int] = 0,
    ) -> torch.Tensor:
        return self.head(
            self.encoder(
                x=x,
                dynamic_world=dynamic_world,
                latlons=latlons,
                mask=mask,
                month=month,
                eval_task=True,
            )
        )


class FinetuningHead(nn.Module):
    def __init__(self, hidden_size: int, num_outputs: int, regression: bool) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        self.num_outputs = num_outputs
        self.regression = regression
        self.linear = nn.Linear(hidden_size, num_outputs)

    def forward(self, x: torch.Tensor):
        x = self.linear(x)
        if (not self.regression) & (self.num_outputs == 1):
            x = torch.sigmoid(x)
        return x
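
# A sigmoid is applied only for single-output (binary) classification; heads
# with multiple outputs return raw logits, and regression heads return
# unbounded values.
# >>> head = FinetuningHead(hidden_size=128, num_outputs=1, regression=False)
# >>> head(torch.randn(4, 128)).shape
# torch.Size([4, 1])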

class Presto(nn.Module):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder: Encoder = encoder
        self.decoder: Decoder = decoder

    def forward(
        self,
        x: torch.Tensor,
        dynamic_world: torch.Tensor,
        latlons: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        month: Union[torch.Tensor, int] = 0,
    ) -> torch.Tensor:
        x, kept_indices, removed_indices = self.encoder(
            x=x,
            dynamic_world=dynamic_world,
            latlons=latlons,
            mask=mask,
            month=month,
            eval_task=False,
        )
        return self.decoder(x, kept_indices, removed_indices, month)

    @classmethod
    def construct(
        cls,
        encoder_embedding_size: int = 128,
        channel_embed_ratio: float = 0.25,
        month_embed_ratio: float = 0.25,
        encoder_depth=2,
        mlp_ratio=4,
        encoder_num_heads=8,
        decoder_embedding_size=128,
        decoder_depth=2,
        decoder_num_heads=8,
        max_sequence_length=24,
    ):
        encoder = Encoder(
            embedding_size=encoder_embedding_size,
            channel_embed_ratio=channel_embed_ratio,
            month_embed_ratio=month_embed_ratio,
            depth=encoder_depth,
            mlp_ratio=mlp_ratio,
            num_heads=encoder_num_heads,
            max_sequence_length=max_sequence_length,
        )
        decoder = Decoder(
            channel_embeddings=encoder.channel_embed,
            encoder_embed_dim=encoder_embedding_size,
            decoder_embed_dim=decoder_embedding_size,
            decoder_depth=decoder_depth,
            decoder_num_heads=decoder_num_heads,
            mlp_ratio=mlp_ratio,
            max_sequence_length=max_sequence_length,
        )
        return cls(encoder, decoder)
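
    # Example (illustrative): build a randomly initialised model with the
    # defaults, or override the decoder size.
    # >>> model = Presto.construct()
    # >>> small = Presto.construct(decoder_embedding_size=64, decoder_depth=1)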

    def construct_finetuning_model(
        self,
        num_outputs: int,
        regression: bool = False,
    ):
        head = FinetuningHead(
            num_outputs=num_outputs,
            hidden_size=self.encoder.embedding_size,
            regression=regression,
        )
        model = PrestoFineTuningModel(self.encoder, head).to(self.encoder.pos_embed.device)
        model.train()
        return model

    @classmethod
    def load_pretrained(cls):
        model = cls.construct()
        model.load_state_dict(torch.load(DEFAULT_MODEL_PATH, map_location=device))
        return model
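

if __name__ == "__main__":
    # Minimal smoke test, not part of the upstream file. The input shapes are
    # assumptions inferred from how Encoder.forward indexes its arguments:
    # x covers the 17 channel indices in BANDS_GROUPS_IDX, dynamic_world holds
    # integer class indices, and latlons are (lat, lon) in degrees.
    batch_size, num_timesteps = 2, 12
    num_channels = max(idx for idxs in BANDS_GROUPS_IDX.values() for idx in idxs) + 1  # 17
    x = torch.randn(batch_size, num_timesteps, num_channels).to(device)
    dynamic_world = torch.randint(
        0, NUM_DYNAMIC_WORLD_CLASSES, (batch_size, num_timesteps)
    ).to(device)
    latlons = torch.tensor([[45.0, 5.0], [-10.0, 120.0]]).to(device)

    # randomly initialised weights; use Presto.load_pretrained() for the released checkpoint
    model = Presto.construct().to(device)
    embeddings = model.encoder(x, dynamic_world=dynamic_world, latlons=latlons, month=3)
    print(embeddings.shape)  # torch.Size([2, 128])

    finetuning_model = model.construct_finetuning_model(num_outputs=1)
    predictions = finetuning_model(x, dynamic_world=dynamic_world, latlons=latlons, month=3)
    print(predictions.shape)  # torch.Size([2, 1])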