# NASA-Galileo: tests/test_galileo.py
# Author: openfree; commit b20c769 "Deploy from GitHub repository" (verified).
import json
import tempfile
import unittest
from pathlib import Path
import torch
from einops import repeat
from single_file_galileo import Encoder as SingleFileEncoder
from src.data import (
SPACE_BAND_GROUPS_IDX,
SPACE_TIME_BANDS_GROUPS_IDX,
STATIC_BAND_GROUPS_IDX,
TIME_BAND_GROUPS_IDX,
Dataset,
)
from src.data.config import CONFIG_FILENAME, ENCODER_FILENAME
from src.data.dataset import DatasetOutput
from src.galileo import Decoder, Encoder
from src.masking import (
MASKING_MODES,
MaskingFunctions,
batch_mask_space,
batch_mask_time,
batch_subset_mask_galileo,
)
from src.utils import device, load_check_config
# Repository-level "data" folder: one directory above the folder holding this file.
DATA_FOLDER = Path(__file__).parents[1].joinpath("data")
# Input folder handed to ``Dataset`` by the tests below.
TIFS_FOLDER = DATA_FOLDER.joinpath("tifs")
# Sibling folder "141" next to this test module.
TEST_MODEL_FOLDER = Path(__file__).parent.joinpath("141")
class TestGalileo(unittest.TestCase):
    """Tests for the Galileo encoder/decoder, token bookkeeping and masking helpers."""

    @staticmethod
    def to_tensor_with_batch_d(input: DatasetOutput):
        """Convert one dataset sample into torch tensors with a leading batch dim of 1.

        Returns ``(space_time_x, space_x, time_x, static_x, months)``; the first
        four are converted from numpy with ``.float()`` (float32) and ``months``
        with ``.long()`` (int64).
        """
        return (
            torch.from_numpy(input.space_time_x).float().unsqueeze(0),
            torch.from_numpy(input.space_x).float().unsqueeze(0),
            torch.from_numpy(input.time_x).float().unsqueeze(0),
            torch.from_numpy(input.static_x).float().unsqueeze(0),
            torch.from_numpy(input.months).long().unsqueeze(0),
        )

    def test_end_to_end(self):
        """End-to-end run with embedding size 16 and patch size 8."""
        self._end_to_end_run(16, 8)

    def test_end_to_end_different_inputs_per_dim_than_default(self):
        """End-to-end run with embedding size 16 and patch size 4."""
        self._end_to_end_run(16, 4)

    def _end_to_end_run(self, embedding_size, patch_size):
        """Mask every dataset sample, run encoder + decoder, and sanity-check outputs.

        For each sample this checks that:

        * tokens selected for decoding (mask value 2) contain no NaNs after the
          encoder's input projection and the first block's ``norm1``;
        * encoder and decoder outputs have the expected shapes;
        * encoder outputs at visible positions (mask 0) and decoder outputs at
          decoded positions (mask 2) contain no NaNs;
        * a backward pass through the summed decoder outputs runs.

        Args:
            embedding_size: width used for both the encoder and the decoder.
            patch_size: spatial patch size; the image is ``4 * patch_size`` pixels
                square, so each spatial token grid is 4 x 4.
        """
        image_size = patch_size * 4
        num_timesteps = 3
        # Integer division so the expected shapes are lists of ints.
        num_patches = image_size // patch_size
        # Expected shapes of outputs 0..3 (space-time, space, time, static tokens),
        # shared by the encoder and the decoder outputs.
        expected_shapes = (
            [
                1,
                num_patches,
                num_patches,
                num_timesteps,
                len(SPACE_TIME_BANDS_GROUPS_IDX),
                embedding_size,
            ],
            [1, num_patches, num_patches, len(SPACE_BAND_GROUPS_IDX), embedding_size],
            [1, num_timesteps, len(TIME_BAND_GROUPS_IDX), embedding_size],
            [1, len(STATIC_BAND_GROUPS_IDX), embedding_size],
        )

        encoder = Encoder(embedding_size=embedding_size, num_heads=1)
        decoder = Decoder(
            encoder_embedding_size=embedding_size,
            decoder_embedding_size=embedding_size,
            num_heads=1,
        )
        ds = Dataset(TIFS_FOLDER, False)
        for i in range(len(ds)):
            s_t_x, sp_x, t_x, st_x, months = self.to_tensor_with_batch_d(ds[i])
            masked_output = batch_subset_mask_galileo(
                s_t_x,
                sp_x,
                t_x,
                st_x,
                months,
                encode_ratio=0.25,
                decode_ratio=0.25,
                patch_size=patch_size,
                image_size=image_size,
                num_timesteps=num_timesteps,
                augmentation_strategies=None,
                masking_probabilities=[1] * len(MASKING_MODES),
                masking_function=MaskingFunctions.SPACE,
                max_unmasking_channels=4,
            )
            # Per-patch views of the spatial masks (one entry per patch), used to
            # index the space-time / space token grids below.
            token_masks = (
                masked_output.space_time_mask[:, 0::patch_size, 0::patch_size],
                masked_output.space_mask[:, 0::patch_size, 0::patch_size],
                masked_output.time_mask,
                masked_output.static_mask,
            )

            # for now, we just make sure it all runs
            with torch.autocast(device_type=device.type, dtype=torch.float16):
                encoder_output = encoder(
                    masked_output.space_time_x,
                    masked_output.space_x,
                    masked_output.time_x,
                    masked_output.static_x,
                    masked_output.space_time_mask,
                    masked_output.space_mask,
                    masked_output.time_mask,
                    masked_output.static_mask,
                    masked_output.months.long(),
                    patch_size=patch_size,
                )
                output = decoder(*encoder_output)

            with torch.no_grad():
                t_s_t, t_sp, t_t, t_st, _, _, _, _ = encoder.apply_linear_projection(
                    masked_output.space_time_x,
                    masked_output.space_x,
                    masked_output.time_x,
                    masked_output.static_x,
                    ~(masked_output.space_time_mask == 2),  # we want 0s where the mask == 2
                    ~(masked_output.space_mask == 2),
                    ~(masked_output.time_mask == 2),
                    ~(masked_output.static_mask == 2),
                    patch_size,
                )
                # Apply the first block's norm1 to every modality exactly once
                # (space-time, space, time and static tokens).
                norm = encoder.blocks[0].norm1
                projected = (norm(t_s_t), norm(t_sp), norm(t_t), norm(t_st))

            # Tokens marked for decoding must not be NaN after projection + norm.
            for tokens, mask in zip(projected, token_masks):
                self.assertFalse(torch.isnan(tokens[mask == 2]).any())

            # Encoder: shapes, and no NaNs at visible (mask == 0) positions.
            for idx, expected in enumerate(expected_shapes):
                self.assertEqual(list(encoder_output[idx].shape), expected)
            for tokens, mask in zip(encoder_output, token_masks):
                self.assertFalse(torch.isnan(tokens[mask == 0]).any())

            # Decoder: shapes, and no NaNs at decoded (mask == 2) positions.
            for idx, expected in enumerate(expected_shapes):
                self.assertEqual(list(output[idx].shape), expected)
            for tokens, mask in zip(output, token_masks):
                self.assertFalse(torch.isnan(tokens[mask == 2]).any())

            # check we can call backwards, with the loss
            summed_output = sum(torch.sum(o) for o in output)
            summed_output.backward()

    def test_decoder_add_masks(self):
        """``Decoder.add_masks`` zeroes mask-2 tokens and leaves the others untouched."""
        embedding_size = 16
        decoder = Decoder(
            encoder_embedding_size=embedding_size,
            decoder_embedding_size=embedding_size,
            num_heads=1,
        )
        b, h, w, t = 5, 6, 7, 8
        s_t_x = torch.ones(b, h, w, t, len(SPACE_TIME_BANDS_GROUPS_IDX), embedding_size)
        s_t_m = torch.zeros(b, h, w, t, len(SPACE_TIME_BANDS_GROUPS_IDX))
        s_t_m[:, :, :, 0] = 2  # the first timestep will get processed by the decoder
        s_t_m[:, :, :, 1] = 1  # the second timestep gets masked but not processed
        sp_x = torch.ones(b, h, w, len(SPACE_BAND_GROUPS_IDX), embedding_size)
        sp_m = torch.zeros(b, h, w, len(SPACE_BAND_GROUPS_IDX))
        sp_m[:, 0] = 2
        sp_m[:, 1] = 1
        t_x = torch.ones(b, t, len(TIME_BAND_GROUPS_IDX), embedding_size)
        t_m = torch.zeros(b, t, len(TIME_BAND_GROUPS_IDX))
        t_m[:, 0] = 2
        t_m[:, 1] = 1
        st_x = torch.ones(b, len(STATIC_BAND_GROUPS_IDX), embedding_size)
        st_m = torch.zeros(b, len(STATIC_BAND_GROUPS_IDX))
        st_m[:, 0] = 2
        st_m[:, 1] = 1
        with torch.no_grad():
            o = decoder.add_masks(s_t_x, sp_x, t_x, st_x, s_t_m, sp_m, t_m, st_m)
        self.assertTrue((o[0][:, :, :, 0] == 0).all())
        self.assertTrue((o[0][:, :, :, 1:] == 1).all())
        self.assertTrue((o[1][:, 0] == 0).all())
        self.assertTrue((o[1][:, 1:] == 1).all())
        self.assertTrue((o[2][:, 0] == 0).all())
        self.assertTrue((o[2][:, 1:] == 1).all())
        self.assertTrue((o[3][:, 0] == 0).all())
        self.assertTrue((o[3][:, 1:] == 1).all())

    def test_mean_of_tokens(self):
        """``Encoder.average_tokens`` ignores masked tokens (all unmasked values are 1)."""
        b, t, d, h, w, s_t_c_g, sp_c_g, t_c_g, st_c_g = 1, 2, 8, 3, 3, 5, 6, 2, 4
        s_t_x = torch.ones((b, h, w, t, s_t_c_g, d))
        sp_x = torch.ones((b, h, w, sp_c_g, d))
        t_x = torch.ones((b, t, t_c_g, d))
        st_x = torch.ones((b, st_c_g, d))
        # the first column (w index 0) and the first timestep are masked
        s_t_m = torch.zeros((b, h, w, t, s_t_c_g))
        s_t_m[:, :, 0, :] = 1
        s_t_m[:, :, :, 0] = 1
        s_t_x[:, :, 0, :] = 0
        s_t_x[:, :, :, 0] = 0
        # the last row is masked
        sp_m = torch.zeros((b, h, w, sp_c_g))
        sp_m[:, -1, :] = 1
        sp_x[:, -1, :] = 0
        # the first timestep is masked
        t_m = torch.zeros((b, t, t_c_g))
        t_m[:, 0] = 1
        t_x[:, 0] = 0
        # the last static channel group is masked
        st_m = torch.zeros((b, st_c_g))
        st_m[:, -1] = 1
        st_x[:, -1] = 0
        mean = Encoder.average_tokens(s_t_x, sp_x, t_x, st_x, s_t_m, sp_m, t_m, st_m)
        self.assertEqual(mean.shape, (b, d))
        self.assertTrue((mean == 1).all())

    def test_mask_and_unmask_tokens(self):
        """Removing masked tokens and adding them back round-trips values and mask."""
        b, d = 2, 2
        x = torch.tensor([[0, 1, 0], [1, 0, 1]]).float()
        x = repeat(x, "b n -> b n d", d=d)
        mask = torch.tensor([[1, 0, 1], [0, 1, 0]]).float()
        out_x, indices, updated_mask = Encoder.remove_masked_tokens(x, mask)
        self.assertEqual(out_x.dtype, x.dtype)
        self.assertEqual(updated_mask.dtype, mask.dtype)
        self.assertEqual(out_x.shape, (b, 2, d))
        # for the 2nd item in the batch, both kept (unmasked) tokens have value 1
        self.assertTrue(torch.equal(out_x[1], torch.ones_like(out_x[1])))
        # for the first item in the batch, only index 1 is unmasked so
        # it should be at the front
        self.assertEqual(indices[0, 0], 1)
        # for the second item, the 0th and 2nd tokens are unmasked and come first
        self.assertTrue(torch.equal(indices[1, :2], torch.tensor([0, 2])))
        self.assertEqual(updated_mask.shape, (b, 2))
        # the first item keeps one real token; its second slot is padding (masked)
        self.assertTrue(torch.equal(updated_mask, torch.Tensor([[0, 1], [0, 0]])))
        # check that when we add things back, they are once again what we had originally
        final_x, final_mask = Encoder.add_removed_tokens(out_x, indices, updated_mask)
        self.assertEqual(final_x.dtype, x.dtype)
        self.assertEqual(final_mask.dtype, mask.dtype)
        self.assertTrue(torch.equal(final_x, x))
        self.assertTrue(torch.equal(final_mask, mask))

    def test_combine_x_y(self):
        """``Decoder.combine_x_y`` scatters queries/keys back to their original order."""
        # x is the query (i.e. the masked tokens)
        x = torch.tensor([[14, 15, 16], [15, 16, 1]]).unsqueeze(-1)
        # y is the keys and values (i.e. the unmasked tokens)
        y = torch.tensor([[5, 6, 7, 8], [4, 5, 6, 7]]).unsqueeze(-1)
        x_mask = torch.tensor([[1, 1, 1], [1, 1, 0]])
        y_mask = torch.tensor([[1, 1, 1, 1], [0, 1, 1, 1]])
        indices = torch.tensor([[6, 7, 8, 4, 5, 0, 1, 2, 3], [7, 8, 3, 4, 5, 6, 0, 1, 2]])
        tokens = Decoder.combine_x_y(x, y, x_mask, y_mask, indices)
        self.assertTrue(
            torch.equal(
                tokens,
                torch.tensor(
                    [[5, 6, 7, 8, 0, 0, 14, 15, 16], [5, 6, 7, 0, 0, 0, 0, 15, 16]]
                ).unsqueeze(-1),
            )
        )

    def test_split_x_y(self):
        """``Decoder.split_x_y`` separates mask-2 queries (x) from mask-0 keys (y)."""
        tokens = torch.tensor(
            [[5, 6, 7, 8, 2, 13, 14, 15, 16], [5, 6, 7, 1, 2, 3, 4, 15, 16]]
        ).unsqueeze(-1)
        mask = torch.tensor([[0, 0, 0, 0, 1, 1, 2, 2, 2], [0, 0, 0, 1, 1, 1, 1, 2, 2]])
        x, y, x_mask, y_mask, _ = Decoder.split_x_y(tokens, mask)
        self.assertTrue(torch.equal(x, torch.tensor([[14, 15, 16], [15, 16, 1]]).unsqueeze(-1)))
        self.assertTrue(torch.equal(y, torch.tensor([[5, 6, 7, 8], [4, 5, 6, 7]]).unsqueeze(-1)))
        self.assertTrue(torch.equal(x_mask, torch.tensor([[1, 1, 1], [1, 1, 0]])))
        self.assertTrue(torch.equal(y_mask, torch.tensor([[1, 1, 1, 1], [0, 1, 1, 1]])))

    def test_x_y_there_and_back_again(self):
        """split_x_y followed by combine_x_y restores tokens, with mask-1 tokens zeroed."""
        tokens = torch.tensor(
            [[5, 6, 7, 8, 2, 13, 14, 15, 16], [5, 6, 7, 1, 2, 3, 4, 15, 16]]
        ).unsqueeze(-1)
        mask = torch.tensor([[0, 0, 0, 0, 1, 1, 2, 2, 2], [0, 0, 0, 1, 1, 1, 1, 2, 2]])
        x, y, x_mask, y_mask, indices = Decoder.split_x_y(tokens, mask)
        new_tokens = Decoder.combine_x_y(x, y, x_mask, y_mask, indices)
        tokens[mask == 1] = 0
        self.assertTrue(torch.equal(tokens, new_tokens))

    def test_load_from_device(self):
        """An encoder saved to a folder (weights + config) reloads with identical weights."""
        config = load_check_config("nano.json")
        original_encoder = Encoder(**config["model"]["encoder"])
        original_state = original_encoder.state_dict()
        with tempfile.TemporaryDirectory() as tempdir:
            folder = Path(tempdir)
            torch.save(original_state, folder / ENCODER_FILENAME)
            with (folder / CONFIG_FILENAME).open("w") as f:
                json.dump(config, f)
            new_encoder = Encoder.load_from_folder(folder)
            new_state = new_encoder.state_dict()
            # Same set of parameters/buffers, and every tensor matches exactly.
            self.assertEqual(set(new_state), set(original_state))
            for key, val in new_state.items():
                self.assertTrue(torch.equal(val, original_state[key]), key)

    def test_decoder_and_mask_static(self):
        """Encoding ("space", "DW") and decoding ("static", "LS") selects one token to decode."""
        patch_size = 4
        ratio = 0.25
        ds = Dataset(TIFS_FOLDER, False)
        tensor_batch = self.to_tensor_with_batch_d(ds[0])
        # the masking helpers are called on a square spatial input
        self.assertEqual(tensor_batch[0].shape[1], tensor_batch[0].shape[2])
        for f in [batch_mask_time, batch_mask_space]:
            masked_output = f(
                *tensor_batch,
                encode_ratio=ratio,
                decode_ratio=ratio,
                mode=[("space", "DW")],
                decoder_mode=[("static", "LS")],
                patch_size=patch_size,
            )
            encoder = Encoder(embedding_size=32, num_heads=1)
            decoder = Decoder(
                encoder_embedding_size=32,
                decoder_embedding_size=32,
                num_heads=1,
            )
            encoder_output = encoder(
                masked_output.space_time_x,
                masked_output.space_x,
                masked_output.time_x,
                masked_output.static_x,
                masked_output.space_time_mask,
                masked_output.space_mask,
                masked_output.time_mask,
                masked_output.static_mask,
                masked_output.months.long(),
                patch_size=patch_size,
            )
            s_t_x, sp_x, t_x, st_x, s_t_m, sp_m, t_m, st_m, _ = encoder_output
            x, m = decoder.collapse_and_combine_hwtc(
                s_t_x, sp_x, t_x, st_x, s_t_m, sp_m, t_m, st_m
            )
            x, _, _, _, _ = decoder.split_x_y(x, m)
            self.assertEqual(x.shape[1], 1, x.shape)

    def test_token_exit_cfgs_single_exit_equivalency(self):
        """A uniform per-group ``token_exit_cfg`` matches a single ``exit_after``."""
        for depth in (0, 6, 12):
            with self.subTest(depth=depth):
                self._token_exit_cfgs_single_exit_equivalency(depth)

    @torch.no_grad()
    def _token_exit_cfgs_single_exit_equivalency(self, depth):
        """Check ``exit_after=depth`` equals a ``token_exit_cfg`` of ``depth`` for every group.

        Both encoder calls use all-zero masks; outputs 0..3 (space-time, space,
        time, static tokens) must be exactly equal.
        """
        embedding_size, patch_size = 16, 1
        image_size = patch_size * 4
        num_timesteps = 3
        encoder = Encoder(embedding_size=embedding_size, num_heads=1, depth=12)
        encoder.eval()
        # for this test, we will keep the same
        # values per shape since we call layer norm
        # on each shape output
        token_exit_cfgs = dict.fromkeys(
            (
                "S1",
                "S2_RGB",
                "S2_Red_Edge",
                "S2_NIR_10m",
                "S2_NIR_20m",
                "S2_SWIR",
                "NDVI",
                "ERA5",
                "TC",
                "VIIRS",
                "SRTM",
                "DW",
                "WC",
                "LS",
                "location",
                "DW_static",
                "WC_static",
            ),
            depth,
        )
        ds = Dataset(TIFS_FOLDER, False)
        for i in range(len(ds)):
            s_t_x, sp_x, t_x, st_x, months = self.to_tensor_with_batch_d(ds[i])
            masked_output = batch_subset_mask_galileo(
                s_t_x,
                sp_x,
                t_x,
                st_x,
                months,
                encode_ratio=0.25,
                decode_ratio=0.25,
                patch_size=patch_size,
                image_size=image_size,
                num_timesteps=num_timesteps,
                augmentation_strategies=None,
                masking_probabilities=[1] * len(MASKING_MODES),
                masking_function=MaskingFunctions.SPACE,
                max_unmasking_channels=4,
            )
            encoder_output_depth = encoder(
                masked_output.space_time_x,
                masked_output.space_x,
                masked_output.time_x,
                masked_output.static_x,
                torch.zeros_like(masked_output.space_time_mask),
                torch.zeros_like(masked_output.space_mask),
                torch.zeros_like(masked_output.time_mask),
                torch.zeros_like(masked_output.static_mask),
                masked_output.months.long(),
                patch_size=patch_size,
                exit_after=depth,
            )
            encoder_output_depth_varied = encoder(
                masked_output.space_time_x,
                masked_output.space_x,
                masked_output.time_x,
                masked_output.static_x,
                torch.zeros_like(masked_output.space_time_mask),
                torch.zeros_like(masked_output.space_mask),
                torch.zeros_like(masked_output.time_mask),
                torch.zeros_like(masked_output.static_mask),
                masked_output.months.long(),
                patch_size=patch_size,
                token_exit_cfg=token_exit_cfgs,
                exit_after=None,
            )
            # outputs 0..3: s_t_x, sp_x, t_x, st_x
            for idx in range(4):
                self.assertTrue(
                    torch.equal(encoder_output_depth_varied[idx], encoder_output_depth[idx]),
                    idx,
                )

    def test_single_file_galileo_matches_galileo(self):
        """The single-file encoder loads the same nano weights as the package encoder."""
        org_model = Encoder.load_from_folder(DATA_FOLDER / "models/nano")
        sf_model = SingleFileEncoder.load_from_folder(
            DATA_FOLDER / "models/nano", device=torch.device("cpu")
        )
        org_params = list(org_model.parameters())
        sf_params = list(sf_model.parameters())
        # zip() would silently ignore extra parameters on either side.
        self.assertEqual(len(org_params), len(sf_params))
        for model_p, sf_model_p in zip(org_params, sf_params):
            self.assertTrue(torch.equal(model_p, sf_model_p))