File size: 1,477 Bytes
b20c769
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import unittest

import torch

from src.data.utils import (
    S2_BANDS,
    SPACE_TIME_BANDS,
    SPACE_TIME_BANDS_GROUPS_IDX,
    construct_galileo_input,
)


class TestDataUtils(unittest.TestCase):
    """Tests for ``construct_galileo_input`` from ``src.data.utils``."""

    def test_construct_galileo_input_s2(self):
        """Passing only ``s2`` should unmask exactly the S2 space-time groups.

        The check runs for ``normalize=True`` and ``normalize=False``. In both
        cases, the space, time and static masks must stay fully masked (all 1).
        In the space-time mask, S2 groups must be unmasked (0) and every other
        group masked (1). The raw S2 values are compared against the output
        only when ``normalize`` is False.
        """
        t, h, w = 2, 4, 4
        s2 = torch.randn((t, h, w, len(S2_BANDS)))

        # These selections do not depend on ``normalize``, so build them once.
        # A group counts as S2 when its key contains the substring "S2".
        s2_group_indices = [
            idx for idx, key in enumerate(SPACE_TIME_BANDS_GROUPS_IDX) if "S2" in key
        ]
        non_s2_group_indices = [
            idx for idx, key in enumerate(SPACE_TIME_BANDS_GROUPS_IDX) if "S2" not in key
        ]
        # Indexing with an empty list selects nothing, and `.all()` on an empty
        # tensor is True. Fail loudly instead of letting the mask checks pass
        # vacuously.
        self.assertTrue(s2_group_indices, "expected at least one S2 space-time group")
        self.assertTrue(non_s2_group_indices, "expected at least one non-S2 space-time group")

        # Column i of `s2` holds band S2_BANDS[i]. Look up where each of those
        # bands lives in SPACE_TIME_BANDS, so the value check below does not
        # depend on both lists ordering the S2 bands the same way.
        band_position = {band: idx for idx, band in enumerate(SPACE_TIME_BANDS)}
        s2_band_indices = [band_position[band] for band in S2_BANDS]

        for normalize in [True, False]:
            # subTest reports which `normalize` value failed and keeps running
            # the other one.
            with self.subTest(normalize=normalize):
                masked_output = construct_galileo_input(s2=s2, normalize=normalize)

                # Only space-time data was supplied, so the other masks stay fully masked.
                self.assertTrue((masked_output.space_mask == 1).all())
                self.assertTrue((masked_output.time_mask == 1).all())
                self.assertTrue((masked_output.static_mask == 1).all())

                # Every non-S2 space-time group must remain masked...
                self.assertTrue(
                    (masked_output.space_time_mask[:, :, :, non_s2_group_indices] == 1).all()
                )
                # ...and every S2 group must be unmasked.
                self.assertTrue(
                    (masked_output.space_time_mask[:, :, :, s2_group_indices] == 0).all()
                )

                # Without normalization, the S2 values must be copied unchanged
                # into their band positions.
                if not normalize:
                    self.assertTrue(
                        torch.equal(masked_output.space_time_x[:, :, :, s2_band_indices], s2)
                    )