Spaces:
Sleeping
Sleeping
File size: 1,267 Bytes
b20c769 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 |
import unittest
import torch
from einops import repeat
from src.data_augmentation import FlipAndRotateSpace
class TestAugmentation(unittest.TestCase):
    """Unit tests for the ``FlipAndRotateSpace`` augmentation."""

    @staticmethod
    def _paired_inputs():
        """Create a random spatial batch plus a time-tiled copy of it.

        Returns:
            A ``(space_time, space)`` tuple. ``space`` has layout
            ``(b, h, w, c)``; ``space_time`` has layout ``(b, h, w, t, c)``
            and holds the values of ``space`` at each of its 8 time steps.
        """
        space = torch.randn(100, 10, 10, 3)  # (b, h, w, c)
        space_time = repeat(space.clone(), "b h w c -> b h w t c", t=8)
        return space_time, space

    def test_flip_and_rotate_space(self):
        """Exercise ``apply`` with the augmentation switched on, then off."""
        # ---- augmentation enabled ----
        augmenter = FlipAndRotateSpace(enabled=True)
        st_before, s_before = self._paired_inputs()
        st_after, s_after = augmenter.apply(st_before, s_before)

        # Each time step started as a copy of the spatial tensor, so
        # averaging over the second-to-last axis only reproduces the spatial
        # output if both tensors were transformed the *same* way.
        time_averaged = st_after.mean(dim=-2)
        self.assertTrue(torch.equal(time_averaged, s_after))

        # With flip+rotate switched on, the outputs must differ from the inputs.
        space_time_unchanged = torch.equal(st_after, st_before)
        space_unchanged = torch.equal(s_after, s_before)
        self.assertFalse(space_time_unchanged)
        self.assertFalse(space_unchanged)

        # ---- augmentation disabled ----
        augmenter = FlipAndRotateSpace(enabled=False)
        st_before, s_before = self._paired_inputs()
        st_after, s_after = augmenter.apply(st_before, s_before)

        # With flip+rotate switched off, the inputs must come back untouched.
        space_time_unchanged = torch.equal(st_after, st_before)
        space_unchanged = torch.equal(s_after, s_before)
        self.assertTrue(space_time_unchanged)
        self.assertTrue(space_unchanged)
|