Tzktz committed on
Commit
8e542dc
·
verified ·
1 Parent(s): fcc0337

Upload 174 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. CodeFormer/weights/CodeFormer/codeformer.pth +3 -0
  2. CodeFormer/weights/facelib/detection_Resnet50_Final.pth +3 -0
  3. CodeFormer/weights/facelib/parsing_parsenet.pth +3 -0
  4. CodeFormer/weights/realesrgan/RealESRGAN_x2plus.pth +3 -0
  5. app.py +188 -0
  6. basicsr/VERSION +1 -0
  7. basicsr/__init__.py +11 -0
  8. basicsr/__pycache__/__init__.cpython-39.pyc +0 -0
  9. basicsr/__pycache__/train.cpython-39.pyc +0 -0
  10. basicsr/__pycache__/version.cpython-39.pyc +0 -0
  11. basicsr/archs/__init__.py +25 -0
  12. basicsr/archs/__pycache__/__init__.cpython-39.pyc +0 -0
  13. basicsr/archs/__pycache__/arcface_arch.cpython-39.pyc +0 -0
  14. basicsr/archs/__pycache__/arch_util.cpython-39.pyc +0 -0
  15. basicsr/archs/__pycache__/codeformer_arch.cpython-39.pyc +0 -0
  16. basicsr/archs/__pycache__/rrdbnet_arch.cpython-39.pyc +0 -0
  17. basicsr/archs/__pycache__/vgg_arch.cpython-39.pyc +0 -0
  18. basicsr/archs/__pycache__/vqgan_arch.cpython-39.pyc +0 -0
  19. basicsr/archs/arcface_arch.py +245 -0
  20. basicsr/archs/arch_util.py +318 -0
  21. basicsr/archs/codeformer_arch.py +280 -0
  22. basicsr/archs/rrdbnet_arch.py +119 -0
  23. basicsr/archs/vgg_arch.py +161 -0
  24. basicsr/archs/vqgan_arch.py +434 -0
  25. basicsr/data/__init__.py +100 -0
  26. basicsr/data/__pycache__/__init__.cpython-39.pyc +0 -0
  27. basicsr/data/__pycache__/data_sampler.cpython-39.pyc +0 -0
  28. basicsr/data/__pycache__/data_util.cpython-39.pyc +0 -0
  29. basicsr/data/__pycache__/ffhq_blind_dataset.cpython-39.pyc +0 -0
  30. basicsr/data/__pycache__/ffhq_blind_joint_dataset.cpython-39.pyc +0 -0
  31. basicsr/data/__pycache__/gaussian_kernels.cpython-39.pyc +0 -0
  32. basicsr/data/__pycache__/paired_image_dataset.cpython-39.pyc +0 -0
  33. basicsr/data/__pycache__/prefetch_dataloader.cpython-39.pyc +0 -0
  34. basicsr/data/__pycache__/transforms.cpython-39.pyc +0 -0
  35. basicsr/data/data_sampler.py +48 -0
  36. basicsr/data/data_util.py +392 -0
  37. basicsr/data/ffhq_blind_dataset.py +299 -0
  38. basicsr/data/ffhq_blind_joint_dataset.py +324 -0
  39. basicsr/data/gaussian_kernels.py +690 -0
  40. basicsr/data/paired_image_dataset.py +101 -0
  41. basicsr/data/prefetch_dataloader.py +125 -0
  42. basicsr/data/transforms.py +165 -0
  43. basicsr/losses/__init__.py +26 -0
  44. basicsr/losses/__pycache__/__init__.cpython-39.pyc +0 -0
  45. basicsr/losses/__pycache__/loss_util.cpython-39.pyc +0 -0
  46. basicsr/losses/__pycache__/losses.cpython-39.pyc +0 -0
  47. basicsr/losses/loss_util.py +95 -0
  48. basicsr/losses/losses.py +455 -0
  49. basicsr/metrics/__init__.py +19 -0
  50. basicsr/metrics/__pycache__/__init__.cpython-39.pyc +0 -0
CodeFormer/weights/CodeFormer/codeformer.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1009e537e0c2a07d4cabce6355f53cb66767cd4b4297ec7a4a64ca4b8a5684b7
3
+ size 376637898
CodeFormer/weights/facelib/detection_Resnet50_Final.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6d1de9c2944f2ccddca5f5e010ea5ae64a39845a86311af6fdf30841b0a5a16d
3
+ size 109497761
CodeFormer/weights/facelib/parsing_parsenet.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3d558d8d0e42c20224f13cf5a29c79eba2d59913419f945545d8cf7b72920de2
3
+ size 85331193
CodeFormer/weights/realesrgan/RealESRGAN_x2plus.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:49fafd45f8fd7aa8d31ab2a22d14d91b536c34494a5cfe31eb5d89c2fa266abb
3
+ size 67061725
app.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import torch
4
+ import gradio as gr
5
+ from torchvision.transforms.functional import normalize
6
+ from basicsr.archs.rrdbnet_arch import RRDBNet
7
+ from basicsr.utils import imwrite, img2tensor, tensor2img
8
+ from basicsr.utils.misc import gpu_is_available, get_device
9
+ from basicsr.utils.realesrgan_utils import RealESRGANer
10
+ from basicsr.utils.registry import ARCH_REGISTRY
11
+
12
+ from facelib.utils.face_restoration_helper import FaceRestoreHelper
13
+ from facelib.utils.misc import is_gray
14
+
15
+
16
def imread(img_path):
    """Read an image from disk and return it in RGB channel order.

    Args:
        img_path (str): Path of the image file to read.

    Returns:
        ndarray: The decoded image, converted from BGR to RGB.

    Raises:
        ValueError: If OpenCV could not read or decode the file.
    """
    img = cv2.imread(img_path)
    # cv2.imread reports failure by returning None rather than raising, so
    # check here instead of letting cvtColor fail with an opaque error.
    if img is None:
        raise ValueError(f'Failed to read image: {img_path}')
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
20
+
21
+
22
def set_realesrgan():
    """Create the x2 RealESRGAN upsampler used by the demo.

    Half precision is requested only when ``gpu_is_available()`` is truthy.

    Returns:
        RealESRGANer: Upsampler wrapping an RRDBNet and pointing at the
        bundled ``RealESRGAN_x2plus.pth`` weights.
    """
    use_half = bool(gpu_is_available())
    rrdb_net = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
    return RealESRGANer(
        scale=2,
        model_path="CodeFormer/weights/realesrgan/RealESRGAN_x2plus.pth",
        model=rrdb_net,
        tile=400,
        tile_pad=40,
        pre_pad=0,
        half=use_half,
    )
42
+
43
+
44
# Shared RealESRGAN upsampler, built once at import time and reused by every request.
upsampler = set_realesrgan()

device = get_device()
codeformer_net = ARCH_REGISTRY.get("CodeFormer")(
    dim_embd=512,
    codebook_size=1024,
    n_head=8,
    n_layers=9,
    connect_list=["32", "64", "128", "256"],
).to(device)
ckpt_path = "CodeFormer/weights/CodeFormer/codeformer.pth"
# Deserialize onto the CPU first: without map_location, torch.load tries to
# restore tensors onto the device they were saved from, which fails on hosts
# that lack that device. load_state_dict then copies the values into the
# parameters that already live on `device`.
checkpoint = torch.load(ckpt_path, map_location="cpu")["params_ema"]
codeformer_net.load_state_dict(checkpoint)
codeformer_net.eval()

# Directory that inference() writes its result images into.
os.makedirs('output', exist_ok=True)
60
+
61
+
62
def inference(image, background_enhance, face_upsample, upscale, codeformer_fidelity):
    """Run a single prediction on the model.

    Args:
        image: Path of the input image; it is read with ``cv2.imread``.
        background_enhance (bool): Whether to upsample the background with the
            shared RealESRGAN upsampler.
        face_upsample (bool): Whether to also upsample the restored faces.
        upscale: Requested rescaling factor. It is converted to ``int``,
            clamped to [1, 4] and lowered further for large inputs.
        codeformer_fidelity (float): Forwarded to CodeFormer as ``w``.

    Returns:
        tuple: ``(restored_img, save_path)`` - the result in RGB order and the
        path of the PNG written under ``output/``. ``(None, None)`` is
        returned when any step fails (the error is printed).
    """
    import uuid  # local import: only used to build a per-request file name

    try:  # global try
        # take the default setting for the demo
        has_aligned = False
        only_center_face = False
        draw_box = False
        detection_model = "retinaface_resnet50"

        img = cv2.imread(str(image), cv2.IMREAD_COLOR)
        # cv2.imread returns None for unreadable files; fail with a clear message.
        if img is None:
            raise ValueError(f'Failed to read input image: {image}')

        # Clamp to [1, 4]; a factor below 1 is not a meaningful upscale.
        upscale = max(1, min(int(upscale), 4))
        longest_side = max(img.shape[:2])
        if upscale > 2 and longest_side > 1000:
            upscale = 2
        if longest_side > 1500:
            # very large inputs: no rescaling and no RealESRGAN passes
            upscale = 1
            background_enhance = False
            face_upsample = False

        face_helper = FaceRestoreHelper(
            upscale,
            face_size=512,
            crop_ratio=(1, 1),
            det_model=detection_model,
            save_ext="png",
            use_parse=True,
            device=device,
        )
        bg_upsampler = upsampler if background_enhance else None
        face_upsampler = upsampler if face_upsample else None

        if has_aligned:
            # the input faces are already cropped and aligned
            img = cv2.resize(img, (512, 512), interpolation=cv2.INTER_LINEAR)
            face_helper.is_gray = is_gray(img, threshold=5)
            if face_helper.is_gray:
                print('\tgrayscale input: True')
            face_helper.cropped_faces = [img]
        else:
            face_helper.read_image(img)
            # get face landmarks for each face
            num_det_faces = face_helper.get_face_landmarks_5(
                only_center_face=only_center_face, resize=640, eye_dist_threshold=5
            )
            print(f'\tdetect {num_det_faces} faces')
            # align and warp each face
            face_helper.align_warp_face()

        # face restoration for each cropped face
        for cropped_face in face_helper.cropped_faces:
            # prepare data: scale to [0, 1], then normalize to [-1, 1]
            cropped_face_t = img2tensor(
                cropped_face / 255.0, bgr2rgb=True, float32=True
            )
            normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
            cropped_face_t = cropped_face_t.unsqueeze(0).to(device)

            try:
                with torch.no_grad():
                    output = codeformer_net(
                        cropped_face_t, w=codeformer_fidelity, adain=True
                    )[0]
                    restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
                del output
                torch.cuda.empty_cache()
            except RuntimeError as error:
                # Best effort: fall back to the unrestored crop for this face.
                print(f"Failed inference for CodeFormer: {error}")
                restored_face = tensor2img(
                    cropped_face_t, rgb2bgr=True, min_max=(-1, 1)
                )

            restored_face = restored_face.astype("uint8")
            face_helper.add_restored_face(restored_face)

        # NOTE(review): restored_img is only assigned inside this branch; if
        # has_aligned were ever set to True the save below would raise NameError.
        if not has_aligned:
            # upsample the background
            if bg_upsampler is not None:
                # Now only support RealESRGAN for upsampling background
                bg_img = bg_upsampler.enhance(img, outscale=upscale)[0]
            else:
                bg_img = None
            face_helper.get_inverse_affine(None)
            # paste each restored face to the input image
            if face_upsample and face_upsampler is not None:
                restored_img = face_helper.paste_faces_to_input_image(
                    upsample_img=bg_img,
                    draw_box=draw_box,
                    face_upsampler=face_upsampler,
                )
            else:
                restored_img = face_helper.paste_faces_to_input_image(
                    upsample_img=bg_img, draw_box=draw_box
                )

        # save restored img
        # Use a unique name per call: the demo queue runs requests concurrently,
        # and a shared fixed file name would let one request overwrite the
        # download of another.
        save_path = f'output/{uuid.uuid4().hex}.png'
        imwrite(restored_img, str(save_path))

        restored_img = cv2.cvtColor(restored_img, cv2.COLOR_BGR2RGB)
        return restored_img, save_path
    except Exception as error:
        print('Global exception', error)
        return None, None
167
+
168
+
169
title = "CodeFormer: Face Restoration "

# Build the UI from the top-level Gradio components throughout (as the
# gr.Slider below already does) instead of the deprecated gr.inputs /
# gr.outputs namespaces; `value=` replaces their `default=` keyword.
demo = gr.Interface(
    inference,
    [
        gr.Image(type="filepath", label="Input"),
        gr.Checkbox(value=True, label="Background_Enhance"),
        gr.Checkbox(value=True, label="Face_Upsample"),
        gr.Number(value=2, label="Rescaling_Factor (up to 4)"),
        gr.Slider(0, 1, value=0.5, step=0.01, label='Codeformer_Fidelity (0 for better quality, 1 for better identity)'),
    ],
    [
        gr.Image(type="numpy", label="Output"),
        gr.File(label="Download the output"),
    ],
    title=title,
)

demo.queue(concurrency_count=2)
demo.launch()
basicsr/VERSION ADDED
@@ -0,0 +1 @@
 
 
1
+ 1.3.2
basicsr/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # https://github.com/xinntao/BasicSR
2
+ # flake8: noqa
3
+ from .archs import *
4
+ from .data import *
5
+ from .losses import *
6
+ from .metrics import *
7
+ from .models import *
8
+ from .ops import *
9
+ from .train import *
10
+ from .utils import *
11
+ from .version import __gitsha__, __version__
basicsr/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (331 Bytes). View file
 
basicsr/__pycache__/train.cpython-39.pyc ADDED
Binary file (6.3 kB). View file
 
basicsr/__pycache__/version.cpython-39.pyc ADDED
Binary file (209 Bytes). View file
 
basicsr/archs/__init__.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import importlib
2
+ from copy import deepcopy
3
+ from os import path as osp
4
+
5
+ from basicsr.utils import get_root_logger, scandir
6
+ from basicsr.utils.registry import ARCH_REGISTRY
7
+
8
__all__ = ['build_network']

# Registry auto-discovery: collect the module name of every file in this
# package folder that ends with '_arch.py', then import each of them.
arch_folder = osp.dirname(osp.abspath(__file__))
arch_filenames = [
    osp.splitext(osp.basename(path))[0]
    for path in scandir(arch_folder)
    if path.endswith('_arch.py')
]
# import all the arch modules
_arch_modules = [
    importlib.import_module(f'basicsr.archs.{module_name}')
    for module_name in arch_filenames
]
17
+
18
+
19
def build_network(opt):
    """Instantiate the network described by ``opt``.

    Args:
        opt (dict): Network options. The entry under 'type' names the
            architecture to look up in ``ARCH_REGISTRY``; every other entry
            is forwarded to it as a keyword argument. The dict is
            deep-copied first, so the caller's object is not modified.

    Returns:
        The constructed network.
    """
    net_kwargs = deepcopy(opt)
    arch_cls = ARCH_REGISTRY.get(net_kwargs.pop('type'))
    net = arch_cls(**net_kwargs)
    get_root_logger().info(f'Network [{net.__class__.__name__}] is created.')
    return net
basicsr/archs/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (1.11 kB). View file
 
basicsr/archs/__pycache__/arcface_arch.cpython-39.pyc ADDED
Binary file (7.38 kB). View file
 
basicsr/archs/__pycache__/arch_util.cpython-39.pyc ADDED
Binary file (10.8 kB). View file
 
basicsr/archs/__pycache__/codeformer_arch.cpython-39.pyc ADDED
Binary file (9.28 kB). View file
 
basicsr/archs/__pycache__/rrdbnet_arch.cpython-39.pyc ADDED
Binary file (4.41 kB). View file
 
basicsr/archs/__pycache__/vgg_arch.cpython-39.pyc ADDED
Binary file (4.8 kB). View file
 
basicsr/archs/__pycache__/vqgan_arch.cpython-39.pyc ADDED
Binary file (11.2 kB). View file
 
basicsr/archs/arcface_arch.py ADDED
@@ -0,0 +1,245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ from basicsr.utils.registry import ARCH_REGISTRY
3
+
4
+
5
def conv3x3(inplanes, outplanes, stride=1):
    """Build a bias-free 3x3 convolution with padding 1.

    Args:
        inplanes (int): Channel number of inputs.
        outplanes (int): Channel number of outputs.
        stride (int): Stride in convolution. Default: 1.

    Returns:
        nn.Conv2d: The configured convolution layer.
    """
    conv_kwargs = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv2d(inplanes, outplanes, **conv_kwargs)
14
+
15
+
16
class BasicBlock(nn.Module):
    """Two-convolution residual block used in the ResNetArcFace architecture.

    Args:
        inplanes (int): Channel number of inputs.
        planes (int): Channel number of outputs.
        stride (int): Stride of the first convolution. Default: 1.
        downsample (nn.Module): Module applied to the input to produce the
            shortcut; the input itself is used when None. Default: None.
    """
    expansion = 1  # output channel expansion ratio

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Apply conv-bn-relu-conv-bn, add the shortcut, then a final ReLU."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        # Use the projected input as shortcut only when a projection was given.
        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)
54
+
55
+
56
class IRBlock(nn.Module):
    """Improved residual block (IR Block) used in the ResNetArcFace architecture.

    Args:
        inplanes (int): Channel number of inputs.
        planes (int): Channel number of outputs.
        stride (int): Stride of the second convolution. Default: 1.
        downsample (nn.Module): Module applied to the input to produce the
            shortcut; the input itself is used when None. Default: None.
        use_se (bool): Whether use the SEBlock (squeeze and excitation block). Default: True.
    """
    expansion = 1  # output channel expansion ratio

    def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True):
        super().__init__()
        self.bn0 = nn.BatchNorm2d(inplanes)
        self.conv1 = conv3x3(inplanes, inplanes)
        self.bn1 = nn.BatchNorm2d(inplanes)
        self.prelu = nn.PReLU()
        self.conv2 = conv3x3(inplanes, planes, stride)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride
        self.use_se = use_se
        if self.use_se:
            self.se = SEBlock(planes)

    def forward(self, x):
        """Apply bn-conv-bn-prelu-conv-bn (plus optional SE), add the shortcut, then PReLU."""
        out = self.conv1(self.bn0(x))
        out = self.prelu(self.bn1(out))
        out = self.bn2(self.conv2(out))
        if self.use_se:
            out = self.se(out)

        # Use the projected input as shortcut only when a projection was given.
        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        # the same PReLU module (shared weights) is reused for the output activation
        return self.prelu(out)
101
+
102
+
103
class Bottleneck(nn.Module):
    """Bottleneck block (1x1 -> 3x3 -> 1x1 convolutions) used in the ResNetArcFace architecture.

    Args:
        inplanes (int): Channel number of inputs.
        planes (int): Channel number of the intermediate features; the block
            outputs ``planes * expansion`` channels.
        stride (int): Stride of the 3x3 convolution. Default: 1.
        downsample (nn.Module): Module applied to the input to produce the
            shortcut; the input itself is used when None. Default: None.
    """
    expansion = 4  # output channel expansion ratio

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        out_planes = planes * self.expansion
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Run the three conv-bn stages, add the shortcut, then a final ReLU."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        # Use the projected input as shortcut only when a projection was given.
        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)
