Spaces:
Running
Running
File size: 2,529 Bytes
b20c769 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 |
import unittest
import torch
from src.data.dataset import (
SPACE_BAND_GROUPS_IDX,
SPACE_TIME_BANDS_GROUPS_IDX,
STATIC_BAND_GROUPS_IDX,
TIME_BAND_GROUPS_IDX,
)
from src.loss import mae_loss
class TestLoss(unittest.TestCase):
    """Smoke tests for ``src.loss.mae_loss``."""

    def test_mae_loss(self):
        """``mae_loss`` returns a finite value on random predictions and targets.

        Predictions for all four modalities (space-time, space, time, static)
        are allocated with a last dimension sized for the largest patch
        (``max_patch_size``) and the largest band group. Targets are built at
        pixel resolution from the smaller ``patch_size``. Every mask entry is
        set to 2.
        """
        b, t_h, t_w, t, patch_size = 16, 4, 4, 3, 2
        pixel_h, pixel_w = t_h * patch_size, t_w * patch_size
        # Single source of truth: this value sizes the predictions below and
        # is also passed to mae_loss, so the two must never diverge.
        max_patch_size = 8

        def n_bands(groups):
            # Total number of band indices across every group of one modality.
            return sum(len(idx) for idx in groups.values())

        all_group_dicts = (
            SPACE_TIME_BANDS_GROUPS_IDX,
            TIME_BAND_GROUPS_IDX,
            SPACE_BAND_GROUPS_IDX,
            STATIC_BAND_GROUPS_IDX,
        )
        max_group_length = max(
            len(idx) for groups in all_group_dicts for idx in groups.values()
        )
        # Width of every prediction token: the largest group flattened over
        # the largest patch.
        pred_dim = max_group_length * max_patch_size**2

        # Predictions: one token per (spatial patch, timestep, group) as
        # applicable to each modality.
        p_s_t = torch.randn(
            b, t_h, t_w, t, len(SPACE_TIME_BANDS_GROUPS_IDX), pred_dim
        )
        p_sp = torch.randn(b, t_h, t_w, len(SPACE_BAND_GROUPS_IDX), pred_dim)
        p_t = torch.randn(b, t, len(TIME_BAND_GROUPS_IDX), pred_dim)
        p_st = torch.randn(b, len(STATIC_BAND_GROUPS_IDX), pred_dim)

        # Targets: raw per-band values at pixel resolution.
        s_t_x = torch.randn(
            b, pixel_h, pixel_w, t, n_bands(SPACE_TIME_BANDS_GROUPS_IDX)
        )
        sp_x = torch.randn(b, pixel_h, pixel_w, n_bands(SPACE_BAND_GROUPS_IDX))
        t_x = torch.randn(b, t, n_bands(TIME_BAND_GROUPS_IDX))
        st_x = torch.randn(b, n_bands(STATIC_BAND_GROUPS_IDX))

        # Masks: one entry per (pixel, timestep, group) as applicable, all set
        # to 2. NOTE(review): presumably 2 marks entries that contribute to
        # the loss -- confirm against the mask convention in src.loss.
        s_t_m = torch.full(
            (b, pixel_h, pixel_w, t, len(SPACE_TIME_BANDS_GROUPS_IDX)), 2.0
        )
        sp_m = torch.full((b, pixel_h, pixel_w, len(SPACE_BAND_GROUPS_IDX)), 2.0)
        t_m = torch.full((b, t, len(TIME_BAND_GROUPS_IDX)), 2.0)
        st_m = torch.full((b, len(STATIC_BAND_GROUPS_IDX)), 2.0)

        loss = mae_loss(
            p_s_t,
            p_sp,
            p_t,
            p_st,
            s_t_x,
            sp_x,
            t_x,
            st_x,
            s_t_m,
            sp_m,
            t_m,
            st_m,
            patch_size,
            max_patch_size,
        )
        # isfinite rejects both NaN and +/-inf; .all() keeps the check valid
        # even if the loss is ever returned unreduced.
        self.assertTrue(
            torch.isfinite(loss).all(),
            f"mae_loss returned a non-finite value: {loss}",
        )
|