Spaces:
Running
Running
File size: 3,422 Bytes
b20c769 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 |
import unittest
import numpy as np
from src.data.dataset import (
SPACE_BAND_GROUPS_IDX,
SPACE_BANDS,
SPACE_TIME_BANDS,
SPACE_TIME_BANDS_GROUPS_IDX,
STATIC_BAND_GROUPS_IDX,
STATIC_BANDS,
TIME_BAND_GROUPS_IDX,
TIME_BANDS,
Normalizer,
)
from src.eval.cropharvest.cropharvest_eval import BANDS, BinaryCropHarvestEval
class TestCropHarvest(unittest.TestCase):
    """Tests for converting CropHarvest arrays into Galileo model inputs."""

    def test_to_galileo_arrays(self):
        """Check shapes and masks returned by ``cropharvest_array_to_normalized_galileo``.

        Every combination of ``ignore_band_groups`` in (``["S1"]``, ``None``) and
        ``include_latlons`` in (``True``, ``False``) runs as its own ``subTest``,
        so a failure reports which combination broke.
        """

        # Defined once per test run rather than once per loop iteration.
        # Its ``__init__`` only sets the attributes below and does not call
        # ``BinaryCropHarvestEval.__init__``; judging by the class name this is
        # to avoid downloading data (NOTE(review): confirm against the parent).
        class BinaryCropHarvestEvalNoDownload(BinaryCropHarvestEval):
            def __init__(self, ignore_band_groups, include_latlons: bool = True):
                self.include_latlons = include_latlons
                self.normalizer = Normalizer(std=False)
                self.ignore_s_t_band_groups = self.indices_of_ignored(
                    SPACE_TIME_BANDS_GROUPS_IDX, ignore_band_groups
                )
                self.ignore_sp_band_groups = self.indices_of_ignored(
                    SPACE_BAND_GROUPS_IDX, ignore_band_groups
                )
                self.ignore_t_band_groups = self.indices_of_ignored(
                    TIME_BAND_GROUPS_IDX, ignore_band_groups
                )
                self.ignore_st_band_groups = self.indices_of_ignored(
                    STATIC_BAND_GROUPS_IDX, ignore_band_groups
                )

        b, t = 8, 12
        # Positions of the band groups whose mask channels are checked below;
        # these do not depend on the loop parameters, so compute them once.
        s1_idx = list(SPACE_TIME_BANDS_GROUPS_IDX.keys()).index("S1")
        srtm_idx = list(SPACE_BAND_GROUPS_IDX.keys()).index("SRTM")
        dw_idx = list(SPACE_BAND_GROUPS_IDX.keys()).index("DW")

        for ignore_band_groups in [["S1"], None]:
            for include_latlons in [True, False]:
                with self.subTest(
                    ignore_band_groups=ignore_band_groups,
                    include_latlons=include_latlons,
                ):
                    # Named ``evaluator`` (not ``eval``) to avoid shadowing the builtin.
                    evaluator = BinaryCropHarvestEvalNoDownload(
                        ignore_band_groups, include_latlons
                    )
                    # Fresh inputs per combination, in case the conversion
                    # modifies them in place.
                    array = np.ones((b, t, len(BANDS)))
                    latlons = np.ones((b, 2))
                    (
                        s_t_x,
                        sp_x,
                        t_x,
                        st_x,
                        s_t_m,
                        sp_m,
                        t_m,
                        st_m,
                        months,
                    ) = evaluator.cropharvest_array_to_normalized_galileo(
                        array, latlons, start_month=1
                    )

                    # Space-time inputs and masks.
                    self.assertEqual(s_t_x.shape, (b, 1, 1, t, len(SPACE_TIME_BANDS)))
                    self.assertEqual(
                        s_t_m.shape, (b, 1, 1, t, len(SPACE_TIME_BANDS_GROUPS_IDX))
                    )
                    if ignore_band_groups is not None:
                        # S1 is the ignored group here: its mask channel must be
                        # 1 (masked) everywhere.
                        self.assertTrue((s_t_m[:, :, :, :, s1_idx] == 1).all())

                    # Space-only inputs and masks: SRTM must be fully unmasked and
                    # DW fully masked, whatever groups are ignored.
                    self.assertEqual(sp_x.shape, (b, 1, 1, len(SPACE_BANDS)))
                    self.assertEqual(sp_m.shape, (b, 1, 1, len(SPACE_BAND_GROUPS_IDX)))
                    self.assertTrue((sp_m[:, :, :, srtm_idx] == 0).all())
                    self.assertTrue((sp_m[:, :, :, dw_idx] == 1).all())

                    # Time-only, static and month outputs.
                    self.assertEqual(t_x.shape, (b, t, len(TIME_BANDS)))
                    self.assertEqual(t_m.shape, (b, t, len(TIME_BAND_GROUPS_IDX)))
                    self.assertEqual(st_x.shape, (b, len(STATIC_BANDS)))
                    self.assertEqual(st_m.shape, (b, len(STATIC_BAND_GROUPS_IDX)))
                    self.assertEqual(months.shape, (b, t))
|