# rayst3r/utils/augmentations.py
# (file-viewer residue: "bartduis's picture", "init", "70d1188", "raw", "history blame", "6.8 kB")
import random
import torch
import torch.nn.functional as F
from abc import ABC, abstractmethod
from torchvision.transforms import GaussianBlur
from utils.batch_prep import compute_pointmaps
import imgaug as ia
import imgaug.augmenters as iaa
import numpy as np
class ChangeBright(torch.nn.Module):
    """Randomly rescale the brightness of a subset of images in a batch.

    Every image is selected independently with probability ``prob``. One
    brightness factor is drawn uniformly from ``[mag[0], mag[1]]`` per call
    and shared by all selected images of that call.

    Args:
        prob: Per-image probability of applying the augmentation.
        mag: Two-element ``[low, high]`` range of the brightness multiplier.
            The default list is only read, never mutated.
    """

    def __init__(self, prob=0.5, mag=[0.5, 2.0]):
        super().__init__()
        self.mag = mag
        self.prob = prob

    def forward(self, rgb):
        """Apply the brightness change in place and return ``rgb``.

        Args:
            rgb: NumPy image batch indexed by image along axis 0. It is
                handed to imgaug through ``images=``, so it presumably has
                to be ``(N, H, W, 3)`` uint8 -- TODO confirm with the caller.

        Returns:
            The same array object; selected images are overwritten.
        """
        n = rgb.shape[0]
        apply_aug = np.random.uniform(0, 1, size=n) < self.prob
        # NOTE: imgaug has a bug around deterministic sampling, so the factor
        # is sampled here and handed to imgaug as a fixed value. It is drawn
        # before the early exit so the NumPy random stream does not depend
        # on whether any image was selected.
        factor = np.random.uniform(self.mag[0], self.mag[1])
        if not apply_aug.any():
            # Nothing selected (or empty batch): do not hand imgaug an
            # empty image array.
            return rgb
        aug = iaa.MultiplyBrightness(factor)
        rgb[apply_aug] = aug(images=rgb[apply_aug])
        return rgb
class ChangeContrast(torch.nn.Module):
    """Randomly apply a gamma-contrast change to a subset of images.

    Every image is selected independently with probability ``prob``. One
    gamma value is drawn uniformly from ``[mag[0], mag[1]]`` per call and
    shared by all selected images of that call.

    Args:
        prob: Per-image probability of applying the augmentation.
        mag: Two-element ``[low, high]`` range of the gamma value passed to
            ``iaa.GammaContrast``. The default list is only read.
    """

    def __init__(self, prob=0.5, mag=[0.5, 2.0]):
        # Without this call the nn.Module bookkeeping dictionaries are
        # missing and parameters(), state_dict(), .to() etc. raise.
        super().__init__()
        self.mag = mag
        self.prob = prob

    def forward(self, rgb):
        """Apply the contrast change in place and return ``rgb``.

        Args:
            rgb: NumPy image batch indexed by image along axis 0. It is
                handed to imgaug through ``images=``, so it presumably has
                to be ``(N, H, W, 3)`` uint8 -- TODO confirm with the caller.

        Returns:
            The same array object; selected images are overwritten.
        """
        n = rgb.shape[0]
        apply_aug = np.random.uniform(0, 1, size=n) < self.prob
        # Drawn before the early exit so the NumPy random stream does not
        # depend on whether any image was selected.
        gamma = np.random.uniform(self.mag[0], self.mag[1])
        if not apply_aug.any():
            # Nothing selected (or empty batch): skip the imgaug call.
            return rgb
        aug = iaa.GammaContrast(gamma)
        rgb[apply_aug] = aug(images=rgb[apply_aug])
        return rgb
