|
import random |
|
import torch |
|
import torch.nn.functional as F |
|
from abc import ABC, abstractmethod |
|
from torchvision.transforms import GaussianBlur |
|
from utils.batch_prep import compute_pointmaps |
|
import imgaug as ia |
|
import imgaug.augmenters as iaa |
|
import numpy as np |
|
|
|
class ChangeBright(torch.nn.Module):
    """Randomly scale the brightness of a subset of the images in a batch.

    Each image is selected independently with probability ``prob``. One
    brightness multiplier, drawn uniformly from ``[mag[0], mag[1]]``, is
    shared by all images selected in a given call.

    Args:
        prob: Per-image probability of applying the augmentation.
        mag: ``(low, high)`` bounds of the brightness multiplier.
    """

    def __init__(self, prob=0.5, mag=(0.5, 2.0)):
        super().__init__()
        # Tuple default: a list default would be one object shared by every
        # instance created without an explicit ``mag``.
        self.mag = mag
        self.prob = prob

    def forward(self, rgb):
        """Augment ``rgb`` in place and return it.

        Args:
            rgb: Batch of images indexed along axis 0, in a form accepted by
                imgaug's ``images=`` argument (presumably a uint8 NumPy array
                of shape (N, H, W, 3) -- TODO confirm against the caller).

        Returns:
            The same ``rgb`` object, with the selected images replaced.
        """
        n = rgb.shape[0]
        apply_aug = np.random.uniform(0, 1, size=n) < self.prob
        # Drawn unconditionally so the NumPy random stream does not depend on
        # which images happened to be selected.
        factor = np.random.uniform(self.mag[0], self.mag[1])

        # Nothing selected: do not build an augmenter or hand imgaug an empty
        # batch.
        if not apply_aug.any():
            return rgb

        aug = iaa.MultiplyBrightness(factor)
        rgb[apply_aug] = aug(images=rgb[apply_aug])
        return rgb
|
|
|
class ChangeContrast(torch.nn.Module):
    """Randomly apply gamma-contrast to a subset of the images in a batch.

    Each image is selected independently with probability ``prob``. One gamma
    value, drawn uniformly from ``[mag[0], mag[1]]``, is shared by all images
    selected in a given call.

    Args:
        prob: Per-image probability of applying the augmentation.
        mag: ``(low, high)`` bounds of the gamma exponent.
    """

    def __init__(self, prob=0.5, mag=(0.5, 2.0)):
        # Required for a torch.nn.Module subclass: without it the module's
        # internal registries are missing and parameters()/repr()/to() fail.
        super().__init__()
        self.mag = mag
        self.prob = prob

    def forward(self, rgb):
        """Augment ``rgb`` in place and return it.

        Implemented as ``forward`` (not ``__call__``) so that calling the
        instance goes through ``torch.nn.Module.__call__`` like the other
        module-based augmentations in this file.

        Args:
            rgb: Batch of images indexed along axis 0, in a form accepted by
                imgaug's ``images=`` argument (presumably a uint8 NumPy array
                of shape (N, H, W, 3) -- TODO confirm against the caller).

        Returns:
            The same ``rgb`` object, with the selected images replaced.
        """
        n = rgb.shape[0]
        apply_aug = np.random.uniform(0, 1, size=n) < self.prob
        # Drawn unconditionally so the NumPy random stream does not depend on
        # which images happened to be selected.
        gamma = np.random.uniform(self.mag[0], self.mag[1])

        # Nothing selected: do not build an augmenter or hand imgaug an empty
        # batch.
        if not apply_aug.any():
            return rgb

        aug = iaa.GammaContrast(gamma)
        rgb[apply_aug] = aug(images=rgb[apply_aug])
        return rgb