147
+
148
+
149
class SEBlock(nn.Module):
    """The squeeze-and-excitation block (SEBlock) used in the IRBlock.

    Args:
        channel (int): Channel number of inputs.
        reduction (int): Channel reduction ratio. Default: 16.
    """

    def __init__(self, channel, reduction=16):
        super().__init__()
        squeezed = channel // reduction
        self.avg_pool = nn.AdaptiveAvgPool2d(1)  # pool to 1x1 without spatial information
        self.fc = nn.Sequential(
            nn.Linear(channel, squeezed),
            nn.PReLU(),
            nn.Linear(squeezed, channel),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Scale every channel of ``x`` by a gate in (0, 1) computed from its spatial mean."""
        batch, channels, _, _ = x.size()
        pooled = self.avg_pool(x).view(batch, channels)
        gate = self.fc(pooled).view(batch, channels, 1, 1)
        return x * gate
169
+
170
+
171
@ARCH_REGISTRY.register()
class ResNetArcFace(nn.Module):
    """ArcFace with ResNet architectures.

    Ref: ArcFace: Additive Angular Margin Loss for Deep Face Recognition.

    Args:
        block (str): Block used in the ArcFace architecture. The string
            'IRBlock' is resolved to the IRBlock class; any other value is
            used as given.
        layers (tuple(int)): Block numbers in each of the four layers.
        use_se (bool): Whether use the SEBlock (squeeze and excitation block). Default: True.
    """

    def __init__(self, block, layers, use_se=True):
        super().__init__()
        if block == 'IRBlock':
            block = IRBlock
        # running input-channel count, updated by _make_layer
        self.inplanes = 64
        self.use_se = use_se

        self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.prelu = nn.PReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.bn4 = nn.BatchNorm2d(512)
        self.dropout = nn.Dropout()
        self.fc5 = nn.Linear(512 * 8 * 8, 512)
        self.bn5 = nn.BatchNorm1d(512)

        # initialization
        for m in self.modules():
            if isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Conv2d):
                nn.init.xavier_normal_(m.weight)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, num_blocks, stride=1):
        """Stack ``num_blocks`` blocks; only the first one gets ``stride`` and the shortcut projection."""
        out_planes = planes * block.expansion
        # A 1x1 conv + BN projection is needed whenever the shortcut shape changes.
        if stride != 1 or self.inplanes != out_planes:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )
        else:
            downsample = None

        blocks = [block(self.inplanes, planes, stride, downsample, use_se=self.use_se)]
        self.inplanes = planes
        for _ in range(1, num_blocks):
            blocks.append(block(self.inplanes, planes, use_se=self.use_se))

        return nn.Sequential(*blocks)

    def forward(self, x):
        """Run the stem, the four residual layers and the fc5/bn5 head."""
        x = self.maxpool(self.prelu(self.bn1(self.conv1(x))))

        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)

        x = self.dropout(self.bn4(x))
        # flatten everything but the batch dimension before the linear head
        x = x.view(x.size(0), -1)
        return self.bn5(self.fc5(x))
basicsr/archs/arch_util.py ADDED
@@ -0,0 +1,318 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import collections.abc
2
+ import math
3
+ import torch
4
+ import torchvision
5
+ import warnings
6
+ from distutils.version import LooseVersion
7
+ from itertools import repeat
8
+ from torch import nn as nn
9
+ from torch.nn import functional as F
10
+ from torch.nn import init as init
11
+ from torch.nn.modules.batchnorm import _BatchNorm
12
+
13
+ from basicsr.ops.dcn import ModulatedDeformConvPack, modulated_deform_conv
14
+ from basicsr.utils import get_root_logger
15
+
16
+
17
@torch.no_grad()
def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs):
    """Initialize network weights.

    Conv2d and Linear weights get Kaiming-normal values multiplied by
    ``scale``; batch-norm weights are set to 1. Any bias that exists on these
    modules is filled with ``bias_fill``. Other module types are untouched.

    Args:
        module_list (list[nn.Module] | nn.Module): Modules to be initialized.
        scale (float): Scale initialized weights, especially for residual
            blocks. Default: 1.
        bias_fill (float): The value to fill bias. Default: 0
        kwargs (dict): Other arguments for initialization function.
    """
    roots = module_list if isinstance(module_list, list) else [module_list]
    for root in roots:
        for m in root.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                init.kaiming_normal_(m.weight, **kwargs)
                m.weight.data *= scale
            elif isinstance(m, _BatchNorm):
                init.constant_(m.weight, 1)
            else:
                continue
            # shared by all handled module types
            if m.bias is not None:
                m.bias.data.fill_(bias_fill)
46
+
47
+
48
def make_layer(basic_block, num_basic_block, **kwarg):
    """Make layers by stacking the same blocks.

    Args:
        basic_block (nn.module): nn.module class for basic block.
        num_basic_block (int): number of blocks.
        **kwarg: Keyword arguments passed to every ``basic_block(...)`` call.

    Returns:
        nn.Sequential: Stacked blocks in nn.Sequential.
    """
    return nn.Sequential(*[basic_block(**kwarg) for _ in range(num_basic_block)])
62
+
63
+
64
class ResidualBlockNoBN(nn.Module):
    """Residual block without BN.

    It has a style of:
        ---Conv-ReLU-Conv-+-
         |________________|

    Args:
        num_feat (int): Channel number of intermediate features.
            Default: 64.
        res_scale (float): Factor applied to the conv branch before it is
            added to the input. Default: 1.
        pytorch_init (bool): If set to True, use pytorch default init,
            otherwise, use default_init_weights. Default: False.
    """

    def __init__(self, num_feat=64, res_scale=1, pytorch_init=False):
        super().__init__()
        self.res_scale = res_scale
        self.conv1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
        self.conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
        self.relu = nn.ReLU(inplace=True)

        if pytorch_init:
            return
        # down-scaled (0.1) Kaiming init for both convolutions
        default_init_weights([self.conv1, self.conv2], 0.1)

    def forward(self, x):
        """Return ``x + res_scale * conv2(relu(conv1(x)))``."""
        branch = self.conv1(x)
        branch = self.relu(branch)
        branch = self.conv2(branch)
        return x + branch * self.res_scale
93
+
94
+
95
class Upsample(nn.Sequential):
    """Upsample module.

    Builds a sequence of (3x3 conv, PixelShuffle) pairs: one x2 pair per
    power of two for ``scale = 2^n``, or a single x3 pair for ``scale = 3``.

    Args:
        scale (int): Scale factor. Supported scales: 2^n and 3.
        num_feat (int): Channel number of intermediate features.

    Raises:
        ValueError: If ``scale`` is neither a positive power of two nor 3.
    """

    def __init__(self, scale, num_feat):
        m = []
        # `scale > 0` keeps 0 out of the power-of-two branch: 0 & -1 == 0 would
        # otherwise pass the bit test and fail later with "math domain error".
        if scale > 0 and (scale & (scale - 1)) == 0:  # scale = 2^n
            # bit_length gives the exact exponent, with no float rounding as in math.log
            for _ in range(int(scale).bit_length() - 1):
                m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
                m.append(nn.PixelShuffle(2))
        elif scale == 3:
            m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
            m.append(nn.PixelShuffle(3))
        else:
            raise ValueError(f'scale {scale} is not supported. Supported scales: 2^n and 3.')
        super(Upsample, self).__init__(*m)
115
+
116
+
117
def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros', align_corners=True):
    """Warp an image or feature map with optical flow.

    A pixel-coordinate grid is offset by ``flow``, normalized to [-1, 1] and
    used as the sampling grid for ``F.grid_sample``.

    Args:
        x (Tensor): Tensor with size (n, c, h, w).
        flow (Tensor): Tensor with size (n, h, w, 2), normal value. The last
            dimension is ordered (x offset, y offset), in pixels.
        interp_mode (str): 'nearest' or 'bilinear'. Default: 'bilinear'.
        padding_mode (str): 'zeros' or 'border' or 'reflection'.
            Default: 'zeros'.
        align_corners (bool): Before pytorch 1.3, the default value is
            align_corners=True. After pytorch 1.3, the default value is
            align_corners=False. Here, we use the True as default.

    Returns:
        Tensor: Warped image or feature map.
    """
    assert x.size()[-2:] == flow.size()[1:3]
    _, _, height, width = x.size()

    # pixel-coordinate grid of shape (h, w, 2), last dim ordered (x, y)
    rows = torch.arange(0, height).type_as(x)
    cols = torch.arange(0, width).type_as(x)
    grid_y, grid_x = torch.meshgrid(rows, cols)
    base_grid = torch.stack((grid_x, grid_y), 2).float()
    base_grid.requires_grad = False

    sample_pts = base_grid + flow
    # scale grid to [-1,1]
    norm_x = 2.0 * sample_pts[:, :, :, 0] / max(width - 1, 1) - 1.0
    norm_y = 2.0 * sample_pts[:, :, :, 1] / max(height - 1, 1) - 1.0
    sample_grid = torch.stack((norm_x, norm_y), dim=3)

    # TODO, what if align_corners=False
    return F.grid_sample(x, sample_grid, mode=interp_mode, padding_mode=padding_mode, align_corners=align_corners)
149
+
150
+
151
def resize_flow(flow, size_type, sizes, interp_mode='bilinear', align_corners=False):
    """Resize a flow according to ratio or shape.

    The flow values are rescaled along with the spatial size: channel 0 is
    multiplied by the width ratio and channel 1 by the height ratio. The
    input tensor is not modified.

    Args:
        flow (Tensor): Precomputed flow. shape [N, 2, H, W].
        size_type (str): 'ratio' or 'shape'.
        sizes (list[int | float]): the ratio for resizing or the final output
            shape.
            1) The order of ratio should be [ratio_h, ratio_w]. For
            downsampling, the ratio should be smaller than 1.0 (i.e., ratio
            < 1.0). For upsampling, the ratio should be larger than 1.0 (i.e.,
            ratio > 1.0).
            2) The order of output_size should be [out_h, out_w].
        interp_mode (str): The mode of interpolation for resizing.
            Default: 'bilinear'.
        align_corners (bool): Whether align corners. Default: False.

    Returns:
        Tensor: Resized flow.

    Raises:
        ValueError: If ``size_type`` is neither 'ratio' nor 'shape'.
    """
    _, _, flow_h, flow_w = flow.size()
    if size_type == 'ratio':
        output_h = int(flow_h * sizes[0])
        output_w = int(flow_w * sizes[1])
    elif size_type == 'shape':
        output_h, output_w = sizes[0], sizes[1]
    else:
        raise ValueError(f'Size type should be ratio or shape, but got type {size_type}.')

    # work on a copy so the caller's flow keeps its values
    scaled_flow = flow.clone()
    scaled_flow[:, 0, :, :] *= output_w / flow_w
    scaled_flow[:, 1, :, :] *= output_h / flow_h
    return F.interpolate(
        input=scaled_flow, size=(output_h, output_w), mode=interp_mode, align_corners=align_corners)
187
+
188
+
189
+ # TODO: may write a cpp file
190
def pixel_unshuffle(x, scale):
    """ Pixel unshuffle.

    Moves every ``scale x scale`` spatial patch into the channel dimension,
    so the output has ``c * scale**2`` channels and a spatial size divided by
    ``scale``.

    Args:
        x (Tensor): Input feature with shape (b, c, hh, hw).
        scale (int): Downsample ratio; must divide both hh and hw.

    Returns:
        Tensor: the pixel unshuffled feature.
    """
    batch, channels, in_h, in_w = x.size()
    assert in_h % scale == 0 and in_w % scale == 0
    out_h, out_w = in_h // scale, in_w // scale
    patches = x.view(batch, channels, out_h, scale, out_w, scale)
    # bring the two in-patch offsets next to the channel dim, then merge them into it
    return patches.permute(0, 1, 3, 5, 2, 4).reshape(batch, channels * (scale**2), out_h, out_w)
207
+
208
+
209
class DCNv2Pack(ModulatedDeformConvPack):
    """Modulated deformable conv for deformable alignment.

    Different from the official DCNv2Pack, which generates offsets and masks
    from the preceding features, this DCNv2Pack takes another different
    features to generate offsets and masks.

    Ref:
        Delving Deep into Deformable Alignment in Video Super-Resolution.
    """

    def forward(self, x, feat):
        """Deform-convolve ``x`` using offsets and masks predicted from ``feat``."""
        # conv_offset output is split in three along channels: two offset parts and the mask
        first, second, mask = torch.chunk(self.conv_offset(feat), 3, dim=1)
        offset = torch.cat((first, second), dim=1)
        mask = torch.sigmoid(mask)

        offset_absmean = torch.mean(torch.abs(offset))
        if offset_absmean > 50:
            get_root_logger().warning(f'Offset abs mean is {offset_absmean}, larger than 50.')

        # older torchvision: fall back to the project's own modulated_deform_conv op
        if LooseVersion(torchvision.__version__) < LooseVersion('0.9.0'):
            return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding,
                                         self.dilation, self.groups, self.deformable_groups)
        return torchvision.ops.deform_conv2d(x, offset, self.weight, self.bias, self.stride, self.padding,
                                             self.dilation, mask)
237
+
238
+
239
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    """Fill ``tensor`` in place with truncated-normal values and return it."""
    # From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    sqrt_two = math.sqrt(2.)

    def std_normal_cdf(value):
        # Computes standard normal cumulative distribution function
        return (1. + math.erf(value / sqrt_two)) / 2.

    mean_far_outside = (mean < a - 2 * std) or (mean > b + 2 * std)
    if mean_far_outside:
        warnings.warn(
            'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. '
            'The distribution of values may be incorrect.',
            stacklevel=2)

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        cdf_low = std_normal_cdf((a - mean) / std)
        cdf_high = std_normal_cdf((b - mean) / std)

        # 1) uniform fill over [2l-1, 2u-1]
        # 2) inverse cdf transform -> truncated standard normal
        # 3) rescale/shift to the requested std and mean
        # 4) clamp to make sure the result is inside [a, b]
        tensor.uniform_(2 * cdf_low - 1, 2 * cdf_high - 1)
        tensor.erfinv_().mul_(std * sqrt_two).add_(mean)
        tensor.clamp_(min=a, max=b)
        return tensor


def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    r"""Fills the input Tensor with values drawn from a truncated
    normal distribution.

    From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py

    The values are effectively drawn from the
    normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
    with values outside :math:`[a, b]` redrawn until they are within
    the bounds. The method used for generating the random values works
    best when :math:`a \leq \text{mean} \leq b`.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value

    Returns:
        The same ``tensor`` object, filled in place.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.trunc_normal_(w)
    """
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)
301
+
302
+
303
+ # From PyTorch
304
def _ntuple(n):
    """Return a converter that repeats a non-iterable value ``n`` times.

    The returned function gives iterables back unchanged and turns any other
    value ``x`` into the tuple ``(x, ..., x)`` of length ``n``.
    """

    def parse(x):
        return x if isinstance(x, collections.abc.Iterable) else (x,) * n

    return parse


to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
to_ntuple = _ntuple
basicsr/archs/codeformer_arch.py ADDED
@@ -0,0 +1,280 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import numpy as np
3
+ import torch
4
+ from torch import nn, Tensor
5
+ import torch.nn.functional as F
6
+ from typing import Optional, List
7
+
8
+ from basicsr.archs.vqgan_arch import *
9
+ from basicsr.utils import get_root_logger
10
+ from basicsr.utils.registry import ARCH_REGISTRY
11
+
12
def calc_mean_std(feat, eps=1e-5):
    """Calculate per-sample, per-channel mean and std for adaptive instance normalization.

    The statistics are taken over the two trailing (spatial) dimensions.

    Args:
        feat (Tensor): 4D tensor. It may be non-contiguous (e.g. a transposed
            or permuted view); the spatial axes are flattened with
            ``flatten``, which copies only when a plain view is impossible.
        eps (float): A small value added to the variance to avoid
            divide-by-zero. Default: 1e-5.

    Returns:
        tuple[Tensor, Tensor]: ``(mean, std)``, each of shape ``(b, c, 1, 1)``.
    """
    size = feat.size()
    assert len(size) == 4, 'The input feature should be 4D tensor.'
    b, c = size[:2]
    # ``view(b, c, -1)`` raises on inputs whose spatial strides cannot be
    # merged; ``flatten(2)`` yields the same view for contiguous tensors and
    # falls back to a copy otherwise.
    flat = feat.flatten(2)
    feat_var = flat.var(dim=2) + eps
    feat_std = feat_var.sqrt().view(b, c, 1, 1)
    feat_mean = flat.mean(dim=2).view(b, c, 1, 1)
    return feat_mean, feat_std
27
+
28
+
29
def adaptive_instance_normalization(content_feat, style_feat):
    """Adaptive instance normalization (AdaIN).

    Re-normalizes ``content_feat`` so that its per-channel mean and standard
    deviation match those of ``style_feat``, i.e. the reference features take
    on the color and illumination statistics of the degraded features.

    Args:
        content_feat (Tensor): The reference feature (4D).
        style_feat (Tensor): The degraded feature (4D) providing the target
            statistics.

    Returns:
        Tensor: Re-normalized features with the shape of ``content_feat``.
    """
    target_shape = content_feat.size()
    content_mean, content_std = calc_mean_std(content_feat)
    style_mean, style_std = calc_mean_std(style_feat)
    # Whiten with the content statistics ...
    whitened = (content_feat - content_mean.expand(target_shape)) / content_std.expand(target_shape)
    # ... then re-color with the style statistics.
    recolored = whitened * style_std.expand(target_shape)
    return recolored + style_mean.expand(target_shape)
44
+
45
+
46
class PositionEmbeddingSine(nn.Module):
    """Fixed 2-D sinusoidal position embedding for image feature maps.

    A generalization to images of the sine/cosine encoding from the
    "Attention Is All You Need" paper: row and column positions are encoded
    separately with ``num_pos_feats`` channels each and concatenated
    (rows first), giving ``2 * num_pos_feats`` output channels.

    Args:
        num_pos_feats (int): Channels per axis. Default: 64.
        temperature (float): Base of the geometric frequency progression.
            Default: 10000.
        normalize (bool): If True, positions are divided by the extent of the
            unmasked area and multiplied by ``scale``. Default: False.
        scale (float | None): Scale for normalized positions; ``2 * pi`` when
            None. May only be given together with ``normalize=True``.

    Raises:
        ValueError: If ``scale`` is passed while ``normalize`` is False.
    """

    def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
        super().__init__()
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        self.scale = 2 * math.pi if scale is None else scale

    def forward(self, x, mask=None):
        """Return the embedding for ``x`` as a ``(B, 2 * num_pos_feats, H, W)`` tensor.

        Args:
            x (Tensor): Only its batch size, last two dimensions (indices 2
                and 3) and device are used.
            mask (Tensor | None): Boolean ``(B, H, W)`` mask, True at
                positions to ignore. Defaults to an all-False mask.
        """
        if mask is None:
            grid_shape = (x.size(0), x.size(2), x.size(3))
            mask = torch.zeros(grid_shape, device=x.device, dtype=torch.bool)
        valid = ~mask
        # 1-based running count of valid cells along rows / columns.
        row_pos = valid.cumsum(1, dtype=torch.float32)
        col_pos = valid.cumsum(2, dtype=torch.float32)
        if self.normalize:
            eps = 1e-6
            row_pos = row_pos / (row_pos[:, -1:, :] + eps) * self.scale
            col_pos = col_pos / (col_pos[:, :, -1:] + eps) * self.scale

        freq_index = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        divisor = self.temperature ** (2 * (freq_index // 2) / self.num_pos_feats)

        def encode(coords):
            # Even feature slots carry sin, odd slots cos; stacking on a new
            # last axis and flattening interleaves them again.
            angles = coords[:, :, :, None] / divisor
            sin_part = angles[:, :, :, 0::2].sin()
            cos_part = angles[:, :, :, 1::2].cos()
            return torch.stack((sin_part, cos_part), dim=4).flatten(3)

        pos = torch.cat((encode(row_pos), encode(col_pos)), dim=3)
        return pos.permute(0, 3, 1, 2)
87
+
88
+ def _get_activation_fn(activation):
89
+ """Return an activation function given a string"""
90
+ if activation == "relu":
91
+ return F.relu
92
+ if activation == "gelu":
93
+ return F.gelu
94
+ if activation == "glu":
95
+ return F.glu
96
+ raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
97
+
98
+
99
class TransformerSALayer(nn.Module):
    """Pre-norm transformer layer: self-attention followed by a two-layer MLP.

    Both sub-layers normalize their input with ``LayerNorm`` first and add
    their (dropout-regularized) output back onto the residual stream.

    Args:
        embed_dim (int): Token embedding width.
        nhead (int): Number of attention heads. Default: 8.
        dim_mlp (int): Hidden width of the feed-forward MLP. Default: 2048.
        dropout (float): Dropout probability used throughout. Default: 0.0.
        activation (str): Name of the MLP activation, resolved through
            ``_get_activation_fn``. Default: "gelu".
    """

    def __init__(self, embed_dim, nhead=8, dim_mlp=2048, dropout=0.0, activation="gelu"):
        super().__init__()
        # Attribute names and creation order are kept stable: they determine
        # the state_dict keys and the order in which parameters are initialized.
        self.self_attn = nn.MultiheadAttention(embed_dim, nhead, dropout=dropout)

        # Feed-forward MLP: embed_dim -> dim_mlp -> embed_dim.
        self.linear1 = nn.Linear(embed_dim, dim_mlp)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_mlp, embed_dim)

        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        """Add the positional embedding ``pos`` to ``tensor`` when one is given."""
        if pos is None:
            return tensor
        return tensor + pos

    def forward(self, tgt,
                tgt_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
        """Apply the layer to ``tgt`` and return a tensor of the same shape.

        Args:
            tgt (Tensor): Input token sequence.
            tgt_mask (Tensor | None): Passed to the attention as ``attn_mask``.
            tgt_key_padding_mask (Tensor | None): Passed to the attention as
                ``key_padding_mask``.
            query_pos (Tensor | None): Positional embedding added to queries
                and keys only; values use the un-shifted normalized input.
        """
        # --- self-attention sub-layer ---
        normed = self.norm1(tgt)
        query_key = self.with_pos_embed(normed, query_pos)
        attended, _ = self.self_attn(query_key, query_key, value=normed, attn_mask=tgt_mask,
                                     key_padding_mask=tgt_key_padding_mask)
        tgt = tgt + self.dropout1(attended)

        # --- feed-forward sub-layer ---
        hidden = self.activation(self.linear1(self.norm2(tgt)))
        tgt = tgt + self.dropout2(self.linear2(self.dropout(hidden)))
        return tgt
135
+
136
class Fuse_sft_block(nn.Module):
    """Fuse encoder features into decoder features with a learned scale and shift.

    The encoder and decoder features are concatenated and reduced by a
    ``ResBlock``; two small conv stacks then predict a per-pixel ``scale``
    and ``shift`` that modulate the decoder features.

    Args:
        in_ch (int): Channels of each of the two inputs.
        out_ch (int): Channels produced by the fusion branch.

    Note:
        ``scale`` and ``shift`` consume the ``out_ch``-channel output of
        ``encode_enc`` while their first conv is declared with ``in_ch`` input
        channels, so the block is only usable with ``in_ch == out_ch``.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.encode_enc = ResBlock(2*in_ch, out_ch)

        def conv_stack():
            # conv3x3 -> LeakyReLU(0.2, inplace) -> conv3x3
            return nn.Sequential(
                nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
                nn.LeakyReLU(0.2, True),
                nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1))

        # Created in this order (scale, then shift) to keep initialization
        # order and state_dict layout unchanged.
        self.scale = conv_stack()
        self.shift = conv_stack()

    def forward(self, enc_feat, dec_feat, w=1):
        """Return ``dec_feat + w * (dec_feat * scale + shift)``.

        Args:
            enc_feat (Tensor): Encoder features.
            dec_feat (Tensor): Decoder features to be modulated.
            w (float): Weight of the fused residual. Default: 1.
        """
        fused = self.encode_enc(torch.cat([enc_feat, dec_feat], dim=1))
        scale = self.scale(fused)
        shift = self.shift(fused)
        return dec_feat + w * (dec_feat * scale + shift)
158
+
159
+
160
@ARCH_REGISTRY.register()
class CodeFormer(VQAutoEncoder):
    """VQ auto-encoder extended with a transformer that predicts codebook indices.

    On top of the inherited ``encoder`` / ``quantize`` / ``generator`` modules
    this class adds

    * ``feat_emb`` + ``ft_layers`` + ``idx_pred_layer``: a stack of
      ``TransformerSALayer`` blocks that maps the encoder output to logits
      over the ``codebook_size`` codes, one prediction per token;
    * ``fuse_convs_dict``: one ``Fuse_sft_block`` per entry of
      ``connect_list`` that blends encoder features into the generator.

    Args:
        dim_embd (int): Transformer embedding width. Default: 512.
        n_head (int): Attention heads per transformer layer. Default: 8.
        n_layers (int): Number of transformer layers. Default: 9.
        codebook_size (int): Size of the logits vector predicted per token;
            also forwarded to the base-class constructor. Default: 1024.
        latent_size (int): Number of rows of the learned position embedding.
            Default: 256.
        connect_list (list[str]): Feature-map sizes (as strings, keys of
            ``self.channels``) at which encoder features are fused into the
            generator.
        fix_modules (list[str] | None): Names of attributes whose parameters
            get ``requires_grad = False``. ``None`` freezes nothing.
        vqgan_path (str | None): Optional checkpoint file; its
            ``'params_ema'`` entry is loaded into this module.
    """

    # NOTE(review): the list defaults below are shared between calls. They are
    # only read here, but ``self.connect_list`` aliases the default object
    # when the caller does not pass a list of their own.
    def __init__(self, dim_embd=512, n_head=8, n_layers=9,
                 codebook_size=1024, latent_size=256,
                 connect_list=['32', '64', '128', '256'],
                 fix_modules=['quantize','generator'], vqgan_path=None):
        # Positional arguments are fixed here; presumably img_size, nf,
        # ch_mult, quantizer type, res blocks and attention resolutions --
        # TODO confirm against VQAutoEncoder's signature.
        super(CodeFormer, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], 'nearest',2, [16], codebook_size)

        # The checkpoint is loaded before any of the extra modules below are
        # registered, i.e. while this object only has the base-class modules.
        # NOTE(review): torch.load unpickles the file; only load trusted
        # checkpoints.
        if vqgan_path is not None:
            self.load_state_dict(
                torch.load(vqgan_path, map_location='cpu')['params_ema'])

        # Freeze the requested sub-modules (looked up by attribute name).
        if fix_modules is not None:
            for module in fix_modules:
                for param in getattr(self, module).parameters():
                    param.requires_grad = False

        self.connect_list = connect_list
        self.n_layers = n_layers
        self.dim_embd = dim_embd
        self.dim_mlp = dim_embd*2

        # Learned position embedding, one row per token, initialized to zeros.
        self.position_emb = nn.Parameter(torch.zeros(latent_size, self.dim_embd))
        # Projects the encoder output channels (hard-coded to 256) to dim_embd.
        self.feat_emb = nn.Linear(256, self.dim_embd)

        # transformer
        self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0)
                                    for _ in range(self.n_layers)])

        # logits_predict head
        self.idx_pred_layer = nn.Sequential(
            nn.LayerNorm(dim_embd),
            nn.Linear(dim_embd, codebook_size, bias=False))

        # Channel count of the feature map at each spatial size (key = size).
        self.channels = {
            '16': 512,
            '32': 256,
            '64': 256,
            '128': 128,
            '256': 128,
            '512': 64,
        }

        # Both tables map a feature-map size to an index into
        # ``self.encoder.blocks`` / ``self.generator.blocks`` respectively.
        # after second residual block for > 16, before attn layer for ==16
        self.fuse_encoder_block = {'512':2, '256':5, '128':8, '64':11, '32':14, '16':18}
        # after first residual block for > 16, before attn layer for ==16
        self.fuse_generator_block = {'16':6, '32': 9, '64':12, '128':15, '256':18, '512':21}

        # fuse_convs_dict
        self.fuse_convs_dict = nn.ModuleDict()
        for f_size in self.connect_list:
            in_ch = self.channels[f_size]
            self.fuse_convs_dict[f_size] = Fuse_sft_block(in_ch, in_ch)

    def _init_weights(self, module):
        """Initialize one module: N(0, 0.02) weights for Linear/Embedding,
        zero Linear bias, and unit weight / zero bias for LayerNorm.

        Not called anywhere inside this class; it is shaped for use with
        ``nn.Module.apply``.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(self, x, w=0, detach_16=True, code_only=False, adain=False):
        """Encode ``x``, predict code indices and decode the looked-up codes.

        Args:
            x (Tensor): Input batch fed to ``self.encoder.blocks``.
            w (float): Fusion weight. Encoder features are fused into the
                generator only when ``w > 0``. Default: 0.
            detach_16 (bool): If True, detach the quantized features before
                they enter the generator. Default: True.
            code_only (bool): If True, return right after the logits are
                computed, skipping quantization and the generator.
            adain (bool): If True, apply ``adaptive_instance_normalization``
                to the quantized features using the encoder output statistics.

        Returns:
            ``(logits, lq_feat)`` when ``code_only`` is True, otherwise
            ``(out, logits, lq_feat)``. ``logits`` are un-normalized scores of
            shape (batch, tokens, codebook_size); ``lq_feat`` is the encoder
            output.
        """
        # ################### Encoder #####################
        enc_feat_dict = {}
        out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list]
        for i, block in enumerate(self.encoder.blocks):
            x = block(x)
            if i in out_list:
                # Keyed by the width of the feature map, as a string.
                enc_feat_dict[str(x.shape[-1])] = x.clone()

        lq_feat = x
        # ################# Transformer ###################
        # quant_feat, codebook_loss, quant_stats = self.quantize(lq_feat)
        # (tokens, 1, C) repeated over the batch -> (tokens, B, C)
        pos_emb = self.position_emb.unsqueeze(1).repeat(1,x.shape[0],1)
        # BCHW -> BC(HW) -> (HW)BC
        feat_emb = self.feat_emb(lq_feat.flatten(2).permute(2,0,1))
        query_emb = feat_emb
        # Transformer encoder
        for layer in self.ft_layers:
            query_emb = layer(query_emb, query_pos=pos_emb)

        # output logits
        logits = self.idx_pred_layer(query_emb) # (hw)bn
        logits = logits.permute(1,0,2) # (hw)bn -> b(hw)n

        if code_only: # for training stage II
            # logits doesn't need softmax before cross_entropy loss
            return logits, lq_feat

        # ################# Quantization ###################
        # if self.training:
        #     quant_feat = torch.einsum('btn,nc->btc', [soft_one_hot, self.quantize.embedding.weight])
        #     # b(hw)c -> bc(hw) -> bchw
        #     quant_feat = quant_feat.permute(0,2,1).view(lq_feat.shape)
        # ------------
        # Pick the single most likely code per token ...
        soft_one_hot = F.softmax(logits, dim=2)
        _, top_idx = torch.topk(soft_one_hot, 1, dim=2)
        # ... and look it up in the codebook. The target shape is hard-coded
        # to a 16x16 grid of 256-channel codes.
        quant_feat = self.quantize.get_codebook_feat(top_idx, shape=[x.shape[0],16,16,256])
        # preserve gradients
        # quant_feat = lq_feat + (quant_feat - lq_feat).detach()

        if detach_16:
            quant_feat = quant_feat.detach() # for training stage III
        if adain:
            quant_feat = adaptive_instance_normalization(quant_feat, lq_feat)

        # ################## Generator ####################
        x = quant_feat
        fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list]

        for i, block in enumerate(self.generator.blocks):
            x = block(x)
            if i in fuse_list: # fuse after i-th block
                f_size = str(x.shape[-1])
                if w>0:
                    # Encoder features are detached, so no gradient flows
                    # back into the encoder through the fusion path.
                    x = self.fuse_convs_dict[f_size](enc_feat_dict[f_size].detach(), x, w)
        out = x
        # logits doesn't need softmax before cross_entropy loss
        return out, logits, lq_feat
basicsr/archs/rrdbnet_arch.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn as nn
3
+ from torch.nn import functional as F
4
+
5
+ from basicsr.utils.registry import ARCH_REGISTRY
6
+ from .arch_util import default_init_weights, make_layer, pixel_unshuffle
7
+
8
+
9
class ResidualDenseBlock(nn.Module):
    """Residual Dense Block.

    Building block of the RRDB block in ESRGAN: five 3x3 convolutions where
    each one sees the block input concatenated with the outputs of all
    previous convolutions.

    Args:
        num_feat (int): Channel number of intermediate features.
        num_grow_ch (int): Channels for each growth.
    """

    def __init__(self, num_feat=64, num_grow_ch=32):
        super(ResidualDenseBlock, self).__init__()
        # The i-th conv receives num_feat + (i - 1) * num_grow_ch channels.
        self.conv1 = nn.Conv2d(num_feat, num_grow_ch, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, kernel_size=3, stride=1, padding=1)
        self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, kernel_size=3, stride=1, padding=1)

        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        # initialization
        default_init_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)

    def forward(self, x):
        """Return ``x`` plus the densely-connected residual scaled by 0.2."""
        dense = [x, self.lrelu(self.conv1(x))]
        for conv in (self.conv2, self.conv3, self.conv4):
            dense.append(self.lrelu(conv(torch.cat(dense, 1))))
        residual = self.conv5(torch.cat(dense, 1))
        # Empirically, scaling the residual by 0.2 gives better performance.
        return residual * 0.2 + x
40
+
41
+
42
class RRDB(nn.Module):
    """Residual in Residual Dense Block.

    Three ``ResidualDenseBlock`` modules in sequence, wrapped in an outer
    scaled residual connection. Used in RRDB-Net in ESRGAN.

    Args:
        num_feat (int): Channel number of intermediate features.
        num_grow_ch (int): Channels for each growth.
    """

    def __init__(self, num_feat, num_grow_ch=32):
        super(RRDB, self).__init__()
        self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
        self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
        self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)

    def forward(self, x):
        """Return ``x`` plus the output of the three dense blocks scaled by 0.2."""
        trunk = x
        for dense_block in (self.rdb1, self.rdb2, self.rdb3):
            trunk = dense_block(trunk)
        # Empirically, scaling the residual by 0.2 gives better performance.
        return trunk * 0.2 + x
64
+
65
+
66
@ARCH_REGISTRY.register()
class RRDBNet(nn.Module):
    """Network built from Residual in Residual Dense Blocks, as used in ESRGAN.

    ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks.

    ESRGAN is extended here to scale x2 and scale x1. The trunk always ends
    with two nearest-neighbor x2 upsampling stages; for scale 2 and scale 1
    the input is first pixel-unshuffled (the inverse of pixel-shuffle) by a
    factor of 2 or 4 respectively, which shrinks the spatial size and enlarges
    the channel count before the main ESRGAN architecture.

    Args:
        num_in_ch (int): Channel number of inputs.
        num_out_ch (int): Channel number of outputs.
        scale (int): Upscaling factor. 2 and 1 enable the pixel-unshuffle
            front end; any other value feeds the input through unchanged.
            Default: 4.
        num_feat (int): Channel number of intermediate features.
            Default: 64
        num_block (int): Block number in the trunk network. Defaults: 23
        num_grow_ch (int): Channels for each growth. Default: 32.
    """

    def __init__(self, num_in_ch, num_out_ch, scale=4, num_feat=64, num_block=23, num_grow_ch=32):
        super(RRDBNet, self).__init__()
        self.scale = scale
        # Pixel-unshuffle by r multiplies the channel count by r * r.
        if scale == 2:
            channel_gain = 4
        elif scale == 1:
            channel_gain = 16
        else:
            channel_gain = 1
        self.conv_first = nn.Conv2d(num_in_ch * channel_gain, num_feat, 3, 1, 1)
        self.body = make_layer(RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch)
        self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        # upsample
        self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)

        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        """Run the network on ``x`` and return the output image tensor."""
        # Optional pixel-unshuffle front end for scale 2 / scale 1.
        if self.scale == 2:
            unshuffle_factor = 2
        elif self.scale == 1:
            unshuffle_factor = 4
        else:
            unshuffle_factor = None
        feat = x if unshuffle_factor is None else pixel_unshuffle(x, scale=unshuffle_factor)

        feat = self.conv_first(feat)
        feat = feat + self.conv_body(self.body(feat))
        # upsample: two rounds of nearest x2 followed by conv + LeakyReLU
        for conv_up in (self.conv_up1, self.conv_up2):
            feat = self.lrelu(conv_up(F.interpolate(feat, scale_factor=2, mode='nearest')))
        return self.conv_last(self.lrelu(self.conv_hr(feat)))
basicsr/archs/vgg_arch.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ from collections import OrderedDict
4
+ from torch import nn as nn
5
+ from torchvision.models import vgg as vgg
6
+
7
+ from basicsr.utils.registry import ARCH_REGISTRY
8
+
9
# Local checkpoint that VGGFeatureExtractor loads instead of downloading
# torchvision weights, if the file exists.
# NOTE(review): the file name is a vgg19 checkpoint, but it is loaded for
# whatever ``vgg_type`` is requested -- confirm this is intended.
VGG_PRETRAIN_PATH = 'experiments/pretrained_models/vgg19-dcbb9e9d.pth'
# Layer names for each plain (non-BN) VGG variant. The lists are matched
# positionally against the layers of ``vgg_net.features``, so their order
# matters. Names for the ``*_bn`` variants are derived with ``insert_bn``.
NAMES = {
    'vgg11': [
        'conv1_1', 'relu1_1', 'pool1', 'conv2_1', 'relu2_1', 'pool2', 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2',
        'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2',
        'pool5'
    ],
    'vgg13': [
        'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
        'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4',
        'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'pool5'
    ],
    'vgg16': [
        'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
        'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2',
        'relu4_2', 'conv4_3', 'relu4_3', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3',
        'pool5'
    ],
    'vgg19': [
        'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
        'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'conv3_4', 'relu3_4', 'pool3', 'conv4_1',
        'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 'conv4_4', 'relu4_4', 'pool4', 'conv5_1', 'relu5_1',
        'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 'conv5_4', 'relu5_4', 'pool5'
    ]
}
34
+
35
+
36
def insert_bn(names):
    """Insert a bn layer name after each conv layer name.

    For every name containing ``'conv'`` a companion name is added right
    after it, formed by removing ``'conv'`` and prefixing ``'bn'``
    (``'conv1_1'`` -> ``'bn1_1'``). All other names are copied unchanged.

    Args:
        names (list): The list of layer names.

    Returns:
        list: A new list of layer names with bn layers.
    """

    def expand(layer):
        # A conv layer contributes itself plus its batch-norm companion.
        if 'conv' not in layer:
            return [layer]
        return [layer, 'bn' + layer.replace('conv', '')]

    return [entry for layer in names for entry in expand(layer)]
52
+
53
+
54
@ARCH_REGISTRY.register()
class VGGFeatureExtractor(nn.Module):
    """VGG network for feature extraction.

    Users can choose the vgg variant and whether the input is normalized.
    Note that the pretrained path must fit the vgg type.

    Args:
        layer_name_list (list[str]): Forward function returns the corresponding
            features according to the layer_name_list.
            Example: ['relu1_1', 'relu2_1', 'relu3_1'].
        vgg_type (str): Set the type of vgg network. Default: 'vgg19'.
        use_input_norm (bool): If True, normalize the input image. Importantly,
            the input feature must be in the range [0, 1]. Default: True.
        range_norm (bool): If True, norm images with range [-1, 1] to [0, 1].
            Default: False.
        requires_grad (bool): If true, the parameters of VGG network will be
            optimized. Default: False.
        remove_pooling (bool): If true, the max pooling operations in VGG net
            will be removed. Default: False.
        pooling_stride (int): The stride of max pooling operation. Default: 2.
    """

    def __init__(self,
                 layer_name_list,
                 vgg_type='vgg19',
                 use_input_norm=True,
                 range_norm=False,
                 requires_grad=False,
                 remove_pooling=False,
                 pooling_stride=2):
        super(VGGFeatureExtractor, self).__init__()

        self.layer_name_list = layer_name_list
        self.use_input_norm = use_input_norm
        self.range_norm = range_norm

        layer_names = NAMES[vgg_type.replace('_bn', '')]
        if 'bn' in vgg_type:
            layer_names = insert_bn(layer_names)
        self.names = layer_names

        # Only borrow the layers up to the deepest requested one, so that no
        # unused parameters are kept around.
        max_idx = max((self.names.index(v) for v in layer_name_list), default=0)

        build_vgg = getattr(vgg, vgg_type)
        if os.path.exists(VGG_PRETRAIN_PATH):
            # Prefer the local checkpoint over downloading torchvision weights.
            vgg_net = build_vgg(pretrained=False)
            state_dict = torch.load(VGG_PRETRAIN_PATH, map_location=lambda storage, loc: storage)
            vgg_net.load_state_dict(state_dict)
        else:
            vgg_net = build_vgg(pretrained=True)

        features = vgg_net.features[:max_idx + 1]

        modified_net = OrderedDict()
        for k, v in zip(self.names, features):
            if 'pool' not in k:
                modified_net[k] = v
            elif not remove_pooling:
                # Pooling layers are rebuilt so the stride can be overridden;
                # with remove_pooling they are dropped altogether.
                modified_net[k] = nn.MaxPool2d(kernel_size=2, stride=pooling_stride)

        self.vgg_net = nn.Sequential(modified_net)

        if requires_grad:
            self.vgg_net.train()
        else:
            self.vgg_net.eval()
        trainable = bool(requires_grad)
        for param in self.parameters():
            param.requires_grad = trainable

        if self.use_input_norm:
            # the mean is for image with range [0, 1]
            self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
            # the std is for image with range [0, 1]
            self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).

        Returns:
            dict: Maps each requested layer name to a clone of the feature
            tensor produced by that layer.
        """
        if self.range_norm:
            x = (x + 1) / 2
        if self.use_input_norm:
            x = (x - self.mean) / self.std

        collected = {}
        for key, layer in self.vgg_net._modules.items():
            x = layer(x)
            if key in self.layer_name_list:
                collected[key] = x.clone()

        return collected