class SaltAndPepper:
    """Randomly add salt-and-pepper noise to a subset of images in a batch.

    Every image is selected independently with probability ``prob``.

    Args:
        prob: Per-image probability of applying the augmentation.
        ratio: First positional argument of ``iaa.SaltAndPepper``
            (presumably the fraction of pixels replaced -- confirm against
            the imgaug documentation).
        per_channel: Forwarded unchanged to ``iaa.SaltAndPepper``.
    """

    def __init__(self, prob=0.3, ratio=0.1, per_channel=True):
        self.prob = prob
        self.ratio = ratio
        self.per_channel = per_channel

    def __call__(self, rgb):
        """Apply the noise in place and return ``rgb``.

        Args:
            rgb: NumPy image batch indexed by image along axis 0. It is
                handed to imgaug through ``images=``, so it presumably has
                to be ``(N, H, W, 3)`` uint8 -- TODO confirm with the caller.

        Returns:
            The same array object; selected images are overwritten.
        """
        n = rgb.shape[0]
        apply_aug = np.random.uniform(0, 1, size=n) < self.prob
        if not apply_aug.any():
            # Nothing selected (or empty batch): avoid building an augmenter
            # and handing imgaug an empty image array.
            return rgb
        aug = iaa.SaltAndPepper(self.ratio, per_channel=self.per_channel).to_deterministic()
        rgb[apply_aug] = aug(images=rgb[apply_aug])
        return rgb
class RGBGaussianNoise:
    """Add clipped zero-mean Gaussian noise to a random subset of images.

    Every image is selected independently with probability ``prob``. The
    noise has standard deviation ``max_noise`` and is clipped to
    ``[-max_noise, max_noise]`` before it is added.

    Args:
        max_noise: Standard deviation and absolute clipping bound of the noise.
        prob: Per-image probability of applying the augmentation.
    """

    def __init__(self, max_noise=10, prob=0.5):
        self.prob = prob
        self.max_noise = max_noise

    def __call__(self, rgb):
        """Perturb the selected images in place and return ``rgb``.

        Args:
            rgb: NumPy image batch indexed by image along axis 0. Selected
                images are written back as uint8 after clipping to [0, 255].

        Returns:
            The same array object that was passed in.
        """
        batch_size = rgb.shape[0]
        selected = np.random.uniform(0, 1, size=batch_size) < self.prob

        bound = self.max_noise
        # Noise is drawn for the whole batch; only the selected rows are used.
        perturbation = np.random.normal(0, bound, size=rgb.shape)
        perturbation = perturbation.clip(-bound, bound)

        # Add in floating point so the sum cannot wrap around, then clamp to
        # the valid 8-bit range before casting back.
        noisy = rgb[selected].astype(float) + perturbation[selected]
        rgb[selected] = noisy.clip(0, 255).astype(np.uint8)
        return rgb
# from https://github.com/mihdalal/manipgen/blob/master/manipgen/utils/obs_utils.py
class DepthWarping(torch.nn.Module):
    """Resample depth maps at randomly jittered pixel locations.

    Every pixel's sampling position is displaced by Gaussian noise (in pixel
    units) and the map is resampled with bilinear interpolation. Whole
    samples are left unwarped with probability ``1 - prob``.

    Args:
        std: Standard deviation of the per-pixel displacement.
        prob: Per-sample probability of applying the warp.
    """

    def __init__(self, std=0.5, prob=0.8):
        super().__init__()
        self.prob = prob
        self.std = std

    def forward(self, depths, device=None):
        """Warp ``depths`` and return the resampled result.

        Args:
            depths: 4-D tensor unpacked as ``(n, _, h, w)``.
            device: Device on which the random offsets and the sampling grid
                are created; defaults to ``depths.device``.

        Returns:
            The output of ``grid_sample`` after two ``squeeze(0)`` calls.
            NOTE(review): because of those calls the shape depends on the
            input: a batch of one single-channel map comes back as
            ``(h, w)``, while a batch with ``n > 1`` keeps all four
            dimensions.
        """
        device = depths.device if device is None else device
        n, _, h, w = depths.shape

        # Per-pixel (x, y) displacement; samples not chosen get zero offsets.
        offsets = torch.normal(mean=0, std=self.std, size=(n, h, w, 2), device=device).float()
        warp_sample = torch.rand(n, device=device) < self.prob
        offsets[~warp_sample] = 0.0

        # Identity sampling grid in pixel coordinates, last axis is (x, y).
        cols = torch.linspace(0, w - 1, w, device=device).unsqueeze(0).repeat(h, 1)
        rows = torch.linspace(0, h - 1, h, device=device).unsqueeze(1).repeat(1, w)
        pixel_grid = torch.stack((cols, rows), 2).unsqueeze(0)

        sample_xy = pixel_grid + offsets

        # grid_sample expects coordinates normalised to [-1, 1].
        norm_x = (sample_xy[..., 0] / (w - 1)) * 2 - 1
        norm_y = (sample_xy[..., 1] / (h - 1)) * 2 - 1
        sampling_grid = torch.stack((norm_x, norm_y), dim=-1)

        warped = F.grid_sample(
            depths,
            sampling_grid,
            mode='bilinear',
            padding_mode='border',
            align_corners=True,
        )
        # Drop leading singleton dimensions (batch, then channel).
        return warped.squeeze(0).squeeze(0)