|
|
|
class SaltAndPepper:
    """Randomly add salt-and-pepper noise to a subset of the images in a batch.

    Args:
        prob: Per-image probability of applying the augmentation.
        ratio: Value forwarded as the first argument of ``iaa.SaltAndPepper``.
        per_channel: Forwarded to ``iaa.SaltAndPepper(per_channel=...)``.
    """

    def __init__(self, prob=0.3, ratio=0.1, per_channel=True):
        self.prob = prob
        self.ratio = ratio
        self.per_channel = per_channel

    def __call__(self, rgb):
        """Augment ``rgb`` in place and return it.

        Args:
            rgb: Batch of images indexed along axis 0, in a form accepted by
                imgaug's ``images=`` argument (presumably a uint8 NumPy array
                of shape (N, H, W, 3) -- TODO confirm against the caller).

        Returns:
            The same ``rgb`` object, with the selected images replaced.
        """
        n = rgb.shape[0]
        apply_aug = np.random.uniform(0, 1, size=n) < self.prob

        # Nothing selected: do not build an augmenter or hand imgaug an empty
        # batch.
        if not apply_aug.any():
            return rgb

        aug = iaa.SaltAndPepper(self.ratio, per_channel=self.per_channel).to_deterministic()
        rgb[apply_aug] = aug(images=rgb[apply_aug])
        return rgb
|
|
|
class RGBGaussianNoise:
    """Add clipped zero-mean Gaussian noise to a random subset of images.

    ``max_noise`` serves both as the standard deviation of the noise and as
    the symmetric bound the noise is clipped to.

    Args:
        max_noise: Noise standard deviation and clipping bound.
        prob: Per-image probability of applying the noise.
    """

    def __init__(self, max_noise=10, prob=0.5):
        self.prob = prob
        self.max_noise = max_noise

    def __call__(self, rgb):
        """Perturb the selected images of ``rgb`` in place and return it.

        Args:
            rgb: NumPy array of images indexed along axis 0.

        Returns:
            The same ``rgb`` array; selected images hold the noisy values
            clipped to [0, 255] and converted to uint8.
        """
        batch_size = rgb.shape[0]
        selected = np.random.uniform(0, 1, size=batch_size) < self.prob

        # Noise is sampled for the whole batch and then indexed by the
        # selection mask.
        bound = self.max_noise
        perturbation = np.clip(
            np.random.normal(0, bound, size=rgb.shape), -bound, bound
        )

        noisy = rgb[selected].astype(float) + perturbation[selected]
        rgb[selected] = np.clip(noisy, 0, 255).astype(np.uint8)
        return rgb
|
|
|
|
|
class DepthWarping(torch.nn.Module):
    """Warp depth maps by resampling them on a randomly jittered pixel grid.

    Args:
        std: Standard deviation, in pixels, of the per-pixel Gaussian offsets.
        prob: Per-sample probability of applying the warp.
    """

    def __init__(self, std=0.5, prob=0.8):
        super().__init__()
        self.prob = prob
        self.std = std

    def forward(self, depths, device=None):
        """Return ``depths`` bilinearly resampled at jittered pixel positions.

        Args:
            depths: 4-D tensor unpacked as (n, _, h, w).
            device: Device for the random offsets and the sampling grid;
                defaults to ``depths.device``.

        Returns:
            The resampled tensor after two ``squeeze(0)`` calls, so leading
            dimensions of size 1 are dropped (e.g. (1, 1, h, w) -> (h, w)).
        """
        device = depths.device if device is None else device
        batch, _, height, width = depths.shape

        # Per-pixel (x, y) offsets; samples that are not selected get zero
        # offset and are therefore sampled on the unshifted grid.
        offsets = torch.normal(
            mean=0, std=self.std, size=(batch, height, width, 2), device=device
        ).float()
        warp_sample = torch.rand(batch, device=device) < self.prob
        offsets[~warp_sample] = 0.0

        # Pixel-coordinate grid of shape (1, h, w, 2) with (x, y) in the last axis.
        cols = torch.linspace(0, width - 1, width, device=device).unsqueeze(0).repeat(height, 1)
        rows = torch.linspace(0, height - 1, height, device=device).unsqueeze(1).repeat(1, width)
        coords = torch.stack((cols, rows), 2).unsqueeze(0) + offsets

        # Map pixel coordinates to the [-1, 1] range expected by grid_sample.
        coords[..., 0] = (coords[..., 0] / (width - 1)) * 2 - 1
        coords[..., 1] = (coords[..., 1] / (height - 1)) * 2 - 1

        warped = F.grid_sample(
            depths, coords, mode='bilinear', padding_mode='border', align_corners=True
        )
        return warped.squeeze(0).squeeze(0)
