NASA-Galileo / tests /test_end_to_end.py
openfree's picture
Deploy from GitHub repository
b20c769 verified
import unittest
from functools import partial
from pathlib import Path
import torch
from torch.utils.data import DataLoader
from src.collate_fns import galileo_collate_fn
from src.data import Dataset
from src.galileo import Decoder, Encoder
from src.loss import mse_loss
from src.utils import device
DATA_FOLDER = Path(__file__).parents[1] / "data/tifs"
class TestEndtoEnd(unittest.TestCase):
def test_end_to_end(self):
self._test_end_to_end()
def _test_end_to_end(self):
embedding_size = 32
dataset = Dataset(DATA_FOLDER, download=False, h5py_folder=None)
dataloader = DataLoader(
dataset,
batch_size=1,
shuffle=True,
num_workers=0,
collate_fn=partial(
galileo_collate_fn,
patch_sizes=[1, 2, 3, 4, 5, 6, 7, 8],
shape_time_combinations=[
{"size": 4, "timesteps": 12},
{"size": 5, "timesteps": 6},
{"size": 6, "timesteps": 4},
{"size": 7, "timesteps": 3},
{"size": 9, "timesteps": 3},
{"size": 12, "timesteps": 3},
],
st_encode_ratio=0.25,
st_decode_ratio=0.25,
random_encode_ratio=0.25,
random_decode_ratio=0.25,
random_masking="half",
),
pin_memory=True,
)
encoder = Encoder(embedding_size=embedding_size, num_heads=1).to(device)
predictor = Decoder(
encoder_embedding_size=embedding_size,
decoder_embedding_size=embedding_size,
num_heads=1,
learnable_channel_embeddings=False,
).to(device)
param_groups = [{"params": encoder.parameters()}, {"params": predictor.parameters()}]
optimizer = torch.optim.AdamW(param_groups, lr=3e-4) # type: ignore
# let's just consider one of the augmentations
for _, bs in enumerate(dataloader):
b = bs[0]
for x in b:
if isinstance(x, torch.Tensor):
self.assertFalse(torch.isnan(x).any())
b = [t.to(device) if isinstance(t, torch.Tensor) else t for t in b]
(
s_t_x,
sp_x,
t_x,
st_x,
s_t_m,
sp_m,
t_m,
st_m,
months,
patch_size,
) = b
# no autocast since its poorly supported on CPU
(p_s_t, p_sp, p_t, p_st) = predictor(
*encoder(
s_t_x=s_t_x.float(),
sp_x=sp_x.float(),
t_x=t_x.float(),
st_x=st_x.float(),
s_t_m=s_t_m.int(),
sp_m=sp_m.int(),
t_m=t_m.int(),
st_m=st_m.int(),
months=months.long(),
patch_size=patch_size,
),
patch_size=patch_size,
)
with torch.no_grad():
t_s_t, t_sp, t_t, t_st, _, _, _, _ = encoder.apply_linear_projection(
s_t_x.float(),
sp_x.float(),
t_x.float(),
st_x.float(),
~(s_t_m == 2).int(), # we want 0s where the mask == 2
~(sp_m == 2).int(),
~(t_m == 2).int(),
~(st_m == 2).int(),
patch_size,
)
t_s_t = encoder.blocks[0].norm1(t_s_t)
t_sp = encoder.blocks[0].norm1(t_sp)
t_sp = encoder.blocks[0].norm1(t_sp)
t_st = encoder.blocks[0].norm1(t_st)
# commenting out because this fails on the github runner. It doesn't fail locally
# or cause problems when running experiments.
# self.assertFalse(torch.isnan(p_s_t[s_t_m[:, 0::patch_size, 0::patch_size] == 2]).any())
# self.assertFalse(torch.isnan(p_sp[sp_m[:, 0::patch_size, 0::patch_size] == 2]).any())
# self.assertFalse(torch.isnan(p_t[t_m == 2]).any())
# self.assertFalse(torch.isnan(p_st[st_m == 2]).any())
# self.assertFalse(torch.isnan(t_s_t[s_t_m[:, 0::patch_size, 0::patch_size] == 2]).any())
# self.assertFalse(torch.isnan(t_sp[sp_m[:, 0::patch_size, 0::patch_size] == 2]).any())
# self.assertFalse(torch.isnan(t_t[t_m == 2]).any())
# self.assertFalse(torch.isnan(t_st[st_m == 2]).any())
self.assertTrue(
len(
torch.concat(
[
p_s_t[s_t_m[:, 0::patch_size, 0::patch_size] == 2],
p_sp[sp_m[:, 0::patch_size, 0::patch_size] == 2],
p_t[t_m == 2],
p_st[st_m == 2],
]
)
> 0
)
)
loss = mse_loss(
t_s_t,
t_sp,
t_t,
t_st,
p_s_t,
p_sp,
p_t,
p_st,
s_t_m[:, 0::patch_size, 0::patch_size],
sp_m[:, 0::patch_size, 0::patch_size],
t_m,
st_m,
)
# this also only fails on the runner
# self.assertFalse(torch.isnan(loss).any())
loss.backward()
optimizer.step()
# check the channel embeddings in the decoder didn't change
self.assertTrue(
torch.equal(
predictor.s_t_channel_embed, torch.zeros_like(predictor.s_t_channel_embed)
)
)
self.assertTrue(
torch.equal(
predictor.sp_channel_embed, torch.zeros_like(predictor.sp_channel_embed)
)
)
self.assertTrue(
torch.equal(predictor.t_channel_embed, torch.zeros_like(predictor.t_channel_embed))
)
self.assertTrue(
torch.equal(
predictor.st_channel_embed, torch.zeros_like(predictor.st_channel_embed)
)
)