basicsr/archs/vqgan_arch.py ADDED
@@ -0,0 +1,434 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ VQGAN code, adapted from the original created by the Unleashing Transformers authors:
3
+ https://github.com/samb-t/unleashing-transformers/blob/master/models/vqgan.py
4
+
5
+ '''
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ import copy
11
+ from basicsr.utils import get_root_logger
12
+ from basicsr.utils.registry import ARCH_REGISTRY
13
+
14
def normalize(in_channels):
    """Return the normalization layer used throughout the VQGAN blocks.

    A ``GroupNorm`` with 32 groups over ``in_channels`` channels, ``eps=1e-6``
    and learnable affine parameters.
    """
    return torch.nn.GroupNorm(32, in_channels, eps=1e-6, affine=True)
16
+
17
+
18
@torch.jit.script
def swish(x):
    """Swish activation: ``x * sigmoid(x)``, compiled with TorchScript."""
    gate = torch.sigmoid(x)
    return x * gate
21
+
22
+
23
+ # Define VQVAE classes
24
class VectorQuantizer(nn.Module):
    """Nearest-neighbor vector quantizer with a learnable codebook.

    Args:
        codebook_size (int): Number of embeddings in the codebook.
        emb_dim (int): Dimension of each embedding.
        beta (float): Weight of the second loss term in ``forward``.
    """

    def __init__(self, codebook_size, emb_dim, beta):
        super(VectorQuantizer, self).__init__()
        self.codebook_size = codebook_size  # number of embeddings
        self.emb_dim = emb_dim  # dimension of embedding
        # Described upstream as the commitment cost; in ``forward`` it
        # multiplies the term whose gradient reaches the codebook.
        self.beta = beta
        self.embedding = nn.Embedding(self.codebook_size, self.emb_dim)
        bound = 1.0 / self.codebook_size
        self.embedding.weight.data.uniform_(-bound, bound)

    def forward(self, z):
        """Quantize ``z`` (B, C, H, W) to its nearest codebook entries.

        Returns:
            tuple: ``(z_q, loss, stats)`` where ``z_q`` has the shape of
            ``z`` (with straight-through gradients), ``loss`` is the scalar
            quantization loss and ``stats`` is a dict with ``perplexity``,
            ``min_encodings`` (one-hot), ``min_encoding_indices`` and
            ``mean_distance``.
        """
        # (B, C, H, W) -> (B, H, W, C), then one row per spatial position.
        z = z.permute(0, 2, 3, 1).contiguous()
        flat = z.view(-1, self.emb_dim)
        codebook = self.embedding.weight

        # Squared distances via (z - e)^2 = z^2 + e^2 - 2 e * z
        d = (flat ** 2).sum(dim=1, keepdim=True) + (codebook**2).sum(1) - \
            2 * torch.matmul(flat, codebook.t())

        mean_distance = torch.mean(d)
        # Closest codebook entry per position, shape (N, 1).
        min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)

        # One-hot encoding of the selected entries.
        min_encodings = torch.zeros(min_encoding_indices.shape[0], self.codebook_size).to(z)
        min_encodings.scatter_(1, min_encoding_indices, 1)

        # Quantized latent vectors.
        z_q = torch.matmul(min_encodings, codebook).view(z.shape)

        # First term: gradient flows to ``z`` only. Second term (weighted by
        # beta): gradient flows to the codebook only.
        encoder_term = torch.mean((z_q.detach()-z)**2)
        codebook_term = torch.mean((z_q - z.detach()) ** 2)
        loss = encoder_term + self.beta * codebook_term

        # Straight-through estimator: forward value z_q, gradient of z.
        z_q = z + (z_q - z).detach()

        # Perplexity of the code usage within this batch.
        e_mean = torch.mean(min_encodings, dim=0)
        perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))

        # Back to (B, C, H, W).
        z_q = z_q.permute(0, 3, 1, 2).contiguous()

        stats = {
            "perplexity": perplexity,
            "min_encodings": min_encodings,
            "min_encoding_indices": min_encoding_indices,
            "mean_distance": mean_distance
        }
        return z_q, loss, stats

    def get_codebook_feat(self, indices, shape):
        """Look up codebook vectors for ``indices``.

        Args:
            indices (Tensor): Integer code indices; flattened to one column.
            shape (sequence | None): Target shape given as
                (batch, height, width, channel). If None, the flat
                (num_indices, emb_dim) matrix is returned.

        Returns:
            Tensor: (batch, channel, height, width) features, or the flat
            matrix when ``shape`` is None.
        """
        # batch*token_num -> (batch*token_num)*1
        indices = indices.view(-1,1)
        one_hot = torch.zeros(indices.shape[0], self.codebook_size).to(indices)
        one_hot.scatter_(1, indices, 1)
        z_q = torch.matmul(one_hot.float(), self.embedding.weight)

        if shape is None:
            return z_q
        # Reshape back to image layout: (B, H, W, C) -> (B, C, H, W).
        return z_q.view(shape).permute(0, 3, 1, 2).contiguous()
85
+
86
+
87
class GumbelQuantizer(nn.Module):
    """Quantizer that selects codebook entries through Gumbel-softmax sampling.

    Args:
        codebook_size (int): Number of embeddings.
        emb_dim (int): Dimension of each embedding.
        num_hiddens (int): Channels of the incoming feature map.
        straight_through (bool): Whether to sample hard one-hot codes while
            training. In eval mode sampling is always hard. Default: False.
        kl_weight (float): Weight of the KL term. Default: 5e-4.
        temp_init (float): Gumbel-softmax temperature. Default: 1.0.
    """

    def __init__(self, codebook_size, emb_dim, num_hiddens, straight_through=False, kl_weight=5e-4, temp_init=1.0):
        super().__init__()
        self.codebook_size = codebook_size  # number of embeddings
        self.emb_dim = emb_dim  # dimension of embedding
        self.straight_through = straight_through
        self.temperature = temp_init
        self.kl_weight = kl_weight
        # 1x1 conv projecting the last encoder layer to per-code logits.
        self.proj = nn.Conv2d(num_hiddens, codebook_size, 1)
        self.embed = nn.Embedding(codebook_size, emb_dim)

    def forward(self, z):
        """Quantize ``z`` (B, num_hiddens, H, W).

        Returns:
            tuple: ``(z_q, diff, stats)`` with ``z_q`` of shape
            (B, emb_dim, H, W), ``diff`` the weighted KL term (scalar) and
            ``stats`` a dict holding ``min_encoding_indices`` (B, H, W).
        """
        hard = True if not self.training else self.straight_through

        logits = self.proj(z)
        soft_one_hot = F.gumbel_softmax(logits, tau=self.temperature, dim=1, hard=hard)
        z_q = torch.einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight)

        # KL divergence of the code distribution to the uniform prior.
        qy = F.softmax(logits, dim=1)
        kl_to_prior = torch.sum(qy * torch.log(qy * self.codebook_size + 1e-10), dim=1).mean()
        diff = self.kl_weight * kl_to_prior

        stats = {
            "min_encoding_indices": soft_one_hot.argmax(dim=1)
        }
        return z_q, diff, stats
115
+
116
+
117
class Downsample(nn.Module):
    """Halve the spatial resolution with a stride-2 3x3 convolution.

    The convolution itself is unpadded; ``forward`` zero-pads one pixel on
    the right and bottom edges only before applying it.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x):
        # Pad order is (left, right, top, bottom).
        padded = torch.nn.functional.pad(x, (0, 1, 0, 1), mode="constant", value=0)
        return self.conv(padded)
127
+
128
+
129
class Upsample(nn.Module):
    """Double the spatial resolution: nearest-neighbor x2, then a 3x3 convolution."""

    def __init__(self, in_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        enlarged = F.interpolate(x, scale_factor=2.0, mode="nearest")
        return self.conv(enlarged)
139
+
140
+
141
class ResBlock(nn.Module):
    """Pre-activation residual block: (norm -> swish -> conv3x3) twice, plus a skip.

    Args:
        in_channels (int): Channels of the input.
        out_channels (int | None): Channels of the output. ``None`` (the
            default) means "same as ``in_channels``".

    When input and output channel counts differ, the skip path goes through
    a 1x1 convolution (``conv_out``) so both branches can be added.
    """

    def __init__(self, in_channels, out_channels=None):
        super(ResBlock, self).__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels if out_channels is None else out_channels
        # All layers are built from the resolved ``self.out_channels``; using
        # the raw argument would pass ``None`` to the layers whenever the
        # default is used.
        self.norm1 = normalize(self.in_channels)
        self.conv1 = nn.Conv2d(self.in_channels, self.out_channels, kernel_size=3, stride=1, padding=1)
        self.norm2 = normalize(self.out_channels)
        self.conv2 = nn.Conv2d(self.out_channels, self.out_channels, kernel_size=3, stride=1, padding=1)
        if self.in_channels != self.out_channels:
            self.conv_out = nn.Conv2d(self.in_channels, self.out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x_in):
        """Return the residual branch output added to the (projected) input."""
        x = x_in
        x = self.norm1(x)
        x = swish(x)
        x = self.conv1(x)
        x = self.norm2(x)
        x = swish(x)
        x = self.conv2(x)
        if self.in_channels != self.out_channels:
            # Match the skip path's channel count to the residual branch.
            x_in = self.conv_out(x_in)

        return x + x_in
165
+
166
+
167
class AttnBlock(nn.Module):
    """Single-head self-attention over all spatial positions of a feature map.

    Queries, keys, values and the output projection are 1x1 convolutions;
    the attended features are added back onto the input.

    Args:
        in_channels (int): Channels of the input (and output) feature map.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        def pointwise_conv():
            return torch.nn.Conv2d(
                in_channels,
                in_channels,
                kernel_size=1,
                stride=1,
                padding=0
            )

        # Creation order is kept (norm, q, k, v, proj_out) so initialization
        # order and state_dict layout stay unchanged.
        self.norm = normalize(in_channels)
        self.q = pointwise_conv()
        self.k = pointwise_conv()
        self.v = pointwise_conv()
        self.proj_out = pointwise_conv()

    def forward(self, x):
        """Return ``x`` plus the projected attention output (same shape as ``x``)."""
        normed = self.norm(x)
        q = self.q(normed)
        k = self.k(normed)
        v = self.v(normed)

        # Attention weights between every pair of positions.
        b, c, h, w = q.shape
        q = q.reshape(b, c, h*w).permute(0, 2, 1)  # (b, hw, c)
        k = k.reshape(b, c, h*w)                   # (b, c, hw)
        attn = torch.bmm(q, k)                     # (b, hw_query, hw_key)
        attn = attn * (int(c)**(-0.5))
        attn = F.softmax(attn, dim=2)

        # Weighted sum of the values.
        v = v.reshape(b, c, h*w)
        attn = attn.permute(0, 2, 1)               # (b, hw_key, hw_query)
        attended = torch.bmm(v, attn).reshape(b, c, h, w)

        attended = self.proj_out(attended)

        return x+attended
227
+
228
+
229
class Encoder(nn.Module):
    """Convolutional encoder: a flat list of blocks applied one after another.

    Layout of ``self.blocks``: an initial 3x3 conv; for each entry of
    ``ch_mult`` a group of ``num_res_blocks`` residual blocks (each followed
    by an attention block when the current resolution is listed in
    ``attn_resolutions``) and, except after the last group, a downsample; then
    ResBlock / AttnBlock / ResBlock; finally a normalization and a 3x3 conv
    to ``emb_dim`` channels.

    Args:
        in_channels (int): Channels of the input image.
        nf (int): Base channel count.
        emb_dim (int): Channels of the encoder output.
        ch_mult (sequence[int]): Channel multiplier per resolution level.
        num_res_blocks (int): Residual blocks per resolution level.
        resolution (int): Input resolution, used to track where attention
            blocks go.
        attn_resolutions (sequence[int]): Resolutions that get attention.
    """

    def __init__(self, in_channels, nf, emb_dim, ch_mult, num_res_blocks, resolution, attn_resolutions):
        super().__init__()
        self.nf = nf
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.attn_resolutions = attn_resolutions

        # Input multiplier of level i is the output multiplier of level i - 1.
        in_ch_mult = (1,)+tuple(ch_mult)
        last_level = self.num_resolutions - 1
        curr_res = self.resolution

        # initial convolution
        layers = [nn.Conv2d(in_channels, nf, kernel_size=3, stride=1, padding=1)]

        # residual and downsampling blocks, with attention on smaller res (16x16)
        for level in range(self.num_resolutions):
            block_in_ch = nf * in_ch_mult[level]
            block_out_ch = nf * ch_mult[level]
            for _ in range(self.num_res_blocks):
                layers.append(ResBlock(block_in_ch, block_out_ch))
                block_in_ch = block_out_ch
                if curr_res in attn_resolutions:
                    layers.append(AttnBlock(block_in_ch))

            if level != last_level:
                layers.append(Downsample(block_in_ch))
                curr_res = curr_res // 2

        # non-local attention block
        layers.append(ResBlock(block_in_ch, block_in_ch))
        layers.append(AttnBlock(block_in_ch))
        layers.append(ResBlock(block_in_ch, block_in_ch))

        # normalize and convert to latent size
        layers.append(normalize(block_in_ch))
        layers.append(nn.Conv2d(block_in_ch, emb_dim, kernel_size=3, stride=1, padding=1))
        self.blocks = nn.ModuleList(layers)

    def forward(self, x):
        """Apply every block of ``self.blocks`` in order and return the result."""
        feat = x
        for layer in self.blocks:
            feat = layer(feat)

        return feat
274
+
275
+
276
class Generator(nn.Module):
    """Sequential convolutional decoder producing a 3-channel output.

    Layer order, as built in ``__init__``:

    1. a 3x3 convolution from ``emb_dim`` to ``nf * ch_mult[-1]`` channels;
    2. a ``ResBlock`` / ``AttnBlock`` / ``ResBlock`` bottleneck;
    3. walking ``ch_mult`` from last to first: ``res_blocks`` ``ResBlock``
       modules, each followed by an ``AttnBlock`` when the tracked resolution
       is in ``attn_resolutions``, then an ``Upsample`` on every level except
       index 0 (the tracked resolution is doubled at that point);
    4. ``normalize`` followed by a 3x3 convolution to 3 channels.

    Args:
        nf (int): Base channel count.
        emb_dim (int): Channels of the input tensor.
        ch_mult (Sequence[int]): Channel multiplier per resolution level.
        res_blocks (int): Number of ``ResBlock`` modules per level.
        img_size (int): Nominal output resolution; the tracked resolution
            starts at ``img_size // 2 ** (len(ch_mult) - 1)``.
        attn_resolutions (Container[int]): Resolutions that get attention.
    """

    def __init__(self, nf, emb_dim, ch_mult, res_blocks, img_size, attn_resolutions):
        super().__init__()
        self.nf = nf
        self.ch_mult = ch_mult
        self.num_resolutions = len(self.ch_mult)
        self.num_res_blocks = res_blocks
        self.resolution = img_size
        self.attn_resolutions = attn_resolutions
        self.in_channels = emb_dim
        self.out_channels = 3

        ch = self.nf * self.ch_mult[-1]
        res = self.resolution // 2 ** (self.num_resolutions - 1)

        layers = [
            # initial conv
            nn.Conv2d(self.in_channels, ch, kernel_size=3, stride=1, padding=1),
            # non-local attention block
            ResBlock(ch, ch),
            AttnBlock(ch),
            ResBlock(ch, ch),
        ]

        for level in range(self.num_resolutions - 1, -1, -1):
            ch_out = self.nf * self.ch_mult[level]

            for _ in range(self.num_res_blocks):
                layers.append(ResBlock(ch, ch_out))
                ch = ch_out

                if res in self.attn_resolutions:
                    layers.append(AttnBlock(ch))

            if level != 0:
                layers.append(Upsample(ch))
                res *= 2

        layers.append(normalize(ch))
        layers.append(nn.Conv2d(ch, self.out_channels, kernel_size=3, stride=1, padding=1))

        # The attribute name and module order define the state_dict keys.
        self.blocks = nn.ModuleList(layers)

    def forward(self, x):
        """Apply every block in construction order and return the result."""
        out = x
        for layer in self.blocks:
            out = layer(out)
        return out
324
+
325
+
326
@ARCH_REGISTRY.register()
class VQAutoEncoder(nn.Module):
    """Encoder -> quantizer -> generator auto-encoder.

    Args:
        img_size (int): Nominal image resolution passed to the encoder and
            generator.
        nf (int): Base channel count.
        ch_mult (Sequence[int]): Channel multiplier per resolution level.
        quantizer (str): ``"nearest"`` builds a ``VectorQuantizer``;
            ``"gumbel"`` builds a ``GumbelQuantizer``. Any other value leaves
            ``self.quantize`` undefined, so ``forward`` will fail.
        res_blocks (int): Residual blocks per level. Default: 2.
        attn_resolutions (list[int]): Resolutions that get attention blocks.
            Default: [16]. The default list is only read here, never mutated.
        codebook_size (int): Number of codebook entries. Default: 1024.
        emb_dim (int): Latent channel count. Default: 256.
        beta (float): Passed to ``VectorQuantizer`` (nearest mode only).
        gumbel_straight_through (bool): Passed to ``GumbelQuantizer``.
        gumbel_kl_weight (float): Passed to ``GumbelQuantizer``.
        model_path (str | None): Optional checkpoint to load. It must contain
            a ``'params_ema'`` or ``'params'`` state dict; ``'params_ema'``
            takes precedence.

    Raises:
        ValueError: If ``model_path`` is given but the checkpoint has neither
            ``'params_ema'`` nor ``'params'``.
    """

    def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=[16],
                 codebook_size=1024, emb_dim=256, beta=0.25, gumbel_straight_through=False,
                 gumbel_kl_weight=1e-8, model_path=None):
        super().__init__()
        logger = get_root_logger()
        self.in_channels = 3
        self.nf = nf
        self.n_blocks = res_blocks
        self.codebook_size = codebook_size
        self.embed_dim = emb_dim
        self.ch_mult = ch_mult
        self.resolution = img_size
        self.attn_resolutions = attn_resolutions
        self.quantizer_type = quantizer
        self.encoder = Encoder(
            self.in_channels,
            self.nf,
            self.embed_dim,
            self.ch_mult,
            self.n_blocks,
            self.resolution,
            self.attn_resolutions
        )
        if self.quantizer_type == "nearest":
            self.beta = beta  # 0.25
            self.quantize = VectorQuantizer(self.codebook_size, self.embed_dim, self.beta)
        elif self.quantizer_type == "gumbel":
            self.gumbel_num_hiddens = emb_dim
            self.straight_through = gumbel_straight_through
            self.kl_weight = gumbel_kl_weight
            self.quantize = GumbelQuantizer(
                self.codebook_size,
                self.embed_dim,
                self.gumbel_num_hiddens,
                self.straight_through,
                self.kl_weight
            )
        # NOTE(review): an unrecognised quantizer type is silently accepted
        # and only surfaces as an AttributeError in forward().
        self.generator = Generator(
            self.nf,
            self.embed_dim,
            self.ch_mult,
            self.n_blocks,
            self.resolution,
            self.attn_resolutions
        )

        if model_path is not None:
            # NOTE: torch.load unpickles the file; only load trusted checkpoints.
            # The checkpoint is read once and reused for the key lookup below.
            chkpt = torch.load(model_path, map_location='cpu')
            if 'params_ema' in chkpt:
                self.load_state_dict(chkpt['params_ema'])
                logger.info(f'vqgan is loaded from: {model_path} [params_ema]')
            elif 'params' in chkpt:
                self.load_state_dict(chkpt['params'])
                logger.info(f'vqgan is loaded from: {model_path} [params]')
            else:
                raise ValueError('Wrong params!')

    def forward(self, x):
        """Encode, quantize and decode ``x``.

        Returns:
            tuple: ``(reconstruction, codebook_loss, quant_stats)`` where the
            last two values are returned unchanged from ``self.quantize``.
        """
        x = self.encoder(x)
        quant, codebook_loss, quant_stats = self.quantize(x)
        x = self.generator(quant)
        return x, codebook_loss, quant_stats
390
+
391
+
392
+
393
+ # patch based discriminator
394
@ARCH_REGISTRY.register()
class VQGANDiscriminator(nn.Module):
    """Patch-based convolutional discriminator.

    The network is an ``nn.Sequential`` of 4x4 convolutions: one stride-2
    input layer, ``n_layers - 1`` stride-2 layers whose width doubles up to
    ``8 * ndf``, one stride-1 layer, and a final stride-1 convolution that
    outputs a single-channel prediction map.

    Args:
        nc (int): Input channels. Default: 3.
        ndf (int): Base number of filters. Default: 64.
        n_layers (int): Controls depth and maximum width. Default: 4.
        model_path (str | None): Optional checkpoint to load. It must contain
            a ``'params_d'`` or ``'params'`` state dict; ``'params_d'`` takes
            precedence.

    Raises:
        ValueError: If ``model_path`` is given but the checkpoint has neither
            ``'params_d'`` nor ``'params'``.
    """

    def __init__(self, nc=3, ndf=64, n_layers=4, model_path=None):
        super().__init__()

        layers = [nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1), nn.LeakyReLU(0.2, True)]
        ndf_mult = 1
        for n in range(1, n_layers):  # gradually increase the number of filters
            ndf_mult_prev = ndf_mult
            ndf_mult = min(2 ** n, 8)
            layers += [
                nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=2, padding=1, bias=False),
                nn.BatchNorm2d(ndf * ndf_mult),
                nn.LeakyReLU(0.2, True)
            ]

        ndf_mult_prev = ndf_mult
        ndf_mult = min(2 ** n_layers, 8)

        layers += [
            nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(ndf * ndf_mult),
            nn.LeakyReLU(0.2, True)
        ]

        # output 1 channel prediction map
        layers += [nn.Conv2d(ndf * ndf_mult, 1, kernel_size=4, stride=1, padding=1)]
        self.main = nn.Sequential(*layers)

        if model_path is not None:
            # NOTE: torch.load unpickles the file; only load trusted checkpoints.
            # The checkpoint is read once and reused for the key lookup below.
            chkpt = torch.load(model_path, map_location='cpu')
            if 'params_d' in chkpt:
                self.load_state_dict(chkpt['params_d'])
            elif 'params' in chkpt:
                self.load_state_dict(chkpt['params'])
            else:
                raise ValueError('Wrong params!')

    def forward(self, x):
        """Return the prediction map produced by ``self.main`` for ``x``."""
        return self.main(x)
basicsr/data/__init__.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import importlib
2
+ import numpy as np
3
+ import random
4
+ import torch
5
+ import torch.utils.data
6
+ from copy import deepcopy
7
+ from functools import partial
8
+ from os import path as osp
9
+
10
+ from basicsr.data.prefetch_dataloader import PrefetchDataLoader
11
+ from basicsr.utils import get_root_logger, scandir
12
+ from basicsr.utils.dist_util import get_dist_info
13
+ from basicsr.utils.registry import DATASET_REGISTRY
14
+
15
# Public API of this package.
__all__ = ['build_dataset', 'build_dataloader']

# Automatically scan and import dataset modules at import time.
# Presumably importing them is what populates DATASET_REGISTRY (via the
# `@DATASET_REGISTRY.register()` decorators) -- confirm against the registry.
# Only files directly reported by `scandir(data_folder)` whose names end with
# '_dataset.py' are picked up; the module name is the file name without '.py'.
data_folder = osp.dirname(osp.abspath(__file__))
dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')]
# Import every discovered module as `basicsr.data.<name>`; the list keeps
# references to the imported module objects.
_dataset_modules = [importlib.import_module(f'basicsr.data.{file_name}') for file_name in dataset_filenames]
23
+
24
+
25
def build_dataset(dataset_opt):
    """Instantiate the dataset class named in the options.

    The options are deep-copied first, so the caller's dict is never touched
    by the dataset constructor.

    Args:
        dataset_opt (dict): Configuration for dataset. It must contain:
            name (str): Dataset name (used in the log message).
            type (str): Key under which the class is registered in
                ``DATASET_REGISTRY``.

    Returns:
        The dataset instance created from the registered class.
    """
    opt = deepcopy(dataset_opt)
    dataset_cls = DATASET_REGISTRY.get(opt['type'])
    dataset = dataset_cls(opt)
    get_root_logger().info(f'Dataset [{dataset.__class__.__name__}] - {opt["name"]} is built.')
    return dataset
38
+
39
+
40
def build_dataloader(dataset, dataset_opt, num_gpu=1, dist=False, sampler=None, seed=None):
    """Create the dataloader for a dataset according to its phase.

    Args:
        dataset (torch.utils.data.Dataset): Dataset.
        dataset_opt (dict): Dataset options. Keys read here:
            phase (str): 'train', 'val' or 'test'.
            num_worker_per_gpu (int): Workers per GPU (train phase only).
            batch_size_per_gpu (int): Batch size per GPU (train phase only).
            pin_memory (bool, optional): Default: False.
            prefetch_mode (str, optional): 'cpu' selects ``PrefetchDataLoader``.
            num_prefetch_queue (int, optional): Used with 'cpu' prefetching.
                Default: 1.
        num_gpu (int): Number of GPUs. Used only in the train phase.
            Default: 1.
        dist (bool): Whether in distributed training. Used only in the train
            phase. Default: False.
        sampler (torch.utils.data.sampler): Data sampler. Default: None.
        seed (int | None): Seed. Default: None

    Returns:
        ``PrefetchDataLoader`` when ``prefetch_mode == 'cpu'``, otherwise a
        plain ``torch.utils.data.DataLoader``.

    Raises:
        ValueError: If ``phase`` is not 'train', 'val' or 'test'.
    """
    phase = dataset_opt['phase']
    rank, _ = get_dist_info()

    if phase == 'train':
        per_gpu_batch = dataset_opt['batch_size_per_gpu']
        per_gpu_workers = dataset_opt['num_worker_per_gpu']
        if dist:
            # distributed training: every process handles a single GPU
            batch_size, num_workers = per_gpu_batch, per_gpu_workers
        else:
            # non-distributed training: one process feeds all GPUs (CPU-only counts as 1)
            scale = num_gpu if num_gpu != 0 else 1
            batch_size, num_workers = per_gpu_batch * scale, per_gpu_workers * scale

        init_fn = None
        if seed is not None:
            init_fn = partial(worker_init_fn, num_workers=num_workers, rank=rank, seed=seed)

        loader_kwargs = {
            'dataset': dataset,
            'batch_size': batch_size,
            # shuffle only when no sampler decides the order
            'shuffle': sampler is None,
            'num_workers': num_workers,
            'sampler': sampler,
            'drop_last': True,
            'worker_init_fn': init_fn,
        }
    elif phase in ('val', 'test'):
        # validation / test: one sample at a time, in order, in the main process
        loader_kwargs = {'dataset': dataset, 'batch_size': 1, 'shuffle': False, 'num_workers': 0}
    else:
        raise ValueError(f"Wrong dataset phase: {phase}. Supported ones are 'train', 'val' and 'test'.")

    loader_kwargs['pin_memory'] = dataset_opt.get('pin_memory', False)

    prefetch_mode = dataset_opt.get('prefetch_mode')
    if prefetch_mode != 'cpu':
        # prefetch_mode=None: Normal dataloader
        # prefetch_mode='cuda': dataloader for CUDAPrefetcher
        return torch.utils.data.DataLoader(**loader_kwargs)

    # CPUPrefetcher
    queue_size = dataset_opt.get('num_prefetch_queue', 1)
    get_root_logger().info(f'Use {prefetch_mode} prefetch dataloader: num_prefetch_queue = {queue_size}')
    return PrefetchDataLoader(num_prefetch_queue=queue_size, **loader_kwargs)
94
+
95
+
96
def worker_init_fn(worker_id, num_workers, rank, seed):
    """Seed NumPy's and Python's global RNGs for one dataloader worker.

    The seed used is ``num_workers * rank + worker_id + seed``.

    Args:
        worker_id (int): Index of the worker.
        num_workers (int): Number of workers per process.
        rank (int): Rank of the current process.
        seed (int): Base seed.
    """
    derived_seed = seed + worker_id + rank * num_workers
    for seed_rng in (np.random.seed, random.seed):
        seed_rng(derived_seed)
basicsr/data/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (3.51 kB). View file
 
basicsr/data/__pycache__/data_sampler.cpython-39.pyc ADDED
Binary file (2.12 kB). View file
 
basicsr/data/__pycache__/data_util.cpython-39.pyc ADDED
Binary file (13.5 kB). View file
 
basicsr/data/__pycache__/ffhq_blind_dataset.cpython-39.pyc ADDED
Binary file (7.88 kB). View file
 
basicsr/data/__pycache__/ffhq_blind_joint_dataset.cpython-39.pyc ADDED
Binary file (8.47 kB). View file
 
basicsr/data/__pycache__/gaussian_kernels.cpython-39.pyc ADDED
Binary file (17.8 kB). View file
 
basicsr/data/__pycache__/paired_image_dataset.cpython-39.pyc ADDED
Binary file (3.73 kB). View file
 
basicsr/data/__pycache__/prefetch_dataloader.cpython-39.pyc ADDED
Binary file (4.35 kB). View file
 
basicsr/data/__pycache__/transforms.cpython-39.pyc ADDED
Binary file (5.41 kB). View file
 
basicsr/data/data_sampler.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ from torch.utils.data.sampler import Sampler
4
+
5
+
6
class EnlargedSampler(Sampler):
    """Sampler that gives each process its own slice of an enlarged index set.

    Modified from torch.utils.data.distributed.DistributedSampler.
    The index space is ``ratio`` times the dataset length (rounded up to a
    multiple of ``num_replicas``); indices are wrapped back into the dataset
    with a modulo, so one pass over the sampler covers the dataset about
    ``ratio`` times. This supports iteration-based training by restarting the
    dataloader less often.

    Args:
        dataset (torch.utils.data.Dataset): Dataset used for sampling.
        num_replicas (int | None): Number of processes participating in
            the training. It is usually the world_size.
        rank (int | None): Rank of the current process within num_replicas.
        ratio (int): Enlarging ratio. Default: 1.
    """

    def __init__(self, dataset, num_replicas, rank, ratio=1):
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.num_samples = math.ceil(len(self.dataset) * ratio / self.num_replicas)
        self.total_size = self.num_samples * self.num_replicas

    def __iter__(self):
        # The permutation depends only on the epoch, so every rank draws the
        # same one and the strided slices below do not overlap.
        rng = torch.Generator()
        rng.manual_seed(self.epoch)
        permutation = torch.randperm(self.total_size, generator=rng).tolist()

        # take this rank's share, then wrap the enlarged indices into the dataset
        own_share = permutation[self.rank:self.total_size:self.num_replicas]
        size = len(self.dataset)
        wrapped = [idx % size for idx in own_share]
        assert len(wrapped) == self.num_samples

        return iter(wrapped)

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        """Set the epoch used to seed the next shuffle."""
        self.epoch = epoch
basicsr/data/data_util.py ADDED
@@ -0,0 +1,392 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import math
3
+ import numpy as np
4
+ import torch
5
+ from os import path as osp
6
+ from PIL import Image, ImageDraw
7
+ from torch.nn import functional as F
8
+
9
+ from basicsr.data.transforms import mod_crop
10
+ from basicsr.utils import img2tensor, scandir
11
+
12
+
13
def read_img_seq(path, require_mod_crop=False, scale=1):
    """Load a sequence of images and stack them into one tensor.

    Args:
        path (list[str] | str): List of image paths or image folder path.
            A folder is listed with ``scandir`` and its entries are sorted.
        require_mod_crop (bool): Apply ``mod_crop`` to every image.
            Default: False.
        scale (int): Scale factor for mod_crop. Default: 1.

    Returns:
        Tensor: size (t, c, h, w), RGB, [0, 1].
    """
    img_paths = path if isinstance(path, list) else sorted(list(scandir(path, full_path=True)))

    # cv2 loads images as BGR; scale the pixel values into [0, 1]
    frames = [cv2.imread(img_path).astype(np.float32) / 255. for img_path in img_paths]
    if require_mod_crop:
        frames = [mod_crop(frame, scale) for frame in frames]

    tensors = img2tensor(frames, bgr2rgb=True, float32=True)
    return torch.stack(tensors, dim=0)
35
+
36
+
37
def generate_frame_indices(crt_idx, max_frame_num, num_frames, padding='reflection'):
    """Build the list of frame indices for a window centred on ``crt_idx``.

    Indices that fall outside ``[0, max_frame_num - 1]`` are remapped
    according to the padding mode.

    Args:
        crt_idx (int): Current center index.
        max_frame_num (int): Max number of the sequence of images (from 1).
        num_frames (int): Number of frames in the window; must be odd.
        padding (str): Padding mode, one of
            'replicate' | 'reflection' | 'reflection_circle' | 'circle'
            Examples: current_idx = 0, num_frames = 5
            The generated frame indices under different padding mode:
            replicate: [0, 0, 0, 1, 2]
            reflection: [2, 1, 0, 1, 2]
            reflection_circle: [4, 3, 0, 1, 2]
            circle: [3, 4, 0, 1, 2]

    Returns:
        list[int]: A list of indices.
    """
    assert num_frames % 2 == 1, 'num_frames should be an odd number.'
    assert padding in ('replicate', 'reflection', 'reflection_circle', 'circle'), f'Wrong padding mode: {padding}.'

    last_idx = max_frame_num - 1  # start from 0
    half = num_frames // 2

    def _pad_before_start(i):
        # i < 0: the window sticks out before the first frame
        if padding == 'replicate':
            return 0
        if padding == 'reflection':
            return -i
        if padding == 'reflection_circle':
            return crt_idx + half - i
        return num_frames + i

    def _pad_after_end(i):
        # i > last_idx: the window sticks out past the last frame
        if padding == 'replicate':
            return last_idx
        if padding == 'reflection':
            return last_idx * 2 - i
        if padding == 'reflection_circle':
            return (crt_idx - half) - (i - last_idx)
        return i - num_frames

    def _resolve(i):
        if i < 0:
            return _pad_before_start(i)
        if i > last_idx:
            return _pad_after_end(i)
        return i

    return [_resolve(i) for i in range(crt_idx - half, crt_idx + half + 1)]
87
+
88
+
89
def paired_paths_from_lmdb(folders, keys):
    """Pair up the entries of an input lmdb and a gt lmdb.

    Contents of lmdb. Taking the `lq.lmdb` for example, the file structure is:

    lq.lmdb
    ├── data.mdb
    ├── lock.mdb
    ├── meta_info.txt

    The data.mdb and lock.mdb are standard lmdb files and you can refer to
    https://lmdb.readthedocs.io/en/release/ for more details.

    The meta_info.txt is a text file recording the meta information of the
    dataset. Each line in the txt file records
    1)image name (with extension),
    2)image shape,
    3)compression level, separated by a white space.
    Example: `baboon.png (120,125,3) 1`

    The lmdb key is the part of each line before its first '.', i.e. the
    image name without extension. The same key is used for the corresponding
    lq and gt images.

    Args:
        folders (list[str]): A list of folder path. The order of list should
            be [input_folder, gt_folder].
        keys (list[str]): A list of keys identifying folders. The order should
            be consistent with folders, e.g., ['lq', 'gt'].
            Note that this key is different from lmdb keys.

    Returns:
        list[dict]: One dict per lmdb key (sorted), mapping
        ``'<input_key>_path'`` and ``'<gt_key>_path'`` to that key.

    Raises:
        ValueError: If a folder does not end with '.lmdb' or the two
            meta_info files list different keys.
    """
    assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
                               f'But got {len(folders)}')
    assert len(keys) == 2, f'The len of keys should be 2 with [input_key, gt_key]. But got {len(keys)}'
    input_folder, gt_folder = folders
    input_key, gt_key = keys

    if not input_folder.endswith('.lmdb') or not gt_folder.endswith('.lmdb'):
        raise ValueError(f'{input_key} folder and {gt_key} folder should both in lmdb '
                         f'formats. But received {input_key}: {input_folder}; '
                         f'{gt_key}: {gt_folder}')

    def _read_keys(lmdb_folder):
        # everything before the first '.' of each line is the key
        with open(osp.join(lmdb_folder, 'meta_info.txt')) as meta:
            return [entry.partition('.')[0] for entry in meta]

    # ensure that the two meta_info files are the same
    input_lmdb_keys = _read_keys(input_folder)
    gt_lmdb_keys = _read_keys(gt_folder)
    if set(input_lmdb_keys) != set(gt_lmdb_keys):
        raise ValueError(f'Keys in {input_key}_folder and {gt_key}_folder are different.')

    return [{f'{input_key}_path': lmdb_key, f'{gt_key}_path': lmdb_key} for lmdb_key in sorted(input_lmdb_keys)]
146
+
147
+
148
def paired_paths_from_meta_info_file(folders, keys, meta_info_file, filename_tmpl):
    """Build input/gt path pairs from the names listed in a meta info file.

    Each line in the meta information file contains the image name and the
    image shape (usually for gt), separated by a white space. Only the text
    before the first space of each line is used.

    Example of a meta information file:
    ```
    0001_s001.png (480,480,3)
    0001_s002.png (480,480,3)
    ```

    Args:
        folders (list[str]): A list of folder path. The order of list should
            be [input_folder, gt_folder].
        keys (list[str]): A list of keys identifying folders. The order should
            be consistent with folders, e.g., ['lq', 'gt'].
        meta_info_file (str): Path to the meta information file.
        filename_tmpl (str): Template for each filename. Note that the
            template excludes the file extension. It is applied to the gt
            base name to obtain the input file name.

    Returns:
        list[dict]: One dict per listed image, mapping ``'<input_key>_path'``
        and ``'<gt_key>_path'`` to the joined paths.
    """
    assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
                               f'But got {len(folders)}')
    assert len(keys) == 2, f'The len of keys should be 2 with [input_key, gt_key]. But got {len(keys)}'
    input_folder, gt_folder = folders
    input_key, gt_key = keys

    with open(meta_info_file, 'r') as meta:
        gt_names = [entry.partition(' ')[0] for entry in meta]

    def _pair(gt_name):
        stem, ext = osp.splitext(osp.basename(gt_name))
        input_name = f'{filename_tmpl.format(stem)}{ext}'
        return {
            f'{input_key}_path': osp.join(input_folder, input_name),
            f'{gt_key}_path': osp.join(gt_folder, gt_name),
        }

    return [_pair(gt_name) for gt_name in gt_names]
190
+
191
+
192
def paired_paths_from_folder(folders, keys, filename_tmpl):
    """Generate paired paths from folders.

    Every entry of the gt folder is matched with the input file whose name is
    ``filename_tmpl`` applied to the gt base name, plus the gt extension.

    Args:
        folders (list[str]): A list of folder path. The order of list should
            be [input_folder, gt_folder].
        keys (list[str]): A list of keys identifying folders. The order should
            be consistent with folders, e.g., ['lq', 'gt'].
        filename_tmpl (str): Template for each filename. Note that the
            template excludes the file extension. Usually the filename_tmpl is
            for files in the input folder.

    Returns:
        list[dict]: One dict per gt entry, mapping ``'<input_key>_path'`` and
        ``'<gt_key>_path'`` to the joined paths.

    Raises:
        AssertionError: If the folders hold a different number of entries or
            an expected input file is missing.
    """
    assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
                               f'But got {len(folders)}')
    assert len(keys) == 2, f'The len of keys should be 2 with [input_key, gt_key]. But got {len(keys)}'
    input_folder, gt_folder = folders
    input_key, gt_key = keys

    input_paths = list(scandir(input_folder))
    gt_paths = list(scandir(gt_folder))
    assert len(input_paths) == len(gt_paths), (f'{input_key} and {gt_key} datasets have different number of images: '
                                               f'{len(input_paths)}, {len(gt_paths)}.')
    # Membership is tested once per gt image; a set keeps that O(1) instead of
    # rescanning the whole input listing for every image.
    input_names = set(input_paths)

    paths = []
    for gt_path in gt_paths:
        basename, ext = osp.splitext(osp.basename(gt_path))
        input_name = f'{filename_tmpl.format(basename)}{ext}'
        assert input_name in input_names, f'{input_name} is not in {input_key}_paths.'
        paths.append({
            f'{input_key}_path': osp.join(input_folder, input_name),
            f'{gt_key}_path': osp.join(gt_folder, gt_path),
        })
    return paths
226
+
227
+
228
def paths_from_folder(folder):
    """List the entries of a folder, each prefixed with the folder path.

    Args:
        folder (str): Folder path.

    Returns:
        list[str]: ``folder`` joined with every entry yielded by ``scandir``.
    """
    return [osp.join(folder, entry) for entry in scandir(folder)]
241
+
242
+
243
def paths_from_lmdb(folder):
    """Generate paths from lmdb.

    Reads ``meta_info.txt`` inside the lmdb folder and returns, for every
    line, the text before its first '.' (the image name without extension).

    Args:
        folder (str): Folder path. Must end with '.lmdb'.

    Returns:
        list[str]: Returned path list.

    Raises:
        ValueError: If ``folder`` does not end with '.lmdb'.
    """
    if not folder.endswith('.lmdb'):
        raise ValueError(f'Folder {folder} should be in lmdb format.')
    with open(osp.join(folder, 'meta_info.txt')) as fin:
        paths = [line.split('.')[0] for line in fin]
    return paths
257
+
258
+
259
def generate_gaussian_kernel(kernel_size=13, sigma=1.6):
    """Generate Gaussian kernel used in `duf_downsample`.

    The kernel is obtained by Gaussian-smoothing a unit impulse placed at the
    centre of a ``kernel_size`` x ``kernel_size`` array.

    Args:
        kernel_size (int): Kernel size. Default: 13.
        sigma (float): Sigma of the Gaussian kernel. Default: 1.6.

    Returns:
        np.array: The Gaussian kernel.
    """
    # Import from scipy.ndimage directly: the scipy.ndimage.filters namespace
    # is deprecated in recent SciPy releases.
    from scipy.ndimage import gaussian_filter
    kernel = np.zeros((kernel_size, kernel_size))
    # set element at the middle to one, a dirac delta
    kernel[kernel_size // 2, kernel_size // 2] = 1
    # gaussian-smooth the dirac, resulting in a gaussian filter
    return gaussian_filter(kernel, sigma)
275
+
276
+
277
def duf_downsample(x, kernel_size=13, scale=4):
    """Downsample frames with the Gaussian kernel used in the DUF official code.

    Every channel of every frame is filtered independently: the input is
    reflect-padded, convolved with a Gaussian kernel of sigma ``0.4 * scale``
    at stride ``scale``, and two border pixels are cropped from each side of
    the result.

    Args:
        x (Tensor): Frames to be downsampled, with shape (b, t, c, h, w).
            A 4-dim input is treated as a single batch and returned 4-dim.
        kernel_size (int): Kernel size. Default: 13.
        scale (int): Downsampling factor. Supported scale: (2, 3, 4).
            Default: 4.

    Returns:
        Tensor: DUF downsampled frames.
    """
    assert scale in (2, 3, 4), f'Only support scale (2, 3, 4), but got {scale}.'

    was_4d = x.ndim == 4
    if was_4d:
        x = x.unsqueeze(0)
    b, t, c, h, w = x.size()

    # fold batch, time and channel together so each plane is filtered on its own
    planes = x.view(-1, 1, h, w)
    pad = kernel_size // 2 + scale * 2
    planes = F.pad(planes, (pad, pad, pad, pad), 'reflect')

    kernel = generate_gaussian_kernel(kernel_size, 0.4 * scale)
    kernel = torch.from_numpy(kernel).type_as(planes).unsqueeze(0).unsqueeze(0)
    planes = F.conv2d(planes, kernel, stride=scale)
    planes = planes[:, :, 2:-2, 2:-2]

    out = planes.view(b, t, c, planes.size(2), planes.size(3))
    return out.squeeze(0) if was_4d else out
308
+
309
+
310
def brush_stroke_mask(img, color=(255,255,255)):
    """Paint random free-form brush strokes onto a PIL image.

    One to three strokes are drawn. Each stroke is a polyline of 8 to 27
    random segments with a random width in [30, 70), with a filled circle of
    the same width on every vertex.

    Note:
        The strokes are drawn on ``img`` itself; the returned object is the
        same image, modified in place.

    Args:
        img (PIL.Image.Image): Image to draw on.
        color (tuple[int]): Fill color of the strokes. Default: (255,255,255).

    Returns:
        PIL.Image.Image: ``img`` with the strokes painted on it.
    """
    min_num_vertex = 8
    max_num_vertex = 28
    mean_angle = 2 * math.pi / 5
    angle_range = 2 * math.pi / 12
    # training large mask ratio (training setting)
    min_width = 30
    max_width = 70
    # The original authors noted min_width = 80 / max_width = 120 for the very
    # large mask ratio (test setting and refine after 200k).

    def generate_mask(H, W, img=None):
        average_radius = math.sqrt(H * H + W * W) / 8
        # draw on the given image, or on a fresh black canvas when there is none
        mask = Image.new('RGB', (W, H), 0) if img is None else img

        # PIL reports size as (width, height); x is bounded by the width and
        # y by the height (the previous code had the two swapped).
        w, h = mask.size
        draw = ImageDraw.Draw(mask)

        for _ in range(np.random.randint(1, 4)):
            num_vertex = np.random.randint(min_num_vertex, max_num_vertex)
            angle_min = mean_angle - np.random.uniform(0, angle_range)
            angle_max = mean_angle + np.random.uniform(0, angle_range)

            # alternate the turning direction from one segment to the next
            angles = []
            for i in range(num_vertex):
                if i % 2 == 0:
                    angles.append(2 * math.pi - np.random.uniform(angle_min, angle_max))
                else:
                    angles.append(np.random.uniform(angle_min, angle_max))

            vertex = [(int(np.random.randint(0, w)), int(np.random.randint(0, h)))]
            for i in range(num_vertex):
                r = np.clip(
                    np.random.normal(loc=average_radius, scale=average_radius // 2),
                    0, 2 * average_radius)
                new_x = np.clip(vertex[-1][0] + r * math.cos(angles[i]), 0, w)
                new_y = np.clip(vertex[-1][1] + r * math.sin(angles[i]), 0, h)
                vertex.append((int(new_x), int(new_y)))

            width = int(np.random.uniform(min_width, max_width))
            draw.line(vertex, fill=color, width=width)
            # a filled circle on every vertex
            for v in vertex:
                draw.ellipse((v[0] - width // 2,
                              v[1] - width // 2,
                              v[0] + width // 2,
                              v[1] + width // 2),
                             fill=color)

        return mask

    width, height = img.size
    mask = generate_mask(height, width, img)
    return mask
363
+
364
+
365
def random_ff_mask(shape, max_angle = 10, max_len = 100, max_width = 70, times = 10):
    """Generate a random free-form mask made of thick line strokes.

    Args:
        shape (Sequence[int]): ``shape[0]`` is the mask height and
            ``shape[1]`` its width.
        max_angle (int): Exclusive upper bound of the random integer part of
            each segment angle. Default: 10.
        max_len (int): Segment lengths are drawn as
            ``10 + randint(max_len - 20, max_len)``. Default: 100.
        max_width (int): Brush widths are drawn as
            ``5 + randint(max_width - 30, max_width)``. Default: 70.
        times (int): The number of strokes is drawn from
            ``randint(times - 5, times)``. Default: 10.

    Returns:
        np.ndarray: float32 array of shape (height, width), initialised to 0,
        on which every stroke is drawn with value 1.0.

    Link:
        https://github.com/csqiangwen/DeepFillv2_Pytorch/blob/master/train_dataset.py
    """
    height, width = shape[0], shape[1]
    mask = np.zeros((height, width), np.float32)

    num_strokes = np.random.randint(times-5, times)
    for stroke_idx in range(num_strokes):
        start_x = np.random.randint(width)
        start_y = np.random.randint(height)
        num_segments = 1 + np.random.randint(5)
        for _ in range(num_segments):
            angle = 0.01 + np.random.randint(max_angle)
            # even-numbered strokes use the mirrored angle
            if stroke_idx % 2 == 0:
                angle = 2 * 3.1415926 - angle
            length = 10 + np.random.randint(max_len-20, max_len)
            brush_w = 5 + np.random.randint(max_width-30, max_width)
            end_x = (start_x + length * np.sin(angle)).astype(np.int32)
            end_y = (start_y + length * np.cos(angle)).astype(np.int32)
            cv2.line(mask, (start_y, start_x), (end_y, end_x), 1.0, brush_w)
            # the next segment continues from where this one ended
            start_x, start_y = end_x, end_y
    return mask.astype(np.float32)
basicsr/data/ffhq_blind_dataset.py ADDED
@@ -0,0 +1,299 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import math
3
+ import random
4
+ import numpy as np
5
+ import os.path as osp
6
+ from scipy.io import loadmat
7
+ from PIL import Image
8
+ import torch
9
+ import torch.utils.data as data
10
+ from torchvision.transforms.functional import (adjust_brightness, adjust_contrast,
11
+ adjust_hue, adjust_saturation, normalize)
12
+ from basicsr.data import gaussian_kernels as gaussian_kernels
13
+ from basicsr.data.transforms import augment
14
+ from basicsr.data.data_util import paths_from_folder, brush_stroke_mask, random_ff_mask
15
+ from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
16
+ from basicsr.utils.registry import DATASET_REGISTRY
17
+
18
@DATASET_REGISTRY.register()
class FFHQBlindDataset(data.Dataset):
    """Dataset yielding (degraded input, ground truth) face image pairs.

    Each item loads one ground-truth image, optionally flips it horizontally,
    and derives the network input either by a synthetic degradation pipeline
    (blur -> downsample -> noise -> JPEG -> resize to ``in_size``) or, when
    ``gen_inpaint_mask`` is set, by painting brush strokes over the image.

    Args:
        opt (dict): Configuration. Keys read by this class:
            io_backend (dict): Must contain ``type``; remaining items are
                forwarded to ``FileClient``.
            dataroot_gt (str): Folder (or ``.lmdb`` directory) of GT images.
            use_hflip (bool): Forwarded to ``augment`` as ``hflip``.
            gt_size, in_size (int): Default 512; ``gt_size >= in_size``.
            mean, std (list): Normalisation statistics, default 0.5 each.
            component_path, latent_gt_path (str, optional): Files loaded with
                ``torch.load``.
            eye_enlarge_ratio, nose_enlarge_ratio, mouth_enlarge_ratio (float).
            gen_inpaint_mask (bool): Default False.
            use_corrupt (bool): Default True.
            blur_kernel_size, blur_sigma, kernel_list, kernel_prob,
            downsample_range, noise_range, jpeg_range: Required when
                ``use_corrupt`` is set and ``gen_inpaint_mask`` is not.
            color_jitter_prob, color_jitter_pt_prob, color_jitter_shift,
            gray_prob, brightness, contrast, saturation, hue: Optional
                colour augmentation settings.
    """

    # parts looked up in the component dictionary
    _PARTS = ('left_eye', 'right_eye', 'nose', 'mouth')

    def __init__(self, opt):
        super(FFHQBlindDataset, self).__init__()
        logger = get_root_logger()
        self.opt = opt
        # file client (io backend); created lazily on first __getitem__
        self.file_client = None
        self.io_backend_opt = opt['io_backend']

        self.gt_folder = opt['dataroot_gt']
        self.gt_size = opt.get('gt_size', 512)
        self.in_size = opt.get('in_size', 512)
        assert self.gt_size >= self.in_size, 'Wrong setting.'

        self.mean = opt.get('mean', [0.5, 0.5, 0.5])
        self.std = opt.get('std', [0.5, 0.5, 0.5])

        self.component_path = opt.get('component_path', None)
        self.latent_gt_path = opt.get('latent_gt_path', None)

        self.crop_components = self.component_path is not None
        if self.crop_components:
            self.components_dict = torch.load(self.component_path)
            self.eye_enlarge_ratio = opt.get('eye_enlarge_ratio', 1.4)
            self.nose_enlarge_ratio = opt.get('nose_enlarge_ratio', 1.1)
            self.mouth_enlarge_ratio = opt.get('mouth_enlarge_ratio', 1.3)

        self.load_latent_gt = self.latent_gt_path is not None
        if self.load_latent_gt:
            self.latent_gt_dict = torch.load(self.latent_gt_path)

        if self.io_backend_opt['type'] == 'lmdb':
            self.io_backend_opt['db_paths'] = self.gt_folder
            if not self.gt_folder.endswith('.lmdb'):
                raise ValueError("'dataroot_gt' should end with '.lmdb', "f'but received {self.gt_folder}')
            with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin:
                self.paths = [line.split('.')[0] for line in fin]
        else:
            self.paths = paths_from_folder(self.gt_folder)

        # inpainting mask
        self.gen_inpaint_mask = opt.get('gen_inpaint_mask', False)
        if self.gen_inpaint_mask:
            logger.info('generate mask ...')

        # perform corrupt
        self.use_corrupt = opt.get('use_corrupt', True)
        # motion blur is hard-disabled here; the option lookup is intentionally
        # not wired up: opt.get('use_motion_kernel', True)
        self.use_motion_kernel = False

        if self.use_motion_kernel:
            self.motion_kernel_prob = opt.get('motion_kernel_prob', 0.001)
            motion_kernel_path = opt.get('motion_kernel_path', 'basicsr/data/motion-blur-kernels-32.pth')
            self.motion_kernels = torch.load(motion_kernel_path)

        if self.use_corrupt and not self.gen_inpaint_mask:
            # degradation configurations
            self.blur_kernel_size = opt['blur_kernel_size']
            self.blur_sigma = opt['blur_sigma']
            self.kernel_list = opt['kernel_list']
            self.kernel_prob = opt['kernel_prob']
            self.downsample_range = opt['downsample_range']
            self.noise_range = opt['noise_range']
            self.jpeg_range = opt['jpeg_range']
            # print
            logger.info(f'Blur: blur_kernel_size {self.blur_kernel_size}, sigma: [{", ".join(map(str, self.blur_sigma))}]')
            logger.info(f'Downsample: downsample_range [{", ".join(map(str, self.downsample_range))}]')
            logger.info(f'Noise: [{", ".join(map(str, self.noise_range))}]')
            logger.info(f'JPEG compression: [{", ".join(map(str, self.jpeg_range))}]')

        # color jitter
        self.color_jitter_prob = opt.get('color_jitter_prob', None)
        self.color_jitter_pt_prob = opt.get('color_jitter_pt_prob', None)
        self.color_jitter_shift = opt.get('color_jitter_shift', 20)
        if self.color_jitter_prob is not None:
            logger.info(f'Use random color jitter. Prob: {self.color_jitter_prob}, shift: {self.color_jitter_shift}')

        # to gray
        self.gray_prob = opt.get('gray_prob', 0.0)
        if self.gray_prob is not None:
            logger.info(f'Use random gray. Prob: {self.gray_prob}')
        # shift is configured on a 0-255 scale; images here are in [0, 1]
        self.color_jitter_shift /= 255.

    @staticmethod
    def color_jitter(img, shift):
        """jitter color: randomly jitter the RGB values, in numpy formats"""
        jitter_val = np.random.uniform(-shift, shift, 3).astype(np.float32)
        img = img + jitter_val
        img = np.clip(img, 0, 1)
        return img

    @staticmethod
    def color_jitter_pt(img, brightness, contrast, saturation, hue):
        """jitter color: randomly jitter the brightness, contrast, saturation, and hue, in torch Tensor formats

        The four adjustments are applied in a random order; an adjustment whose
        range argument is None is skipped.
        """
        for fn_id in torch.randperm(4).tolist():
            if fn_id == 0 and brightness is not None:
                brightness_factor = torch.tensor(1.0).uniform_(brightness[0], brightness[1]).item()
                img = adjust_brightness(img, brightness_factor)

            if fn_id == 1 and contrast is not None:
                contrast_factor = torch.tensor(1.0).uniform_(contrast[0], contrast[1]).item()
                img = adjust_contrast(img, contrast_factor)

            if fn_id == 2 and saturation is not None:
                saturation_factor = torch.tensor(1.0).uniform_(saturation[0], saturation[1]).item()
                img = adjust_saturation(img, saturation_factor)

            if fn_id == 3 and hue is not None:
                hue_factor = torch.tensor(1.0).uniform_(hue[0], hue[1]).item()
                img = adjust_hue(img, hue_factor)
        return img

    def get_component_locations(self, name, status):
        """Return enlarged component boxes in GT and input coordinates.

        Args:
            name (str): Key into the component dictionary.
            status (sequence): ``status[0]`` is truthy when the image was
                horizontally flipped.

        Returns:
            tuple: ``(locations_gt, locations_in)``, two dicts mapping each of
            'left_eye', 'right_eye', 'nose', 'mouth' to a float tensor built
            from ``(centre - half_len + 1, centre + half_len)``.
        """
        import copy

        # Work on a private copy: the flip handling below swaps entries and
        # rewrites coordinates, which must not leak into the cached dictionary
        # (otherwise every flipped read corrupts later reads of this sample).
        components_bbox = copy.deepcopy(self.components_dict[name])
        if status[0]:  # hflip
            # exchange right and left eye
            components_bbox['left_eye'], components_bbox['right_eye'] = (
                components_bbox['right_eye'], components_bbox['left_eye'])
            # modify the width coordinate
            for part in self._PARTS:
                components_bbox[part][0] = self.gt_size - components_bbox[part][0]

        enlarge = {
            'left_eye': self.eye_enlarge_ratio,
            'right_eye': self.eye_enlarge_ratio,
            'nose': self.nose_enlarge_ratio,
            'mouth': self.mouth_enlarge_ratio,
        }
        # true division so non-integer gt/in ratios are scaled correctly
        scale = self.gt_size / self.in_size

        locations_gt = {}
        locations_in = {}
        for part in self._PARTS:
            # presumably each entry is [x, y, half_len] -- inferred from indexing
            mean = components_bbox[part][0:2]
            half_len = components_bbox[part][2] * enlarge[part]
            loc = np.hstack((mean - half_len + 1, mean + half_len))
            loc = torch.from_numpy(loc).float()
            locations_gt[part] = loc
            locations_in[part] = loc / scale
        return locations_gt, locations_in

    def _degrade(self, img, blur_sigma, downsample_range, noise_range, jpeg_range):
        """Apply blur, downsample, noise and JPEG, then resize to ``in_size``."""
        # motion blur
        if self.use_motion_kernel and random.random() < self.motion_kernel_prob:
            m_i = random.randint(0, 31)
            k = self.motion_kernels[f'{m_i:02d}']
            img = cv2.filter2D(img, -1, k)

        # gaussian blur
        kernel = gaussian_kernels.random_mixed_kernels(
            self.kernel_list,
            self.kernel_prob,
            self.blur_kernel_size,
            blur_sigma,
            blur_sigma,
            [-math.pi, math.pi],
            noise_range=None)
        img = cv2.filter2D(img, -1, kernel)

        # downsample
        scale = np.random.uniform(downsample_range[0], downsample_range[1])
        small = int(self.gt_size // scale)
        img = cv2.resize(img, (small, small), interpolation=cv2.INTER_LINEAR)

        # noise
        if noise_range is not None:
            noise_sigma = np.random.uniform(noise_range[0] / 255., noise_range[1] / 255.)
            noise = np.float32(np.random.randn(*(img.shape))) * noise_sigma
            img = img + noise
            img = np.clip(img, 0, 1)

        # jpeg
        if jpeg_range is not None:
            jpeg_p = np.random.uniform(jpeg_range[0], jpeg_range[1])
            # imencode parameters are integers; cast the sampled quality
            encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), int(jpeg_p)]
            _, encimg = cv2.imencode('.jpg', img * 255., encode_param)
            img = np.float32(cv2.imdecode(encimg, 1)) / 255.

        # resize to in_size
        return cv2.resize(img, (self.in_size, self.in_size), interpolation=cv2.INTER_LINEAR)

    def __getitem__(self, index):
        """Load sample ``index`` and return a dict with keys 'in', 'gt', 'gt_path'.

        'locations_in'/'locations_gt' are added when components are enabled and
        'latent_gt' when a latent file was configured.
        """
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

        # load gt image
        gt_path = self.paths[index]
        # strip the extension whatever its length (lmdb keys have none)
        name = osp.splitext(osp.basename(gt_path))[0]
        img_bytes = self.file_client.get(gt_path)
        img_gt = imfrombytes(img_bytes, float32=True)

        # random horizontal flip
        img_gt, status = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False, return_status=True)

        if self.load_latent_gt:
            latent_gt = self.latent_gt_dict['hflip' if status[0] else 'orig'][name]

        if self.crop_components:
            locations_gt, locations_in = self.get_component_locations(name, status)

        # generate in image
        img_in = img_gt
        if self.use_corrupt and not self.gen_inpaint_mask:
            img_in = self._degrade(img_in, self.blur_sigma, self.downsample_range,
                                   self.noise_range, self.jpeg_range)

        if self.gen_inpaint_mask:
            img_in = (img_in*255).astype('uint8')
            img_in = brush_stroke_mask(Image.fromarray(img_in))
            # keep float32 so the cv2 colour conversion below accepts it
            img_in = np.array(img_in).astype(np.float32) / 255.

        # random color jitter (only for lq)
        if self.color_jitter_prob is not None and (np.random.uniform() < self.color_jitter_prob):
            img_in = self.color_jitter(img_in, self.color_jitter_shift)
        # random to gray (only for lq)
        if self.gray_prob and np.random.uniform() < self.gray_prob:
            img_in = cv2.cvtColor(img_in, cv2.COLOR_BGR2GRAY)
            img_in = np.tile(img_in[:, :, None], [1, 1, 3])

        # BGR to RGB, HWC to CHW, numpy to tensor
        img_in, img_gt = img2tensor([img_in, img_gt], bgr2rgb=True, float32=True)

        # random color jitter (pytorch version) (only for lq)
        if self.color_jitter_pt_prob is not None and (np.random.uniform() < self.color_jitter_pt_prob):
            brightness = self.opt.get('brightness', (0.5, 1.5))
            contrast = self.opt.get('contrast', (0.5, 1.5))
            saturation = self.opt.get('saturation', (0, 1.5))
            hue = self.opt.get('hue', (-0.1, 0.1))
            img_in = self.color_jitter_pt(img_in, brightness, contrast, saturation, hue)

        # round and clip (img_in is a tensor here, so use the torch op)
        img_in = torch.clamp((img_in * 255.0).round(), 0, 255) / 255.

        # Set vgg range_norm=True if use the normalization here
        # normalize
        normalize(img_in, self.mean, self.std, inplace=True)
        normalize(img_gt, self.mean, self.std, inplace=True)

        return_dict = {'in': img_in, 'gt': img_gt, 'gt_path': gt_path}

        if self.crop_components:
            return_dict['locations_in'] = locations_in
            return_dict['locations_gt'] = locations_gt

        if self.load_latent_gt:
            return_dict['latent_gt'] = latent_gt

        return return_dict

    def __len__(self):
        """Number of ground-truth images found."""
        return len(self.paths)
basicsr/data/ffhq_blind_joint_dataset.py ADDED
@@ -0,0 +1,324 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import math
3
+ import random
4
+ import numpy as np
5
+ import os.path as osp
6
+ from scipy.io import loadmat
7
+ import torch
8
+ import torch.utils.data as data
9
+ from torchvision.transforms.functional import (adjust_brightness, adjust_contrast,
10
+ adjust_hue, adjust_saturation, normalize)
11
+ from basicsr.data import gaussian_kernels as gaussian_kernels
12
+ from basicsr.data.transforms import augment
13
+ from basicsr.data.data_util import paths_from_folder
14
+ from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
15
+ from basicsr.utils.registry import DATASET_REGISTRY
16
+
17
@DATASET_REGISTRY.register()
class FFHQBlindJointDataset(data.Dataset):
    """Dataset yielding two degraded inputs (small and large) per GT face image.

    Each item loads one ground-truth image, optionally flips it horizontally,
    and runs the same degradation pipeline (blur -> downsample -> noise ->
    JPEG -> resize to ``in_size``) twice: once with the regular settings
    (returned as 'in') and once with the ``*_large`` settings (returned as
    'in_large_de').

    Args:
        opt (dict): Configuration. Keys read by this class:
            io_backend (dict): Must contain ``type``; remaining items are
                forwarded to ``FileClient``.
            dataroot_gt (str): Folder (or ``.lmdb`` directory) of GT images.
            use_hflip (bool): Forwarded to ``augment`` as ``hflip``.
            gt_size, in_size (int): Default 512; ``gt_size >= in_size``.
            mean, std (list): Normalisation statistics, default 0.5 each.
            component_path, latent_gt_path (str, optional): Files loaded with
                ``torch.load``.
            eye_enlarge_ratio, nose_enlarge_ratio, mouth_enlarge_ratio (float).
            use_corrupt (bool): Default True.
            blur_kernel_size, kernel_list, kernel_prob, blur_sigma,
            downsample_range, noise_range, jpeg_range, blur_sigma_large,
            downsample_range_large, noise_range_large, jpeg_range_large:
                Required when ``use_corrupt`` is set.
            color_jitter_prob, color_jitter_pt_prob, color_jitter_shift,
            gray_prob, brightness, contrast, saturation, hue: Optional
                colour augmentation settings.
    """

    # parts looked up in the component dictionary
    _PARTS = ('left_eye', 'right_eye', 'nose', 'mouth')

    def __init__(self, opt):
        super(FFHQBlindJointDataset, self).__init__()
        logger = get_root_logger()
        self.opt = opt
        # file client (io backend); created lazily on first __getitem__
        self.file_client = None
        self.io_backend_opt = opt['io_backend']

        self.gt_folder = opt['dataroot_gt']
        self.gt_size = opt.get('gt_size', 512)
        self.in_size = opt.get('in_size', 512)
        assert self.gt_size >= self.in_size, 'Wrong setting.'

        self.mean = opt.get('mean', [0.5, 0.5, 0.5])
        self.std = opt.get('std', [0.5, 0.5, 0.5])

        self.component_path = opt.get('component_path', None)
        self.latent_gt_path = opt.get('latent_gt_path', None)

        self.crop_components = self.component_path is not None
        if self.crop_components:
            self.components_dict = torch.load(self.component_path)
            self.eye_enlarge_ratio = opt.get('eye_enlarge_ratio', 1.4)
            self.nose_enlarge_ratio = opt.get('nose_enlarge_ratio', 1.1)
            self.mouth_enlarge_ratio = opt.get('mouth_enlarge_ratio', 1.3)

        self.load_latent_gt = self.latent_gt_path is not None
        if self.load_latent_gt:
            self.latent_gt_dict = torch.load(self.latent_gt_path)

        if self.io_backend_opt['type'] == 'lmdb':
            self.io_backend_opt['db_paths'] = self.gt_folder
            if not self.gt_folder.endswith('.lmdb'):
                raise ValueError("'dataroot_gt' should end with '.lmdb', "f'but received {self.gt_folder}')
            with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin:
                self.paths = [line.split('.')[0] for line in fin]
        else:
            self.paths = paths_from_folder(self.gt_folder)

        # perform corrupt
        self.use_corrupt = opt.get('use_corrupt', True)
        # motion blur is hard-disabled here; the option lookup is intentionally
        # not wired up: opt.get('use_motion_kernel', True)
        self.use_motion_kernel = False

        if self.use_motion_kernel:
            self.motion_kernel_prob = opt.get('motion_kernel_prob', 0.001)
            motion_kernel_path = opt.get('motion_kernel_path', 'basicsr/data/motion-blur-kernels-32.pth')
            self.motion_kernels = torch.load(motion_kernel_path)

        if self.use_corrupt:
            # degradation configurations
            self.blur_kernel_size = self.opt['blur_kernel_size']
            self.kernel_list = self.opt['kernel_list']
            self.kernel_prob = self.opt['kernel_prob']
            # Small degradation
            self.blur_sigma = self.opt['blur_sigma']
            self.downsample_range = self.opt['downsample_range']
            self.noise_range = self.opt['noise_range']
            self.jpeg_range = self.opt['jpeg_range']
            # Large degradation
            self.blur_sigma_large = self.opt['blur_sigma_large']
            self.downsample_range_large = self.opt['downsample_range_large']
            self.noise_range_large = self.opt['noise_range_large']
            self.jpeg_range_large = self.opt['jpeg_range_large']

            # print
            logger.info(f'Blur: blur_kernel_size {self.blur_kernel_size}, sigma: [{", ".join(map(str, self.blur_sigma))}]')
            logger.info(f'Downsample: downsample_range [{", ".join(map(str, self.downsample_range))}]')
            logger.info(f'Noise: [{", ".join(map(str, self.noise_range))}]')
            logger.info(f'JPEG compression: [{", ".join(map(str, self.jpeg_range))}]')

        # color jitter
        self.color_jitter_prob = opt.get('color_jitter_prob', None)
        self.color_jitter_pt_prob = opt.get('color_jitter_pt_prob', None)
        self.color_jitter_shift = opt.get('color_jitter_shift', 20)
        if self.color_jitter_prob is not None:
            logger.info(f'Use random color jitter. Prob: {self.color_jitter_prob}, shift: {self.color_jitter_shift}')

        # to gray
        self.gray_prob = opt.get('gray_prob', 0.0)
        if self.gray_prob is not None:
            logger.info(f'Use random gray. Prob: {self.gray_prob}')
        # shift is configured on a 0-255 scale; images here are in [0, 1]
        self.color_jitter_shift /= 255.

    @staticmethod
    def color_jitter(img, shift):
        """jitter color: randomly jitter the RGB values, in numpy formats"""
        jitter_val = np.random.uniform(-shift, shift, 3).astype(np.float32)
        img = img + jitter_val
        img = np.clip(img, 0, 1)
        return img

    @staticmethod
    def color_jitter_pt(img, brightness, contrast, saturation, hue):
        """jitter color: randomly jitter the brightness, contrast, saturation, and hue, in torch Tensor formats

        The four adjustments are applied in a random order; an adjustment whose
        range argument is None is skipped.
        """
        for fn_id in torch.randperm(4).tolist():
            if fn_id == 0 and brightness is not None:
                brightness_factor = torch.tensor(1.0).uniform_(brightness[0], brightness[1]).item()
                img = adjust_brightness(img, brightness_factor)

            if fn_id == 1 and contrast is not None:
                contrast_factor = torch.tensor(1.0).uniform_(contrast[0], contrast[1]).item()
                img = adjust_contrast(img, contrast_factor)

            if fn_id == 2 and saturation is not None:
                saturation_factor = torch.tensor(1.0).uniform_(saturation[0], saturation[1]).item()
                img = adjust_saturation(img, saturation_factor)

            if fn_id == 3 and hue is not None:
                hue_factor = torch.tensor(1.0).uniform_(hue[0], hue[1]).item()
                img = adjust_hue(img, hue_factor)
        return img

    def get_component_locations(self, name, status):
        """Return enlarged component boxes in GT and input coordinates.

        Args:
            name (str): Key into the component dictionary.
            status (sequence): ``status[0]`` is truthy when the image was
                horizontally flipped.

        Returns:
            tuple: ``(locations_gt, locations_in)``, two dicts mapping each of
            'left_eye', 'right_eye', 'nose', 'mouth' to a float tensor built
            from ``(centre - half_len + 1, centre + half_len)``.
        """
        import copy

        # Work on a private copy: the flip handling below swaps entries and
        # rewrites coordinates, which must not leak into the cached dictionary
        # (otherwise every flipped read corrupts later reads of this sample).
        components_bbox = copy.deepcopy(self.components_dict[name])
        if status[0]:  # hflip
            # exchange right and left eye
            components_bbox['left_eye'], components_bbox['right_eye'] = (
                components_bbox['right_eye'], components_bbox['left_eye'])
            # modify the width coordinate
            for part in self._PARTS:
                components_bbox[part][0] = self.gt_size - components_bbox[part][0]

        enlarge = {
            'left_eye': self.eye_enlarge_ratio,
            'right_eye': self.eye_enlarge_ratio,
            'nose': self.nose_enlarge_ratio,
            'mouth': self.mouth_enlarge_ratio,
        }
        # true division so non-integer gt/in ratios are scaled correctly
        scale = self.gt_size / self.in_size

        locations_gt = {}
        locations_in = {}
        for part in self._PARTS:
            # presumably each entry is [x, y, half_len] -- inferred from indexing
            mean = components_bbox[part][0:2]
            half_len = components_bbox[part][2] * enlarge[part]
            loc = np.hstack((mean - half_len + 1, mean + half_len))
            loc = torch.from_numpy(loc).float()
            locations_gt[part] = loc
            locations_in[part] = loc / scale
        return locations_gt, locations_in

    def _degrade(self, img, blur_sigma, downsample_range, noise_range, jpeg_range):
        """Apply blur, downsample, noise and JPEG, then resize to ``in_size``."""
        # motion blur
        if self.use_motion_kernel and random.random() < self.motion_kernel_prob:
            m_i = random.randint(0, 31)
            k = self.motion_kernels[f'{m_i:02d}']
            img = cv2.filter2D(img, -1, k)

        # gaussian blur
        kernel = gaussian_kernels.random_mixed_kernels(
            self.kernel_list,
            self.kernel_prob,
            self.blur_kernel_size,
            blur_sigma,
            blur_sigma,
            [-math.pi, math.pi],
            noise_range=None)
        img = cv2.filter2D(img, -1, kernel)

        # downsample
        scale = np.random.uniform(downsample_range[0], downsample_range[1])
        small = int(self.gt_size // scale)
        img = cv2.resize(img, (small, small), interpolation=cv2.INTER_LINEAR)

        # noise
        if noise_range is not None:
            noise_sigma = np.random.uniform(noise_range[0] / 255., noise_range[1] / 255.)
            noise = np.float32(np.random.randn(*(img.shape))) * noise_sigma
            img = img + noise
            img = np.clip(img, 0, 1)

        # jpeg
        if jpeg_range is not None:
            jpeg_p = np.random.uniform(jpeg_range[0], jpeg_range[1])
            # imencode parameters are integers; cast the sampled quality
            encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), int(jpeg_p)]
            _, encimg = cv2.imencode('.jpg', img * 255., encode_param)
            img = np.float32(cv2.imdecode(encimg, 1)) / 255.

        # resize to in_size
        return cv2.resize(img, (self.in_size, self.in_size), interpolation=cv2.INTER_LINEAR)

    def __getitem__(self, index):
        """Load sample ``index``; return a dict with 'in', 'in_large_de', 'gt', 'gt_path'.

        'locations_in'/'locations_gt' are added when components are enabled and
        'latent_gt' when a latent file was configured.
        """
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

        # load gt image
        gt_path = self.paths[index]
        # strip the extension whatever its length (lmdb keys have none)
        name = osp.splitext(osp.basename(gt_path))[0]
        img_bytes = self.file_client.get(gt_path)
        img_gt = imfrombytes(img_bytes, float32=True)

        # random horizontal flip
        img_gt, status = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False, return_status=True)

        if self.load_latent_gt:
            latent_gt = self.latent_gt_dict['hflip' if status[0] else 'orig'][name]

        if self.crop_components:
            locations_gt, locations_in = self.get_component_locations(name, status)

        # generate in image (small degradation), then in_large (large
        # degradation); the order is kept so the random draws are unchanged
        img_in = img_gt
        if self.use_corrupt:
            img_in = self._degrade(img_in, self.blur_sigma, self.downsample_range,
                                   self.noise_range, self.jpeg_range)

        img_in_large = img_gt
        if self.use_corrupt:
            img_in_large = self._degrade(img_in_large, self.blur_sigma_large, self.downsample_range_large,
                                         self.noise_range_large, self.jpeg_range_large)

        # random color jitter (only for lq)
        if self.color_jitter_prob is not None and (np.random.uniform() < self.color_jitter_prob):
            img_in = self.color_jitter(img_in, self.color_jitter_shift)
            img_in_large = self.color_jitter(img_in_large, self.color_jitter_shift)
        # random to gray (only for lq)
        if self.gray_prob and np.random.uniform() < self.gray_prob:
            img_in = cv2.cvtColor(img_in, cv2.COLOR_BGR2GRAY)
            img_in = np.tile(img_in[:, :, None], [1, 1, 3])
            img_in_large = cv2.cvtColor(img_in_large, cv2.COLOR_BGR2GRAY)
            img_in_large = np.tile(img_in_large[:, :, None], [1, 1, 3])

        # BGR to RGB, HWC to CHW, numpy to tensor
        img_in, img_in_large, img_gt = img2tensor([img_in, img_in_large, img_gt], bgr2rgb=True, float32=True)

        # random color jitter (pytorch version) (only for lq)
        if self.color_jitter_pt_prob is not None and (np.random.uniform() < self.color_jitter_pt_prob):
            brightness = self.opt.get('brightness', (0.5, 1.5))
            contrast = self.opt.get('contrast', (0.5, 1.5))
            saturation = self.opt.get('saturation', (0, 1.5))
            hue = self.opt.get('hue', (-0.1, 0.1))
            img_in = self.color_jitter_pt(img_in, brightness, contrast, saturation, hue)
            img_in_large = self.color_jitter_pt(img_in_large, brightness, contrast, saturation, hue)

        # round and clip (inputs are tensors here, so use the torch op)
        img_in = torch.clamp((img_in * 255.0).round(), 0, 255) / 255.
        img_in_large = torch.clamp((img_in_large * 255.0).round(), 0, 255) / 255.

        # Set vgg range_norm=True if use the normalization here
        # normalize
        normalize(img_in, self.mean, self.std, inplace=True)
        normalize(img_in_large, self.mean, self.std, inplace=True)
        normalize(img_gt, self.mean, self.std, inplace=True)

        return_dict = {'in': img_in, 'in_large_de': img_in_large, 'gt': img_gt, 'gt_path': gt_path}

        if self.crop_components:
            return_dict['locations_in'] = locations_in
            return_dict['locations_gt'] = locations_gt

        if self.load_latent_gt:
            return_dict['latent_gt'] = latent_gt

        return return_dict

    def __len__(self):
        """Number of ground-truth images found."""
        return len(self.paths)
basicsr/data/gaussian_kernels.py ADDED
@@ -0,0 +1,690 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import numpy as np
3
+ import random
4
+ from scipy.ndimage.interpolation import shift
5
+ from scipy.stats import multivariate_normal
6
+
7
+
8
def sigma_matrix2(sig_x, sig_y, theta):
    """Build the rotated 2x2 sigma (covariance) matrix.

    Computes ``U . diag(sig_x^2, sig_y^2) . U^T`` where ``U`` is the 2-D
    rotation matrix for ``theta``.

    Args:
        sig_x (float): Standard deviation along the first axis.
        sig_y (float): Standard deviation along the second axis.
        theta (float): Rotation angle in radians.

    Returns:
        ndarray: Rotated sigma matrix with the shape (2, 2).
    """
    cos_t, sin_t = np.cos(theta), np.sin(theta)
    rotation = np.array([[cos_t, -sin_t],
                         [sin_t, cos_t]])
    variances = np.array([[sig_x**2, 0], [0, sig_y**2]])
    # same association as U . (D . U^T)
    scaled = np.dot(variances, rotation.T)
    return np.dot(rotation, scaled)
21
+
22
+
23
def mesh_grid(kernel_size):
    """Create a square coordinate grid centred at zero.

    Args:
        kernel_size (int): Number of samples along each axis.

    Returns:
        xy (ndarray): with the shape (kernel_size, kernel_size, 2); the last
            axis holds the (x, y) pair of each cell.
        xx (ndarray): with the shape (kernel_size, kernel_size)
        yy (ndarray): with the shape (kernel_size, kernel_size)
    """
    first = -kernel_size // 2 + 1.
    stop = kernel_size // 2 + 1.
    axis = np.arange(first, stop)
    xx, yy = np.meshgrid(axis, axis)
    # pair up the two coordinate planes along a new trailing axis
    xy = np.stack((xx, yy), axis=-1)
    return xy, xx, yy
38
+
39
+
40
def pdf2(sigma_matrix, grid):
    """Evaluate the un-normalized bivariate Gaussian density on a grid.

    Args:
        sigma_matrix (ndarray): with the shape (2, 2)
        grid (ndarray): generated by :func:`mesh_grid`,
            with the shape (K, K, 2), K is the kernel size.

    Returns:
        kernel (ndarray): un-normalized kernel, ``exp(-0.5 * x^T S^-1 x)``
            for every grid point ``x``.
    """
    precision = np.linalg.inv(sigma_matrix)
    quad_form = (np.dot(grid, precision) * grid).sum(axis=2)
    return np.exp(-0.5 * quad_form)
52
+
53
+
54
def cdf2(D, grid):
    """Evaluate the standard bivariate Gaussian CDF on a skewed grid.

    Used in skewed Gaussian distribution. The grid coordinates are first
    multiplied by ``D`` and the CDF of a zero-mean, identity-covariance
    bivariate normal is then evaluated at the transformed points.

    Args:
        D (ndarray): skew matrix.
        grid (ndarray): generated by :func:`mesh_grid`,
            with the shape (K, K, 2), K is the kernel size.

    Returns:
        cdf (ndarray): skewed cdf.
    """
    standard_normal = multivariate_normal([0, 0], [[1, 0], [0, 1]])
    skewed_grid = np.dot(grid, D)
    return standard_normal.cdf(skewed_grid)
68
+
69
+
70
def bivariate_skew_Gaussian(kernel_size, sig_x, sig_y, theta, D, grid=None):
    """Generate a bivariate skew Gaussian kernel.

    The kernel is the product of the Gaussian density (:func:`pdf2`) and the
    skewed CDF (:func:`cdf2`), normalized to sum to one. See
    `A multivariate skew normal distribution`_ (2004).

    Args:
        kernel_size (int):
        sig_x (float):
        sig_y (float):
        theta (float): Radian measurement.
        D (ndarray): skew matrix.
        grid (ndarray, optional): generated by :func:`mesh_grid`,
            with the shape (K, K, 2), K is the kernel size. Default: None

    Returns:
        kernel (ndarray): normalized kernel.

    .. _A multivariate skew normal distribution:
        https://www.sciencedirect.com/science/article/pii/S0047259X03001313
    """
    if grid is None:
        grid = mesh_grid(kernel_size)[0]
    density = pdf2(sigma_matrix2(sig_x, sig_y, theta), grid)
    skewed = density * cdf2(D, grid)
    return skewed / np.sum(skewed)
94
+
95
+
96
def mass_center_shift(kernel_size, kernel):
    """Calculate the offset of a kernel's mass center from the grid center.

    Args:
        kernel_size (int): Side length of the kernel.
        kernel (ndarray): normalized kernel.

    Returns:
        delta_h (float): Offset along axis 0 (dot product of the row sums
            with the zero-centred coordinate axis).
        delta_w (float): Offset along axis 1 (same, using the column sums).
    """
    coords = np.arange(-kernel_size // 2 + 1., kernel_size // 2 + 1.)
    delta_h = np.dot(kernel.sum(axis=1), coords)
    delta_w = np.dot(kernel.sum(axis=0), coords)
    return delta_h, delta_w
110
+
111
+
112
def bivariate_skew_Gaussian_center(kernel_size,
                                   sig_x,
                                   sig_y,
                                   theta,
                                   D,
                                   grid=None):
    """Generate a bivariate skew Gaussian kernel re-centered on its mass center.

    The skew kernel is shifted by the negative of its mass-center offset
    (nearest-value padding at the borders) and then re-normalized.

    Args:
        kernel_size (int): Kernel size K.
        sig_x (float): Sigma along the first axis.
        sig_y (float): Sigma along the second axis.
        theta (float): Rotation in radians.
        D (ndarray): Skew matrix.
        grid (ndarray, optional): Generated by :func:`mesh_grid`,
            with the shape (K, K, 2). Default: None.

    Returns:
        kernel (ndarray): Centered and normalized kernel.
    """
    if grid is None:
        grid, _, _ = mesh_grid(kernel_size)
    skewed = bivariate_skew_Gaussian(kernel_size, sig_x, sig_y, theta, D, grid)
    offset_h, offset_w = mass_center_shift(kernel_size, skewed)
    centered = shift(skewed, [-offset_h, -offset_w], mode='nearest')
    return centered / np.sum(centered)
137
+
138
+
139
def bivariate_anisotropic_Gaussian(kernel_size,
                                   sig_x,
                                   sig_y,
                                   theta,
                                   grid=None):
    """Generate a bivariate anisotropic Gaussian kernel.

    Args:
        kernel_size (int): Kernel size K.
        sig_x (float): Sigma along the first axis.
        sig_y (float): Sigma along the second axis.
        theta (float): Rotation in radians.
        grid (ndarray, optional): Generated by :func:`mesh_grid`,
            with the shape (K, K, 2). Default: None.

    Returns:
        kernel (ndarray): Normalized kernel.
    """
    if grid is None:
        grid, _, _ = mesh_grid(kernel_size)
    covariance = sigma_matrix2(sig_x, sig_y, theta)
    unnormalized = pdf2(covariance, grid)
    return unnormalized / np.sum(unnormalized)
161
+
162
+
163
def bivariate_isotropic_Gaussian(kernel_size, sig, grid=None):
    """Generate a bivariate isotropic Gaussian kernel.

    Args:
        kernel_size (int): Kernel size K.
        sig (float): Sigma shared by both axes.
        grid (ndarray, optional): Generated by :func:`mesh_grid`,
            with the shape (K, K, 2). Default: None.

    Returns:
        kernel (ndarray): Normalized kernel.
    """
    if grid is None:
        grid, _, _ = mesh_grid(kernel_size)
    # Diagonal covariance with the same variance on both axes.
    covariance = np.array([[sig**2, 0], [0, sig**2]])
    unnormalized = pdf2(covariance, grid)
    return unnormalized / np.sum(unnormalized)
179
+
180
+
181
def bivariate_generalized_Gaussian(kernel_size,
                                   sig_x,
                                   sig_y,
                                   theta,
                                   beta,
                                   grid=None):
    """Generate a bivariate generalized Gaussian kernel.

    Computes ``exp(-0.5 * q ** beta)`` where ``q`` is the quadratic form
    ``x^T Sigma^-1 x`` at every grid point, then normalizes the result.
    Described in `Parameter Estimation For Multivariate Generalized Gaussian
    Distributions`_ by Pascal et al. (2013).

    Args:
        kernel_size (int): Kernel size K.
        sig_x (float): Sigma along the first axis.
        sig_y (float): Sigma along the second axis.
        theta (float): Rotation in radians.
        beta (float): Shape parameter; beta = 1 gives the normal distribution.
        grid (ndarray, optional): Generated by :func:`mesh_grid`,
            with the shape (K, K, 2). Default: None.

    Returns:
        kernel (ndarray): Normalized kernel.

    .. _Parameter Estimation For Multivariate Generalized Gaussian Distributions:
        https://arxiv.org/abs/1302.6498
    """
    if grid is None:
        grid, _, _ = mesh_grid(kernel_size)
    precision = np.linalg.inv(sigma_matrix2(sig_x, sig_y, theta))
    quad_form = np.sum(np.dot(grid, precision) * grid, 2)
    unnormalized = np.exp(-0.5 * np.power(quad_form, beta))
    return unnormalized / np.sum(unnormalized)
211
+
212
+
213
def bivariate_plateau_type1(kernel_size, sig_x, sig_y, theta, beta, grid=None):
    """Generate a plateau-like anisotropic kernel.

    Computes ``1 / (1 + q ** beta)`` where ``q`` is the quadratic form
    ``x^T Sigma^-1 x`` at every grid point, then normalizes the result.

    Args:
        kernel_size (int): Kernel size K.
        sig_x (float): Sigma along the first axis.
        sig_y (float): Sigma along the second axis.
        theta (float): Rotation in radians.
        beta (float): Shape parameter (exponent applied to the quadratic form).
        grid (ndarray, optional): Generated by :func:`mesh_grid`,
            with the shape (K, K, 2). Default: None.

    Returns:
        kernel (ndarray): Normalized kernel.
    """
    if grid is None:
        grid, _, _ = mesh_grid(kernel_size)
    precision = np.linalg.inv(sigma_matrix2(sig_x, sig_y, theta))
    quad_form = np.sum(np.dot(grid, precision) * grid, 2)
    unnormalized = np.reciprocal(np.power(quad_form, beta) + 1)
    return unnormalized / np.sum(unnormalized)
235
+
236
+
237
def bivariate_plateau_type1_iso(kernel_size, sig, beta, grid=None):
    """Generate a plateau-like isotropic kernel.

    Computes ``1 / (1 + q ** beta)`` where ``q`` is the quadratic form
    ``x^T Sigma^-1 x`` with ``Sigma = sig**2 * I``, then normalizes the result.

    Args:
        kernel_size (int): Kernel size K.
        sig (float): Sigma shared by both axes.
        beta (float): Shape parameter (exponent applied to the quadratic form).
        grid (ndarray, optional): Generated by :func:`mesh_grid`,
            with the shape (K, K, 2). Default: None.

    Returns:
        kernel (ndarray): Normalized kernel.
    """
    if grid is None:
        grid, _, _ = mesh_grid(kernel_size)
    covariance = np.array([[sig**2, 0], [0, sig**2]])
    precision = np.linalg.inv(covariance)
    quad_form = np.sum(np.dot(grid, precision) * grid, 2)
    unnormalized = np.reciprocal(np.power(quad_form, beta) + 1)
    return unnormalized / np.sum(unnormalized)
257
+
258
+
259
def random_bivariate_skew_Gaussian_center(kernel_size,
                                          sigma_x_range,
                                          sigma_y_range,
                                          rotation_range,
                                          noise_range=None,
                                          strict=False):
    """Randomly generate a centered bivariate skew Gaussian kernel.

    Args:
        kernel_size (int): Odd kernel size.
        sigma_x_range (tuple): Sampling range for sigma_x, e.g. [0.6, 5].
        sigma_y_range (tuple): Sampling range for sigma_y, e.g. [0.6, 5].
        rotation_range (tuple): Sampling range for the rotation,
            e.g. [-math.pi, math.pi].
        noise_range (tuple, optional): Range of the multiplicative kernel
            noise, e.g. [0.75, 1.25]. Default: None.
        strict (bool): If True, sigma_x is forced to be the larger of the two
            sampled sigmas and the sampled parameters are returned as well.
            Default: False.

    Returns:
        kernel (ndarray), or the tuple
        ``(kernel, sigma_x, sigma_y, rotation, D)`` when ``strict`` is True.
    """
    assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
    assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
    assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
    assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'

    sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
    sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
    if strict:
        sampled = [sigma_x, sigma_y]
        sigma_x, sigma_y = np.max(sampled), np.min(sampled)
    rotation = np.random.uniform(rotation_range[0], rotation_range[1])

    # Skew entries are bounded by 3 / (largest sigma); drawn in row-major order.
    thres = 3 / np.max([sigma_x, sigma_y])
    D = [[np.random.uniform(-thres, thres) for _ in range(2)] for _ in range(2)]

    kernel = bivariate_skew_Gaussian_center(kernel_size, sigma_x, sigma_y, rotation, D)

    # add multiplicative noise
    if noise_range is not None:
        assert noise_range[0] < noise_range[1], 'Wrong noise range.'
        kernel = kernel * np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape)
    kernel = kernel / np.sum(kernel)

    return (kernel, sigma_x, sigma_y, rotation, D) if strict else kernel
308
+
309
+
310
def random_bivariate_anisotropic_Gaussian(kernel_size,
                                          sigma_x_range,
                                          sigma_y_range,
                                          rotation_range,
                                          noise_range=None,
                                          strict=False):
    """Randomly generate a bivariate anisotropic Gaussian kernel.

    Args:
        kernel_size (int): Odd kernel size.
        sigma_x_range (tuple): Sampling range for sigma_x, e.g. [0.6, 5].
        sigma_y_range (tuple): Sampling range for sigma_y, e.g. [0.6, 5].
        rotation_range (tuple): Sampling range for the rotation,
            e.g. [-math.pi, math.pi].
        noise_range (tuple, optional): Range of the multiplicative kernel
            noise, e.g. [0.75, 1.25]. Default: None.
        strict (bool): If True, sigma_x is forced to be the larger of the two
            sampled sigmas and the sampled parameters are returned as well.
            Default: False.

    Returns:
        kernel (ndarray), or the tuple
        ``(kernel, sigma_x, sigma_y, rotation)`` when ``strict`` is True.
    """
    assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
    assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
    assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
    assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'

    sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
    sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
    if strict:
        sampled = [sigma_x, sigma_y]
        sigma_x, sigma_y = np.max(sampled), np.min(sampled)
    rotation = np.random.uniform(rotation_range[0], rotation_range[1])

    kernel = bivariate_anisotropic_Gaussian(kernel_size, sigma_x, sigma_y, rotation)

    # add multiplicative noise
    if noise_range is not None:
        assert noise_range[0] < noise_range[1], 'Wrong noise range.'
        kernel = kernel * np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape)
    kernel = kernel / np.sum(kernel)

    return (kernel, sigma_x, sigma_y, rotation) if strict else kernel
352
+
353
+
354
def random_bivariate_isotropic_Gaussian(kernel_size,
                                        sigma_range,
                                        noise_range=None,
                                        strict=False):
    """Randomly generate a bivariate isotropic Gaussian kernel.

    Args:
        kernel_size (int): Odd kernel size.
        sigma_range (tuple): Sampling range for sigma, e.g. [0.6, 5].
        noise_range (tuple, optional): Range of the multiplicative kernel
            noise, e.g. [0.75, 1.25]. Default: None.
        strict (bool): If True, the sampled sigma is returned as well.
            Default: False.

    Returns:
        kernel (ndarray), or the tuple ``(kernel, sigma)`` when ``strict``
        is True.
    """
    assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
    assert sigma_range[0] < sigma_range[1], 'Wrong sigma_x_range.'

    sigma = np.random.uniform(sigma_range[0], sigma_range[1])
    kernel = bivariate_isotropic_Gaussian(kernel_size, sigma)

    # add multiplicative noise
    if noise_range is not None:
        assert noise_range[0] < noise_range[1], 'Wrong noise range.'
        kernel = kernel * np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape)
    kernel = kernel / np.sum(kernel)

    return (kernel, sigma) if strict else kernel
383
+
384
+
385
def random_bivariate_generalized_Gaussian(kernel_size,
                                          sigma_x_range,
                                          sigma_y_range,
                                          rotation_range,
                                          beta_range,
                                          noise_range=None,
                                          strict=False):
    """Randomly generate a bivariate generalized Gaussian kernel.

    Args:
        kernel_size (int): Odd kernel size.
        sigma_x_range (tuple): Sampling range for sigma_x, e.g. [0.6, 5].
        sigma_y_range (tuple): Sampling range for sigma_y, e.g. [0.6, 5].
        rotation_range (tuple): Sampling range for the rotation,
            e.g. [-math.pi, math.pi].
        beta_range (tuple): Sampling range for beta, e.g. [0.5, 8]. With equal
            probability beta is drawn between ``beta_range[0]`` and 1, or
            between 1 and ``beta_range[1]``.
        noise_range (tuple, optional): Range of the multiplicative kernel
            noise, e.g. [0.75, 1.25]. Default: None.
        strict (bool): If True, sigma_x is forced to be the larger of the two
            sampled sigmas and the sampled parameters are returned as well.
            Default: False.

    Returns:
        kernel (ndarray), or the tuple
        ``(kernel, sigma_x, sigma_y, rotation, beta)`` when ``strict`` is True.
    """
    assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
    assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
    assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
    assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'

    sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
    sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
    if strict:
        sampled = [sigma_x, sigma_y]
        sigma_x, sigma_y = np.max(sampled), np.min(sampled)
    rotation = np.random.uniform(rotation_range[0], rotation_range[1])

    # Coin flip decides which side of beta = 1 the shape parameter falls on.
    pick_lower_half = np.random.uniform() < 0.5
    if pick_lower_half:
        beta = np.random.uniform(beta_range[0], 1)
    else:
        beta = np.random.uniform(1, beta_range[1])

    kernel = bivariate_generalized_Gaussian(kernel_size, sigma_x, sigma_y, rotation, beta)

    # add multiplicative noise
    if noise_range is not None:
        assert noise_range[0] < noise_range[1], 'Wrong noise range.'
        kernel = kernel * np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape)
    kernel = kernel / np.sum(kernel)

    return (kernel, sigma_x, sigma_y, rotation, beta) if strict else kernel
433
+
434
+
435
def random_bivariate_plateau_type1(kernel_size,
                                   sigma_x_range,
                                   sigma_y_range,
                                   rotation_range,
                                   beta_range,
                                   noise_range=None,
                                   strict=False):
    """Randomly generate a bivariate plateau type1 kernel.

    Args:
        kernel_size (int): Odd kernel size.
        sigma_x_range (tuple): Sampling range for sigma_x, e.g. [0.6, 5].
        sigma_y_range (tuple): Sampling range for sigma_y, e.g. [0.6, 5].
        rotation_range (tuple): Sampling range for the rotation,
            e.g. [-math.pi/2, math.pi/2].
        beta_range (tuple): Sampling range for beta, e.g. [1, 4]. With equal
            probability beta is drawn between ``beta_range[0]`` and 1, or
            between 1 and ``beta_range[1]``.
        noise_range (tuple, optional): Range of the multiplicative kernel
            noise, e.g. [0.75, 1.25]. Default: None.
        strict (bool): If True, sigma_x is forced to be the larger of the two
            sampled sigmas and the sampled parameters are returned as well.
            Default: False.

    Returns:
        kernel (ndarray), or the tuple
        ``(kernel, sigma_x, sigma_y, rotation, beta)`` when ``strict`` is True.
    """
    assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
    assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
    assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
    assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'

    sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
    sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
    if strict:
        sampled = [sigma_x, sigma_y]
        sigma_x, sigma_y = np.max(sampled), np.min(sampled)
    rotation = np.random.uniform(rotation_range[0], rotation_range[1])

    # Coin flip decides which side of beta = 1 the shape parameter falls on.
    pick_lower_half = np.random.uniform() < 0.5
    if pick_lower_half:
        beta = np.random.uniform(beta_range[0], 1)
    else:
        beta = np.random.uniform(1, beta_range[1])

    kernel = bivariate_plateau_type1(kernel_size, sigma_x, sigma_y, rotation, beta)

    # add multiplicative noise
    if noise_range is not None:
        assert noise_range[0] < noise_range[1], 'Wrong noise range.'
        kernel = kernel * np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape)
    kernel = kernel / np.sum(kernel)

    return (kernel, sigma_x, sigma_y, rotation, beta) if strict else kernel
483
+
484
+
485
def random_bivariate_plateau_type1_iso(kernel_size,
                                       sigma_range,
                                       beta_range,
                                       noise_range=None,
                                       strict=False):
    """Randomly generate an isotropic bivariate plateau type1 kernel.

    Args:
        kernel_size (int): Odd kernel size.
        sigma_range (tuple): Sampling range for sigma, e.g. [0.6, 5].
        beta_range (tuple): Sampling range for beta, e.g. [1, 4].
        noise_range (tuple, optional): Range of the multiplicative kernel
            noise, e.g. [0.75, 1.25]. Default: None.
        strict (bool): If True, the sampled sigma and beta are returned as
            well. Default: False.

    Returns:
        kernel (ndarray), or the tuple ``(kernel, sigma, beta)`` when
        ``strict`` is True.
    """
    assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
    assert sigma_range[0] < sigma_range[1], 'Wrong sigma_x_range.'

    sigma = np.random.uniform(sigma_range[0], sigma_range[1])
    beta = np.random.uniform(beta_range[0], beta_range[1])
    kernel = bivariate_plateau_type1_iso(kernel_size, sigma, beta)

    # add multiplicative noise
    if noise_range is not None:
        assert noise_range[0] < noise_range[1], 'Wrong noise range.'
        kernel = kernel * np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape)
    kernel = kernel / np.sum(kernel)

    return (kernel, sigma, beta) if strict else kernel
517
+
518
+
519
def random_mixed_kernels(kernel_list,
                         kernel_prob,
                         kernel_size=21,
                         sigma_x_range=(0.6, 5),
                         sigma_y_range=(0.6, 5),
                         rotation_range=(-math.pi, math.pi),
                         beta_range=(0.5, 8),
                         noise_range=None):
    """Randomly generate a kernel whose type is drawn from ``kernel_list``.

    Args:
        kernel_list (tuple): Names of the kernel types to draw from. Supported:
            ['iso', 'aniso', 'skew', 'generalized', 'plateau_iso',
            'plateau_aniso'].
        kernel_prob (tuple): Sampling weight of each entry in ``kernel_list``.
        kernel_size (int): Odd kernel size. Default: 21.
        sigma_x_range (tuple): Sampling range for sigma_x (also used as the
            sigma range of the isotropic kernels). Default: (0.6, 5).
        sigma_y_range (tuple): Sampling range for sigma_y. Default: (0.6, 5).
        rotation_range (tuple): Sampling range for the rotation.
            Default: (-math.pi, math.pi).
        beta_range (tuple): Sampling range for beta. Default: (0.5, 8).
        noise_range (tuple, optional): Range of the multiplicative kernel
            noise, e.g. [0.75, 1.25]. Default: None.

    Returns:
        kernel (ndarray): Normalized kernel.

    Raises:
        ValueError: If the drawn kernel type is not one of the supported names.
    """
    kernel_type = random.choices(kernel_list, kernel_prob)[0]
    if kernel_type == 'iso':
        kernel = random_bivariate_isotropic_Gaussian(
            kernel_size, sigma_x_range, noise_range=noise_range)
    elif kernel_type == 'aniso':
        kernel = random_bivariate_anisotropic_Gaussian(
            kernel_size,
            sigma_x_range,
            sigma_y_range,
            rotation_range,
            noise_range=noise_range)
    elif kernel_type == 'skew':
        kernel = random_bivariate_skew_Gaussian_center(
            kernel_size,
            sigma_x_range,
            sigma_y_range,
            rotation_range,
            noise_range=noise_range)
    elif kernel_type == 'generalized':
        kernel = random_bivariate_generalized_Gaussian(
            kernel_size,
            sigma_x_range,
            sigma_y_range,
            rotation_range,
            beta_range,
            noise_range=noise_range)
    elif kernel_type == 'plateau_iso':
        kernel = random_bivariate_plateau_type1_iso(
            kernel_size, sigma_x_range, beta_range, noise_range=noise_range)
    elif kernel_type == 'plateau_aniso':
        kernel = random_bivariate_plateau_type1(
            kernel_size,
            sigma_x_range,
            sigma_y_range,
            rotation_range,
            beta_range,
            noise_range=noise_range)
    else:
        # Fail with a clear message instead of an unbound `kernel` further down.
        raise ValueError(f'Unsupported kernel type: {kernel_type!r}.')

    # add multiplicative noise
    # NOTE(review): each random_* generator above already applies noise_range,
    # so the kernel is perturbed a second time here. Kept as-is so sampled
    # kernels stay identical to the existing behaviour -- confirm intent.
    if noise_range is not None:
        assert noise_range[0] < noise_range[1], 'Wrong noise range.'
        noise = np.random.uniform(
            noise_range[0], noise_range[1], size=kernel.shape)
        kernel = kernel * noise
    kernel = kernel / np.sum(kernel)
    return kernel
586
+
587
+
588
def show_one_kernel():
    """Plot one example kernel in a 2x2 matplotlib figure.

    Several kernel types are generated; only the last one (generalized
    Gaussian) is visualised. Its mass-center shift is printed first, then the
    figure shows a matrix view, a rescaled image, contour lines and filled
    contours.
    """
    import matplotlib.pyplot as plt
    kernel_size = 21

    # bivariate skew Gaussian
    D = [[3 / 4, 0], [0, 0.5]]
    kernel = bivariate_skew_Gaussian_center(kernel_size, 2, 4, -math.pi / 4, D)
    # bivariate anisotropic Gaussian
    kernel = bivariate_anisotropic_Gaussian(kernel_size, 2, 4, -math.pi / 4)
    # bivariate isotropic Gaussian
    kernel = bivariate_isotropic_Gaussian(kernel_size, 1)
    # bivariate generalized Gaussian (this is the kernel that gets plotted)
    kernel = bivariate_generalized_Gaussian(
        kernel_size, 2, 4, -math.pi / 4, beta=4)

    shift_h, shift_w = mass_center_shift(kernel_size, kernel)
    print(shift_h, shift_w)

    fig, axs = plt.subplots(nrows=2, ncols=2)
    (ax_matrix, ax_image), (ax_contour, ax_filled) = axs

    # raw kernel values
    matrix_artist = ax_matrix.matshow(kernel, cmap='jet', origin='upper')
    fig.colorbar(matrix_artist, ax=ax_matrix)

    # image, rescaled to [0, 255]
    kernel_vis = kernel - np.min(kernel)
    kernel_vis = kernel_vis / np.max(kernel_vis) * 255.
    ax_image.imshow(kernel_vis, interpolation='nearest')

    _, xx, yy = mesh_grid(kernel_size)
    # contour
    contour_set = ax_contour.contour(xx, yy, kernel, origin='upper')
    ax_contour.clabel(contour_set, inline=1, fontsize=3)

    # contourf, on the peak-normalized kernel
    kernel = kernel / np.max(kernel)
    filled = ax_filled.contourf(
        xx, yy, kernel, origin='upper', levels=np.linspace(-0.05, 1.05, 10))
    fig.colorbar(filled)

    plt.show()
633
+
634
+
635
def show_plateau_kernel():
    """Plot an anisotropic plateau kernel in a 2x2 matplotlib figure.

    The kernel's mass-center shift is printed first, then the figure shows a
    matrix view, a rescaled image, contour lines and filled contours.
    """
    import matplotlib.pyplot as plt
    kernel_size = 21

    # `bivariate_plateau_type1` is the plateau generator defined in this
    # module; the previous name `plateau_type1` does not exist and raised
    # a NameError as soon as this function was called.
    kernel = bivariate_plateau_type1(
        kernel_size, 2, 4, -math.pi / 8, 2, grid=None)
    delta_h, delta_w = mass_center_shift(kernel_size, kernel)
    print(delta_h, delta_w)

    fig, axs = plt.subplots(nrows=2, ncols=2)
    # raw kernel values
    ax = axs[0][0]
    im = ax.matshow(kernel, cmap='jet', origin='upper')
    fig.colorbar(im, ax=ax)

    # image, rescaled to [0, 255]
    ax = axs[0][1]
    kernel_vis = kernel - np.min(kernel)
    kernel_vis = kernel_vis / np.max(kernel_vis) * 255.
    ax.imshow(kernel_vis, interpolation='nearest')

    _, xx, yy = mesh_grid(kernel_size)
    # contour
    ax = axs[1][0]
    CS = ax.contour(xx, yy, kernel, origin='upper')
    ax.clabel(CS, inline=1, fontsize=3)

    # contourf, on the peak-normalized kernel
    ax = axs[1][1]
    kernel = kernel / np.max(kernel)
    p = ax.contourf(
        xx, yy, kernel, origin='upper', levels=np.linspace(-0.05, 1.05, 10))
    fig.colorbar(p)

    plt.show()
basicsr/data/paired_image_dataset.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.utils import data as data
2
+ from torchvision.transforms.functional import normalize
3
+
4
+ from basicsr.data.data_util import paired_paths_from_folder, paired_paths_from_lmdb, paired_paths_from_meta_info_file
5
+ from basicsr.data.transforms import augment, paired_random_crop
6
+ from basicsr.utils import FileClient, imfrombytes, img2tensor
7
+ from basicsr.utils.registry import DATASET_REGISTRY
8
+
9
+
10
@DATASET_REGISTRY.register()
class PairedImageDataset(data.Dataset):
    """Paired image dataset for image restoration.

    Reads LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and
    GT image pairs.

    The path list is built in one of three ways:
        1. 'lmdb': used when opt['io_backend']['type'] == 'lmdb'.
        2. 'meta_info_file': used when the backend is not lmdb and
            opt['meta_info_file'] is present and not None.
        3. 'folder': scan the two folders; used in every other case.

    Args:
        opt (dict): Config for train datasets. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            dataroot_lq (str): Data root path for lq.
            meta_info_file (str): Path for meta information file.
            io_backend (dict): IO backend type and other kwarg.
            filename_tmpl (str): Template for each filename. Note that the
                template excludes the file extension. Default: '{}'.
            mean, std (optional): Passed to ``normalize`` when either is set.
            gt_size (int): Cropped patch size for gt patches (train phase).
            use_flip (bool): Use horizontal flips (train phase).
            use_rot (bool): Use rotation (train phase).

            scale: Scale factor passed to the paired random crop.
            phase (str): 'train' enables cropping and augmentation.
    """

    def __init__(self, opt):
        super().__init__()
        self.opt = opt
        # file client (io backend); created lazily on the first item access
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        self.mean = opt.get('mean')
        self.std = opt.get('std')

        self.gt_folder = opt['dataroot_gt']
        self.lq_folder = opt['dataroot_lq']
        self.filename_tmpl = opt.get('filename_tmpl', '{}')

        folders = [self.lq_folder, self.gt_folder]
        keys = ['lq', 'gt']
        if self.io_backend_opt['type'] == 'lmdb':
            self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder]
            self.io_backend_opt['client_keys'] = ['lq', 'gt']
            self.paths = paired_paths_from_lmdb(folders, keys)
        elif self.opt.get('meta_info_file') is not None:
            self.paths = paired_paths_from_meta_info_file(folders, keys, self.opt['meta_info_file'],
                                                          self.filename_tmpl)
        else:
            self.paths = paired_paths_from_folder(folders, keys, self.filename_tmpl)

    def __getitem__(self, index):
        if self.file_client is None:
            # 'type' is removed from the backend options; the rest are kwargs.
            backend_type = self.io_backend_opt.pop('type')
            self.file_client = FileClient(backend_type, **self.io_backend_opt)

        scale = self.opt['scale']
        entry = self.paths[index]

        # Load gt and lq images (decoded with float32=True).
        gt_path = entry['gt_path']
        img_gt = imfrombytes(self.file_client.get(gt_path, 'gt'), float32=True)
        lq_path = entry['lq_path']
        img_lq = imfrombytes(self.file_client.get(lq_path, 'lq'), float32=True)

        # augmentation for training
        if self.opt['phase'] == 'train':
            gt_size = self.opt['gt_size']
            # random crop
            img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, gt_path)
            # flip, rotation
            img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_flip'], self.opt['use_rot'])

        # TODO: color space transform
        # BGR to RGB, HWC to CHW, numpy to tensor
        img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)
        # normalize
        if not (self.mean is None and self.std is None):
            normalize(img_lq, self.mean, self.std, inplace=True)
            normalize(img_gt, self.mean, self.std, inplace=True)

        return {'lq': img_lq, 'gt': img_gt, 'lq_path': lq_path, 'gt_path': gt_path}

    def __len__(self):
        return len(self.paths)
basicsr/data/prefetch_dataloader.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import queue as Queue
2
+ import threading
3
+ import torch
4
+ from torch.utils.data import DataLoader
5
+
6
+
7
class PrefetchGenerator(threading.Thread):
    """A general prefetch generator.

    A daemon thread drains ``generator`` into a bounded queue so that items
    are produced ahead of the consumer. ``None`` is used as the end-of-stream
    sentinel, so a ``None`` item yielded by ``generator`` also ends iteration.

    Ref:
    https://stackoverflow.com/questions/7323664/python-generator-pre-fetch

    Args:
        generator: Python generator.
        num_prefetch_queue (int): Maximum number of prefetched items held in
            the queue.
    """

    def __init__(self, generator, num_prefetch_queue):
        threading.Thread.__init__(self)
        self.queue = Queue.Queue(num_prefetch_queue)
        self.generator = generator
        # Must exist before the worker starts, since run() may set the error.
        self._prefetch_error = None
        self._prefetch_done = False
        self.daemon = True
        self.start()

    def run(self):
        try:
            for item in self.generator:
                self.queue.put(item)
        except Exception as exc:
            # Hand the failure to the consumer instead of losing it here.
            self._prefetch_error = exc
        finally:
            # Always deliver the sentinel, otherwise the consumer would block
            # forever on an empty queue.
            self.queue.put(None)

    def __next__(self):
        # Once exhausted, keep raising StopIteration (iterator protocol)
        # rather than waiting on a queue nobody fills any more.
        if self._prefetch_done:
            raise StopIteration
        next_item = self.queue.get()
        if next_item is None:
            self._prefetch_done = True
            if self._prefetch_error is not None:
                error, self._prefetch_error = self._prefetch_error, None
                raise error
            raise StopIteration
        return next_item

    def __iter__(self):
        return self
38
+
39
+
40
class PrefetchDataLoader(DataLoader):
    """Prefetch version of dataloader.

    Iterating wraps the regular dataloader iterator in a
    :class:`PrefetchGenerator`.

    Ref:
    https://github.com/IgorSusmelj/pytorch-styleguide/issues/5#

    TODO:
    Need to test on single gpu and ddp (multi-gpu). There is a known issue in
    ddp.

    Args:
        num_prefetch_queue (int): Queue size handed to the prefetch generator.
        kwargs (dict): Other arguments for dataloader.
    """

    def __init__(self, num_prefetch_queue, **kwargs):
        self.num_prefetch_queue = num_prefetch_queue
        super().__init__(**kwargs)

    def __iter__(self):
        base_iterator = super().__iter__()
        return PrefetchGenerator(base_iterator, self.num_prefetch_queue)
61
+
62
+
63
class CPUPrefetcher():
    """CPU prefetcher.

    Thin wrapper around an iterator over ``loader`` that returns ``None``
    when the iterator is exhausted and can be restarted with :meth:`reset`.

    Args:
        loader: Dataloader.
    """

    def __init__(self, loader):
        self.ori_loader = loader
        self.loader = iter(loader)

    def next(self):
        # None signals that the current pass over the loader has ended.
        return next(self.loader, None)

    def reset(self):
        self.loader = iter(self.ori_loader)
82
+
83
+
84
class CUDAPrefetcher():
    """CUDA prefetcher.

    Loads the next batch ahead of time and moves its tensors to the target
    device. On CUDA the copy is issued on a side stream; when
    ``opt['num_gpu'] == 0`` the device is the CPU and no CUDA stream is
    created, so the prefetcher also works on machines without CUDA.

    Ref:
    https://github.com/NVIDIA/apex/issues/304#

    It may consume more GPU memory.

    Args:
        loader: Dataloader. Each batch is expected to be a dict.
        opt (dict): Options. ``opt['num_gpu']`` selects the device
            ('cuda' unless it is 0).
    """

    def __init__(self, loader, opt):
        self.ori_loader = loader
        self.loader = iter(loader)
        self.opt = opt
        self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu')
        # Only create a CUDA stream when CUDA is the target; creating one
        # unconditionally fails on CPU-only setups.
        self.stream = torch.cuda.Stream() if self.device.type == 'cuda' else None
        self.preload()

    def _batch_to_device(self):
        # Move every tensor value of the current batch to the target device.
        for k, v in self.batch.items():
            if torch.is_tensor(v):
                self.batch[k] = self.batch[k].to(device=self.device, non_blocking=True)

    def preload(self):
        try:
            self.batch = next(self.loader)  # self.batch is a dict
        except StopIteration:
            self.batch = None
            return None
        # put tensors to the target device
        if self.stream is None:
            self._batch_to_device()
        else:
            with torch.cuda.stream(self.stream):
                self._batch_to_device()

    def next(self):
        if self.stream is not None:
            # Make the consumer stream wait for the pending copy.
            torch.cuda.current_stream().wait_stream(self.stream)
        batch = self.batch
        self.preload()
        return batch

    def reset(self):
        self.loader = iter(self.ori_loader)
        self.preload()
basicsr/data/transforms.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import random
3
+
4
+
5
def mod_crop(img, scale):
    """Crop a copy of the image so height and width are multiples of ``scale``.

    Used during testing.

    Args:
        img (ndarray): Input image with 2 or 3 dimensions.
        scale (int): Scale factor.

    Returns:
        ndarray: Result image (the input itself is not modified).

    Raises:
        ValueError: If ``img`` is neither 2- nor 3-dimensional.
    """
    img = img.copy()
    if img.ndim not in (2, 3):
        raise ValueError(f'Wrong img ndim: {img.ndim}.')
    height, width = img.shape[:2]
    # Drop the trailing rows/columns that do not fill a whole `scale` step.
    return img[:height - height % scale, :width - width % scale, ...]
23
+
24
+
25
def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale, gt_path):
    """Paired random crop.

    It crops lists of lq and gt images with corresponding locations.

    Args:
        img_gts (list[ndarray] | ndarray): GT images. Note that all images
            should have the same shape. If the input is an ndarray, it will
            be transformed to a list containing itself.
        img_lqs (list[ndarray] | ndarray): LQ images. Note that all images
            should have the same shape. If the input is an ndarray, it will
            be transformed to a list containing itself.
        gt_patch_size (int): GT patch size.
        scale (int): Scale factor.
        gt_path (str): Path to ground-truth; only used in error messages.

    Returns:
        list[ndarray] | ndarray: GT images and LQ images. If returned results
            only have one element, just return ndarray.

    Raises:
        ValueError: If the GT size is not ``scale`` times the LQ size, or if
            the LQ image is smaller than the LQ patch size.
    """

    if not isinstance(img_gts, list):
        img_gts = [img_gts]
    if not isinstance(img_lqs, list):
        img_lqs = [img_lqs]

    h_lq, w_lq, _ = img_lqs[0].shape
    h_gt, w_gt, _ = img_gts[0].shape
    lq_patch_size = gt_patch_size // scale

    if h_gt != h_lq * scale or w_gt != w_lq * scale:
        # The two f-strings are adjacent literals (no comma between them), so
        # the exception carries one message string rather than a 2-tuple.
        raise ValueError(f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x '
                         f'multiplication of LQ ({h_lq}, {w_lq}).')
    if h_lq < lq_patch_size or w_lq < lq_patch_size:
        raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size '
                         f'({lq_patch_size}, {lq_patch_size}). '
                         f'Please remove {gt_path}.')

    # randomly choose top and left coordinates for lq patch
    top = random.randint(0, h_lq - lq_patch_size)
    left = random.randint(0, w_lq - lq_patch_size)

    # crop lq patch
    img_lqs = [v[top:top + lq_patch_size, left:left + lq_patch_size, ...] for v in img_lqs]

    # crop corresponding gt patch
    top_gt, left_gt = int(top * scale), int(left * scale)
    img_gts = [v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...] for v in img_gts]
    if len(img_gts) == 1:
        img_gts = img_gts[0]
    if len(img_lqs) == 1:
        img_lqs = img_lqs[0]
    return img_gts, img_lqs
78
+
79
+
80
def augment(imgs, hflip=True, rotation=True, flows=None, return_status=False):
    """Augment: random horizontal flip and/or rotation (0, 90, 180, 270 degrees).

    Rotation is implemented with a vertical flip and a transpose, each
    applied with probability 0.5. All the images in the list use the same
    augmentation. Flips are performed in place on the given arrays.

    Args:
        imgs (list[ndarray] | ndarray): Images to be augmented. If the input
            is an ndarray, it will be transformed to a list.
        hflip (bool): Horizontal flip. Default: True.
        rotation (bool): Rotation. Default: True.
        flows (list[ndarray]): Flows to be augmented. If the input is an
            ndarray, it will be transformed to a list.
            Dimension is (h, w, 2). Default: None.
        return_status (bool): Return the status of flip and rotation.
            Only used when ``flows`` is None. Default: False.

    Returns:
        list[ndarray] | ndarray: Augmented images and flows. If returned
            results only have one element, just return ndarray.

    """
    # Decide the augmentation once so every image/flow gets the same one.
    hflip = hflip and random.random() < 0.5
    vflip = rotation and random.random() < 0.5
    rot90 = rotation and random.random() < 0.5

    def _transform_image(image):
        if hflip:  # horizontal
            cv2.flip(image, 1, image)
        if vflip:  # vertical
            cv2.flip(image, 0, image)
        return image.transpose(1, 0, 2) if rot90 else image

    def _transform_flow(field):
        # Flipping an axis also negates the flow component along that axis.
        if hflip:  # horizontal
            cv2.flip(field, 1, field)
            field[:, :, 0] *= -1
        if vflip:  # vertical
            cv2.flip(field, 0, field)
            field[:, :, 1] *= -1
        if rot90:
            # Transposing h and w also swaps the two flow components.
            field = field.transpose(1, 0, 2)
            field = field[:, :, [1, 0]]
        return field

    if not isinstance(imgs, list):
        imgs = [imgs]
    imgs = [_transform_image(image) for image in imgs]
    if len(imgs) == 1:
        imgs = imgs[0]

    if flows is not None:
        if not isinstance(flows, list):
            flows = [flows]
        flows = [_transform_flow(field) for field in flows]
        if len(flows) == 1:
            flows = flows[0]
        return imgs, flows

    if return_status:
        return imgs, (hflip, vflip, rot90)
    return imgs
145
+
146
+
147
def img_rotate(img, angle, center=None, scale=1.0):
    """Rotate image.

    The output has the same width and height as the input.

    Args:
        img (ndarray): Image to be rotated.
        angle (float): Rotation angle in degrees. Positive values mean
            counter-clockwise rotation.
        center (tuple[int]): Rotation center. If the center is None,
            initialize it as the center of the image. Default: None.
        scale (float): Isotropic scale factor. Default: 1.0.

    Returns:
        ndarray: Rotated image.
    """
    height, width = img.shape[:2]
    if center is None:
        center = (width // 2, height // 2)

    rotation_matrix = cv2.getRotationMatrix2D(center, angle, scale)
    return cv2.warpAffine(img, rotation_matrix, (width, height))
basicsr/losses/__init__.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from copy import deepcopy
2
+
3
+ from basicsr.utils import get_root_logger
4
+ from basicsr.utils.registry import LOSS_REGISTRY
5
+ from .losses import (CharbonnierLoss, GANLoss, L1Loss, MSELoss, PerceptualLoss, WeightedTVLoss, g_path_regularize,
6
+ gradient_penalty_loss, r1_penalty)
7
+
8
+ __all__ = [
9
+ 'L1Loss', 'MSELoss', 'CharbonnierLoss', 'WeightedTVLoss', 'PerceptualLoss', 'GANLoss', 'gradient_penalty_loss',
10
+ 'r1_penalty', 'g_path_regularize'
11
+ ]
12
+
13
+
14
def build_loss(opt):
    """Instantiate a loss from an options dict and log its creation.

    Args:
        opt (dict): Configuration. It must contain:
            type (str): Name under which the loss is registered in
                ``LOSS_REGISTRY``.
            All remaining entries are forwarded to the loss constructor as
            keyword arguments.

    Returns:
        The constructed loss object.
    """
    # Work on a copy so popping 'type' does not mutate the caller's dict.
    cfg = deepcopy(opt)
    loss_cls = LOSS_REGISTRY.get(cfg.pop('type'))
    loss = loss_cls(**cfg)
    get_root_logger().info(f'Loss [{loss.__class__.__name__}] is created.')
    return loss
basicsr/losses/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (999 Bytes). View file
 
basicsr/losses/__pycache__/loss_util.cpython-39.pyc ADDED
Binary file (2.67 kB). View file
 
basicsr/losses/__pycache__/losses.cpython-39.pyc ADDED
Binary file (14.6 kB). View file
 
basicsr/losses/loss_util.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+ from torch.nn import functional as F
3
+
4
+
5
def reduce_loss(loss, reduction):
    """Collapse an element-wise loss according to ``reduction``.

    Args:
        loss (Tensor): Elementwise loss tensor.
        reduction (str): Options are 'none', 'mean' and 'sum'.

    Returns:
        Tensor: ``loss`` unchanged for 'none', its mean for 'mean', and its
        sum otherwise.
    """
    # PyTorch's internal enum: none -> 0, (elementwise_)mean -> 1, sum -> 2.
    mode = F._Reduction.get_enum(reduction)
    if mode == 0:
        return loss
    return loss.mean() if mode == 1 else loss.sum()
23
+
24
+
25
def weight_reduce_loss(loss, weight=None, reduction='mean'):
    """Apply optional element-wise weights to a loss, then reduce it.

    Args:
        loss (Tensor): Element-wise loss.
        weight (Tensor): Element-wise weights. Must have the same number of
            dimensions as ``loss`` and a size along dim 1 of either 1 or
            ``loss.size(1)``. Default: None.
        reduction (str): Same as built-in losses of PyTorch. Options are
            'none', 'mean' and 'sum'. Default: 'mean'.

    Returns:
        Tensor: Loss values. With a weight and ``reduction='mean'`` the sum of
        the weighted loss is divided by the total weight (scaled by the size
        of dim 1 when the weight has a single entry along dim 1), not by the
        number of elements.
    """
    # Unweighted case: plain reduction.
    if weight is None:
        return reduce_loss(loss, reduction)

    assert weight.dim() == loss.dim()
    assert weight.size(1) == 1 or weight.size(1) == loss.size(1)
    weighted = loss * weight

    if reduction == 'sum':
        return reduce_loss(weighted, reduction)

    if reduction == 'mean':
        # Normalize by the weighted region instead of the element count.
        denominator = weight.sum()
        if not weight.size(1) > 1:
            # A weight with one entry along dim 1 is broadcast over that dim,
            # so it counts once per entry there.
            denominator = denominator * weighted.size(1)
        return weighted.sum() / denominator

    # Any other reduction value leaves the weighted loss unreduced.
    return weighted
55
+
56
+
57
def weighted_loss(loss_func):
    """Decorator that adds ``weight`` and ``reduction`` support to a loss.

    The wrapped function must have a signature like
    `loss_func(pred, target, **kwargs)` and return the element-wise loss
    without any reduction. The decorated function has the signature
    `loss_func(pred, target, weight=None, reduction='mean', **kwargs)`;
    weighting and reduction are delegated to `weight_reduce_loss`.

    :Example:

    >>> import torch
    >>> @weighted_loss
    >>> def l1_loss(pred, target):
    >>>     return (pred - target).abs()

    >>> pred = torch.Tensor([[0, 2, 3]])
    >>> target = torch.Tensor([[1, 1, 1]])
    >>> weight = torch.Tensor([[1, 0, 1]])

    >>> l1_loss(pred, target)
    tensor(1.3333)
    >>> l1_loss(pred, target, weight)
    tensor(1.5000)
    >>> l1_loss(pred, target, reduction='none')
    tensor([[1., 1., 2.]])
    >>> l1_loss(pred, target, weight, reduction='sum')
    tensor(3.)
    """

    @functools.wraps(loss_func)
    def wrapper(pred, target, weight=None, reduction='mean', **kwargs):
        # Element-wise loss first, then weighting and reduction in one step.
        elementwise = loss_func(pred, target, **kwargs)
        return weight_reduce_loss(elementwise, weight, reduction)

    return wrapper
basicsr/losses/losses.py ADDED
@@ -0,0 +1,455 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import lpips
3
+ import torch
4
+ from torch import autograd as autograd
5
+ from torch import nn as nn
6
+ from torch.nn import functional as F
7
+
8
+ from basicsr.archs.vgg_arch import VGGFeatureExtractor
9
+ from basicsr.utils.registry import LOSS_REGISTRY
10
+ from .loss_util import weighted_loss
11
+
12
+ _reduction_modes = ['none', 'mean', 'sum']
13
+
14
+
15
@weighted_loss
def l1_loss(pred, target):
    """Element-wise absolute error; weighting/reduction come from the decorator."""
    elementwise = F.l1_loss(pred, target, reduction='none')
    return elementwise
18
+
19
+
20
@weighted_loss
def mse_loss(pred, target):
    """Element-wise squared error; weighting/reduction come from the decorator."""
    elementwise = F.mse_loss(pred, target, reduction='none')
    return elementwise
23
+
24
+
25
@weighted_loss
def charbonnier_loss(pred, target, eps=1e-12):
    """Element-wise Charbonnier term ``sqrt((pred - target)^2 + eps)``.

    Weighting and reduction are added by the ``weighted_loss`` decorator.
    """
    residual = pred - target
    return torch.sqrt(residual**2 + eps)
28
+
29
+
30
@LOSS_REGISTRY.register()
class L1Loss(nn.Module):
    """L1 (mean absolute error, MAE) loss.

    Args:
        loss_weight (float): Multiplier applied to the L1 loss. Default: 1.0.
        reduction (str): Specifies the reduction to apply to the output.
            Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.

    Raises:
        ValueError: If ``reduction`` is not one of the supported modes.
    """

    def __init__(self, loss_weight=1.0, reduction='mean'):
        super().__init__()
        if reduction not in _reduction_modes:
            raise ValueError(f'Unsupported reduction mode: {reduction}. ' f'Supported ones are: {_reduction_modes}')

        self.loss_weight = loss_weight
        self.reduction = reduction

    def forward(self, pred, target, weight=None, **kwargs):
        """Compute the weighted, reduced L1 loss scaled by ``loss_weight``.

        Args:
            pred (Tensor): of shape (N, C, H, W). Predicted tensor.
            target (Tensor): of shape (N, C, H, W). Ground truth tensor.
            weight (Tensor, optional): of shape (N, C, H, W). Element-wise
                weights. Default: None.
            **kwargs: Accepted and ignored.
        """
        unscaled = l1_loss(pred, target, weight, reduction=self.reduction)
        return self.loss_weight * unscaled
57
+
58
+
59
@LOSS_REGISTRY.register()
class MSELoss(nn.Module):
    """MSE (L2) loss.

    Args:
        loss_weight (float): Multiplier applied to the MSE loss. Default: 1.0.
        reduction (str): Specifies the reduction to apply to the output.
            Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.

    Raises:
        ValueError: If ``reduction`` is not one of the supported modes.
    """

    def __init__(self, loss_weight=1.0, reduction='mean'):
        super().__init__()
        if reduction not in _reduction_modes:
            raise ValueError(f'Unsupported reduction mode: {reduction}. ' f'Supported ones are: {_reduction_modes}')

        self.loss_weight = loss_weight
        self.reduction = reduction

    def forward(self, pred, target, weight=None, **kwargs):
        """Compute the weighted, reduced MSE loss scaled by ``loss_weight``.

        Args:
            pred (Tensor): of shape (N, C, H, W). Predicted tensor.
            target (Tensor): of shape (N, C, H, W). Ground truth tensor.
            weight (Tensor, optional): of shape (N, C, H, W). Element-wise
                weights. Default: None.
            **kwargs: Accepted and ignored.
        """
        unscaled = mse_loss(pred, target, weight, reduction=self.reduction)
        return self.loss_weight * unscaled
86
+
87
+
88
@LOSS_REGISTRY.register()
class CharbonnierLoss(nn.Module):
    """Charbonnier loss (one variant of Robust L1Loss, a differentiable
    variant of L1Loss).

    Described in "Deep Laplacian Pyramid Networks for Fast and Accurate
    Super-Resolution".

    Args:
        loss_weight (float): Multiplier applied to the loss. Default: 1.0.
        reduction (str): Specifies the reduction to apply to the output.
            Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
        eps (float): A value used to control the curvature near zero.
            Default: 1e-12.

    Raises:
        ValueError: If ``reduction`` is not one of the supported modes.
    """

    def __init__(self, loss_weight=1.0, reduction='mean', eps=1e-12):
        super().__init__()
        if reduction not in _reduction_modes:
            raise ValueError(f'Unsupported reduction mode: {reduction}. ' f'Supported ones are: {_reduction_modes}')

        self.loss_weight = loss_weight
        self.reduction = reduction
        self.eps = eps

    def forward(self, pred, target, weight=None, **kwargs):
        """Compute the weighted, reduced Charbonnier loss scaled by ``loss_weight``.

        Args:
            pred (Tensor): of shape (N, C, H, W). Predicted tensor.
            target (Tensor): of shape (N, C, H, W). Ground truth tensor.
            weight (Tensor, optional): of shape (N, C, H, W). Element-wise
                weights. Default: None.
            **kwargs: Accepted and ignored.
        """
        unscaled = charbonnier_loss(pred, target, weight, eps=self.eps, reduction=self.reduction)
        return self.loss_weight * unscaled
122
+
123
+
124
@LOSS_REGISTRY.register()
class WeightedTVLoss(L1Loss):
    """Weighted total-variation (TV) loss.

    The loss is the sum of two L1 losses between ``pred`` and a copy of
    itself shifted by one position: once along dim 2 and once along dim 3.

    Args:
        loss_weight (float): Loss weight. Default: 1.0.
    """

    def __init__(self, loss_weight=1.0):
        super(WeightedTVLoss, self).__init__(loss_weight=loss_weight)

    def forward(self, pred, weight=None):
        """Compute the TV loss of ``pred``.

        Args:
            pred (Tensor): 4-D tensor; indexed as ``pred[:, :, i, j]``.
            weight (Tensor, optional): Element-wise weights with the same
                layout as ``pred``. When None, both difference terms are
                unweighted. Default: None.

        Returns:
            Tensor: Sum of the dim-3 and dim-2 difference losses.
        """
        # Slicing None would raise TypeError, so only crop a real weight to
        # the shape of the shifted differences.
        if weight is None:
            y_weight = None
            x_weight = None
        else:
            y_weight = weight[:, :, :-1, :]
            x_weight = weight[:, :, :, :-1]

        y_diff = super(WeightedTVLoss, self).forward(pred[:, :, :-1, :], pred[:, :, 1:, :], weight=y_weight)
        x_diff = super(WeightedTVLoss, self).forward(pred[:, :, :, :-1], pred[:, :, :, 1:], weight=x_weight)

        loss = x_diff + y_diff

        return loss
142
+
143
+
144
@LOSS_REGISTRY.register()
class PerceptualLoss(nn.Module):
    """Perceptual loss with commonly used style loss.

    Args:
        layer_weights (dict): The weight for each layer of vgg feature.
            Here is an example: {'conv5_4': 1.}, which means the conv5_4
            feature layer (before relu5_4) will be extracted with weight
            1.0 in calculating losses.
        vgg_type (str): The type of vgg network used as feature extractor.
            Default: 'vgg19'.
        use_input_norm (bool): If True, normalize the input image in vgg.
            Default: True.
        range_norm (bool): If True, norm images with range [-1, 1] to [0, 1].
            Default: False.
        perceptual_weight (float): If `perceptual_weight > 0`, the perceptual
            loss will be calculated and the loss will multiplied by the
            weight. Default: 1.0.
        style_weight (float): If `style_weight > 0`, the style loss will be
            calculated and the loss will multiplied by the weight.
            Default: 0.
        criterion (str): Criterion used for perceptual loss. One of
            'l1' | 'l2' | 'mse' | 'fro'. Default: 'l1'.

    Raises:
        NotImplementedError: If ``criterion`` is not a supported name.
    """

    def __init__(self,
                 layer_weights,
                 vgg_type='vgg19',
                 use_input_norm=True,
                 range_norm=False,
                 perceptual_weight=1.0,
                 style_weight=0.,
                 criterion='l1'):
        super(PerceptualLoss, self).__init__()
        self.perceptual_weight = perceptual_weight
        self.style_weight = style_weight
        self.layer_weights = layer_weights
        self.vgg = VGGFeatureExtractor(
            layer_name_list=list(layer_weights.keys()),
            vgg_type=vgg_type,
            use_input_norm=use_input_norm,
            range_norm=range_norm)

        self.criterion_type = criterion
        if self.criterion_type == 'l1':
            self.criterion = torch.nn.L1Loss()
        elif self.criterion_type == 'l2':
            # torch.nn has no `L2loss`; the L2 criterion is MSELoss.
            self.criterion = torch.nn.MSELoss()
        elif self.criterion_type == 'mse':
            self.criterion = torch.nn.MSELoss(reduction='mean')
        elif self.criterion_type == 'fro':
            # The Frobenius norm is computed inline with torch.norm in forward().
            self.criterion = None
        else:
            raise NotImplementedError(f'{criterion} criterion has not been supported.')

    def forward(self, x, gt):
        """Forward function.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).
            gt (Tensor): Ground-truth tensor with shape (n, c, h, w).

        Returns:
            tuple: ``(percep_loss, style_loss)``. Each entry is a Tensor, or
            None when the corresponding weight is not greater than 0.
        """
        # extract vgg features; the ground truth is detached so no gradient
        # flows into it
        x_features = self.vgg(x)
        gt_features = self.vgg(gt.detach())

        # calculate perceptual loss
        if self.perceptual_weight > 0:
            percep_loss = 0
            for k in x_features.keys():
                if self.criterion_type == 'fro':
                    percep_loss += torch.norm(x_features[k] - gt_features[k], p='fro') * self.layer_weights[k]
                else:
                    percep_loss += self.criterion(x_features[k], gt_features[k]) * self.layer_weights[k]
            percep_loss *= self.perceptual_weight
        else:
            percep_loss = None

        # calculate style loss on the Gram matrices of the same features
        if self.style_weight > 0:
            style_loss = 0
            for k in x_features.keys():
                if self.criterion_type == 'fro':
                    style_loss += torch.norm(
                        self._gram_mat(x_features[k]) - self._gram_mat(gt_features[k]), p='fro') * self.layer_weights[k]
                else:
                    style_loss += self.criterion(self._gram_mat(x_features[k]), self._gram_mat(
                        gt_features[k])) * self.layer_weights[k]
            style_loss *= self.style_weight
        else:
            style_loss = None

        return percep_loss, style_loss

    def _gram_mat(self, x):
        """Calculate Gram matrix.

        Args:
            x (torch.Tensor): Tensor with shape of (n, c, h, w).

        Returns:
            torch.Tensor: Gram matrix of shape (n, c, c), normalized by
            ``c * h * w``.
        """
        n, c, h, w = x.size()
        features = x.view(n, c, w * h)
        features_t = features.transpose(1, 2)
        gram = features.bmm(features_t) / (c * h * w)
        return gram
254
+
255
+
256
@LOSS_REGISTRY.register()
class LPIPSLoss(nn.Module):
    """Loss built on ``lpips.LPIPS(net="vgg", spatial=False)`` in eval mode.

    Args:
        loss_weight (float): Multiplier applied to the mean LPIPS value.
            Default: 1.0.
        use_input_norm (bool): If True, subtract the registered ``mean`` and
            divide by the registered ``std`` before calling LPIPS.
            Default: True.
        range_norm (bool): If True, map inputs with ``(x + 1) / 2`` first.
            Default: False.
    """

    def __init__(self,
                 loss_weight=1.0,
                 use_input_norm=True,
                 range_norm=False,):
        super().__init__()
        self.perceptual = lpips.LPIPS(net="vgg", spatial=False).eval()
        self.loss_weight = loss_weight
        self.use_input_norm = use_input_norm
        self.range_norm = range_norm

        if self.use_input_norm:
            # Per-channel statistics for images with range [0, 1], shaped to
            # broadcast over (N, 3, H, W).
            channel_mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
            self.register_buffer('mean', channel_mean)
            channel_std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
            self.register_buffer('std', channel_std)

    def forward(self, pred, target):
        """Return ``loss_weight`` times the mean LPIPS output for the pair.

        Args:
            pred (Tensor): Predicted image tensor.
            target (Tensor): Reference image tensor.
        """
        if self.range_norm:
            pred, target = (pred + 1) / 2, (target + 1) / 2
        if self.use_input_norm:
            pred, target = (pred - self.mean) / self.std, (target - self.mean) / self.std
        distance = self.perceptual(target.contiguous(), pred.contiguous())
        return self.loss_weight * distance.mean()
283
+
284
+
285
@LOSS_REGISTRY.register()
class GANLoss(nn.Module):
    """Define GAN loss.

    Args:
        gan_type (str): Support 'vanilla', 'lsgan', 'wgan', 'wgan_softplus',
            'hinge'.
        real_label_val (float): The value for real label. Default: 1.0.
        fake_label_val (float): The value for fake label. Default: 0.0.
        loss_weight (float): Loss weight. Default: 1.0.
            Note that loss_weight is only for generators; and it is always 1.0
            for discriminators.

    Raises:
        NotImplementedError: If ``gan_type`` is not one of the supported names.
    """

    def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0, loss_weight=1.0):
        super().__init__()
        self.gan_type = gan_type
        self.loss_weight = loss_weight
        self.real_label_val = real_label_val
        self.fake_label_val = fake_label_val

        # Name -> zero-argument factory producing the criterion callable.
        criterion_table = (
            ('vanilla', nn.BCEWithLogitsLoss),
            ('lsgan', nn.MSELoss),
            ('wgan', lambda: self._wgan_loss),
            ('wgan_softplus', lambda: self._wgan_softplus_loss),
            ('hinge', nn.ReLU),
        )
        for name, make_criterion in criterion_table:
            if self.gan_type == name:
                self.loss = make_criterion()
                break
        else:
            raise NotImplementedError(f'GAN type {self.gan_type} is not implemented.')

    def _wgan_loss(self, input, target):
        """wgan loss.

        Args:
            input (Tensor): Input tensor.
            target (bool): Target label.

        Returns:
            Tensor: The negated mean of ``input`` when ``target`` is truthy,
            otherwise the mean.
        """
        mean_score = input.mean()
        if target:
            return -mean_score
        return mean_score

    def _wgan_softplus_loss(self, input, target):
        """wgan loss with soft plus. softplus is a smooth approximation to the
        ReLU function.

        In StyleGAN2, it is called:
            Logistic loss for discriminator;
            Non-saturating loss for generator.

        Args:
            input (Tensor): Input tensor.
            target (bool): Target label.

        Returns:
            Tensor: Mean softplus of ``-input`` when ``target`` is truthy,
            otherwise mean softplus of ``input``.
        """
        signed = -input if target else input
        return F.softplus(signed).mean()

    def get_target_label(self, input, target_is_real):
        """Get target label.

        Args:
            input (Tensor): Input tensor.
            target_is_real (bool): Whether the target is real or fake.

        Returns:
            (bool | Tensor): Target tensor. Return bool for wgan, otherwise,
                return a Tensor shaped like ``input`` filled with the label
                value.
        """
        if self.gan_type in ['wgan', 'wgan_softplus']:
            return target_is_real
        if target_is_real:
            label_value = self.real_label_val
        else:
            label_value = self.fake_label_val
        return input.new_ones(input.size()) * label_value

    def forward(self, input, target_is_real, is_disc=False):
        """
        Args:
            input (Tensor): The input for the loss module, i.e., the network
                prediction.
            target_is_real (bool): Whether the target is real or fake.
            is_disc (bool): Whether the loss for discriminators or not.
                Default: False.

        Returns:
            Tensor: GAN loss value.
        """
        if self.gan_type != 'hinge':
            # other gan types
            loss = self.loss(input, self.get_target_label(input, target_is_real))
        elif is_disc:
            # for discriminators in hinge-gan
            input = -input if target_is_real else input
            loss = self.loss(1 + input).mean()
        else:
            # for generators in hinge-gan
            loss = -input.mean()

        # loss_weight is always 1.0 for discriminators
        if is_disc:
            return loss
        return loss * self.loss_weight
388
+
389
+
390
def r1_penalty(real_pred, real_img):
    """R1 regularization for discriminator.

    Penalizes the gradient of the discriminator output on real data alone:
    when the generator distribution produces the true data distribution and
    the discriminator is equal to 0 on the data manifold, the gradient
    penalty ensures that the discriminator cannot create a non-zero gradient
    orthogonal to the data manifold without suffering a loss in the GAN game.

    Ref:
        Eq. 9 in Which training methods for GANs do actually converge.

    Args:
        real_pred (Tensor): Discriminator output computed from ``real_img``.
        real_img (Tensor): Input that ``real_pred`` was computed from; it
            must be part of the autograd graph.

    Returns:
        Tensor: Batch mean of the per-sample squared gradient norm.
    """
    # create_graph=True keeps the penalty itself differentiable.
    grad_real = autograd.grad(outputs=real_pred.sum(), inputs=real_img, create_graph=True)[0]
    squared_norm_per_sample = grad_real.pow(2).view(grad_real.shape[0], -1).sum(1)
    return squared_norm_per_sample.mean()
405
+
406
+
407
def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01):
    """Path-length penalty computed from generated images and their latents.

    Args:
        fake_img (Tensor): Generated images; dims 2 and 3 are used to scale
            the random noise.
        latents (Tensor): Latents that ``fake_img`` was computed from. The
            gradient is reduced with ``sum(2).mean(1)``, so it is indexed as
            a 3-D tensor.
        mean_path_length (float | Tensor): Running mean of the path length.
        decay (float): Step size of the running-mean update. Default: 0.01.

    Returns:
        tuple: ``(path_penalty, mean of path_lengths (detached),
        updated running mean (detached))``.
    """
    pixel_count = fake_img.shape[2] * fake_img.shape[3]
    noise = torch.randn_like(fake_img) / math.sqrt(pixel_count)
    grad = autograd.grad(outputs=(fake_img * noise).sum(), inputs=latents, create_graph=True)[0]
    path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1))

    # Move the running mean toward the current batch mean by `decay`.
    batch_mean = path_lengths.mean()
    path_mean = mean_path_length + decay * (batch_mean - mean_path_length)

    deviation = path_lengths - path_mean
    path_penalty = deviation.pow(2).mean()

    return path_penalty, path_lengths.detach().mean(), path_mean.detach()
417
+
418
+
419
def gradient_penalty_loss(discriminator, real_data, fake_data, weight=None):
    """Calculate gradient penalty for wgan-gp.

    Args:
        discriminator (nn.Module): Network for the discriminator.
        real_data (Tensor): Real input data. The interpolation factor has
            shape (batch, 1, 1, 1), so 4-D inputs are expected.
        fake_data (Tensor): Fake input data.
        weight (Tensor): Weight tensor multiplied into the gradients; the
            penalty is then divided by its mean. Default: None.

    Returns:
        Tensor: A tensor for gradient penalty.
    """

    batch_size = real_data.size(0)
    # Sample with the default generator, then match real_data's dtype/device.
    # This avoids the copy-construct warning of `new_tensor(<tensor>)`.
    alpha = torch.rand(batch_size, 1, 1, 1).to(real_data)

    # interpolate between real_data and fake_data
    interpolates = alpha * real_data + (1. - alpha) * fake_data
    # Fresh leaf tensor so gradients are taken w.r.t. the interpolated input
    # only (replaces the deprecated autograd.Variable wrapper).
    interpolates = interpolates.detach().requires_grad_(True)

    disc_interpolates = discriminator(interpolates)
    gradients = autograd.grad(
        outputs=disc_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(disc_interpolates),
        create_graph=True,
        retain_graph=True,
        only_inputs=True)[0]

    if weight is not None:
        gradients = gradients * weight

    # NOTE(review): the 2-norm is taken over dim 1 only, not over all
    # non-batch dims - confirm this is intended.
    gradients_penalty = ((gradients.norm(2, dim=1) - 1)**2).mean()
    if weight is not None:
        gradients_penalty /= torch.mean(weight)

    return gradients_penalty
basicsr/metrics/__init__.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from copy import deepcopy
2
+
3
+ from basicsr.utils.registry import METRIC_REGISTRY
4
+ from .psnr_ssim import calculate_psnr, calculate_ssim
5
+
6
+ __all__ = ['calculate_psnr', 'calculate_ssim']
7
+
8
+
9
def calculate_metric(data, opt):
    """Calculate metric from data and options.

    Args:
        data (dict): Keyword arguments forwarded to the metric function.
        opt (dict): Configuration. It must contain:
            type (str): Name under which the metric is registered in
                ``METRIC_REGISTRY``.
            All remaining entries are forwarded as keyword arguments too.

    Returns:
        Whatever the registered metric function returns.
    """
    # Work on a copy so popping 'type' does not mutate the caller's dict.
    cfg = deepcopy(opt)
    metric_fn = METRIC_REGISTRY.get(cfg.pop('type'))
    return metric_fn(**data, **cfg)
basicsr/metrics/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (695 Bytes). View file