import random

import numpy as np
import torch
import torch.nn.functional as F
import imgaug.augmenters as iaa
from torchvision.transforms import GaussianBlur

from utils.batch_prep import compute_pointmaps
class ChangeBright(torch.nn.Module):
    def __init__(self, prob=0.5, mag=(0.5, 2.0)):
        super().__init__()
        self.mag = mag
        self.prob = prob

    def forward(self, rgb):
        # Decide per image whether to apply the augmentation.
        n = rgb.shape[0]
        apply_aug = np.random.uniform(0, 1, size=n) < self.prob
        # NOTE: imgaug's deterministic mode is buggy, so we sample the brightness factor ourselves.
        aug = iaa.MultiplyBrightness(np.random.uniform(self.mag[0], self.mag[1]))
        rgb[apply_aug] = aug(images=rgb[apply_aug])
        return rgb
class ChangeContrast(torch.nn.Module):
    def __init__(self, prob=0.5, mag=(0.5, 2.0)):
        super().__init__()
        self.mag = mag
        self.prob = prob

    def forward(self, rgb):
        # Decide per image whether to apply the augmentation.
        n = rgb.shape[0]
        apply_aug = np.random.uniform(0, 1, size=n) < self.prob
        # Sample the gamma ourselves, for the same imgaug reason as in ChangeBright.
        aug = iaa.GammaContrast(np.random.uniform(self.mag[0], self.mag[1]))
        rgb[apply_aug] = aug(images=rgb[apply_aug])
        return rgb
class SaltAndPepper:
    def __init__(self, prob=0.3, ratio=0.1, per_channel=True):
        self.prob = prob
        self.ratio = ratio
        self.per_channel = per_channel

    def __call__(self, rgb):
        # Decide per image whether to apply the augmentation.
        n = rgb.shape[0]
        apply_aug = np.random.uniform(0, 1, size=n) < self.prob
        aug = iaa.SaltAndPepper(self.ratio, per_channel=self.per_channel).to_deterministic()
        rgb[apply_aug] = aug(images=rgb[apply_aug])
        return rgb
class RGBGaussianNoise:
    def __init__(self, max_noise=10, prob=0.5):
        self.max_noise = max_noise
        self.prob = prob

    def __call__(self, rgb):
        # Decide per image whether to apply the augmentation.
        n = rgb.shape[0]
        apply_aug = np.random.uniform(0, 1, size=n) < self.prob
        # Additive Gaussian noise, clipped so no pixel shifts by more than max_noise.
        noise = np.random.normal(0, self.max_noise, size=rgb.shape).clip(-self.max_noise, self.max_noise)
        rgb[apply_aug] = (rgb[apply_aug].astype(float) + noise[apply_aug]).clip(0, 255).astype(np.uint8)
        return rgb
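
# A minimal usage sketch for the RGB operators above, assuming a uint8 batch of
# shape (n, h, w, 3) as imgaug expects; the concrete sizes and the chaining order
# here are illustrative assumptions, not part of the training pipeline.
def _demo_rgb_augmenters():
    rgbs = np.random.randint(0, 256, size=(4, 64, 64, 3), dtype=np.uint8)  # hypothetical batch
    for op in (ChangeBright(), ChangeContrast(), SaltAndPepper(), RGBGaussianNoise()):
        rgbs = op(rgbs)  # each operator mutates only the samples it draws for augmentation
    return rgbs
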
# Adapted from https://github.com/mihdalal/manipgen/blob/master/manipgen/utils/obs_utils.py
class DepthWarping(torch.nn.Module):
    def __init__(self, std=0.5, prob=0.8):
        super().__init__()
        self.std = std
        self.prob = prob

    def forward(self, depths, device=None):
        if device is None:
            device = depths.device
        n, _, h, w = depths.shape
        # Generate per-pixel Gaussian shifts; zero them out for samples that are not augmented.
        gaussian_shifts = torch.normal(mean=0, std=self.std, size=(n, h, w, 2), device=device).float()
        apply_shifts = torch.rand(n, device=device) < self.prob
        gaussian_shifts[~apply_shifts] = 0.0
        # Create a grid of the original pixel coordinates.
        xx = torch.linspace(0, w - 1, w, device=device)
        yy = torch.linspace(0, h - 1, h, device=device)
        xx = xx.unsqueeze(0).repeat(h, 1)
        yy = yy.unsqueeze(1).repeat(1, w)
        grid = torch.stack((xx, yy), 2).unsqueeze(0)  # add a batch dimension; broadcasts over n
        # Apply the Gaussian shifts to the grid.
        grid = grid + gaussian_shifts
        # Normalize grid values to [-1, 1] as expected by grid_sample.
        grid[..., 0] = (grid[..., 0] / (w - 1)) * 2 - 1
        grid[..., 1] = (grid[..., 1] / (h - 1)) * 2 - 1
        # Remap the depths with bilinear interpolation; the output keeps the (n, 1, h, w) shape.
        depth_interp = F.grid_sample(depths, grid, mode='bilinear', padding_mode='border', align_corners=True)
        return depth_interp
class DepthHoles(torch.nn.Module):
    def __init__(self, prob=0.5, kernel_size_lower=3, kernel_size_upper=27, sigma_lower=1.0,
                 sigma_upper=7.0, thresh_lower=0.6, thresh_upper=0.9):
        super().__init__()
        self.prob = prob
        self.kernel_size_lower = kernel_size_lower
        self.kernel_size_upper = kernel_size_upper
        self.sigma_lower = sigma_lower
        self.sigma_upper = sigma_upper
        self.thresh_lower = thresh_lower
        self.thresh_upper = thresh_upper

    def forward(self, depths, device=None):
        if device is None:
            device = depths.device
        n, _, h, w = depths.shape
        # Generate random noise.
        noise = torch.rand(n, 1, h, w, device=device)
        # Blur it with a randomly sized Gaussian kernel (odd sizes only).
        k = random.choice(list(range(self.kernel_size_lower, self.kernel_size_upper + 1, 2)))
        noise = GaussianBlur(kernel_size=k, sigma=(self.sigma_lower, self.sigma_upper))(noise)
        # Normalize the noise to [0, 1] over the whole batch.
        noise = (noise - noise.min()) / (noise.max() - noise.min())
        # Threshold with a per-sample cutoff to carve out hole regions.
        thresh = torch.rand(n, 1, 1, 1, device=device) * (self.thresh_upper - self.thresh_lower) + self.thresh_lower
        mask = (noise > thresh)
        # Only keep hole masks for the samples selected for augmentation.
        keep_mask = torch.rand(n, device=device) < self.prob
        mask[~keep_mask] = False
        return mask
class DepthNoise(torch.nn.Module):
    def __init__(self, std=0.005, prob=1.0):
        super().__init__()
        self.std = std
        self.prob = prob

    def forward(self, depths, device=None):
        if device is None:
            device = depths.device
        n, _, h, w = depths.shape
        # Additive Gaussian noise, zeroed for samples that are not augmented.
        apply_noise = torch.rand(n, device=device) < self.prob
        noise = torch.randn(n, 1, h, w, device=device) * self.std
        noise[~apply_noise] = 0.0
        return depths + noise
class Augmentor(torch.nn.Module):
    def __init__(self, depth_holes=DepthHoles(), depth_warping=DepthWarping(), depth_noise=DepthNoise(),
                 rgb_operators=[ChangeBright(), SaltAndPepper(), ChangeContrast(), RGBGaussianNoise()]):
        super().__init__()
        self.depth_holes = depth_holes
        self.depth_warping = depth_warping
        self.depth_noise = depth_noise
        self.rgb_operators = rgb_operators

    def forward(self, batch):
        input_depths = batch['input_cams']['depths']
        # Simulate sensor dropout by invalidating pixels inside the sampled hole masks.
        if self.depth_holes.prob > 0:
            masks = self.depth_holes(input_depths)
            batch['input_cams']['valid_masks'][masks] = False
        # Depth warping is currently disabled.
        # if self.depth_warping.prob > 0:
        #     input_depths = self.depth_warping(input_depths)
        if self.depth_noise.prob > 0:
            input_depths = self.depth_noise(input_depths)
        # Round-trip through numpy for the imgaug-based RGB operators; a bit inefficient, but acceptable.
        imgs_device = batch['input_cams']['imgs'].device
        input_rgbs = batch['input_cams']['imgs'].squeeze(1).cpu().numpy()
        for op in self.rgb_operators:
            input_rgbs = op(input_rgbs)
        batch['input_cams']['imgs'] = torch.from_numpy(input_rgbs).to(imgs_device).unsqueeze(1)
        batch['input_cams']['depths'] = input_depths
        # Recompute the pointmaps from the augmented depths (this repeats work done during batch prep).
        batch['input_cams']['pointmaps'] = compute_pointmaps(batch['input_cams']['depths'],
                                                             batch['input_cams']['Ks'],
                                                             batch['input_cams']['c2ws'])
        return batch
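
# A minimal smoke-test sketch for the full Augmentor, assuming batch['input_cams']
# holds uint8 images of shape (n, 1, h, w, 3), depths and valid_masks of shape
# (n, 1, h, w), 3x3 intrinsics 'Ks', and 4x4 camera-to-world poses 'c2ws'; the
# shapes and the identity intrinsics/poses are assumptions for illustration only,
# and compute_pointmaps may impose further requirements not shown here.
if __name__ == "__main__":
    n, h, w = 2, 64, 64
    batch = {
        'input_cams': {
            'imgs': torch.randint(0, 256, (n, 1, h, w, 3), dtype=torch.uint8),
            'depths': torch.rand(n, 1, h, w),
            'valid_masks': torch.ones(n, 1, h, w, dtype=torch.bool),
            'Ks': torch.eye(3).expand(n, 3, 3).clone(),    # hypothetical intrinsics
            'c2ws': torch.eye(4).expand(n, 4, 4).clone(),  # hypothetical poses
        }
    }
    batch = Augmentor()(batch)
    print(batch['input_cams']['imgs'].shape, batch['input_cams']['valid_masks'].float().mean())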