import unittest import numpy as np from src.data.dataset import ( SPACE_BAND_GROUPS_IDX, SPACE_BANDS, SPACE_TIME_BANDS, SPACE_TIME_BANDS_GROUPS_IDX, STATIC_BAND_GROUPS_IDX, STATIC_BANDS, TIME_BAND_GROUPS_IDX, TIME_BANDS, Normalizer, ) from src.eval.cropharvest.cropharvest_eval import BANDS, BinaryCropHarvestEval class TestCropHarvest(unittest.TestCase): def test_to_galileo_arrays(self): for ignore_band_groups in [["S1"], None]: for include_latlons in [True, False]: class BinaryCropHarvestEvalNoDownload(BinaryCropHarvestEval): def __init__(self, ignore_band_groups, include_latlons: bool = True): self.include_latlons = include_latlons self.normalizer = Normalizer(std=False) self.ignore_s_t_band_groups = self.indices_of_ignored( SPACE_TIME_BANDS_GROUPS_IDX, ignore_band_groups ) self.ignore_sp_band_groups = self.indices_of_ignored( SPACE_BAND_GROUPS_IDX, ignore_band_groups ) self.ignore_t_band_groups = self.indices_of_ignored( TIME_BAND_GROUPS_IDX, ignore_band_groups ) self.ignore_st_band_groups = self.indices_of_ignored( STATIC_BAND_GROUPS_IDX, ignore_band_groups ) eval = BinaryCropHarvestEvalNoDownload(ignore_band_groups, include_latlons) b, t = 8, 12 array = np.ones((b, t, len(BANDS))) latlons = np.ones((b, 2)) ( s_t_x, sp_x, t_x, st_x, s_t_m, sp_m, t_m, st_m, months, ) = eval.cropharvest_array_to_normalized_galileo(array, latlons, start_month=1) self.assertEqual(s_t_x.shape, (b, 1, 1, t, len(SPACE_TIME_BANDS))) self.assertEqual(s_t_m.shape, (b, 1, 1, t, len(SPACE_TIME_BANDS_GROUPS_IDX))) if ignore_band_groups is not None: # check s1 got masked self.assertTrue( ( s_t_m[:, :, :, :, list(SPACE_TIME_BANDS_GROUPS_IDX.keys()).index("S1")] == 1 ).all() ) self.assertEqual(sp_x.shape, (b, 1, 1, len(SPACE_BANDS))) self.assertEqual(sp_m.shape, (b, 1, 1, len(SPACE_BAND_GROUPS_IDX))) self.assertTrue( (sp_m[:, :, :, list(SPACE_BAND_GROUPS_IDX.keys()).index("SRTM")] == 0).all() ) self.assertTrue( (sp_m[:, :, :, list(SPACE_BAND_GROUPS_IDX.keys()).index("DW")] == 1).all() ) self.assertEqual(t_x.shape, (b, t, len(TIME_BANDS))) self.assertEqual(t_m.shape, (b, t, len(TIME_BAND_GROUPS_IDX))) self.assertEqual(st_x.shape, (b, len(STATIC_BANDS))) self.assertEqual(st_m.shape, (b, len(STATIC_BAND_GROUPS_IDX))) self.assertEqual(months.shape, (b, t))