Spaces:
Sleeping
Sleeping
from pathlib import Path | |
from typing import Optional | |
import torch | |
from einops import repeat | |
from torch import nn | |
from .single_file_presto import ( | |
NUM_DYNAMIC_WORLD_CLASSES, | |
PRESTO_BANDS, | |
PRESTO_S1_BANDS, | |
PRESTO_S2_BANDS, | |
Presto, | |
) | |
WEIGHTS_PATH = Path(__file__).parent / "default_model.pt"
# Fail fast at import time if the pretrained checkpoint is missing. An explicit
# raise (instead of `assert`) keeps the check active under `python -O`.
if not WEIGHTS_PATH.exists():
    raise FileNotFoundError(f"Presto weights not found at {WEIGHTS_PATH}")
# Band lists with B9 removed. The wrapper below sizes the encoder input from
# INPUT_PRESTO_BANDS, so any channel index into that input must be taken
# relative to this list, not to PRESTO_BANDS.
INPUT_PRESTO_BANDS = [b for b in PRESTO_BANDS if b != "B9"]
INPUT_PRESTO_S2_BANDS = [b for b in PRESTO_S2_BANDS if b != "B9"]
class PrestoWrapper(nn.Module):
    """Run the pretrained Presto encoder on Sentinel-2 or Sentinel-1 image tensors.

    Images are averaged over their spatial axes, scattered into the encoder's band
    layout (``INPUT_PRESTO_BANDS``), and every channel that was not supplied is
    masked before calling the encoder, which returns one pooled embedding per sample.
    """

    # we assume any data passed to this wrapper
    # will contain S2 data with the following channels
    S2_BAND_ORDERING = [
        "B1",
        "B2",
        "B3",
        "B4",
        "B5",
        "B6",
        "B7",
        "B8",
        "B8A",
        "B9",
        "B10",
        "B11",
        "B12",
    ]
    # ... and, when S1 data is passed instead, these channels in this order
    S1_BAND_ORDERING = [
        "VV",
        "VH",
    ]

    def __init__(self, do_pool=True, temporal_pooling: str = "mean"):
        """Load the pretrained Presto encoder and precompute band index maps.

        Args:
            do_pool: must be True; Presto cannot output spatial tokens.
            temporal_pooling: must be "mean"; no other pooling is supported.

        Raises:
            ValueError: if an unsupported option is requested.
        """
        super().__init__()
        # Validate the options before paying for the checkpoint load.
        if temporal_pooling != "mean":
            raise ValueError("Only mean temporal pooling supported by Presto")
        if not do_pool:
            raise ValueError("Presto cannot output spatial tokens")

        model = Presto.construct()
        model.load_state_dict(torch.load(WEIGHTS_PATH, map_location="cpu"))
        self.encoder = model.encoder
        self.dim = self.encoder.embedding_size
        self.do_pool = do_pool

        # (position in caller's channel axis, band name) for every band Presto uses.
        kept_s2 = [
            (i, band)
            for i, band in enumerate(self.S2_BAND_ORDERING)
            if band in INPUT_PRESTO_S2_BANDS
        ]
        kept_s1 = [
            (i, band) for i, band in enumerate(self.S1_BAND_ORDERING) if band in PRESTO_S1_BANDS
        ]
        self.kept_s2_band_idx = [i for i, _ in kept_s2]
        self.kept_s1_band_idx = [i for i, _ in kept_s1]
        # Destination channels in the encoder input, which has len(INPUT_PRESTO_BANDS)
        # channels (B9 removed). Indexing PRESTO_BANDS instead would shift every band
        # listed after B9 one channel to the right whenever PRESTO_BANDS contains B9.
        self.to_presto_s2_map = [INPUT_PRESTO_BANDS.index(band) for _, band in kept_s2]
        self.to_presto_s1_map = [INPUT_PRESTO_BANDS.index(band) for _, band in kept_s1]
        self.month = 6  # default month, used when `months` is not supplied

    @staticmethod
    def _spatial_mean(image: torch.Tensor) -> torch.Tensor:
        """Average ``image`` over its two spatial axes, always returning (b, t, c).

        A (b, h, w, c) input is treated as a single timestep; a (b, h, w, t, c)
        input keeps its t timesteps.
        """
        if image.dim() == 4:
            return torch.mean(image, dim=(1, 2)).unsqueeze(1)
        if image.dim() == 5:
            return torch.mean(image, dim=(1, 2))
        raise ValueError(
            f"Expected a (b, h, w, c) or (b, h, w, t, c) tensor, got shape {tuple(image.shape)}"
        )

    def preproccess(
        self,
        s2: Optional[torch.Tensor] = None,
        s1: Optional[torch.Tensor] = None,
        months: Optional[torch.Tensor] = None,
    ):
        """Build the (x, mask, dynamic_world, months) inputs for the Presto encoder.

        Args:
            s2: image of shape (b, h, w, c) or (b, h, w, t, c), channels ordered as
                ``S2_BAND_ORDERING``. Takes precedence: when given, ``s1`` is ignored.
            s1: image with the same layout, channels ordered as ``S1_BAND_ORDERING``.
            months: optional month indices whose last axis has length t. When omitted,
                every timestep uses ``self.month``.

        Returns:
            x: (b, t, len(INPUT_PRESTO_BANDS)) spatial means, zero where no data exists.
            mask: same shape and dtype as x; 0 on supplied channels, 1 elsewhere.
            dynamic_world: (b, t) long tensor filled with NUM_DYNAMIC_WORLD_CLASSES
                (presumably the "no label" index; confirm in single_file_presto).
            months: long tensor of month indices on the data's device.

        Raises:
            ValueError: if neither image is given, or on a rank/channel mismatch.
        """
        if s2 is not None:
            image, ordering = s2, self.S2_BAND_ORDERING
            src_idx, dst_idx = self.kept_s2_band_idx, self.to_presto_s2_map
        elif s1 is not None:
            image, ordering = s1, self.S1_BAND_ORDERING
            src_idx, dst_idx = self.kept_s1_band_idx, self.to_presto_s1_map
        else:
            raise ValueError("no s1 or s2?")

        pooled = self._spatial_mean(image)  # (b, t, c)
        b, t, c = pooled.shape
        if c != len(ordering):
            raise ValueError(f"Expected {len(ordering)} channels, got {c}")

        x = torch.zeros((b, t, len(INPUT_PRESTO_BANDS)), dtype=pooled.dtype, device=pooled.device)
        x[:, :, dst_idx] = pooled[:, :, src_idx]
        # 1 = masked; unmask only the channels that were actually filled in.
        mask = torch.ones_like(x)
        mask[:, :, dst_idx] = 0

        if months is None:
            months = torch.full((b, t), self.month, device=image.device)
        elif months.shape[-1] != t:
            raise ValueError(f"months has last dim {months.shape[-1]}, expected {t} timesteps")
        dynamic_world = torch.full(
            (b, t), NUM_DYNAMIC_WORLD_CLASSES, dtype=torch.long, device=image.device
        )
        return (
            x,
            mask,
            dynamic_world,
            months.to(device=image.device).long(),
        )

    # Correctly spelled alias; `preproccess` is kept for existing callers.
    preprocess = preproccess

    def forward(self, s2=None, s1=None, months=None):
        """Embed S2 (or, if no S2 is given, S1) imagery with the Presto encoder.

        See ``preproccess`` for the accepted input shapes.
        """
        x, mask, dynamic_world, months = self.preproccess(s2=s2, s1=s1, months=months)
        return self.encoder(
            x=x, dynamic_world=dynamic_world, mask=mask, month=months, eval_task=True
        )  # [B, self.dim]