Spaces:
Build error
Build error
| # https://github.com/XPixelGroup/BasicSR/blob/master/basicsr/utils/img_process_util.py | |
| import cv2 | |
| import numpy as np | |
| import torch | |
| from .common import filter2D | |
class USMSharp(torch.nn.Module):
    """Unsharp-mask (USM) sharpening gated by a feathered edge mask.

    A 2-D Gaussian kernel is built once at construction time and registered
    as the ``kernel`` buffer. ``forward`` uses that same kernel twice: to blur
    the input (yielding the high-frequency residual) and to soften the binary
    edge mask that decides where the sharpened image is blended in.

    Args:
        radius: Value passed as the kernel size to ``cv2.getGaussianKernel``.
            Even values are bumped up by one so the kernel size is always odd.
            Despite the name, this is the full kernel width, not a half-width.
            The adjusted value is stored on ``self.radius``.
        sigma: Gaussian standard deviation forwarded to
            ``cv2.getGaussianKernel``. Per the OpenCV docs, a non-positive
            value (the default ``0``) makes OpenCV derive sigma from the
            kernel size.
    """

    def __init__(self, radius=50, sigma=0):
        super().__init__()
        # Guarantee an odd kernel size so the Gaussian has a center tap.
        ksize = radius + 1 if radius % 2 == 0 else radius
        self.radius = ksize

        # OpenCV yields a (ksize, 1) float64 column vector; its outer product
        # with itself is the 2-D Gaussian.
        gauss_col = cv2.getGaussianKernel(ksize, sigma)
        gauss_2d = gauss_col @ gauss_col.T

        # Stored as float32 with a leading singleton dimension: (1, ksize, ksize).
        kernel_2d = torch.from_numpy(gauss_2d).float().unsqueeze(0)
        self.register_buffer('kernel', kernel_2d)

    def forward(self, img, weight=0.5, threshold=10):
        """Sharpen ``img`` and apply the result only around strong edges.

        Args:
            img: Image tensor handed to ``filter2D``. Presumably (N, C, H, W)
                with values in [0, 1], as suggested by the ``* 255`` threshold
                scaling and the clamp to [0, 1] below (TODO: confirm against
                callers and ``filter2D``).
            weight: Multiplier on the high-frequency residual added back to
                ``img``.
            threshold: Minimum residual magnitude, measured on a 0-255 scale,
                for a pixel to be marked as an edge.

        Returns:
            A per-pixel blend of the clamped sharpened image and the original
            ``img``, weighted by the Gaussian-blurred edge mask.
        """
        # High-frequency detail = image minus its Gaussian blur.
        residual = img - filter2D(img, self.kernel)

        # Binary edge mask, then feathered with the same Gaussian so the
        # transition between sharpened and untouched regions is smooth.
        edge_mask = (residual.abs() * 255 > threshold).float()
        soft_mask = filter2D(edge_mask, self.kernel)

        sharpened = (img + weight * residual).clamp(0, 1)
        return soft_mask * sharpened + (1 - soft_mask) * img