Spaces:
Sleeping
Sleeping
from typing import cast | |
import torch | |
from src.data import ( | |
SPACE_BAND_GROUPS_IDX, | |
SPACE_TIME_BANDS_GROUPS_IDX, | |
STATIC_BAND_GROUPS_IDX, | |
TIME_BAND_GROUPS_IDX, | |
) | |
from src.data.dataset import ( | |
SPACE_BANDS, | |
SPACE_TIME_BANDS, | |
STATIC_BANDS, | |
TIME_BANDS, | |
Normalizer, | |
to_cartesian, | |
) | |
from src.data.earthengine.eo import ( | |
DW_BANDS, | |
ERA5_BANDS, | |
LANDSCAN_BANDS, | |
LOCATION_BANDS, | |
S1_BANDS, | |
S2_BANDS, | |
SRTM_BANDS, | |
TC_BANDS, | |
VIIRS_BANDS, | |
WC_BANDS, | |
) | |
from src.masking import MaskedOutput | |
# Month value used to fill the `months` tensor for every timestep when the
# caller of construct_galileo_input does not pass `months`.
# NOTE(review): whether months are 0- or 1-indexed is not visible here — confirm.
DEFAULT_MONTH = 5
def construct_galileo_input( | |
s1: torch.Tensor | None = None, # [H, W, T, D] | |
s2: torch.Tensor | None = None, # [H, W, T, D] | |
era5: torch.Tensor | None = None, # [T, D] | |
tc: torch.Tensor | None = None, # [T, D] | |
viirs: torch.Tensor | None = None, # [T, D] | |
srtm: torch.Tensor | None = None, # [H, W, D] | |
dw: torch.Tensor | None = None, # [H, W, D] | |
wc: torch.Tensor | None = None, # [H, W, D] | |
landscan: torch.Tensor | None = None, # [D] | |
latlon: torch.Tensor | None = None, # [D] | |
months: torch.Tensor | None = None, # [T] | |
normalize: bool = False, | |
): | |
space_time_inputs = [s1, s2] | |
time_inputs = [era5, tc, viirs] | |
space_inputs = [srtm, dw, wc] | |
static_inputs = [landscan, latlon] | |
devices = [ | |
x.device | |
for x in space_time_inputs + time_inputs + space_inputs + static_inputs | |
if x is not None | |
] | |
if len(devices) == 0: | |
raise ValueError("At least one input must be not None") | |
if not all(devices[0] == device for device in devices): | |
raise ValueError("Received tensors on multiple devices") | |
device = devices[0] | |
# first, check all the input shapes are consistent | |
timesteps_list = [x.shape[2] for x in space_time_inputs if x is not None] + [ | |
x.shape[1] for x in time_inputs if x is not None | |
] | |
height_list = [x.shape[0] for x in space_time_inputs if x is not None] + [ | |
x.shape[0] for x in space_inputs if x is not None | |
] | |
width_list = [x.shape[1] for x in space_time_inputs if x is not None] + [ | |
x.shape[1] for x in space_inputs if x is not None | |
] | |
if len(timesteps_list) > 0: | |
if not all(timesteps_list[0] == timestep for timestep in timesteps_list): | |
raise ValueError("Inconsistent number of timesteps per input") | |
t = timesteps_list[0] | |
else: | |
t = 1 | |
if len(height_list) > 0: | |
if not all(height_list[0] == height for height in height_list): | |
raise ValueError("Inconsistent heights per input") | |
if not all(width_list[0] == width for width in width_list): | |
raise ValueError("Inconsistent widths per input") | |
h = height_list[0] | |
w = width_list[0] | |
else: | |
h, w = 1, 1 | |
# now, we can construct our empty input tensors. By default, everything is masked | |
s_t_x = torch.zeros((h, w, t, len(SPACE_TIME_BANDS)), dtype=torch.float, device=device) | |
s_t_m = torch.ones( | |
(h, w, t, len(SPACE_TIME_BANDS_GROUPS_IDX)), dtype=torch.float, device=device | |
) | |
sp_x = torch.zeros((h, w, len(SPACE_BANDS)), dtype=torch.float, device=device) | |
sp_m = torch.ones((h, w, len(SPACE_BAND_GROUPS_IDX)), dtype=torch.float, device=device) | |
t_x = torch.zeros((t, len(TIME_BANDS)), dtype=torch.float, device=device) | |
t_m = torch.ones((t, len(TIME_BAND_GROUPS_IDX)), dtype=torch.float, device=device) | |
st_x = torch.zeros((len(STATIC_BANDS)), dtype=torch.float, device=device) | |
st_m = torch.ones((len(STATIC_BAND_GROUPS_IDX)), dtype=torch.float, device=device) | |
for x, bands_list, group_key in zip([s1, s2], [S1_BANDS, S2_BANDS], ["S1", "S2"]): | |
if x is not None: | |
indices = [idx for idx, val in enumerate(SPACE_TIME_BANDS) if val in bands_list] | |
groups_idx = [ | |
idx for idx, key in enumerate(SPACE_TIME_BANDS_GROUPS_IDX) if group_key in key | |
] | |
s_t_x[:, :, :, indices] = x | |
s_t_m[:, :, :, groups_idx] = 0 | |
for x, bands_list, group_key in zip( | |
[srtm, dw, wc], [SRTM_BANDS, DW_BANDS, WC_BANDS], ["SRTM", "DW", "WC"] | |
): | |
if x is not None: | |
indices = [idx for idx, val in enumerate(SPACE_BANDS) if val in bands_list] | |
groups_idx = [idx for idx, key in enumerate(SPACE_BAND_GROUPS_IDX) if group_key in key] | |
sp_x[:, :, indices] = x | |
sp_m[:, :, groups_idx] = 0 | |
for x, bands_list, group_key in zip( | |
[era5, tc, viirs], [ERA5_BANDS, TC_BANDS, VIIRS_BANDS], ["ERA5", "TC", "VIIRS"] | |
): | |
if x is not None: | |
indices = [idx for idx, val in enumerate(TIME_BANDS) if val in bands_list] | |
groups_idx = [idx for idx, key in enumerate(TIME_BAND_GROUPS_IDX) if group_key in key] | |
t_x[:, indices] = x | |
t_m[:, groups_idx] = 0 | |
for x, bands_list, group_key in zip( | |
[landscan, latlon], [LANDSCAN_BANDS, LOCATION_BANDS], ["LS", "location"] | |
): | |
if x is not None: | |
if group_key == "location": | |
# transform latlon to cartesian | |
x = cast(torch.Tensor, to_cartesian(x[0], x[1])) | |
indices = [idx for idx, val in enumerate(STATIC_BANDS) if val in bands_list] | |
groups_idx = [ | |
idx for idx, key in enumerate(STATIC_BAND_GROUPS_IDX) if group_key in key | |
] | |
st_x[indices] = x | |
st_m[groups_idx] = 0 | |
if months is None: | |
months = torch.ones((t), dtype=torch.long, device=device) * DEFAULT_MONTH | |
else: | |
if months.shape[0] != t: | |
raise ValueError("Incorrect number of input months") | |
if normalize: | |
normalizer = Normalizer(std=False) | |
s_t_x = torch.from_numpy(normalizer(s_t_x.cpu().numpy())).to(device) | |
sp_x = torch.from_numpy(normalizer(sp_x.cpu().numpy())).to(device) | |
t_x = torch.from_numpy(normalizer(t_x.cpu().numpy())).to(device) | |
st_x = torch.from_numpy(normalizer(st_x.cpu().numpy())).to(device) | |
return MaskedOutput( | |
space_time_x=s_t_x, | |
space_time_mask=s_t_m, | |
space_x=sp_x, | |
space_mask=sp_m, | |
time_x=t_x, | |
time_mask=t_m, | |
static_x=st_x, | |
static_mask=st_m, | |
months=months, | |
) | |