File size: 2,529 Bytes
b20c769
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import unittest

import torch

from src.data.dataset import (
    SPACE_BAND_GROUPS_IDX,
    SPACE_TIME_BANDS_GROUPS_IDX,
    STATIC_BAND_GROUPS_IDX,
    TIME_BAND_GROUPS_IDX,
)
from src.loss import mae_loss


class TestLoss(unittest.TestCase):
    """Smoke tests for the loss functions in ``src.loss``."""

    def test_mae_loss(self):
        """``mae_loss`` returns a finite value for random inputs of consistent shape.

        Builds random predictions for the four token families (space-time, space,
        time, static) and random pixel-level targets whose channel counts come from
        the band-group index dictionaries. Then checks that the loss contains
        neither NaN nor infinity.
        """
        b, t_h, t_w, t, patch_size = 16, 4, 4, 3, 2
        pixel_h, pixel_w = t_h * patch_size, t_w * patch_size
        max_patch_size = 8

        def num_bands(groups):
            # Total number of band indices across all groups of one family.
            return sum(len(idx) for idx in groups.values())

        # The last dimension of every prediction tensor is sized for the largest
        # band group (across all four families) at the largest patch size.
        max_group_length = max(
            len(idx)
            for groups in (
                SPACE_TIME_BANDS_GROUPS_IDX,
                TIME_BAND_GROUPS_IDX,
                SPACE_BAND_GROUPS_IDX,
                STATIC_BAND_GROUPS_IDX,
            )
            for idx in groups.values()
        )
        token_dim = max_group_length * max_patch_size**2

        # Predictions: one token per (token-grid position, group).
        p_s_t = torch.randn(b, t_h, t_w, t, len(SPACE_TIME_BANDS_GROUPS_IDX), token_dim)
        p_sp = torch.randn(b, t_h, t_w, len(SPACE_BAND_GROUPS_IDX), token_dim)
        p_t = torch.randn(b, t, len(TIME_BAND_GROUPS_IDX), token_dim)
        p_st = torch.randn(b, len(STATIC_BAND_GROUPS_IDX), token_dim)

        # Targets at pixel resolution (token grid * patch_size).
        s_t_x = torch.randn(b, pixel_h, pixel_w, t, num_bands(SPACE_TIME_BANDS_GROUPS_IDX))
        sp_x = torch.randn(b, pixel_h, pixel_w, num_bands(SPACE_BAND_GROUPS_IDX))
        t_x = torch.randn(b, t, num_bands(TIME_BAND_GROUPS_IDX))
        st_x = torch.randn(b, num_bands(STATIC_BAND_GROUPS_IDX))

        # NOTE(review): every mask entry is 2; presumably that value marks positions
        # the loss is computed over — confirm against src.loss.mae_loss.
        s_t_m = torch.full((b, pixel_h, pixel_w, t, len(SPACE_TIME_BANDS_GROUPS_IDX)), 2.0)
        sp_m = torch.full((b, pixel_h, pixel_w, len(SPACE_BAND_GROUPS_IDX)), 2.0)
        t_m = torch.full((b, t, len(TIME_BAND_GROUPS_IDX)), 2.0)
        st_m = torch.full((b, len(STATIC_BAND_GROUPS_IDX)), 2.0)

        loss = mae_loss(
            p_s_t,
            p_sp,
            p_t,
            p_st,
            s_t_x,
            sp_x,
            t_x,
            st_x,
            s_t_m,
            sp_m,
            t_m,
            st_m,
            patch_size,
            max_patch_size,
        )
        # isfinite rejects +/-inf as well as NaN; .all() also copes with a
        # non-scalar loss tensor.
        self.assertTrue(torch.isfinite(loss).all())