File size: 6,803 Bytes
70d1188
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
import random
import torch
import torch.nn.functional as F
from abc import ABC, abstractmethod
from torchvision.transforms import GaussianBlur
from utils.batch_prep import compute_pointmaps
import imgaug as ia
import imgaug.augmenters as iaa
import numpy as np

class ChangeBright(torch.nn.Module):
    """Randomly scale the brightness of a subset of images in a batch.

    Each image is selected independently with probability ``prob``; every
    selected image in one call is scaled by the same factor, drawn uniformly
    from ``[mag[0], mag[1]]``.
    """

    def __init__(self, prob=0.5, mag=[0.5, 2.0]):
        super().__init__()
        self.prob = prob
        self.mag = mag

    def forward(self, rgb):
        """Augment ``rgb`` in place and return it.

        Assumes ``rgb`` is a NumPy image batch with the batch on axis 0 in a
        layout imgaug accepts -- TODO confirm against the caller.
        """
        # Per-image coin flip deciding which entries get augmented.
        selected = np.random.uniform(0, 1, size=rgb.shape[0]) < self.prob

        # NOTE: imgaug's deterministic mode is buggy, so the factor is sampled
        # here and handed to the augmenter as a fixed value.
        low, high = self.mag[0], self.mag[1]
        factor = np.random.uniform(low, high)
        brighten = iaa.MultiplyBrightness(factor)

        rgb[selected] = brighten(images=rgb[selected])
        return rgb

class ChangeContrast(torch.nn.Module):
    """Randomly apply gamma-contrast to a subset of images in a batch.

    Each image is selected independently with probability ``prob``; every
    selected image in one call shares a single gamma drawn uniformly from
    ``[mag[0], mag[1]]``.
    """

    def __init__(self, prob=0.5, mag=[0.5, 2.0]):
        # Module state must be initialised, otherwise repr()/.to()/.train()
        # raise AttributeError on this instance.
        super().__init__()
        self.mag = mag
        self.prob = prob

    def forward(self, rgb):
        """Augment ``rgb`` in place and return it.

        Assumes ``rgb`` is a NumPy image batch with the batch on axis 0 in a
        layout imgaug accepts -- TODO confirm against the caller.
        """
        n = rgb.shape[0]
        apply_aug = np.random.uniform(0, 1, size=n) < self.prob
        # Sampled before the early exit so the NumPy random stream is consumed
        # the same way whether or not any image is selected.
        gamma = np.random.uniform(self.mag[0], self.mag[1])

        if not apply_aug.any():
            # Nothing selected: skip imgaug rather than feed it an empty batch.
            return rgb

        aug = iaa.GammaContrast(gamma)
        rgb[apply_aug] = aug(images=rgb[apply_aug])
        return rgb

class SaltAndPepper:
    """Randomly add salt-and-pepper noise to a subset of images in a batch.

    Each image is selected independently with probability ``prob``. ``ratio``
    and ``per_channel`` are forwarded unchanged to ``iaa.SaltAndPepper``.
    """

    def __init__(self, prob=0.3, ratio=0.1, per_channel=True):
        self.prob, self.ratio, self.per_channel = prob, ratio, per_channel

    def __call__(self, rgb):
        """Augment ``rgb`` in place and return it.

        Assumes ``rgb`` is a NumPy image batch with the batch on axis 0 in a
        layout imgaug accepts -- TODO confirm against the caller.
        """
        # Per-image coin flip deciding which entries get augmented.
        selected = np.random.uniform(0, 1, size=rgb.shape[0]) < self.prob

        noiser = iaa.SaltAndPepper(self.ratio, per_channel=self.per_channel)
        noiser = noiser.to_deterministic()

        rgb[selected] = noiser(images=rgb[selected])
        return rgb

class RGBGaussianNoise:
    """Randomly add clipped Gaussian noise to a subset of images in a batch.

    Noise is drawn from a zero-mean normal with standard deviation
    ``max_noise`` and clipped to ``[-max_noise, max_noise]``. Each image is
    selected independently with probability ``prob``.
    """

    def __init__(self, max_noise=10, prob=0.5):
        self.max_noise = max_noise
        self.prob = prob

    def __call__(self, rgb):
        """Augment ``rgb`` in place and return it.

        ``rgb`` must be a NumPy array with the batch on axis 0; selected
        entries are rewritten as uint8 values clipped to ``[0, 255]``.
        """
        batch_size = rgb.shape[0]
        selected = np.random.uniform(0, 1, size=batch_size) < self.prob

        # Noise is drawn for the full batch and then indexed by the selection.
        limit = self.max_noise
        noise = np.clip(np.random.normal(0, limit, size=rgb.shape), -limit, limit)

        # Add in floating point so the sum cannot wrap around in uint8.
        noisy = rgb[selected].astype(float) + noise[selected]
        rgb[selected] = np.clip(noisy, 0, 255).astype(np.uint8)
        return rgb

