Model code

- utils/__init__.py +3 -0
- utils/logging.py +22 -0
- utils/metrics.py +207 -0
- utils/optimize.py +13 -0
- utils/sampling.py +89 -0
utils/__init__.py
ADDED
@@ -0,0 +1,3 @@
+from utils.logging import *
+from utils.sampling import *
+from utils.optimize import *
utils/logging.py
ADDED
@@ -0,0 +1,22 @@
+import torch
+import os
+import torchvision.utils as tvu
+
+
+def save_image(img, file_directory):
+    if not os.path.exists(os.path.dirname(file_directory)):
+        os.makedirs(os.path.dirname(file_directory))
+    tvu.save_image(img, file_directory)
+
+
+def save_checkpoint(state, filename):
+    if not os.path.exists(os.path.dirname(filename)):
+        os.makedirs(os.path.dirname(filename))
+    torch.save(state, filename + '.pth.tar')
+
+
+def load_checkpoint(path, device):
+    if device is None:
+        return torch.load(path)
+    else:
+        return torch.load(path, map_location=device)
utils/metrics.py
ADDED
@@ -0,0 +1,207 @@
+import cv2
+import numpy as np
+import PIL
+
+def calculate_psnr(img1, img2, test_y_channel=False):
+    """Calculate PSNR (Peak Signal-to-Noise Ratio).
+
+    Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio
+
+    Args:
+        img1 (ndarray): Images with range [0, 255].
+        img2 (ndarray): Images with range [0, 255].
+        test_y_channel (bool): Test on Y channel of YCbCr. Default: False.
+
+    Returns:
+        float: psnr result.
+    """
+
+    assert img1.shape == img2.shape, (f'Image shapes are different: {img1.shape}, {img2.shape}.')
+    assert img1.shape[2] == 3
+    img1 = img1.astype(np.float64)
+    img2 = img2.astype(np.float64)
+
+    if test_y_channel:
+        img1 = to_y_channel(img1)
+        img2 = to_y_channel(img2)
+
+    mse = np.mean((img1 - img2) ** 2)
+    if mse == 0:
+        return float('inf')
+    return 20. * np.log10(255. / np.sqrt(mse))
+
+
+def _ssim(img1, img2):
+    """Calculate SSIM (structural similarity) for one channel images.
+
+    It is called by :func:`calculate_ssim`.
+
+    Args:
+        img1 (ndarray): Images with range [0, 255] with order 'HWC'.
+        img2 (ndarray): Images with range [0, 255] with order 'HWC'.
+
+    Returns:
+        float: ssim result.
+    """
+
+    C1 = (0.01 * 255) ** 2
+    C2 = (0.03 * 255) ** 2
+
+    img1 = img1.astype(np.float64)
+    img2 = img2.astype(np.float64)
+    kernel = cv2.getGaussianKernel(11, 1.5)
+    window = np.outer(kernel, kernel.transpose())
+
+    mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5]
+    mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
+    mu1_sq = mu1 ** 2
+    mu2_sq = mu2 ** 2
+    mu1_mu2 = mu1 * mu2
+    sigma1_sq = cv2.filter2D(img1 ** 2, -1, window)[5:-5, 5:-5] - mu1_sq
+    sigma2_sq = cv2.filter2D(img2 ** 2, -1, window)[5:-5, 5:-5] - mu2_sq
+    sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
+
+    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
+    return ssim_map.mean()
+
+
+def calculate_ssim(img1, img2, test_y_channel=False):
+    """Calculate SSIM (structural similarity).
+
+    Ref:
+    Image quality assessment: From error visibility to structural similarity
+
+    The results are the same as that of the official released MATLAB code in
+    https://ece.uwaterloo.ca/~z70wang/research/ssim/.
+
+    For three-channel images, SSIM is calculated for each channel and then
+    averaged.
+
+    Args:
+        img1 (ndarray): Images with range [0, 255].
+        img2 (ndarray): Images with range [0, 255].
+        test_y_channel (bool): Test on Y channel of YCbCr. Default: False.
+
+    Returns:
+        float: ssim result.
+    """
+
+    assert img1.shape == img2.shape, (f'Image shapes are different: {img1.shape}, {img2.shape}.')
+    assert img1.shape[2] == 3
+    img1 = img1.astype(np.float64)
+    img2 = img2.astype(np.float64)
+
+    if test_y_channel:
+        img1 = to_y_channel(img1)
+        img2 = to_y_channel(img2)
+
+    ssims = []
+    for i in range(img1.shape[2]):
+        ssims.append(_ssim(img1[..., i], img2[..., i]))
+    return np.array(ssims).mean()
+
+
+def to_y_channel(img):
+    """Change to Y channel of YCbCr.
+
+    Args:
+        img (ndarray): Images with range [0, 255].
+
+    Returns:
+        (ndarray): Images with range [0, 255] (float type) without round.
+    """
+    img = img.astype(np.float32) / 255.
+    if img.ndim == 3 and img.shape[2] == 3:
+        img = bgr2ycbcr(img, y_only=True)
+        img = img[..., None]
+    return img * 255.
+
+
+def _convert_input_type_range(img):
+    """Convert the type and range of the input image.
+
+    It converts the input image to np.float32 type and range of [0, 1].
+    It is mainly used for pre-processing the input image in colorspace
+    conversion functions such as rgb2ycbcr and ycbcr2rgb.
+
+    Args:
+        img (ndarray): The input image. It accepts:
+            1. np.uint8 type with range [0, 255];
+            2. np.float32 type with range [0, 1].
+
+    Returns:
+        (ndarray): The converted image with type of np.float32 and range of
+            [0, 1].
+    """
+    img_type = img.dtype
+    img = img.astype(np.float32)
+    if img_type == np.float32:
+        pass
+    elif img_type == np.uint8:
+        img /= 255.
+    else:
+        raise TypeError('The img type should be np.float32 or np.uint8, ' f'but got {img_type}')
+    return img
+
+
+def _convert_output_type_range(img, dst_type):
+    """Convert the type and range of the image according to dst_type.
+
+    It converts the image to desired type and range. If `dst_type` is np.uint8,
+    images will be converted to np.uint8 type with range [0, 255]. If
+    `dst_type` is np.float32, it converts the image to np.float32 type with
+    range [0, 1].
+    It is mainly used for post-processing images in colorspace conversion
+    functions such as rgb2ycbcr and ycbcr2rgb.
+
+    Args:
+        img (ndarray): The image to be converted with np.float32 type and
+            range [0, 255].
+        dst_type (np.uint8 | np.float32): If dst_type is np.uint8, it
+            converts the image to np.uint8 type with range [0, 255]. If
+            dst_type is np.float32, it converts the image to np.float32 type
+            with range [0, 1].
+
+    Returns:
+        (ndarray): The converted image with desired type and range.
+    """
+    if dst_type not in (np.uint8, np.float32):
+        raise TypeError('The dst_type should be np.float32 or np.uint8, ' f'but got {dst_type}')
+    if dst_type == np.uint8:
+        img = img.round()
+    else:
+        img /= 255.
+    return img.astype(dst_type)
+
+
+def bgr2ycbcr(img, y_only=False):
+    """Convert a BGR image to YCbCr image.
+
+    The bgr version of rgb2ycbcr.
+    It implements the ITU-R BT.601 conversion for standard-definition
+    television. See more details in
+    https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
+
+    It differs from a similar function in cv2.cvtColor: `BGR <-> YCrCb`.
+    In OpenCV, it implements a JPEG conversion. See more details in
+    https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
+
+    Args:
+        img (ndarray): The input image. It accepts:
+            1. np.uint8 type with range [0, 255];
+            2. np.float32 type with range [0, 1].
+        y_only (bool): Whether to only return Y channel. Default: False.
+
+    Returns:
+        ndarray: The converted YCbCr image. The output image has the same type
+            and range as input image.
+    """
+    img_type = img.dtype
+    img = _convert_input_type_range(img)
+    if y_only:
+        out_img = np.dot(img, [24.966, 128.553, 65.481]) + 16.0
+    else:
+        out_img = np.matmul(
+            img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], [65.481, -37.797, 112.0]]) + [16, 128, 128]
+    out_img = _convert_output_type_range(out_img, img_type)
+    return out_img
utils/optimize.py
ADDED
@@ -0,0 +1,13 @@
+import torch.optim as optim
+
+
+def get_optimizer(config, parameters):
+    if config.optim.optimizer == 'Adam':
+        return optim.Adam(parameters, lr=config.optim.lr, weight_decay=config.optim.weight_decay,
+                          betas=(0.9, 0.999), amsgrad=config.optim.amsgrad, eps=config.optim.eps)
+    elif config.optim.optimizer == 'RMSProp':
+        return optim.RMSprop(parameters, lr=config.optim.lr, weight_decay=config.optim.weight_decay)
+    elif config.optim.optimizer == 'SGD':
+        return optim.SGD(parameters, lr=config.optim.lr, momentum=0.9)
+    else:
+        raise NotImplementedError('Optimizer {} not understood.'.format(config.optim.optimizer))
utils/sampling.py
ADDED
@@ -0,0 +1,89 @@
+import torch
+from torchvision.transforms.functional import crop
+
+
+def compute_alpha(beta, t):
+    beta = torch.cat([torch.zeros(1).to(beta.device), beta], dim=0)
+    a = (1 - beta).cumprod(dim=0).index_select(0, t + 1).view(-1, 1, 1, 1)
+    return a
+
+
+def data_transform(X):
+    return 2 * X - 1.0
+
+
+def inverse_data_transform(X):
+    return torch.clamp((X + 1.0) / 2.0, 0.0, 1.0)
+
+
+def generalized_steps(x, x_cond, seq, model, b, eta=0., device=None):
+    with torch.no_grad():
+        n = x.size(0)
+        seq_next = [-1] + list(seq[:-1])
+        x0_preds = []
+        xs = [x]
+        for i, j in zip(reversed(seq), reversed(seq_next)):
+            t = (torch.ones(n) * i).to(x.device)
+            next_t = (torch.ones(n) * j).to(x.device)
+            at = compute_alpha(b, t.long())
+            at_next = compute_alpha(b, next_t.long())
+            xt = xs[-1].to(device)
+
+            et = model(torch.cat([x_cond, xt], dim=1), t)
+            x0_t = (xt - et * (1 - at).sqrt()) / at.sqrt()
+            x0_preds.append(x0_t.to(device))
+
+            c1 = eta * ((1 - at / at_next) * (1 - at_next) / (1 - at)).sqrt()
+            c2 = ((1 - at_next) - c1 ** 2).sqrt()
+            xt_next = at_next.sqrt() * x0_t + c1 * torch.randn_like(x) + c2 * et
+            xs.append(xt_next.to(device))
+    return xs, x0_preds
+
+
+def generalized_steps_overlapping(x, x_cond, seq, model, b, eta=0., corners=None, p_size=None, manual_batching=True,
+                                  device=None):
+    with torch.no_grad():
+        n = x.size(0)
+        seq_next = [-1] + list(seq[:-1])
+        x0_preds = []
+        xs = [x]
+
+        x_grid_mask = torch.zeros_like(x_cond, device=x.device)
+        for (hi, wi) in corners:
+            x_grid_mask[:, :, hi:hi + p_size, wi:wi + p_size] += 1
+
+        for i, j in zip(reversed(seq), reversed(seq_next)):
+            t = (torch.ones(n) * i).to(x.device)
+            next_t = (torch.ones(n) * j).to(x.device)
+            at = compute_alpha(b, t.long())
+            at_next = compute_alpha(b, next_t.long())
+            xt = xs[-1].to(device)
+            et_output = torch.zeros_like(x_cond, device=x.device)
+
+            if manual_batching:
+                manual_batching_size = 64
+                xt_patch = torch.cat([crop(xt, hi, wi, p_size, p_size) for (hi, wi) in corners], dim=0)
+                x_cond_patch = torch.cat([data_transform(crop(x_cond, hi, wi, p_size, p_size)) for (hi, wi) in corners],
+                                         dim=0)
+                for i in range(0, len(corners), manual_batching_size):
+                    outputs = model(torch.cat([x_cond_patch[i:i + manual_batching_size],
+                                               xt_patch[i:i + manual_batching_size]], dim=1), t)
+                    for idx, (hi, wi) in enumerate(corners[i:i + manual_batching_size]):
+                        et_output[0, :, hi:hi + p_size, wi:wi + p_size] += outputs[idx]
+            else:
+                for (hi, wi) in corners:
+                    xt_patch = crop(xt, hi, wi, p_size, p_size)
+                    x_cond_patch = crop(x_cond, hi, wi, p_size, p_size)
+                    x_cond_patch = data_transform(x_cond_patch)
+                    et_output[:, :, hi:hi + p_size, wi:wi + p_size] += model(torch.cat([x_cond_patch, xt_patch], dim=1),
+                                                                             t)
+
+            et = torch.div(et_output, x_grid_mask)
+            x0_t = (xt - et * (1 - at).sqrt()) / at.sqrt()
+            x0_preds.append(x0_t.to(device))
+
+            c1 = eta * ((1 - at / at_next) * (1 - at_next) / (1 - at)).sqrt()
+            c2 = ((1 - at_next) - c1 ** 2).sqrt()
+            xt_next = at_next.sqrt() * x0_t + c1 * torch.randn_like(x) + c2 * et
+            xs.append(xt_next.to(device))
+    return xs, x0_preds