class DepthHoles(torch.nn.Module):
    """Generate random blob-shaped hole masks for a batch of depth maps.

    Uniform noise is blurred with a Gaussian kernel, rescaled to [0, 1] using
    the minimum and maximum over the whole batch, and thresholded per sample.

    Args:
        prob: Per-sample probability that the sample receives any holes.
        kernel_size_lower: Smallest blur kernel size considered.
        kernel_size_upper: Largest blur kernel size considered (inclusive);
            candidates advance from ``kernel_size_lower`` in steps of 2.
        sigma_lower: Lower bound of the sigma range given to ``GaussianBlur``.
        sigma_upper: Upper bound of the sigma range given to ``GaussianBlur``.
        thresh_lower: Lower bound of the per-sample threshold.
        thresh_upper: Upper bound of the per-sample threshold.
    """

    def __init__(self, prob=0.5, kernel_size_lower=3, kernel_size_upper=27, sigma_lower=1.0,
                 sigma_upper=7.0, thresh_lower=0.6, thresh_upper=0.9):
        super().__init__()
        self.prob = prob
        self.kernel_size_lower, self.kernel_size_upper = kernel_size_lower, kernel_size_upper
        self.sigma_lower, self.sigma_upper = sigma_lower, sigma_upper
        self.thresh_lower, self.thresh_upper = thresh_lower, thresh_upper

    def forward(self, depths, device=None):
        """Return a boolean hole mask for ``depths``.

        Args:
            depths: 4-D tensor unpacked as ``(n, _, h, w)``; only its shape
                and device are used.
            device: Device for the generated tensors; defaults to
                ``depths.device``.

        Returns:
            Boolean tensor of shape ``(n, 1, h, w)``; ``True`` marks a hole.
            Samples that were not selected are all ``False``.
        """
        device = depths.device if device is None else device
        n, _, h, w = depths.shape

        # Start from white noise and smooth it into blobs.
        field = torch.rand(n, 1, h, w, device=device)
        candidate_sizes = list(range(self.kernel_size_lower, self.kernel_size_upper + 1, 2))
        kernel_size = random.choice(candidate_sizes)
        blur = GaussianBlur(kernel_size=kernel_size, sigma=(self.sigma_lower, self.sigma_upper))
        field = blur(field)

        # Rescale to [0, 1] using the extrema of the whole batch.
        field = (field - field.min()) / (field.max() - field.min())

        # One threshold per sample, uniform in [thresh_lower, thresh_upper).
        span = self.thresh_upper - self.thresh_lower
        cutoff = torch.rand(n, 1, 1, 1, device=device) * span + self.thresh_lower
        holes = field > cutoff

        # Clear the mask of every sample that was not selected.
        selected = torch.rand(n, device=device) < self.prob
        holes[~selected, :] = 0
        return holes