# from https://github.com/mihdalal/manipgen/blob/master/manipgen/utils/obs_utils.py
class DepthWarping(torch.nn.Module):
    """Jitter depth maps by resampling them on a randomly perturbed pixel grid.

    Each batch entry is warped with probability ``prob``. For warped entries
    every pixel's sampling location is offset by Gaussian noise with standard
    deviation ``std`` (in pixels); other entries are sampled on the identity
    grid.
    """

    def __init__(self, std=0.5, prob=0.8):
        super().__init__()
        self.std = std
        self.prob = prob

    def forward(self, depths, device=None):
        """Return the warped depth maps.

        Args:
            depths: 4-D tensor unpacked as ``(n, c, h, w)``.
            device: device for the sampling grid; defaults to ``depths.device``.

        Returns:
            The ``F.grid_sample`` output, with the same 4-D shape as ``depths``
            for every batch size (a batch of one is no longer squeezed to 2-D).
        """
        if device is None:
            device = depths.device

        n, _, h, w = depths.shape

        # Per-pixel (x, y) offsets in pixel units, zeroed for entries that are
        # not selected for warping.
        gaussian_shifts = torch.normal(mean=0, std=self.std, size=(n, h, w, 2), device=device).float()
        apply_shifts = torch.rand(n, device=device) < self.prob
        gaussian_shifts[~apply_shifts] = 0.0

        # Identity sampling grid in pixel coordinates, shape (1, h, w, 2).
        xx = torch.linspace(0, w - 1, w, device=device).unsqueeze(0).repeat(h, 1)
        yy = torch.linspace(0, h - 1, h, device=device).unsqueeze(1).repeat(1, w)
        grid = torch.stack((xx, yy), 2).unsqueeze(0)

        # Broadcasts the identity grid over the batch dimension.
        grid = grid + gaussian_shifts

        # Normalise to [-1, 1] as grid_sample expects. max(..., 1) avoids a
        # 0/0 NaN grid when the input is a single pixel wide or tall.
        grid[..., 0] = (grid[..., 0] / max(w - 1, 1)) * 2 - 1
        grid[..., 1] = (grid[..., 1] / max(h - 1, 1)) * 2 - 1

        # Border padding clamps samples that were shifted outside the image.
        return F.grid_sample(depths, grid, mode='bilinear', padding_mode='border', align_corners=True)

class DepthHoles(torch.nn.Module):
    """Generate random blob-shaped hole masks for a batch of depth maps.

    Uniform noise is Gaussian-blurred, rescaled to [0, 1] using the minimum and
    maximum over the whole batch, and thresholded per batch entry. Each entry
    keeps its mask with probability ``prob``; otherwise its mask is all False.
    """

    def __init__(self, prob=0.5, kernel_size_lower=3, kernel_size_upper=27, sigma_lower=1.0,
                 sigma_upper=7.0, thresh_lower=0.6, thresh_upper=0.9):
        super().__init__()
        self.prob = prob
        # Blur kernel sizes are drawn from lower, lower + 2, ..., up to upper.
        self.kernel_size_lower = kernel_size_lower
        self.kernel_size_upper = kernel_size_upper
        # Sigma range handed to torchvision's GaussianBlur.
        self.sigma_lower = sigma_lower
        self.sigma_upper = sigma_upper
        # Range of the per-entry threshold applied to the blurred noise.
        self.thresh_lower = thresh_lower
        self.thresh_upper = thresh_upper

    def forward(self, depths, device=None):
        """Return a boolean mask of shape ``(n, 1, h, w)``; True marks a hole.

        Only the shape and device of ``depths`` are used.
        """
        device = depths.device if device is None else device
        n, _, h, w = depths.shape

        # Start from uniform noise, one channel per batch entry.
        field = torch.rand(n, 1, h, w, device=device)

        # Smooth it with a kernel whose size is picked at random for this call.
        kernel_size = random.choice(range(self.kernel_size_lower, self.kernel_size_upper + 1, 2))
        blur = GaussianBlur(kernel_size=kernel_size, sigma=(self.sigma_lower, self.sigma_upper))
        field = blur(field)

        # Rescale to [0, 1] using the batch-wide extrema.
        lowest, highest = field.min(), field.max()
        field = (field - lowest) / (highest - lowest)

        # One threshold per batch entry, uniform in [thresh_lower, thresh_upper).
        span = self.thresh_upper - self.thresh_lower
        thresh = torch.rand(n, 1, 1, 1, device=device) * span + self.thresh_lower
        mask = field > thresh

        # Clear the mask of every entry that was not selected.
        keep_mask = torch.rand(n, device=device) < self.prob
        mask[~keep_mask, :] = 0

        return mask

