# NASA-Galileo / tests/test_masking.py
# Hosting-page residue kept for provenance: "openfree's picture",
# "Deploy from GitHub repository", "b20c769 verified".
import random
import unittest
import torch
from einops import repeat
from src.masking import (
MASKING_MODES,
MAX_MASKING_STRATEGIES,
SPACE_BAND_GROUPS_IDX,
SPACE_TIME_BANDS_GROUPS_IDX,
STATIC_BAND_GROUPS_IDX,
TIME_BAND_GROUPS_IDX,
batch_mask_random,
batch_mask_space,
batch_mask_time,
check_modes_for_conflicts,
filter_unmasking_mode_candidates,
weighted_sample_without_replacement,
)
class TestMasking(unittest.TestCase):
    """Exercise the batch masking functions imported from ``src.masking``.

    The assertions in this class treat mask entries as one of three values:
    entries equal to 0 are counted against ``encode_ratio``, entries equal to
    2 are counted against ``decode_ratio``, and band groups passed via
    ``ignore_band_groups`` are expected to be filled with 1.
    """

    @staticmethod
    def _count(mask, value):
        """Return the number of entries of ``mask`` equal to ``value`` as an int."""
        return int((mask == value).sum())

    def check_all_values_in_masks(
        self, space_time_mask, space_mask, time_mask, static_mask, masking_modes, unmasking_modes
    ):
        """Assert that each of the values 2, 0 and 1 occurs in at least one mask.

        Args:
            space_time_mask: Mask tensor returned for the space-time input.
            space_mask: Mask tensor returned for the space input.
            time_mask: Mask tensor returned for the time input.
            static_mask: Mask tensor returned for the static input.
            masking_modes: Only used to label the failure message.
            unmasking_modes: Only used to label the failure message.
        """
        masks = (space_time_mask, space_mask, time_mask, static_mask)
        for value in (2, 0, 1):
            self.assertTrue(
                any(bool((mask == value).any()) for mask in masks),
                f"{value} check failed for {masking_modes}, {unmasking_modes}",
            )

    def _check_mask_shapes(self, output, b, h, w, t):
        """Assert every mask in ``output`` has one channel per band group.

        Args:
            output: Object returned by a ``batch_mask_*`` function; must expose
                ``space_time_mask``, ``space_mask``, ``time_mask`` and
                ``static_mask`` tensors.
            b: Batch size of the inputs.
            h: Height of the spatial inputs.
            w: Width of the spatial inputs.
            t: Number of timesteps of the temporal inputs.
        """
        self.assertEqual(
            (b, h, w, t, len(SPACE_TIME_BANDS_GROUPS_IDX)),
            output.space_time_mask.shape,
        )
        self.assertEqual((b, h, w, len(SPACE_BAND_GROUPS_IDX)), output.space_mask.shape)
        self.assertEqual((b, t, len(TIME_BAND_GROUPS_IDX)), output.time_mask.shape)
        self.assertEqual((b, len(STATIC_BAND_GROUPS_IDX)), output.static_mask.shape)

    def _sample_non_conflicting_modes(self):
        """Draw random masking / unmasking modes and resolve their conflicts.

        Between 2 and ``MAX_MASKING_STRATEGIES`` masking modes and exactly one
        unmasking mode are sampled uniformly from ``MASKING_MODES`` and then
        passed through ``check_modes_for_conflicts``. The result is asserted
        to keep the single unmasking mode, to keep at least one masking mode,
        and to be disjoint.

        Returns:
            The ``(masking_modes, unmasking_modes)`` pair returned by
            ``check_modes_for_conflicts``.
        """
        num_masking_modes = random.randint(2, MAX_MASKING_STRATEGIES)
        num_unmasking_modes = 1

        # A fresh weights list is built for each call in case the sampler
        # mutates its argument.
        masking_modes = weighted_sample_without_replacement(
            MASKING_MODES, weights=[1] * len(MASKING_MODES), k=num_masking_modes
        )
        unmasking_modes = weighted_sample_without_replacement(
            MASKING_MODES, weights=[1] * len(MASKING_MODES), k=num_unmasking_modes
        )
        self.assertEqual(
            len(unmasking_modes), num_unmasking_modes, f"Got {len(unmasking_modes)}"
        )

        masking_modes, unmasking_modes = check_modes_for_conflicts(
            masking_modes, unmasking_modes
        )
        self.assertEqual(
            len(unmasking_modes), num_unmasking_modes, f"Got {len(unmasking_modes)}"
        )
        self.assertGreaterEqual(len(masking_modes), 1, f"Got {len(masking_modes)}")
        # Disjointness is symmetric, so checking one direction is sufficient.
        for mode in masking_modes:
            self.assertNotIn(mode, unmasking_modes, f"{mode} in {unmasking_modes}")
        return masking_modes, unmasking_modes

    def test_mask_by_time(self):
        """Check ``batch_mask_time`` on a fixed regression case and random modes."""
        # testing specific failure modes
        self._test_mask_by_for_f(
            batch_mask_time,
            [
                ("static", "LS"),
                ("static", "location"),
                ("space", "WC"),
                ("space_time", "S2_SWIR"),
                ("space", "DW"),
                ("space_time", "S2_NIR_20m"),
            ],
            [("time", "TC"), ("time", "VIIRS")],
        )

        for _ in range(100):
            masking_modes, unmasking_modes = self._sample_non_conflicting_modes()
            # This loop previously called batch_mask_space, so the random
            # modes never reached the time-masking function under test.
            self._test_mask_by_for_f(batch_mask_time, masking_modes, unmasking_modes)

    def test_mask_by_space(self):
        """Check ``batch_mask_space`` on 100 randomly sampled mode combinations."""
        for _ in range(100):
            masking_modes, unmasking_modes = self._sample_non_conflicting_modes()
            self._test_mask_by_for_f(batch_mask_space, masking_modes, unmasking_modes)

    def _test_mask_by_for_f(self, f, masking_modes, unmasking_modes):
        """Run masking function ``f`` for several sequence lengths and check its output.

        For each ``t`` in 4..7 the function is called on all-ones inputs with
        encode and decode ratios of 0.25 and a patch size of 4. The returned
        masks must contain the values 0, 1 and 2 somewhere and must have the
        expected shapes.

        Args:
            f: Masking function to call (e.g. ``batch_mask_time``).
            masking_modes: Passed to ``f`` as ``mode``.
            unmasking_modes: Passed to ``f`` as ``decoder_mode``.
        """
        b, h, w = 2, 16, 16
        ratio = 0.25
        for t in range(4, 8):
            space_time_input = torch.ones((b, h, w, t, 8))
            space_input = torch.ones((b, h, w, 8))
            time_input = torch.ones((b, t, 8))
            static_input = torch.ones((b, 8))
            months = repeat(torch.arange(0, t), "t -> b t", b=b)

            output = f(
                space_time_input,
                space_input,
                time_input,
                static_input,
                months,
                encode_ratio=ratio,
                decode_ratio=ratio,
                mode=masking_modes,
                decoder_mode=unmasking_modes,
                patch_size=4,
            )

            self.check_all_values_in_masks(
                output.space_time_mask,
                output.space_mask,
                output.time_mask,
                output.static_mask,
                masking_modes,
                unmasking_modes,
            )
            self._check_mask_shapes(output, b, h, w, t)

    def test_mask_by_random(self):
        """Check ``batch_mask_random`` shapes, patch consistency, ratios and ignored bands."""
        b, t, h, w, p = 2, 8, 16, 16, 4
        # Token counts are integers, so use floor division.
        h_tokens, w_tokens = h // p, w // p
        space_time_input = torch.ones((b, h, w, t, 8))
        space_input = torch.ones((b, h, w, 8))
        time_input = torch.ones((b, t, 8))
        static_input = torch.ones((b, 8))
        months = repeat(torch.arange(0, t), "t -> b t", b=b)
        ratio = 0.25

        output = batch_mask_random(
            space_time_input,
            space_input,
            time_input,
            static_input,
            months,
            encode_ratio=ratio,
            decode_ratio=ratio,
            patch_size=p,
            ignore_band_groups=["DW", "DW_static"],
        )

        self.check_all_values_in_masks(
            output.space_time_mask,
            output.space_mask,
            output.time_mask,
            output.static_mask,
            "random",
            "random",
        )
        self._check_mask_shapes(output, b, h, w, t)

        # Pixels sampled along the diagonal of each patch must share one mask value.
        for i in range(1, p):
            for mask in (output.space_time_mask, output.space_mask):
                self.assertTrue(
                    torch.equal(
                        mask[:, i::p, i::p],
                        mask[:, i - 1 :: p, i - 1 :: p],
                    )
                )

        # Sample one pixel per patch so spatial masks are counted per token
        # rather than per pixel; the time and static masks are used as-is.
        per_token_masks = (
            output.space_time_mask[:, ::p, ::p],
            output.space_mask[:, ::p, ::p],
            output.time_mask,
            output.static_mask,
        )
        # These totals span the whole batch, matching the "* b" in total_tokens.
        num_encoded = sum(self._count(mask, 0) for mask in per_token_masks)
        num_decoded = sum(self._count(mask, 2) for mask in per_token_masks)

        total_tokens = (
            (h_tokens * w_tokens * t * len(SPACE_TIME_BANDS_GROUPS_IDX))
            # -1 because we have now masked out dynamic world
            + (h_tokens * w_tokens * (len(SPACE_BAND_GROUPS_IDX) - 1))
            + (t * len(TIME_BAND_GROUPS_IDX))
            # -1 because we have now masked out dynamic world static
            + (len(STATIC_BAND_GROUPS_IDX) - 1)
        ) * b
        expected = int(total_tokens * ratio)

        # A tolerance of one token handles off-by-one rounding.
        self.assertLessEqual(
            abs(num_encoded - expected), 1, f"encoded {num_encoded}, expected ~{expected}"
        )
        self.assertLessEqual(
            abs(num_decoded - expected), 1, f"decoded {num_decoded}, expected ~{expected}"
        )

        # check DW was masked out
        expected_mask_index = list(SPACE_BAND_GROUPS_IDX.keys()).index("DW")
        self.assertTrue((output.space_mask[:, :, :, expected_mask_index] == 1).all())
        expected_mask_index = list(STATIC_BAND_GROUPS_IDX.keys()).index("DW_static")
        self.assertTrue((output.static_mask[:, expected_mask_index] == 1).all())

    def test_filter_candidates(self):
        """Candidates that contain an ignored band group must be filtered out."""
        candidates = [
            [("space", "SRTM"), ("space_time", "S2_RGB")],
            [("space", "SRTM"), ("space", "DW")],
            [("space", "DW")],
        ]
        ignore_bands = ["DW"]
        outputs = filter_unmasking_mode_candidates(candidates, ignore_bands)
        self.assertEqual(outputs, [[("space", "SRTM"), ("space_time", "S2_RGB")]])