File size: 3,422 Bytes
b20c769
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import unittest

import numpy as np

from src.data.dataset import (
    SPACE_BAND_GROUPS_IDX,
    SPACE_BANDS,
    SPACE_TIME_BANDS,
    SPACE_TIME_BANDS_GROUPS_IDX,
    STATIC_BAND_GROUPS_IDX,
    STATIC_BANDS,
    TIME_BAND_GROUPS_IDX,
    TIME_BANDS,
    Normalizer,
)
from src.eval.cropharvest.cropharvest_eval import BANDS, BinaryCropHarvestEval


class TestCropHarvest(unittest.TestCase):
    """Tests for converting CropHarvest arrays into Galileo model inputs."""

    def test_to_galileo_arrays(self):
        """Check shapes and key mask entries of ``cropharvest_array_to_normalized_galileo``.

        Every combination of ``ignore_band_groups`` (``["S1"]`` or ``None``) and
        ``include_latlons`` (``True`` / ``False``) runs as its own subTest. A
        failure then reports which combination broke, and the remaining
        combinations still run.
        """

        class BinaryCropHarvestEvalNoDownload(BinaryCropHarvestEval):
            # Deliberately skips BinaryCropHarvestEval.__init__ and sets only the
            # attributes below. NOTE(review): presumably these are exactly what
            # cropharvest_array_to_normalized_galileo reads, and skipping the parent
            # constructor avoids the download implied by the class name — confirm
            # against BinaryCropHarvestEval.
            def __init__(self, ignore_band_groups, include_latlons: bool = True):
                self.include_latlons = include_latlons
                self.normalizer = Normalizer(std=False)
                self.ignore_s_t_band_groups = self.indices_of_ignored(
                    SPACE_TIME_BANDS_GROUPS_IDX, ignore_band_groups
                )
                self.ignore_sp_band_groups = self.indices_of_ignored(
                    SPACE_BAND_GROUPS_IDX, ignore_band_groups
                )
                self.ignore_t_band_groups = self.indices_of_ignored(
                    TIME_BAND_GROUPS_IDX, ignore_band_groups
                )
                self.ignore_st_band_groups = self.indices_of_ignored(
                    STATIC_BAND_GROUPS_IDX, ignore_band_groups
                )

        b, t = 8, 12
        # Channel positions of the band groups whose masks are asserted below.
        s1_idx = list(SPACE_TIME_BANDS_GROUPS_IDX.keys()).index("S1")
        srtm_idx = list(SPACE_BAND_GROUPS_IDX.keys()).index("SRTM")
        dw_idx = list(SPACE_BAND_GROUPS_IDX.keys()).index("DW")

        for ignore_band_groups in [["S1"], None]:
            for include_latlons in [True, False]:
                with self.subTest(
                    ignore_band_groups=ignore_band_groups, include_latlons=include_latlons
                ):
                    evaluator = BinaryCropHarvestEvalNoDownload(
                        ignore_band_groups, include_latlons
                    )
                    # Fresh inputs for each combination, in case the conversion
                    # modifies its arguments in place.
                    array = np.ones((b, t, len(BANDS)))
                    latlons = np.ones((b, 2))
                    (
                        s_t_x,
                        sp_x,
                        t_x,
                        st_x,
                        s_t_m,
                        sp_m,
                        t_m,
                        st_m,
                        months,
                    ) = evaluator.cropharvest_array_to_normalized_galileo(
                        array, latlons, start_month=1
                    )

                    # Space-time inputs: (batch, 1, 1, time, bands / band groups).
                    self.assertEqual(s_t_x.shape, (b, 1, 1, t, len(SPACE_TIME_BANDS)))
                    self.assertEqual(
                        s_t_m.shape, (b, 1, 1, t, len(SPACE_TIME_BANDS_GROUPS_IDX))
                    )
                    if ignore_band_groups is not None:
                        # S1 was asked to be ignored, so its mask channel must be all 1s.
                        self.assertTrue((s_t_m[:, :, :, :, s1_idx] == 1).all())

                    # Space-only inputs: (batch, 1, 1, bands / band groups).
                    self.assertEqual(sp_x.shape, (b, 1, 1, len(SPACE_BANDS)))
                    self.assertEqual(sp_m.shape, (b, 1, 1, len(SPACE_BAND_GROUPS_IDX)))
                    # SRTM is expected to stay unmasked (0) and DW masked (1).
                    self.assertTrue((sp_m[:, :, :, srtm_idx] == 0).all())
                    self.assertTrue((sp_m[:, :, :, dw_idx] == 1).all())

                    # Time-only and static inputs.
                    self.assertEqual(t_x.shape, (b, t, len(TIME_BANDS)))
                    self.assertEqual(t_m.shape, (b, t, len(TIME_BAND_GROUPS_IDX)))
                    self.assertEqual(st_x.shape, (b, len(STATIC_BANDS)))
                    self.assertEqual(st_m.shape, (b, len(STATIC_BAND_GROUPS_IDX)))
                    self.assertEqual(months.shape, (b, t))