from pathlib import Path
from typing import List, Optional, Tuple

import torch
from einops import repeat
from torch import nn

from .single_file_presto import (
    NUM_DYNAMIC_WORLD_CLASSES,
    PRESTO_BANDS,
    PRESTO_S1_BANDS,
    PRESTO_S2_BANDS,
    Presto,
)

# Pretrained weights are expected to sit next to this module.
WEIGHTS_PATH = Path(__file__).parent / "default_model.pt"
assert WEIGHTS_PATH.exists()

# Band lists with "B9" dropped; the tensors built by the wrapper are
# len(INPUT_PRESTO_BANDS) channels wide.
INPUT_PRESTO_BANDS = [b for b in PRESTO_BANDS if b != "B9"]
INPUT_PRESTO_S2_BANDS = [b for b in PRESTO_S2_BANDS if b != "B9"]


class PrestoWrapper(nn.Module):
    """Adapter that feeds spatially averaged S2 or S1 imagery to the Presto encoder.

    Images are averaged over their height and width dimensions, the bands
    Presto accepts are scattered into a zero tensor with
    ``len(INPUT_PRESTO_BANDS)`` channels, and a matching mask (0 for the
    channels that were filled, 1 elsewhere) is built alongside it.
    """

    # we assume any data passed to this wrapper
    # will contain S2 data with the following channels
    S2_BAND_ORDERING = [
        "B1",
        "B2",
        "B3",
        "B4",
        "B5",
        "B6",
        "B7",
        "B8",
        "B8A",
        "B9",
        "B10",
        "B11",
        "B12",
    ]
    S1_BAND_ORDERING = [
        "VV",
        "VH",
    ]

    def __init__(self, do_pool=True, temporal_pooling: str = "mean"):
        """Load the pretrained Presto encoder and precompute band index maps.

        Args:
            do_pool: Must be truthy; anything else is rejected.
            temporal_pooling: Must be ``"mean"``; anything else is rejected.

        Raises:
            ValueError: If ``temporal_pooling`` is not ``"mean"`` or
                ``do_pool`` is falsy.
        """
        super().__init__()

        # Reject unsupported configurations before paying for model
        # construction and the weight load from disk.
        if temporal_pooling != "mean":
            raise ValueError("Only mean temporal pooling supported by Presto")
        if not do_pool:
            raise ValueError("Presto cannot output spatial tokens")
        self.do_pool = do_pool

        model = Presto.construct()
        model.load_state_dict(torch.load(WEIGHTS_PATH, map_location="cpu"))
        self.encoder = model.encoder
        self.dim = self.encoder.embedding_size

        # Positions (in the incoming channel axis) of the bands we forward.
        self.kept_s2_band_idx = [
            i for i, v in enumerate(self.S2_BAND_ORDERING) if v in INPUT_PRESTO_S2_BANDS
        ]
        self.kept_s1_band_idx = [
            i for i, v in enumerate(self.S1_BAND_ORDERING) if v in PRESTO_S1_BANDS
        ]
        kept_s2_band_names = [self.S2_BAND_ORDERING[i] for i in self.kept_s2_band_idx]
        kept_s1_band_names = [self.S1_BAND_ORDERING[i] for i in self.kept_s1_band_idx]

        # Destination positions of those bands in the tensor given to Presto.
        # NOTE(review): these indices are looked up in PRESTO_BANDS, while the
        # destination tensor is len(INPUT_PRESTO_BANDS) wide (PRESTO_BANDS
        # minus "B9"). If "B9" precedes other kept bands in PRESTO_BANDS the
        # two index spaces differ by one for those bands - confirm against
        # single_file_presto before changing.
        self.to_presto_s2_map = [PRESTO_BANDS.index(val) for val in kept_s2_band_names]
        self.to_presto_s1_map = [PRESTO_BANDS.index(val) for val in kept_s1_band_names]

        self.month = 6  # default month

    @staticmethod
    def _pool_and_place(
        image: torch.Tensor,
        sensor: str,
        band_ordering: List[str],
        kept_band_idx: List[int],
        presto_map: List[int],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Average one sensor's image over h, w and scatter it into the Presto layout.

        Args:
            image: Tensor of shape (b h w c) or (b h w t c).
            sensor: Name used in error messages only.
            band_ordering: Expected channel names of ``image``, in order.
            kept_band_idx: Channel positions of ``image`` to forward.
            presto_map: Destination channel position for each kept band.

        Returns:
            ``(x, mask)``, both of shape ``(b, t, len(INPUT_PRESTO_BANDS))``.
            ``mask`` is 0 on the channels in ``presto_map`` and 1 elsewhere.

        Raises:
            ValueError: If ``image`` is not 4- or 5-dimensional, or its last
                dimension does not match ``len(band_ordering)``.
        """
        if image.ndim == 4:
            # (b h w c): no time axis, so add a single timestep
            pooled = repeat(torch.mean(image, dim=(1, 2)), "b d -> b t d", t=1)
        elif image.ndim == 5:
            pooled = torch.mean(image, dim=(1, 2))
        else:
            raise ValueError(
                f"{sensor} must have shape (b h w c) or (b h w t c), "
                f"got {image.ndim} dimensions"
            )

        b, t, num_channels = pooled.shape
        if num_channels != len(band_ordering):
            raise ValueError(
                f"{sensor} must have {len(band_ordering)} channels, got {num_channels}"
            )

        x = torch.zeros(
            (b, t, len(INPUT_PRESTO_BANDS)), dtype=pooled.dtype, device=pooled.device
        )
        x[:, :, presto_map] = pooled[:, :, kept_band_idx]

        # Everything is flagged with 1 except the channels we just filled.
        mask = torch.ones_like(x)
        mask[:, :, presto_map] = 0
        return x, mask

    def preproccess(
        self,
        s2: Optional[torch.Tensor] = None,
        s1: Optional[torch.Tensor] = None,
        months: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Turn raw imagery into the tensors passed to the Presto encoder.

        Args:
            s2: Optional tensor of shape (b h w c) or (b h w t c) whose
                channels follow ``S2_BAND_ORDERING``.
            s1: Optional tensor of shape (b h w c) or (b h w t c) whose
                channels follow ``S1_BAND_ORDERING``. Only used when ``s2``
                is ``None``; if both are given, ``s1`` is ignored.
            months: Optional tensor whose last dimension equals the number
                of timesteps. Defaults to ``self.month`` for every timestep.

        Returns:
            ``(x, mask, dynamic_world, months)``. ``x`` and ``mask`` have
            shape ``(b, t, len(INPUT_PRESTO_BANDS))``; ``dynamic_world`` is a
            ``(b, t)`` long tensor filled with ``NUM_DYNAMIC_WORLD_CLASSES``;
            ``months`` is cast to long.

        Raises:
            ValueError: If neither ``s2`` nor ``s1`` is given, if the image
                has an unexpected rank or channel count, or if ``months``
                does not match the number of timesteps.
        """
        # images should have shape (b h w c) or (b h w t c)
        if s2 is not None:
            x, s_t_m = self._pool_and_place(
                s2,
                "s2",
                self.S2_BAND_ORDERING,
                self.kept_s2_band_idx,
                self.to_presto_s2_map,
            )
        elif s1 is not None:
            x, s_t_m = self._pool_and_place(
                s1,
                "s1",
                self.S1_BAND_ORDERING,
                self.kept_s1_band_idx,
                self.to_presto_s1_map,
            )
        else:
            raise ValueError("no s1 or s2?")

        b, t, _ = x.shape
        data_device = x.device

        if months is None:
            months = torch.ones((b, t), device=data_device) * self.month
        else:
            if months.shape[-1] != t:
                raise ValueError(
                    f"months has {months.shape[-1]} timesteps but the data has {t}"
                )
            # Keep months on the same device as the tensors built above.
            months = months.to(data_device)

        dynamic_world = torch.ones((b, t), device=data_device) * NUM_DYNAMIC_WORLD_CLASSES

        return (
            x,
            s_t_m,
            dynamic_world.long(),
            months.long(),
        )

    def forward(
        self,
        s2: Optional[torch.Tensor] = None,
        s1: Optional[torch.Tensor] = None,
        months: Optional[torch.Tensor] = None,
    ):
        """Preprocess the inputs and run them through the Presto encoder.

        Arguments are forwarded unchanged to ``preproccess``; the encoder is
        called with ``eval_task=True``.
        """
        x, mask, dynamic_world, months = self.preproccess(s2=s2, s1=s1, months=months)
        return self.encoder(
            x=x, dynamic_world=dynamic_world, mask=mask, month=months, eval_task=True
        )  # [B, self.dim]