|
|
|
class DepthHoles(torch.nn.Module):
    """Generate random blob-shaped hole masks for a batch of depth maps.

    Uniform noise is Gaussian-blurred, min-max normalised over the whole
    batch and thresholded per sample; pixels above the threshold are holes.

    Args:
        prob: Per-sample probability of keeping the generated hole mask.
        kernel_size_lower: Smallest blur kernel size (inclusive).
        kernel_size_upper: Largest blur kernel size (inclusive); candidate
            sizes step by 2 starting from ``kernel_size_lower``.
        sigma_lower: Lower bound of the blur sigma range.
        sigma_upper: Upper bound of the blur sigma range.
        thresh_lower: Lower bound of the per-sample threshold.
        thresh_upper: Upper bound of the per-sample threshold.
    """

    def __init__(self, prob=0.5, kernel_size_lower=3, kernel_size_upper=27, sigma_lower=1.0,
                 sigma_upper=7.0, thresh_lower=0.6, thresh_upper=0.9):
        super().__init__()
        self.prob = prob
        self.kernel_size_lower, self.kernel_size_upper = kernel_size_lower, kernel_size_upper
        self.sigma_lower, self.sigma_upper = sigma_lower, sigma_upper
        self.thresh_lower, self.thresh_upper = thresh_lower, thresh_upper

    def forward(self, depths, device=None):
        """Return a boolean hole mask with shape (n, 1, h, w).

        Args:
            depths: 4-D tensor unpacked as (n, _, h, w); only its shape and
                device are used.
            device: Device for the generated tensors; defaults to
                ``depths.device``.

        Returns:
            Boolean tensor that is True where a hole should be cut. Samples
            that were not selected have an all-False mask.
        """
        device = depths.device if device is None else device
        batch, _, height, width = depths.shape

        field = torch.rand(batch, 1, height, width, device=device)

        # One kernel size per call, shared by every sample in the batch.
        candidate_sizes = list(range(self.kernel_size_lower, self.kernel_size_upper + 1, 2))
        kernel = random.choice(candidate_sizes)
        blur = GaussianBlur(kernel_size=kernel, sigma=(self.sigma_lower, self.sigma_upper))
        field = blur(field)

        # Min-max normalisation over the whole batch, not per sample.
        field = (field - field.min()) / (field.max() - field.min())

        # Per-sample threshold drawn uniformly from [thresh_lower, thresh_upper).
        span = self.thresh_upper - self.thresh_lower
        thresh = torch.rand(batch, 1, 1, 1, device=device) * span + self.thresh_lower
        mask = field > thresh

        # Clear the mask of every sample that was not selected.
        selected = torch.rand(batch, device=device) < self.prob
        mask[~selected] = False
        return mask
