Spaces:
Running
on
Zero
Running
on
Zero
| # -*- coding: utf-8 -*- | |
| import torch.fft | |
| import torch | |
| import numpy as np | |
| from scipy import ndimage | |
| from scipy.interpolate import interp2d | |
def splits(a, sf):
    '''Rearrange a 4-D tensor into sf*sf sub-blocks collected on a new last axis.

    Axis 2 is cut into ``sf`` chunks that are stacked on a new axis 4; the
    result is then cut into ``sf`` chunks along axis 3, which are concatenated
    on that same axis 4.

    Args:
        a: 4-D tensor (N, C, D2, D3); D2 and D3 are expected to be divisible
            by ``sf`` so that all chunks have equal size.
        sf: split factor.

    Returns:
        Tensor of shape (N, C, D2/sf, D3/sf, sf*sf).
    '''
    row_chunks = a.chunk(sf, dim=2)
    stacked = torch.stack(row_chunks, dim=4)
    col_chunks = stacked.chunk(sf, dim=3)
    return torch.cat(col_chunks, dim=4)
def p2o(psf, shape):
    '''Convert a point-spread function (PSF) to an optical transfer function (OTF).

    The PSF is zero-padded to ``shape``, circularly shifted so that its centre
    tap lands on index (0, 0), and transformed with a 2-D FFT over the last two
    axes. The shift makes the resulting OTF independent of the PSF
    off-centering.

    Args:
        psf: tensor whose last two axes are the kernel, e.g. NxCxhxw.
        shape: target spatial size ``(H, W)``; any length-2 sequence (tuple,
            list or ``torch.Size``) is accepted.

    Returns:
        Complex tensor with the leading axes of ``psf`` followed by HxW
        (e.g. NxCxHxW).
    '''
    # Normalise to plain tuples so that a list-valued ``shape`` can be concatenated.
    full_shape = tuple(psf.shape[:-2]) + tuple(shape)
    kh, kw = psf.shape[-2:]

    # Allocate directly with the PSF's dtype/device instead of converting afterwards.
    otf = torch.zeros(full_shape, dtype=psf.dtype, device=psf.device)
    otf[..., :kh, :kw].copy_(psf)

    # Move the kernel centre to the origin (circular shift by half the kernel size).
    otf = torch.roll(otf, shifts=(-(kh // 2), -(kw // 2)), dims=(-2, -1))
    return torch.fft.fftn(otf, dim=(-2, -1))
def upsample(x, sf=3):
    '''Zero-insertion upsampler.

    Enlarges the two spatial axes (axes 2 and 3) by ``sf`` and places the
    input samples at every ``sf``-th position starting from index 0; all other
    entries are zero.

    Args:
        x: 4-D tensor image (N, C, D2, D3).
        sf: upsampling factor.

    Returns:
        Tensor of shape (N, C, D2*sf, D3*sf) with the dtype and device of ``x``.
    '''
    out_shape = (x.shape[0], x.shape[1], x.shape[2] * sf, x.shape[3] * sf)
    out = x.new_zeros(out_shape)
    out[..., ::sf, ::sf] = x
    return out
def downsample(x, sf=3):
    '''Strided downsampler.

    Keeps the upper-left sample of every distinct sf x sf patch along the last
    two axes and drops the rest.

    Args:
        x: tensor image whose last two axes are spatial.
        sf: downsampling factor.

    Returns:
        A strided view of ``x`` holding every ``sf``-th sample of the last two axes.
    '''
    return x[..., ::sf, ::sf]
def data_solution_simple(x, F2K, FKFy, rho):
    '''Evaluate ``real(ifft((FKFy + fft(rho * x)) / (F2K + rho)))``.

    FFTs are taken over the last two axes. ``rho`` is clamped from below at
    1e-2 before use, which keeps the denominator away from zero when ``F2K``
    is non-negative.

    Args:
        x: current estimate (spatial domain).
        F2K: term added to ``rho`` in the denominator (frequency domain).
        FKFy: term added to ``fft(rho * x)`` in the numerator (frequency domain).
        rho: penalty weight tensor.

    Returns:
        Real-valued tensor: the updated estimate in the spatial domain.
    '''
    rho = torch.clamp(rho, min=1e-2)
    weighted_spectrum = torch.fft.fftn(rho * x, dim=(-2, -1))
    solution_spectrum = (FKFy + weighted_spectrum) / (F2K + rho)
    return torch.fft.ifftn(solution_spectrum, dim=(-2, -1)).real
def data_solution_nonuniform(x, FK, FKC, F2KM, FKFMy, rho):
    '''Evaluate ``real(ifft((FKFMy + fft(rho * x)) / (F2KM + rho)))``.

    FFTs are taken over the last two axes. ``rho`` is clamped from below at
    1e-2 before use.

    Args:
        x: current estimate (spatial domain).
        FK: unused by this function; kept in the signature for callers.
        FKC: unused by this function; kept in the signature for callers.
        F2KM: term added to ``rho`` in the denominator (frequency domain).
        FKFMy: term added to ``fft(rho * x)`` in the numerator (frequency domain).
        rho: penalty weight tensor.

    Returns:
        Real-valued tensor: the updated estimate in the spatial domain.
    '''
    rho = torch.clamp(rho, min=1e-2)
    weighted_spectrum = torch.fft.fftn(rho * x, dim=(-2, -1))
    solution_spectrum = (FKFMy + weighted_spectrum) / (F2KM + rho)
    return torch.fft.ifftn(solution_spectrum, dim=(-2, -1)).real
def pre_calculate(x, k):
    '''Precompute the kernel spectra that are reused during iterations.

    Args:
        x: NxCxHxW, LR input; only its last two dimensions are read.
        k: NxCxhxw blur kernel.

    Returns:
        (FK, FKC, F2K): the kernel's transfer function at the spatial size of
        ``x``, its complex conjugate, and its squared magnitude.
    '''
    rows, cols = x.shape[-2:]
    FK = p2o(k, (rows, cols))
    FKC = FK.conj()
    F2K = FK.abs() ** 2
    return FK, FKC, F2K
def pre_calculate_FK(k, shape=(512, 512)):
    '''Precompute the transfer functions of a bank of kernels.

    Args:
        k: kernel bank, e.g. [25, 1, 33, 33] where 25 is the number of filters.
        shape: spatial size ``(H, W)`` the kernels are zero-padded to before
            the FFT. Defaults to ``(512, 512)``, the size that was previously
            hard-coded.

    Returns:
        (FK, FKC): the transfer functions, e.g. [25, 1, H, W], and their
        complex conjugates.
    '''
    # Pass a plain tuple so p2o can concatenate it with the leading dimensions.
    FK = p2o(k, tuple(shape))
    FKC = torch.conj(FK)
    return FK, FKC
def pre_calculate_nonuniform(x, y, FK, FKC, mask):
    '''Precompute the mask-weighted frequency-domain terms for a kernel bank.

    Shapes below are the ones given by the original authors for a 512x512
    image and 25 filters; the code only relies on them being broadcastable.

    Args:
        x: [1, 3, 512, 512]; not used by the computation, kept for callers.
        y: [1, 3, 512, 512] observed image.
        FK: [25, 1, 512, 512] transfer functions of the filters.
        FKC: [25, 1, 512, 512] complex conjugates of ``FK``.
        mask: [1, 25, 512, 512] per-filter spatial weights.

    Returns:
        (F2KM, FKFMy):
            F2KM: ``sum_k |FKC_k * mask_k**2 * FK_k|``, [1, 1, 512, 512].
            FKFMy: ``sum_k FK_k * fft2(mask_k * y)``, [1, 3, 512, 512].
    '''
    # [1, K, H, W] -> [K, 1, H, W] so the filter index lines up with FK's first axis.
    mask = mask.transpose(0, 1)

    # Broadcasting yields the same [K, C, H, W] product as repeating y K times,
    # without materialising the repeated copy first.
    My = mask * y
    FMy = torch.fft.fft2(My)

    # Accumulate over the filter axis.
    FKFMy = torch.sum(FK * FMy, dim=0, keepdim=True)
    F2KM = torch.sum(torch.abs(FKC * (mask ** 2) * FK), dim=0, keepdim=True)
    return F2KM, FKFMy
def classical_degradation(x, k, sf=3):
    '''Blur followed by downsampling.

    The image is convolved with ``k`` using wrap-around (circular) boundary
    handling, then every ``sf``-th pixel is kept along the first two axes,
    starting from the upper-left pixel.

    Args:
        x: HxWxC image (or HxW single-channel image), [0, 1]/[0, 255].
        k: hxw blur kernel, double.
        sf: down-scale factor.

    Returns:
        Downsampled LR image.
    '''
    x = np.asarray(x)
    # A colour image needs a singleton channel axis on the kernel so each
    # channel is blurred independently; a 2-D image uses the kernel as is.
    kernel = np.expand_dims(k, axis=2) if x.ndim == 3 else np.asarray(k)
    # ``ndimage.convolve`` is the supported entry point; the
    # ``ndimage.filters`` namespace is deprecated.
    blurred = ndimage.convolve(x, kernel, mode='wrap')
    return blurred[::sf, ::sf, ...]
def shift_pixel(x, sf, upper_left=True):
    """Shift pixels for super-resolution with different scale factors.

    The image is resampled with bilinear interpolation on a grid offset by
    ``(sf - 1) / 2`` pixels along both axes; sample positions are clamped to
    the image extent.

    Args:
        x: HxWxC or HxW array, image or kernel.
        sf: scale factor.
        upper_left: shift direction. If True the sampling grid is moved by
            ``+(sf - 1) / 2``, otherwise by ``-(sf - 1) / 2``.

    Returns:
        The shifted array. A 2-D input yields a new float array; a 3-D input
        is overwritten channel by channel and the same array object is
        returned (this in-place behaviour is kept for existing callers).
    """
    # ``interp2d`` was removed in SciPy 1.14; a degree-1 RectBivariateSpline on
    # the same regular grid is its documented bilinear replacement.
    from scipy.interpolate import RectBivariateSpline

    h, w = x.shape[:2]
    shift = (sf - 1) * 0.5
    xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0)

    offset = shift if upper_left else -shift
    x1 = np.clip(xv + offset, 0, w - 1)
    y1 = np.clip(yv + offset, 0, h - 1)

    def _resample(plane):
        # Rows are the first spline axis, so the (h, w) plane needs no transpose.
        return RectBivariateSpline(yv, xv, plane, kx=1, ky=1)(y1, x1)

    if x.ndim == 2:
        x = _resample(x)
    if x.ndim == 3:
        for i in range(x.shape[-1]):
            x[:, :, i] = _resample(x[:, :, i])
    return x