# NOTE: removed non-Python residue (file-viewer chrome: status lines, byte
# count, commit hash, and a pasted run of line numbers) that preceded the code.
from pathlib import Path
from typing import Optional
import torch
from einops import repeat
from torch import nn
from .single_file_presto import (
NUM_DYNAMIC_WORLD_CLASSES,
PRESTO_BANDS,
PRESTO_S1_BANDS,
PRESTO_S2_BANDS,
Presto,
)
# Pretrained Presto checkpoint shipped alongside this module.
WEIGHTS_PATH = Path(__file__).parent / "default_model.pt"
if not WEIGHTS_PATH.exists():
    # Fail fast at import time with a real exception: `assert` would be
    # silently stripped when Python runs with -O.
    raise FileNotFoundError(f"Presto weights not found at {WEIGHTS_PATH}")

# Band layouts actually fed to the encoder. B9 is excluded from the input
# layout — presumably because the expected upstream data does not provide a
# usable B9 value; TODO(review): confirm against the data pipeline.
INPUT_PRESTO_BANDS = [b for b in PRESTO_BANDS if b != "B9"]
INPUT_PRESTO_S2_BANDS = [b for b in PRESTO_S2_BANDS if b != "B9"]
class PrestoWrapper(nn.Module):
    """Wrap a pretrained Presto encoder behind a simple image interface.

    Accepts S2 and/or S1 imagery shaped ``(b, h, w, c)`` or
    ``(b, h, w, t, c)``, mean-pools away the spatial dimensions, remaps the
    channels into Presto's band layout (zero-filling bands it does not
    receive and masking them out), and returns the encoder's pooled
    embedding of shape ``(b, self.dim)``.
    """

    # Channel ordering we assume for any S2 tensor passed to this wrapper.
    S2_BAND_ORDERING = [
        "B1",
        "B2",
        "B3",
        "B4",
        "B5",
        "B6",
        "B7",
        "B8",
        "B8A",
        "B9",
        "B10",
        "B11",
        "B12",
    ]
    # Channel ordering we assume for any S1 tensor passed to this wrapper.
    S1_BAND_ORDERING = [
        "VV",
        "VH",
    ]

    def __init__(self, do_pool=True, temporal_pooling: str = "mean"):
        """Load the default pretrained Presto model and build band mappings.

        Args:
            do_pool: must be True — Presto only emits pooled embeddings,
                never spatial tokens.
            temporal_pooling: must be "mean" — the only pooling Presto
                supports.

        Raises:
            ValueError: if an unsupported pooling configuration is requested.
        """
        super().__init__()
        model = Presto.construct()
        # Trusted, locally-bundled checkpoint (see WEIGHTS_PATH above).
        model.load_state_dict(torch.load(WEIGHTS_PATH, map_location="cpu"))
        self.encoder = model.encoder
        self.dim = self.encoder.embedding_size
        self.do_pool = do_pool
        if temporal_pooling != "mean":
            raise ValueError("Only mean temporal pooling supported by Presto")
        if not do_pool:
            raise ValueError("Presto cannot output spatial tokens")

        # Indices (into the incoming channel axis) of the bands Presto keeps.
        self.kept_s2_band_idx = [
            i for i, v in enumerate(self.S2_BAND_ORDERING) if v in INPUT_PRESTO_S2_BANDS
        ]
        self.kept_s1_band_idx = [
            i for i, v in enumerate(self.S1_BAND_ORDERING) if v in PRESTO_S1_BANDS
        ]
        # Destination indices (into Presto's band axis) for those kept bands,
        # in the same order as the kept-index lists above.
        kept_s2_band_names = [val for val in self.S2_BAND_ORDERING if val in INPUT_PRESTO_S2_BANDS]
        kept_s1_band_names = [val for val in self.S1_BAND_ORDERING if val in PRESTO_S1_BANDS]
        self.to_presto_s2_map = [PRESTO_BANDS.index(val) for val in kept_s2_band_names]
        self.to_presto_s1_map = [PRESTO_BANDS.index(val) for val in kept_s1_band_names]
        # Default calendar month used when the caller supplies no months.
        self.month = 6

    @staticmethod
    def _spatial_mean(data: torch.Tensor):
        """Collapse (b, h, w, [t,] c) to (b, t, c) by averaging over h, w.

        Returns the pooled tensor plus (b, t, c).

        Raises:
            ValueError: if the tensor is not 4- or 5-dimensional.
        """
        if data.ndim == 4:
            b, _, _, c = data.shape
            # No time axis supplied: pool, then add a singleton t dimension.
            return torch.mean(data, dim=(1, 2)).unsqueeze(1), b, 1, c
        if data.ndim == 5:
            b, _, _, t, c = data.shape
            return torch.mean(data, dim=(1, 2)), b, t, c
        raise ValueError(f"Expected a 4d or 5d tensor, got {data.ndim}d")

    def preprocess(
        self,
        s2: Optional[torch.Tensor] = None,
        s1: Optional[torch.Tensor] = None,
        months: Optional[torch.Tensor] = None,
    ):
        """Turn raw imagery into Presto encoder inputs.

        Images should have shape (b, h, w, c) or (b, h, w, t, c). If both
        s2 and s1 are given, only s2 is used (matching historical behavior).

        Args:
            s2: optional S2 imagery with channels in S2_BAND_ORDERING order.
            s1: optional S1 imagery with channels in S1_BAND_ORDERING order.
            months: optional (..., t) month indices; defaults to self.month.

        Returns:
            Tuple of (x, mask, dynamic_world, months) where x is
            (b, t, len(INPUT_PRESTO_BANDS)), mask is 1 where a band was NOT
            provided, dynamic_world is filled with the "missing" class, and
            months is a long tensor of shape (b, t).

        Raises:
            ValueError: if neither s2 nor s1 is given, or on bad shapes.
        """
        if s2 is not None:
            source = s2
            kept_idx = self.kept_s2_band_idx
            presto_map = self.to_presto_s2_map
            expected_channels = len(self.S2_BAND_ORDERING)
        elif s1 is not None:
            source = s1
            kept_idx = self.kept_s1_band_idx
            presto_map = self.to_presto_s1_map
            expected_channels = len(self.S1_BAND_ORDERING)
        else:
            raise ValueError("no s1 or s2?")

        data_device = source.device
        pooled, b, t, c = self._spatial_mean(source)
        if c != expected_channels:
            raise ValueError(f"Expected {expected_channels} channels, got {c}")

        # Scatter the kept bands into Presto's band layout; absent bands
        # stay zero and are flagged in the mask below.
        x = torch.zeros(
            (b, t, len(INPUT_PRESTO_BANDS)), dtype=pooled.dtype, device=pooled.device
        )
        x[:, :, presto_map] = pooled[:, :, kept_idx]

        # Mask convention: 1 = band missing, 0 = band provided.
        s_t_m = torch.ones_like(x)
        s_t_m[:, :, presto_map] = 0

        if months is None:
            months = torch.full((b, t), self.month, device=data_device)
        elif months.shape[-1] != t:
            raise ValueError(f"months last dim {months.shape[-1]} != t={t}")

        # Dynamic-world input is always "unknown" (the class-count sentinel).
        dynamic_world = torch.full(
            (b, t), NUM_DYNAMIC_WORLD_CLASSES, device=data_device
        )
        return (
            x,
            s_t_m,
            dynamic_world.long(),
            months.long(),
        )

    # Backward-compatible alias for the original (misspelled) method name.
    preproccess = preprocess

    def forward(self, s2=None, s1=None, months=None):
        """Encode imagery and return the pooled embedding, shape [B, self.dim]."""
        x, mask, dynamic_world, months = self.preprocess(s2=s2, s1=s1, months=months)
        return self.encoder(
            x=x, dynamic_world=dynamic_world, mask=mask, month=months, eval_task=True
        )