Spaces:
Sleeping
Sleeping
import unittest | |
from functools import partial | |
from pathlib import Path | |
import torch | |
from torch.utils.data import DataLoader | |
from src.collate_fns import galileo_collate_fn | |
from src.data import Dataset | |
from src.galileo import Decoder, Encoder | |
from src.loss import mse_loss | |
from src.utils import device | |
DATA_FOLDER = Path(__file__).parents[1] / "data/tifs" | |
class TestEndtoEnd(unittest.TestCase): | |
def test_end_to_end(self): | |
self._test_end_to_end() | |
def _test_end_to_end(self): | |
embedding_size = 32 | |
dataset = Dataset(DATA_FOLDER, download=False, h5py_folder=None) | |
dataloader = DataLoader( | |
dataset, | |
batch_size=1, | |
shuffle=True, | |
num_workers=0, | |
collate_fn=partial( | |
galileo_collate_fn, | |
patch_sizes=[1, 2, 3, 4, 5, 6, 7, 8], | |
shape_time_combinations=[ | |
{"size": 4, "timesteps": 12}, | |
{"size": 5, "timesteps": 6}, | |
{"size": 6, "timesteps": 4}, | |
{"size": 7, "timesteps": 3}, | |
{"size": 9, "timesteps": 3}, | |
{"size": 12, "timesteps": 3}, | |
], | |
st_encode_ratio=0.25, | |
st_decode_ratio=0.25, | |
random_encode_ratio=0.25, | |
random_decode_ratio=0.25, | |
random_masking="half", | |
), | |
pin_memory=True, | |
) | |
encoder = Encoder(embedding_size=embedding_size, num_heads=1).to(device) | |
predictor = Decoder( | |
encoder_embedding_size=embedding_size, | |
decoder_embedding_size=embedding_size, | |
num_heads=1, | |
learnable_channel_embeddings=False, | |
).to(device) | |
param_groups = [{"params": encoder.parameters()}, {"params": predictor.parameters()}] | |
optimizer = torch.optim.AdamW(param_groups, lr=3e-4) # type: ignore | |
# let's just consider one of the augmentations | |
for _, bs in enumerate(dataloader): | |
b = bs[0] | |
for x in b: | |
if isinstance(x, torch.Tensor): | |
self.assertFalse(torch.isnan(x).any()) | |
b = [t.to(device) if isinstance(t, torch.Tensor) else t for t in b] | |
( | |
s_t_x, | |
sp_x, | |
t_x, | |
st_x, | |
s_t_m, | |
sp_m, | |
t_m, | |
st_m, | |
months, | |
patch_size, | |
) = b | |
# no autocast since its poorly supported on CPU | |
(p_s_t, p_sp, p_t, p_st) = predictor( | |
*encoder( | |
s_t_x=s_t_x.float(), | |
sp_x=sp_x.float(), | |
t_x=t_x.float(), | |
st_x=st_x.float(), | |
s_t_m=s_t_m.int(), | |
sp_m=sp_m.int(), | |
t_m=t_m.int(), | |
st_m=st_m.int(), | |
months=months.long(), | |
patch_size=patch_size, | |
), | |
patch_size=patch_size, | |
) | |
with torch.no_grad(): | |
t_s_t, t_sp, t_t, t_st, _, _, _, _ = encoder.apply_linear_projection( | |
s_t_x.float(), | |
sp_x.float(), | |
t_x.float(), | |
st_x.float(), | |
~(s_t_m == 2).int(), # we want 0s where the mask == 2 | |
~(sp_m == 2).int(), | |
~(t_m == 2).int(), | |
~(st_m == 2).int(), | |
patch_size, | |
) | |
t_s_t = encoder.blocks[0].norm1(t_s_t) | |
t_sp = encoder.blocks[0].norm1(t_sp) | |
t_sp = encoder.blocks[0].norm1(t_sp) | |
t_st = encoder.blocks[0].norm1(t_st) | |
# commenting out because this fails on the github runner. It doesn't fail locally | |
# or cause problems when running experiments. | |
# self.assertFalse(torch.isnan(p_s_t[s_t_m[:, 0::patch_size, 0::patch_size] == 2]).any()) | |
# self.assertFalse(torch.isnan(p_sp[sp_m[:, 0::patch_size, 0::patch_size] == 2]).any()) | |
# self.assertFalse(torch.isnan(p_t[t_m == 2]).any()) | |
# self.assertFalse(torch.isnan(p_st[st_m == 2]).any()) | |
# self.assertFalse(torch.isnan(t_s_t[s_t_m[:, 0::patch_size, 0::patch_size] == 2]).any()) | |
# self.assertFalse(torch.isnan(t_sp[sp_m[:, 0::patch_size, 0::patch_size] == 2]).any()) | |
# self.assertFalse(torch.isnan(t_t[t_m == 2]).any()) | |
# self.assertFalse(torch.isnan(t_st[st_m == 2]).any()) | |
self.assertTrue( | |
len( | |
torch.concat( | |
[ | |
p_s_t[s_t_m[:, 0::patch_size, 0::patch_size] == 2], | |
p_sp[sp_m[:, 0::patch_size, 0::patch_size] == 2], | |
p_t[t_m == 2], | |
p_st[st_m == 2], | |
] | |
) | |
> 0 | |
) | |
) | |
loss = mse_loss( | |
t_s_t, | |
t_sp, | |
t_t, | |
t_st, | |
p_s_t, | |
p_sp, | |
p_t, | |
p_st, | |
s_t_m[:, 0::patch_size, 0::patch_size], | |
sp_m[:, 0::patch_size, 0::patch_size], | |
t_m, | |
st_m, | |
) | |
# this also only fails on the runner | |
# self.assertFalse(torch.isnan(loss).any()) | |
loss.backward() | |
optimizer.step() | |
# check the channel embeddings in the decoder didn't change | |
self.assertTrue( | |
torch.equal( | |
predictor.s_t_channel_embed, torch.zeros_like(predictor.s_t_channel_embed) | |
) | |
) | |
self.assertTrue( | |
torch.equal( | |
predictor.sp_channel_embed, torch.zeros_like(predictor.sp_channel_embed) | |
) | |
) | |
self.assertTrue( | |
torch.equal(predictor.t_channel_embed, torch.zeros_like(predictor.t_channel_embed)) | |
) | |
self.assertTrue( | |
torch.equal( | |
predictor.st_channel_embed, torch.zeros_like(predictor.st_channel_embed) | |
) | |
) | |