class DepthNoise(torch.nn.Module):
    """Add zero-mean Gaussian noise to a random subset of depth maps.

    Each batch entry receives noise with probability ``prob``; the noise has
    standard deviation ``std``.
    """

    def __init__(self, std=0.005, prob=1.0):
        super().__init__()
        self.std = std
        self.prob = prob

    def forward(self, depths, device=None):
        """Return ``depths`` plus noise as a new tensor; the input is not modified.

        Args:
            depths: 4-D tensor unpacked as ``(n, c, h, w)``.
            device: device for the random tensors; defaults to ``depths.device``.
        """
        device = depths.device if device is None else device
        n, _, h, w = depths.shape

        # Decide per batch entry whether it is perturbed at all.
        selected = torch.rand(n, device=device) < self.prob

        # A single noise channel per entry, broadcast over the channels of depths.
        perturbation = self.std * torch.randn(n, 1, h, w, device=device)
        perturbation[~selected] = 0.0

        return depths + perturbation

class Augmentor(torch.nn.Module):
    """Apply depth and RGB augmentations to the input cameras of a batch.

    Args:
        depth_holes: module producing a boolean hole mask from depths; must
            expose a ``prob`` attribute. Defaults to a fresh ``DepthHoles()``.
        depth_warping: depth warping module (stored but currently not applied).
            Defaults to a fresh ``DepthWarping()``.
        depth_noise: module adding noise to depths; must expose a ``prob``
            attribute. Defaults to a fresh ``DepthNoise()``.
        rgb_operators: callables applied in order to the NumPy image batch.
            Defaults to a fresh list of ``ChangeBright``, ``SaltAndPepper``,
            ``ChangeContrast`` and ``RGBGaussianNoise``.
    """

    def __init__(self, depth_holes=None, depth_warping=None, depth_noise=None,
                 rgb_operators=None):
        super().__init__()
        # Defaults are built per instance: instantiating them in the signature
        # would share one set of modules (and one list) across every Augmentor.
        self.depth_holes = DepthHoles() if depth_holes is None else depth_holes
        self.depth_warping = DepthWarping() if depth_warping is None else depth_warping
        self.depth_noise = DepthNoise() if depth_noise is None else depth_noise
        if rgb_operators is None:
            rgb_operators = [ChangeBright(), SaltAndPepper(), ChangeContrast(), RGBGaussianNoise()]
        self.rgb_operators = rgb_operators

    def forward(self, batch):
        """Augment ``batch['input_cams']`` in place and return ``batch``.

        Reads the ``depths``, ``valid_masks``, ``imgs``, ``Ks`` and ``c2ws``
        entries and rewrites ``valid_masks`` (in place), ``imgs``, ``depths``
        and ``pointmaps``.
        """
        cams = batch['input_cams']
        input_depths = cams['depths']

        if self.depth_holes.prob > 0:
            # Pixels that fall inside a generated hole are marked invalid.
            masks = self.depth_holes(input_depths)
            cams['valid_masks'][masks] = False
        # NOTE: depth warping is currently disabled.
        #if self.depth_warping.prob > 0:
            #input_depths = self.depth_warping(input_depths)
        if self.depth_noise.prob > 0:
            input_depths = self.depth_noise(input_depths)

        # The RGB operators work on NumPy arrays, so the images take a round
        # trip through host memory (a bit inefficient, but acceptable).
        imgs = cams['imgs']
        input_rgbs = imgs.squeeze(1).cpu().numpy()
        for op in self.rgb_operators:
            input_rgbs = op(input_rgbs)
        # Return the images to the device they came from rather than assuming
        # CUDA, so CPU-only runs work.
        cams['imgs'] = torch.from_numpy(input_rgbs).to(imgs.device).unsqueeze(1)

        cams['depths'] = input_depths
        # Pointmaps are recomputed from the augmented depths (so they end up
        # being computed twice overall).
        cams['pointmaps'] = compute_pointmaps(cams['depths'], cams['Ks'], cams['c2ws'])
        return batch