|
|
|
class DepthNoise(torch.nn.Module):
    """Add zero-mean Gaussian noise to a random subset of depth maps.

    Args:
        std: Standard deviation of the additive noise.
        prob: Per-sample probability of adding noise.
    """

    def __init__(self, std=0.005, prob=1.0):
        super().__init__()
        self.prob = prob
        self.std = std

    def forward(self, depths, device=None):
        """Return a new tensor equal to ``depths`` plus per-pixel noise.

        Args:
            depths: 4-D tensor unpacked as (n, _, h, w).
            device: Device for the generated noise; defaults to
                ``depths.device``.

        Returns:
            ``depths + noise`` where noise has shape (n, 1, h, w) and is zero
            for samples that were not selected.
        """
        device = depths.device if device is None else device
        batch, _, height, width = depths.shape

        selected = torch.rand(batch, device=device) < self.prob
        perturbation = torch.randn(batch, 1, height, width, device=device) * self.std
        # Zero the noise of every sample that was not selected.
        skipped = ~selected
        perturbation[skipped] = 0.0

        return depths + perturbation
|
|
|
class Augmentor(torch.nn.Module):
    """Apply depth and RGB augmentations to the input cameras of a batch.

    Args:
        depth_holes: Module producing a boolean hole mask from the depths;
            defaults to a fresh ``DepthHoles()``.
        depth_warping: Depth-warping module; defaults to a fresh
            ``DepthWarping()``. It is stored but not applied in ``forward``.
        depth_noise: Module adding noise to the depths; defaults to a fresh
            ``DepthNoise()``.
        rgb_operators: Sequence of callables applied in order to the RGB
            NumPy batch; defaults to a fresh
            ``[ChangeBright(), SaltAndPepper(), ChangeContrast(), RGBGaussianNoise()]``.
    """

    def __init__(self, depth_holes=None, depth_warping=None, depth_noise=None,
                 rgb_operators=None):
        super().__init__()
        # Defaults are built per instance: instances/lists used as default
        # argument values are created once at definition time and would be
        # shared by every Augmentor.
        self.depth_holes = DepthHoles() if depth_holes is None else depth_holes
        self.depth_warping = DepthWarping() if depth_warping is None else depth_warping
        self.depth_noise = DepthNoise() if depth_noise is None else depth_noise
        if rgb_operators is None:
            rgb_operators = [ChangeBright(), SaltAndPepper(), ChangeContrast(), RGBGaussianNoise()]
        self.rgb_operators = rgb_operators

    def forward(self, batch):
        """Augment ``batch['input_cams']`` in place and return ``batch``.

        Reads the ``'depths'``, ``'valid_masks'``, ``'imgs'``, ``'Ks'`` and
        ``'c2ws'`` entries of ``batch['input_cams']`` and rewrites
        ``'valid_masks'`` (in place), ``'imgs'``, ``'depths'`` and
        ``'pointmaps'``.

        Args:
            batch: Mapping with an ``'input_cams'`` mapping holding the
                entries listed above (presumably torch tensors with a
                singleton axis 1 on ``'imgs'`` -- TODO confirm with caller).

        Returns:
            The same ``batch`` object.
        """
        cams = batch['input_cams']
        input_depths = cams['depths']

        if self.depth_holes.prob > 0:
            masks = self.depth_holes(input_depths)
            cams['valid_masks'][masks] = False

        if self.depth_noise.prob > 0:
            input_depths = self.depth_noise(input_depths)

        imgs = cams['imgs']
        # Choose where the augmented images go. CUDA inputs return to their
        # own device; CPU inputs are still moved to CUDA when it is available
        # (the previous behaviour) and stay on CPU otherwise instead of
        # raising.
        if imgs.is_cuda or not torch.cuda.is_available():
            target_device = imgs.device
        else:
            target_device = torch.device('cuda')

        # The RGB operators work on NumPy arrays, so round-trip through host memory.
        input_rgbs = imgs.squeeze(1).cpu().numpy()
        for op in self.rgb_operators:
            input_rgbs = op(input_rgbs)
        cams['imgs'] = torch.from_numpy(input_rgbs).to(target_device).unsqueeze(1)

        cams['depths'] = input_depths
        # Point maps are recomputed from the augmented depths.
        cams['pointmaps'] = compute_pointmaps(cams['depths'], cams['Ks'], cams['c2ws'])
        return batch