File size: 1,267 Bytes
b20c769
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import unittest

import torch
from einops import repeat

from src.data_augmentation import FlipAndRotateSpace


class TestAugmentation(unittest.TestCase):
    """Unit tests for the ``FlipAndRotateSpace`` augmentation."""

    @staticmethod
    def _make_inputs(num_timesteps=8):
        """Build a matching ``(space_time_x, space_x)`` input pair.

        ``space_x`` is a random tensor of shape (b, h, w, c) = (100, 10, 10, 3).
        ``space_time_x`` holds the same values copied along a new time axis,
        giving shape (b, h, w, t, c), so every timestep equals ``space_x``.
        """
        space_x = torch.randn(100, 10, 10, 3)  # (b, h, w, c)
        space_time_x = repeat(
            space_x.clone(), "b h w c -> b h w t c", t=num_timesteps
        )
        return space_time_x, space_x

    def test_flip_and_rotate_space(self):
        """Check the augmentation when enabled and when disabled.

        Enabled: both tensors must change, and must change identically.
        Disabled: both tensors must come back unchanged.
        """
        aug = FlipAndRotateSpace(enabled=True)
        space_time_x, space_x = self._make_inputs()
        new_space_time_x, new_space_x = aug.apply(space_time_x, space_x)

        # check that space_x and space_time_x are transformed the *same* way.
        # Each timestep started as an exact copy of space_x, so each one must
        # still match new_space_x exactly. Comparing slices directly (instead
        # of averaging over time) avoids depending on a float reduction being
        # bit-exact and cannot hide differences between timesteps.
        num_timesteps = space_time_x.shape[-2]
        self.assertEqual(new_space_time_x.shape[-2], num_timesteps)
        for t in range(num_timesteps):
            self.assertTrue(
                torch.equal(new_space_time_x[..., t, :], new_space_x),
                msg=f"timestep {t} was transformed differently from space_x",
            )

        # check that tensors were changed when flip+rotate=True
        # NOTE(review): presumably the transform is random; this relies on
        # the random inputs not being left untouched - confirm against
        # FlipAndRotateSpace if this ever becomes flaky.
        self.assertFalse(torch.equal(new_space_time_x, space_time_x))
        self.assertFalse(torch.equal(new_space_x, space_x))

        aug = FlipAndRotateSpace(enabled=False)
        space_time_x, space_x = self._make_inputs()
        new_space_time_x, new_space_x = aug.apply(space_time_x, space_x)

        # check that tensors were not changed when flip+rotate=False
        self.assertTrue(torch.equal(new_space_time_x, space_time_x))
        self.assertTrue(torch.equal(new_space_x, space_x))