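"""Tests for src.data.dataset: tif-to-array conversion, normalization,
image subsetting, lat/lon to Cartesian conversion, and h5py export."""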
import tempfile
import unittest
from pathlib import Path

import h5py
import numpy as np
import torch

from src.data.dataset import (
    SPACE_BANDS,
    SPACE_TIME_BANDS,
    STATIC_BANDS,
    TIME_BANDS,
    Dataset,
    Normalizer,
    to_cartesian,
)

BROKEN_FILE = "min_lat=24.7979_min_lon=-105.1508_max_lat=24.8069_max_lon=-105.141_dates=2022-01-01_2023-12-31.tif"
TEST_FILENAMES = [
    "min_lat=5.4427_min_lon=101.4016_max_lat=5.4518_max_lon=101.4107_dates=2022-01-01_2023-12-31.tif",
    "min_lat=-27.6721_min_lon=25.6796_max_lat=-27.663_max_lon=25.6897_dates=2022-01-01_2023-12-31.tif",
]

TIFS_FOLDER = Path(__file__).parents[1] / "data/tifs"
TEST_FILES = [TIFS_FOLDER / x for x in TEST_FILENAMES]

class TestDataset(unittest.TestCase):
    def test_tif_to_array(self):
        ds = Dataset(TIFS_FOLDER, download=False)
        for test_file in TEST_FILES:
            s_t_x, sp_x, t_x, st_x, months = ds._tif_to_array(test_file)

            # all returned arrays should be finite
            self.assertFalse(np.isnan(s_t_x).any())
            self.assertFalse(np.isnan(sp_x).any())
            self.assertFalse(np.isnan(t_x).any())
            self.assertFalse(np.isnan(st_x).any())
            self.assertFalse(np.isinf(s_t_x).any())
            self.assertFalse(np.isinf(sp_x).any())
            self.assertFalse(np.isinf(t_x).any())
            self.assertFalse(np.isinf(st_x).any())

            # spatial and temporal dimensions must agree across array types,
            # and the channel dimension must match the band definitions
            self.assertEqual(sp_x.shape[0], s_t_x.shape[0])
            self.assertEqual(sp_x.shape[1], s_t_x.shape[1])
            self.assertEqual(t_x.shape[0], s_t_x.shape[2])
            self.assertEqual(len(SPACE_TIME_BANDS), s_t_x.shape[-1])
            self.assertEqual(len(SPACE_BANDS), sp_x.shape[-1])
            self.assertEqual(len(TIME_BANDS), t_x.shape[-1])
            self.assertEqual(len(STATIC_BANDS), st_x.shape[-1])
            self.assertEqual(months[0], 0)

    def test_files_are_replaced(self):
        # BROKEN_FILE starts out in the dataset's file list and should no longer
        # be listed once every tif has been iterated over
        ds = Dataset(TIFS_FOLDER, download=False)
        assert TIFS_FOLDER / BROKEN_FILE in ds.tifs
        for b in ds:
            assert len(b) == 5
        assert TIFS_FOLDER / BROKEN_FILE not in ds.tifs

    def test_normalization(self):
        ds = Dataset(TIFS_FOLDER, download=False)
        o = ds.load_normalization_values(path=Path("config/normalization.json"))
        for t in [len(SPACE_TIME_BANDS), len(SPACE_BANDS), len(STATIC_BANDS), len(TIME_BANDS)]:
            subdict = o[t]
            self.assertTrue("mean" in subdict)
            self.assertTrue("std" in subdict)
            self.assertTrue(len(subdict["mean"]) == len(subdict["std"]))
        normalizer = Normalizer(normalizing_dicts=o)
        ds.normalizer = normalizer
        for b in ds:
            for t in b:
                self.assertFalse(np.isnan(t).any())

    def test_subset_image_with_minimum_size(self):
        input = np.ones((3, 3, 1))
        months = static = np.ones(1)
        output = Dataset.subset_image(input, input, months, static, months, 3, 1)
        self.assertTrue(np.equal(input, output[0]).all())
        self.assertTrue(np.equal(input, output[1]).all())
        self.assertTrue(np.equal(months, output[2]).all())

    def test_subset_with_too_small_image(self):
        input = np.ones((2, 2, 1))
        months = static = np.ones(1)
        self.assertRaises(
            AssertionError, Dataset.subset_image, input, input, months, static, months, 3, 1
        )

    def test_subset_with_larger_images(self):
        input = np.ones((5, 5, 1))
        months = static = np.ones(1)
        output = Dataset.subset_image(input, input, months, static, months, 3, 1)
        self.assertTrue(np.equal(np.ones((3, 3, 1)), output[0]).all())
        self.assertTrue(np.equal(np.ones((3, 3, 1)), output[1]).all())
        self.assertTrue(np.equal(months, output[2]).all())

    def test_latlon_checks_float(self):
        # just checking it runs
        _ = to_cartesian(30.0, 40.0)
        with self.assertRaises(AssertionError):
            to_cartesian(1000.0, 1000.0)

    def test_latlon_checks_np(self):
        # just checking it runs
        _ = to_cartesian(np.array([30.0]), np.array([40.0]))
        with self.assertRaises(AssertionError):
            to_cartesian(np.array([1000.0]), np.array([1000.0]))

    def test_latlon_checks_tensor(self):
        # just checking it runs
        _ = to_cartesian(torch.tensor([30.0]), torch.tensor([40.0]))
        with self.assertRaises(AssertionError):
            to_cartesian(torch.tensor([1000.0]), torch.tensor([1000.0]))

    def test_process_h5pys(self):
        with tempfile.TemporaryDirectory() as tempdir_str:
            tempdir = Path(tempdir_str)
            dataset = Dataset(
                TIFS_FOLDER,
                download=False,
                h5py_folder=tempdir,
                h5pys_only=False,
            )
            dataset.process_h5pys()
            h5py_files = list(tempdir.glob("*.h5"))
            self.assertEqual(len(h5py_files), 2)
            for h5_file in h5py_files:
                with h5py.File(h5_file, "r") as f:
                    # mostly checking it can be read
                    self.assertEqual(f["t_x"].shape[0], 24)
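
# Optional entry point so the module can also be run directly with
# `python <this file>`; pytest/unittest test discovery does not require it.
if __name__ == "__main__":
    unittest.main()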