Spaces:
Running
Running
import unittest | |
import torch | |
from src.data.utils import ( | |
S2_BANDS, | |
SPACE_TIME_BANDS, | |
SPACE_TIME_BANDS_GROUPS_IDX, | |
construct_galileo_input, | |
) | |
class TestDataUtils(unittest.TestCase):
    """Unit tests for helpers imported from ``src.data.utils``."""

    def test_construct_galileo_input_s2(self):
        """Check the masks and band placement produced when only ``s2`` is given.

        ``construct_galileo_input`` is called with an ``s2`` tensor alone, once
        with ``normalize=True`` and once with ``normalize=False``. In both cases
        the space, time and static masks must be all ones, and within
        ``space_time_mask`` only the S2 band groups may be zero (unmasked).
        Without normalization, the S2 slice of ``space_time_x`` must equal the
        input tensor exactly.
        """
        t, h, w = 2, 4, 4
        s2 = torch.randn((t, h, w, len(S2_BANDS)))

        # The index lists below depend only on module-level constants, so they
        # are built once rather than on every loop iteration.
        non_s2_group_indices = [
            idx for idx, key in enumerate(SPACE_TIME_BANDS_GROUPS_IDX) if "S2" not in key
        ]
        s2_group_indices = [
            idx for idx, key in enumerate(SPACE_TIME_BANDS_GROUPS_IDX) if "S2" in key
        ]
        s2_band_indices = [
            idx for idx, val in enumerate(SPACE_TIME_BANDS) if val in S2_BANDS
        ]

        for normalize in (True, False):
            # subTest labels a failure with the ``normalize`` value that caused
            # it and lets the other setting still run.
            with self.subTest(normalize=normalize):
                masked_output = construct_galileo_input(s2=s2, normalize=normalize)

                # No space / time / static inputs were passed, so those masks
                # are expected to be entirely 1.
                self.assertTrue((masked_output.space_mask == 1).all())
                self.assertTrue((masked_output.time_mask == 1).all())
                self.assertTrue((masked_output.static_mask == 1).all())

                # check that only the s2 bands got unmasked
                self.assertTrue(
                    (masked_output.space_time_mask[:, :, :, non_s2_group_indices] == 1).all()
                )
                # and that s2 got unmasked
                self.assertTrue(
                    (masked_output.space_time_mask[:, :, :, s2_group_indices] == 0).all()
                )

                # and got assigned to the right indices; exact equality is only
                # checked for the un-normalized call.
                if not normalize:
                    self.assertTrue(
                        torch.equal(masked_output.space_time_x[:, :, :, s2_band_indices], s2)
                    )