import random
import unittest

import torch
from einops import repeat

from src.masking import (
    MASKING_MODES,
    MAX_MASKING_STRATEGIES,
    SPACE_BAND_GROUPS_IDX,
    SPACE_TIME_BANDS_GROUPS_IDX,
    STATIC_BAND_GROUPS_IDX,
    TIME_BAND_GROUPS_IDX,
    batch_mask_random,
    batch_mask_space,
    batch_mask_time,
    check_modes_for_conflicts,
    filter_unmasking_mode_candidates,
    weighted_sample_without_replacement,
)


class TestMasking(unittest.TestCase):
    """Tests for the batch masking helpers imported from ``src.masking``.

    The assertions in this file treat mask values as follows: tokens counted
    as "encode" have value 0, tokens counted as "decode" have value 2, and
    band groups passed via ``ignore_band_groups`` are expected to be 1.
    """

    def check_all_values_in_masks(
        self, space_time_mask, space_mask, time_mask, static_mask, masking_modes, unmasking_modes
    ):
        """Assert that each of the values 2, 0 and 1 occurs in at least one mask.

        Args:
            space_time_mask: mask tensor returned for the space-time input.
            space_mask: mask tensor returned for the space input.
            time_mask: mask tensor returned for the time input.
            static_mask: mask tensor returned for the static input.
            masking_modes: only used to label the failure message.
            unmasking_modes: only used to label the failure message.
        """
        masks = (space_time_mask, space_mask, time_mask, static_mask)
        # Same order as the original three hand-written checks.
        for value in (2, 0, 1):
            self.assertTrue(
                any(bool((mask == value).any()) for mask in masks),
                f"{value} check failed for {masking_modes}, {unmasking_modes}",
            )

    def _sample_non_conflicting_modes(self):
        """Randomly draw masking / unmasking modes and resolve their conflicts.

        Draws between 2 and ``MAX_MASKING_STRATEGIES`` masking modes and exactly
        one unmasking mode (uniform weights), passes both through
        ``check_modes_for_conflicts`` and asserts that the result is usable:
        the unmasking mode count is preserved, at least one masking mode
        remains, and the two lists share no element.

        Returns:
            The ``(masking_modes, unmasking_modes)`` pair returned by
            ``check_modes_for_conflicts``.
        """
        num_masking_modes = random.randint(2, MAX_MASKING_STRATEGIES)
        num_unmasking_modes = 1

        # A fresh weights list is built for every call, as in the original
        # code, in case the sampler modifies the list it is given.
        masking_modes = weighted_sample_without_replacement(
            MASKING_MODES, weights=[1] * len(MASKING_MODES), k=num_masking_modes
        )
        unmasking_modes = weighted_sample_without_replacement(
            MASKING_MODES, weights=[1] * len(MASKING_MODES), k=num_unmasking_modes
        )
        self.assertEqual(
            len(unmasking_modes), num_unmasking_modes, f"Got {len(unmasking_modes)}"
        )

        masking_modes, unmasking_modes = check_modes_for_conflicts(
            masking_modes, unmasking_modes
        )
        self.assertEqual(
            len(unmasking_modes), num_unmasking_modes, f"Got {len(unmasking_modes)}"
        )
        self.assertGreaterEqual(len(masking_modes), 1, f"Got {len(masking_modes)}")
        for m_m in masking_modes:
            self.assertNotIn(m_m, unmasking_modes, f"{m_m} in {unmasking_modes}")
        for u_m in unmasking_modes:
            self.assertNotIn(u_m, masking_modes, f"{u_m} in {masking_modes}")
        return masking_modes, unmasking_modes

    def test_mask_by_time(self):
        """Exercise ``batch_mask_time`` on a fixed regression case and random modes."""
        # testing specific failure modes
        self._test_mask_by_for_f(
            batch_mask_time,
            [
                ("static", "LS"),
                ("static", "location"),
                ("space", "WC"),
                ("space_time", "S2_SWIR"),
                ("space", "DW"),
                ("space_time", "S2_NIR_20m"),
            ],
            [("time", "TC"), ("time", "VIIRS")],
        )

        for _ in range(100):
            masking_modes, unmasking_modes = self._sample_non_conflicting_modes()
            # This loop previously called batch_mask_space (copy-pasted from
            # test_mask_by_space), so batch_mask_time was never tested with
            # randomly sampled modes.
            self._test_mask_by_for_f(batch_mask_time, masking_modes, unmasking_modes)

    def test_mask_by_space(self):
        """Exercise ``batch_mask_space`` on randomly sampled mode combinations."""
        for _ in range(100):
            masking_modes, unmasking_modes = self._sample_non_conflicting_modes()
            self._test_mask_by_for_f(batch_mask_space, masking_modes, unmasking_modes)

    def _test_mask_by_for_f(self, f, masking_modes, unmasking_modes):
        """Run masking function ``f`` for several sequence lengths and check its output.

        Args:
            f: masking function under test; called with the four inputs, the
                months tensor and the keyword arguments shown below.
            masking_modes: passed to ``f`` as ``mode``.
            unmasking_modes: passed to ``f`` as ``decoder_mode``.
        """
        b, h, w = 2, 16, 16
        ratio = 0.25
        for t in range(4, 8):
            space_time_input = torch.ones((b, h, w, t, 8))
            space_input = torch.ones((b, h, w, 8))
            time_input = torch.ones((b, t, 8))
            static_input = torch.ones((b, 8))
            months = repeat(torch.arange(0, t), "t -> b t", b=b)

            output = f(
                space_time_input,
                space_input,
                time_input,
                static_input,
                months,
                encode_ratio=ratio,
                decode_ratio=ratio,
                mode=masking_modes,
                decoder_mode=unmasking_modes,
                patch_size=4,
            )

            self.check_all_values_in_masks(
                output.space_time_mask,
                output.space_mask,
                output.time_mask,
                output.static_mask,
                masking_modes,
                unmasking_modes,
            )
            # Each mask keeps the input's leading dims; the last dim is one
            # entry per band group.
            self.assertEqual(
                (b, h, w, t, len(SPACE_TIME_BANDS_GROUPS_IDX)),
                output.space_time_mask.shape,
            )
            self.assertEqual((b, h, w, len(SPACE_BAND_GROUPS_IDX)), output.space_mask.shape)
            self.assertEqual((b, t, len(TIME_BAND_GROUPS_IDX)), output.time_mask.shape)
            self.assertEqual((b, len(STATIC_BAND_GROUPS_IDX)), output.static_mask.shape)

    @staticmethod
    def _count_value(mask, value):
        """Return, as a Python int, how many entries of ``mask`` equal ``value``."""
        return int((mask == value).sum())

    def test_mask_by_random(self):
        """Check shapes, patch consistency, token budgets and ignored bands for random masking."""
        b, t, h, w, p = 2, 8, 16, 16, 4
        # Token counts are integers; h and w are exact multiples of p here.
        h_tokens, w_tokens = h // p, w // p
        space_time_input = torch.ones((b, h, w, t, 8))
        space_input = torch.ones((b, h, w, 8))
        time_input = torch.ones((b, t, 8))
        static_input = torch.ones((b, 8))
        months = repeat(torch.arange(0, t), "t -> b t", b=b)
        ratio = 0.25

        output = batch_mask_random(
            space_time_input,
            space_input,
            time_input,
            static_input,
            months,
            encode_ratio=ratio,
            decode_ratio=ratio,
            patch_size=p,
            ignore_band_groups=["DW", "DW_static"],
        )

        self.check_all_values_in_masks(
            output.space_time_mask,
            output.space_mask,
            output.time_mask,
            output.static_mask,
            "random",
            "random",
        )
        self.assertEqual(
            (b, h, w, t, len(SPACE_TIME_BANDS_GROUPS_IDX)), output.space_time_mask.shape
        )
        self.assertEqual((b, h, w, len(SPACE_BAND_GROUPS_IDX)), output.space_mask.shape)
        self.assertEqual((b, t, len(TIME_BAND_GROUPS_IDX)), output.time_mask.shape)
        self.assertEqual((b, len(STATIC_BAND_GROUPS_IDX)), output.static_mask.shape)

        # Sampling the masks at consecutive diagonal offsets with stride p
        # must give identical tensors.
        for i in range(1, p):
            self.assertTrue(
                torch.equal(
                    output.space_time_mask[:, i::p, i::p],
                    output.space_time_mask[:, i - 1 :: p, i - 1 :: p],
                )
            )
            self.assertTrue(
                torch.equal(
                    output.space_mask[:, i::p, i::p],
                    output.space_mask[:, i - 1 :: p, i - 1 :: p],
                )
            )

        # One sample per patch (stride p) for the spatial masks. The offset is
        # the last one visited by the loop above (p - 1), which the original
        # code reached through the leftover loop variable.
        space_time_per_token = output.space_time_mask[:, p - 1 :: p, p - 1 :: p]
        space_per_token = output.space_mask[:, p - 1 :: p, p - 1 :: p]
        per_token_masks = (
            space_time_per_token,
            space_per_token,
            output.time_mask,
            output.static_mask,
        )
        encoded_tokens = sum(self._count_value(mask, 0) for mask in per_token_masks)
        decoded_tokens = sum(self._count_value(mask, 2) for mask in per_token_masks)

        total_tokens = (
            (h_tokens * w_tokens * t * len(SPACE_TIME_BANDS_GROUPS_IDX))
            # -1 because we have now masked out dynamic world
            + (h_tokens * w_tokens * (len(SPACE_BAND_GROUPS_IDX) - 1))
            + (t * len(TIME_BAND_GROUPS_IDX))
            # -1 because we have now masked out dynamic world static
            + (len(STATIC_BAND_GROUPS_IDX) - 1)
        ) * b
        expected_tokens = int(total_tokens * ratio)

        # handles off by one errors
        self.assertLessEqual(
            abs(encoded_tokens - expected_tokens),
            1,
            f"encoded {encoded_tokens} tokens, expected about {expected_tokens}",
        )
        self.assertLessEqual(
            abs(decoded_tokens - expected_tokens),
            1,
            f"decoded {decoded_tokens} tokens, expected about {expected_tokens}",
        )

        # check DW was masked out
        expected_mask_index = list(SPACE_BAND_GROUPS_IDX.keys()).index("DW")
        self.assertTrue((output.space_mask[:, :, :, expected_mask_index] == 1).all())
        expected_mask_index = list(STATIC_BAND_GROUPS_IDX.keys()).index("DW_static")
        self.assertTrue((output.static_mask[:, expected_mask_index] == 1).all())

    def test_filter_candidates(self):
        """Candidates that contain an ignored band group must be filtered out."""
        candidates = [
            [("space", "SRTM"), ("space_time", "S2_RGB")],
            [("space", "SRTM"), ("space", "DW")],
            [("space", "DW")],
        ]
        ignore_bands = ["DW"]
        outputs = filter_unmasking_mode_candidates(candidates, ignore_bands)
        self.assertEqual(outputs, [[("space", "SRTM"), ("space_time", "S2_RGB")]])