class DepthNoise(torch.nn.Module):
    """Add zero-mean Gaussian noise to a random subset of depth maps.

    Args:
        std: Standard deviation of the additive noise.
        prob: Per-sample probability of adding noise.
    """

    def __init__(self, std=0.005, prob=1.0):
        super().__init__()
        self.prob = prob
        self.std = std

    def forward(self, depths, device=None):
        """Return ``depths`` plus per-pixel noise on the selected samples.

        Args:
            depths: 4-D tensor unpacked as ``(n, _, h, w)``.
            device: Device on which the noise is created; defaults to
                ``depths.device``.

        Returns:
            A new tensor ``depths + noise``. The noise has shape
            ``(n, 1, h, w)`` and is zero for samples that were not selected.
        """
        device = depths.device if device is None else device
        n, _, h, w = depths.shape

        selected = torch.rand(n, device=device) < self.prob
        perturbation = self.std * torch.randn(n, 1, h, w, device=device)
        # Samples that were not selected stay exactly as they are.
        perturbation[~selected] = 0.0
        return depths + perturbation
class Augmentor(torch.nn.Module):
    """Apply depth and RGB augmentations to the input cameras of a batch.

    Args:
        depth_holes: Module returning a boolean hole mask for the depths;
            ``None`` creates a fresh ``DepthHoles()``.
        depth_warping: Depth warping module; ``None`` creates a fresh
            ``DepthWarping()``. It is stored but currently not applied in
            ``forward``.
        depth_noise: Module adding noise to the depths; ``None`` creates a
            fresh ``DepthNoise()``.
        rgb_operators: List of callables applied in order to the NumPy image
            batch; ``None`` creates ``[ChangeBright(), SaltAndPepper(),
            ChangeContrast(), RGBGaussianNoise()]``. A list passed in is
            stored by reference.
    """

    def __init__(self, depth_holes=None, depth_warping=None, depth_noise=None,
                 rgb_operators=None):
        super().__init__()
        # The defaults are built per instance: as default-argument values they
        # would be created once at definition time and shared by every Augmentor.
        self.depth_holes = DepthHoles() if depth_holes is None else depth_holes
        self.depth_warping = DepthWarping() if depth_warping is None else depth_warping
        self.depth_noise = DepthNoise() if depth_noise is None else depth_noise
        if rgb_operators is None:
            rgb_operators = [ChangeBright(), SaltAndPepper(), ChangeContrast(), RGBGaussianNoise()]
        self.rgb_operators = rgb_operators

    def forward(self, batch):
        """Augment ``batch['input_cams']`` in place and return ``batch``.

        Reads ``depths``, ``valid_masks``, ``imgs``, ``Ks`` and ``c2ws`` from
        ``batch['input_cams']``; overwrites ``imgs``, ``depths`` and
        ``pointmaps`` and clears ``valid_masks`` wherever a hole is generated.

        Args:
            batch: Mapping with an ``'input_cams'`` mapping holding the keys
                above. ``imgs`` is squeezed on axis 1 before the RGB operators
                run and unsqueezed afterwards (presumably one view per sample
                -- TODO confirm).

        Returns:
            The same ``batch`` object.
        """
        cams = batch['input_cams']
        input_depths = cams['depths']

        if self.depth_holes.prob > 0:
            masks = self.depth_holes(input_depths)
            cams['valid_masks'][masks] = False

        # Depth warping is disabled for now; self.depth_warping is kept only
        # so it can be re-enabled here later.

        if self.depth_noise.prob > 0:
            input_depths = self.depth_noise(input_depths)

        imgs = cams['imgs']
        # The RGB operators work on NumPy arrays, hence the host round trip
        # (a bit inefficient, but acceptable).
        input_rgbs = imgs.squeeze(1).cpu().numpy()
        for op in self.rgb_operators:
            input_rgbs = op(input_rgbs)
        # Return the images to the device they came from instead of assuming
        # CUDA, so the augmentor also works on CPU-only hosts.
        cams['imgs'] = torch.from_numpy(input_rgbs).to(imgs.device).unsqueeze(1)

        cams['depths'] = input_depths
        # Pointmaps have to be recomputed from the augmented depths (this
        # duplicates work done earlier in the pipeline).
        cams['pointmaps'] = compute_pointmaps(cams['depths'], cams['Ks'], cams['c2ws'])
        return batch