Spaces:
Sleeping
Sleeping
import random | |
import unittest | |
import torch | |
from einops import repeat | |
from src.masking import ( | |
MASKING_MODES, | |
MAX_MASKING_STRATEGIES, | |
SPACE_BAND_GROUPS_IDX, | |
SPACE_TIME_BANDS_GROUPS_IDX, | |
STATIC_BAND_GROUPS_IDX, | |
TIME_BAND_GROUPS_IDX, | |
batch_mask_random, | |
batch_mask_space, | |
batch_mask_time, | |
check_modes_for_conflicts, | |
filter_unmasking_mode_candidates, | |
weighted_sample_without_replacement, | |
) | |
class TestMasking(unittest.TestCase):
    """Tests for the batch masking helpers imported from ``src.masking``.

    Judging by how the assertions in this class use them, mask values appear
    to mean: ``0`` = visible to the encoder, ``2`` = to be reconstructed by
    the decoder, ``1`` = neither (e.g. an ignored band group).
    NOTE(review): confirm these semantics against ``src.masking``.
    """

    def check_all_values_in_masks(
        self, space_time_mask, space_mask, time_mask, static_mask, masking_modes, unmasking_modes
    ):
        """Assert that each of the values 2, 0 and 1 occurs in at least one mask.

        Args:
            space_time_mask: mask tensor for the space-time inputs.
            space_mask: mask tensor for the space inputs.
            time_mask: mask tensor for the time inputs.
            static_mask: mask tensor for the static inputs.
            masking_modes: modes used for masking; only used in failure messages.
            unmasking_modes: modes used for unmasking; only used in failure messages.
        """
        masks = (space_time_mask, space_mask, time_mask, static_mask)
        for value in (2, 0, 1):
            self.assertTrue(
                any(bool((mask == value).any()) for mask in masks),
                f"{value} check failed for {masking_modes}, {unmasking_modes}",
            )

    def _sample_conflict_free_modes(self):
        """Sample masking / unmasking modes and assert they are conflict free.

        Draws between 2 and ``MAX_MASKING_STRATEGIES`` masking modes and exactly
        one unmasking mode (uniform weights), resolves overlaps with
        ``check_modes_for_conflicts`` and verifies the result.

        Returns:
            A ``(masking_modes, unmasking_modes)`` tuple.
        """
        num_masking_modes = random.randint(2, MAX_MASKING_STRATEGIES)
        num_unmasking_modes = 1
        uniform_weights = [1] * len(MASKING_MODES)

        masking_modes = weighted_sample_without_replacement(
            MASKING_MODES, weights=uniform_weights, k=num_masking_modes
        )
        unmasking_modes = weighted_sample_without_replacement(
            MASKING_MODES, weights=uniform_weights, k=num_unmasking_modes
        )
        self.assertEqual(
            len(unmasking_modes), num_unmasking_modes, f"Got {len(unmasking_modes)}"
        )

        masking_modes, unmasking_modes = check_modes_for_conflicts(
            masking_modes, unmasking_modes
        )
        # Conflict resolution must keep the unmasking mode and leave at least
        # one masking mode behind.
        self.assertEqual(
            len(unmasking_modes), num_unmasking_modes, f"Got {len(unmasking_modes)}"
        )
        self.assertGreaterEqual(len(masking_modes), 1, f"Got {len(masking_modes)}")
        for m_m in masking_modes:
            self.assertNotIn(m_m, unmasking_modes, f"{m_m} in {unmasking_modes}")
        for u_m in unmasking_modes:
            self.assertNotIn(u_m, masking_modes, f"{u_m} in {masking_modes}")
        return masking_modes, unmasking_modes

    def _run_random_mode_trials(self, f, num_trials=100):
        """Run ``_test_mask_by_for_f`` on ``f`` for ``num_trials`` random mode draws."""
        for _ in range(num_trials):
            masking_modes, unmasking_modes = self._sample_conflict_free_modes()
            self._test_mask_by_for_f(f, masking_modes, unmasking_modes)

    def test_mask_by_time(self):
        """Check ``batch_mask_time`` on a known failure case and on random modes."""
        # testing specific failure modes
        self._test_mask_by_for_f(
            batch_mask_time,
            [
                ("static", "LS"),
                ("static", "location"),
                ("space", "WC"),
                ("space_time", "S2_SWIR"),
                ("space", "DW"),
                ("space_time", "S2_NIR_20m"),
            ],
            [("time", "TC"), ("time", "VIIRS")],
        )
        # The randomized trials previously called batch_mask_space here, so
        # time masking was only ever exercised by the single case above.
        self._run_random_mode_trials(batch_mask_time)

    def test_mask_by_space(self):
        """Check ``batch_mask_space`` on randomly sampled mode combinations."""
        self._run_random_mode_trials(batch_mask_space)

    def _test_mask_by_for_f(self, f, masking_modes, unmasking_modes):
        """Run masking function ``f`` for several sequence lengths and check its output.

        Args:
            f: masking function under test (e.g. ``batch_mask_time``).
            masking_modes: value passed to ``f`` as ``mode``.
            unmasking_modes: value passed to ``f`` as ``decoder_mode``.
        """
        b, h, w = 2, 16, 16
        ratio = 0.25
        for t in range(4, 8):
            space_time_input = torch.ones((b, h, w, t, 8))
            space_input = torch.ones((b, h, w, 8))
            time_input = torch.ones((b, t, 8))
            static_input = torch.ones((b, 8))
            months = repeat(torch.arange(0, t), "t -> b t", b=b)

            output = f(
                space_time_input,
                space_input,
                time_input,
                static_input,
                months,
                encode_ratio=ratio,
                decode_ratio=ratio,
                mode=masking_modes,
                decoder_mode=unmasking_modes,
                patch_size=4,
            )

            self.check_all_values_in_masks(
                output.space_time_mask,
                output.space_mask,
                output.time_mask,
                output.static_mask,
                masking_modes,
                unmasking_modes,
            )
            # Each mask has one channel per band group rather than per input band.
            self.assertEqual(
                (b, h, w, t, len(SPACE_TIME_BANDS_GROUPS_IDX)),
                output.space_time_mask.shape,
            )
            self.assertEqual((b, h, w, len(SPACE_BAND_GROUPS_IDX)), output.space_mask.shape)
            self.assertEqual((b, t, len(TIME_BAND_GROUPS_IDX)), output.time_mask.shape)
            self.assertEqual((b, len(STATIC_BAND_GROUPS_IDX)), output.static_mask.shape)

    @staticmethod
    def _count_value(mask, value):
        """Return how many entries of ``mask`` are equal to ``value`` as an ``int``."""
        return int((mask == value).sum().item())

    def test_mask_by_random(self):
        """Check shapes, patch consistency and encode/decode ratios of ``batch_mask_random``."""
        b, t, h, w, p = 2, 8, 16, 16, 4
        # Integer division: these are token counts, not fractional quantities.
        h_tokens, w_tokens = h // p, w // p
        space_time_input = torch.ones((b, h, w, t, 8))
        space_input = torch.ones((b, h, w, 8))
        time_input = torch.ones((b, t, 8))
        static_input = torch.ones((b, 8))
        months = repeat(torch.arange(0, t), "t -> b t", b=b)
        ratio = 0.25

        output = batch_mask_random(
            space_time_input,
            space_input,
            time_input,
            static_input,
            months,
            encode_ratio=ratio,
            decode_ratio=ratio,
            patch_size=p,
            ignore_band_groups=["DW", "DW_static"],
        )

        self.check_all_values_in_masks(
            output.space_time_mask,
            output.space_mask,
            output.time_mask,
            output.static_mask,
            "random",
            "random",
        )
        self.assertEqual(
            (b, h, w, t, len(SPACE_TIME_BANDS_GROUPS_IDX)), output.space_time_mask.shape
        )
        self.assertEqual((b, h, w, len(SPACE_BAND_GROUPS_IDX)), output.space_mask.shape)
        self.assertEqual((b, t, len(TIME_BAND_GROUPS_IDX)), output.time_mask.shape)
        self.assertEqual((b, len(STATIC_BAND_GROUPS_IDX)), output.static_mask.shape)

        # Pixels at consecutive diagonal offsets within a patch must share a mask value.
        for i in range(1, p):
            self.assertTrue(
                torch.equal(
                    output.space_time_mask[:, i::p, i::p],
                    output.space_time_mask[:, i - 1 :: p, i - 1 :: p],
                )
            )
            self.assertTrue(
                torch.equal(
                    output.space_mask[:, i::p, i::p],
                    output.space_mask[:, i - 1 :: p, i - 1 :: p],
                )
            )

        # Sample one pixel per patch so spatial masks are counted per token.
        # The offset is explicit instead of relying on the leaked loop variable.
        offset = p - 1
        per_token_masks = (
            output.space_time_mask[:, offset::p, offset::p],
            output.space_mask[:, offset::p, offset::p],
            output.time_mask,
            output.static_mask,
        )
        encoded_tokens = sum(self._count_value(mask, 0) for mask in per_token_masks)
        decoded_tokens = sum(self._count_value(mask, 2) for mask in per_token_masks)

        total_tokens = (
            (h_tokens * w_tokens * t * len(SPACE_TIME_BANDS_GROUPS_IDX))
            # -1 because we have now masked out dynamic world
            + (h_tokens * w_tokens * (len(SPACE_BAND_GROUPS_IDX) - 1))
            + (t * len(TIME_BAND_GROUPS_IDX))
            # -1 because we have now masked out dynamic world static
            + (len(STATIC_BAND_GROUPS_IDX) - 1)
        ) * b
        expected_tokens = int(total_tokens * ratio)

        # A tolerance of 1 handles off by one errors from rounding.
        self.assertLessEqual(
            abs(encoded_tokens - expected_tokens),
            1,
            f"Encoded {encoded_tokens} tokens, expected about {expected_tokens}",
        )
        self.assertLessEqual(
            abs(decoded_tokens - expected_tokens),
            1,
            f"Decoded {decoded_tokens} tokens, expected about {expected_tokens}",
        )

        # check DW was masked out
        expected_mask_index = list(SPACE_BAND_GROUPS_IDX.keys()).index("DW")
        self.assertTrue((output.space_mask[:, :, :, expected_mask_index] == 1).all())
        expected_mask_index = list(STATIC_BAND_GROUPS_IDX.keys()).index("DW_static")
        self.assertTrue((output.static_mask[:, expected_mask_index] == 1).all())

    def test_filter_candidates(self):
        """Candidates that contain an ignored band group must be dropped entirely."""
        candidates = [
            [("space", "SRTM"), ("space_time", "S2_RGB")],
            [("space", "SRTM"), ("space", "DW")],
            [("space", "DW")],
        ]
        ignore_bands = ["DW"]
        outputs = filter_unmasking_mode_candidates(candidates, ignore_bands)
        self.assertEqual(outputs, [[("space", "SRTM"), ("space_time", "S2_RGB")]])