HaisuGuan committed on
Commit cf75d92 · 1 Parent(s): 332a731

Model code

Files changed (5)
  1. utils/__init__.py +3 -0
  2. utils/logging.py +22 -0
  3. utils/metrics.py +207 -0
  4. utils/optimize.py +13 -0
  5. utils/sampling.py +89 -0
utils/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from utils.logging import *
+ from utils.sampling import *
+ from utils.optimize import *
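
Note that utils/metrics.py is not re-exported by these star imports, so its functions still need an explicit submodule import. A minimal sketch of what callers can do after this change:

    from utils import save_image, get_optimizer, generalized_steps
    from utils.metrics import calculate_psnr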
utils/logging.py ADDED
@@ -0,0 +1,22 @@
+ import torch
+ import os
+ import torchvision.utils as tvu
+
+
+ def save_image(img, file_directory):
+     if not os.path.exists(os.path.dirname(file_directory)):
+         os.makedirs(os.path.dirname(file_directory))
+     tvu.save_image(img, file_directory)
+
+
+ def save_checkpoint(state, filename):
+     if not os.path.exists(os.path.dirname(filename)):
+         os.makedirs(os.path.dirname(filename))
+     torch.save(state, filename + '.pth.tar')
+
+
+ def load_checkpoint(path, device):
+     if device is None:
+         return torch.load(path)
+     else:
+         return torch.load(path, map_location=device)
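
A short usage sketch of these helpers (not part of the commit; the paths and placeholder model below are illustrative):

    import torch
    from utils.logging import save_checkpoint, load_checkpoint, save_image

    model = torch.nn.Linear(4, 4)  # placeholder model for illustration

    # save_checkpoint appends '.pth.tar', so pass the path without an extension;
    # the parent directory is created if it does not exist.
    save_checkpoint({'model': model.state_dict()}, 'ckpts/demo')

    # Pass device=None to keep tensors on the device they were saved from.
    state = load_checkpoint('ckpts/demo.pth.tar', torch.device('cpu'))
    model.load_state_dict(state['model'])

    # save_image also creates the parent directory before writing.
    save_image(torch.rand(1, 3, 64, 64), 'results/sample.png')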
utils/metrics.py ADDED
@@ -0,0 +1,207 @@
+ import cv2
+ import numpy as np
+ import PIL
+
+ def calculate_psnr(img1, img2, test_y_channel=False):
+     """Calculate PSNR (Peak Signal-to-Noise Ratio).
+
+     Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio
+
+     Args:
+         img1 (ndarray): Images with range [0, 255].
+         img2 (ndarray): Images with range [0, 255].
+         test_y_channel (bool): Test on Y channel of YCbCr. Default: False.
+
+     Returns:
+         float: psnr result.
+     """
+
+     assert img1.shape == img2.shape, (f'Image shapes are different: {img1.shape}, {img2.shape}.')
+     assert img1.shape[2] == 3
+     img1 = img1.astype(np.float64)
+     img2 = img2.astype(np.float64)
+
+     if test_y_channel:
+         img1 = to_y_channel(img1)
+         img2 = to_y_channel(img2)
+
+     mse = np.mean((img1 - img2) ** 2)
+     if mse == 0:
+         return float('inf')
+     return 20. * np.log10(255. / np.sqrt(mse))
+
+
+ def _ssim(img1, img2):
+     """Calculate SSIM (structural similarity) for one-channel images.
+
+     It is called by :func:`calculate_ssim`.
+
+     Args:
+         img1 (ndarray): Images with range [0, 255] with order 'HWC'.
+         img2 (ndarray): Images with range [0, 255] with order 'HWC'.
+
+     Returns:
+         float: ssim result.
+     """
+
+     C1 = (0.01 * 255) ** 2
+     C2 = (0.03 * 255) ** 2
+
+     img1 = img1.astype(np.float64)
+     img2 = img2.astype(np.float64)
+     kernel = cv2.getGaussianKernel(11, 1.5)
+     window = np.outer(kernel, kernel.transpose())
+
+     mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5]
+     mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
+     mu1_sq = mu1 ** 2
+     mu2_sq = mu2 ** 2
+     mu1_mu2 = mu1 * mu2
+     sigma1_sq = cv2.filter2D(img1 ** 2, -1, window)[5:-5, 5:-5] - mu1_sq
+     sigma2_sq = cv2.filter2D(img2 ** 2, -1, window)[5:-5, 5:-5] - mu2_sq
+     sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
+
+     ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
+     return ssim_map.mean()
+
+
+ def calculate_ssim(img1, img2, test_y_channel=False):
+     """Calculate SSIM (structural similarity).
+
+     Ref:
+     Image quality assessment: From error visibility to structural similarity
+
+     The results are the same as those of the officially released MATLAB code in
+     https://ece.uwaterloo.ca/~z70wang/research/ssim/.
+
+     For three-channel images, SSIM is calculated for each channel and then
+     averaged.
+
+     Args:
+         img1 (ndarray): Images with range [0, 255].
+         img2 (ndarray): Images with range [0, 255].
+         test_y_channel (bool): Test on Y channel of YCbCr. Default: False.
+
+     Returns:
+         float: ssim result.
+     """
+
+     assert img1.shape == img2.shape, (f'Image shapes are different: {img1.shape}, {img2.shape}.')
+     assert img1.shape[2] == 3
+     img1 = img1.astype(np.float64)
+     img2 = img2.astype(np.float64)
+
+     if test_y_channel:
+         img1 = to_y_channel(img1)
+         img2 = to_y_channel(img2)
+
+     ssims = []
+     for i in range(img1.shape[2]):
+         ssims.append(_ssim(img1[..., i], img2[..., i]))
+     return np.array(ssims).mean()
+
+
+ def to_y_channel(img):
+     """Change to Y channel of YCbCr.
+
+     Args:
+         img (ndarray): Images with range [0, 255].
+
+     Returns:
+         (ndarray): Images with range [0, 255] (float type) without rounding.
+     """
+     img = img.astype(np.float32) / 255.
+     if img.ndim == 3 and img.shape[2] == 3:
+         img = bgr2ycbcr(img, y_only=True)
+         img = img[..., None]
+     return img * 255.
+
+
+ def _convert_input_type_range(img):
+     """Convert the type and range of the input image.
+
+     It converts the input image to np.float32 type and range of [0, 1].
+     It is mainly used for pre-processing the input image in colorspace
+     conversion functions such as rgb2ycbcr and ycbcr2rgb.
+
+     Args:
+         img (ndarray): The input image. It accepts:
+             1. np.uint8 type with range [0, 255];
+             2. np.float32 type with range [0, 1].
+
+     Returns:
+         (ndarray): The converted image with type of np.float32 and range of
+             [0, 1].
+     """
+     img_type = img.dtype
+     img = img.astype(np.float32)
+     if img_type == np.float32:
+         pass
+     elif img_type == np.uint8:
+         img /= 255.
+     else:
+         raise TypeError('The img type should be np.float32 or np.uint8, ' f'but got {img_type}')
+     return img
+
+
+ def _convert_output_type_range(img, dst_type):
+     """Convert the type and range of the image according to dst_type.
+
+     It converts the image to the desired type and range. If `dst_type` is
+     np.uint8, images will be converted to np.uint8 type with range [0, 255].
+     If `dst_type` is np.float32, it converts the image to np.float32 type
+     with range [0, 1].
+     It is mainly used for post-processing images in colorspace conversion
+     functions such as rgb2ycbcr and ycbcr2rgb.
+
+     Args:
+         img (ndarray): The image to be converted with np.float32 type and
+             range [0, 255].
+         dst_type (np.uint8 | np.float32): If dst_type is np.uint8, it
+             converts the image to np.uint8 type with range [0, 255]. If
+             dst_type is np.float32, it converts the image to np.float32 type
+             with range [0, 1].
+
+     Returns:
+         (ndarray): The converted image with desired type and range.
+     """
+     if dst_type not in (np.uint8, np.float32):
+         raise TypeError('The dst_type should be np.float32 or np.uint8, ' f'but got {dst_type}')
+     if dst_type == np.uint8:
+         img = img.round()
+     else:
+         img /= 255.
+     return img.astype(dst_type)
+
+
+ def bgr2ycbcr(img, y_only=False):
+     """Convert a BGR image to YCbCr image.
+
+     The BGR version of rgb2ycbcr.
+     It implements the ITU-R BT.601 conversion for standard-definition
+     television. See more details in
+     https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
+
+     It differs from a similar function in cv2.cvtColor: `BGR <-> YCrCb`.
+     In OpenCV, it implements a JPEG conversion. See more details in
+     https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
+
+     Args:
+         img (ndarray): The input image. It accepts:
+             1. np.uint8 type with range [0, 255];
+             2. np.float32 type with range [0, 1].
+         y_only (bool): Whether to only return Y channel. Default: False.
+
+     Returns:
+         ndarray: The converted YCbCr image. The output image has the same type
+             and range as the input image.
+     """
+     img_type = img.dtype
+     img = _convert_input_type_range(img)
+     if y_only:
+         out_img = np.dot(img, [24.966, 128.553, 65.481]) + 16.0
+     else:
+         out_img = np.matmul(
+             img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], [65.481, -37.797, 112.0]]) + [16, 128, 128]
+     out_img = _convert_output_type_range(out_img, img_type)
+     return out_img
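
A quick sanity check of the metrics (a sketch, not part of the commit; the random pair below stands in for a restored/ground-truth pair, and the Y-channel path assumes cv2-style BGR channel order):

    import numpy as np
    from utils.metrics import calculate_psnr, calculate_ssim

    rng = np.random.default_rng(0)
    gt = rng.integers(0, 256, size=(64, 64, 3)).astype(np.uint8)   # HWC, [0, 255]
    noisy = np.clip(gt + rng.normal(0, 5, gt.shape), 0, 255)

    print('PSNR:', calculate_psnr(gt, noisy))   # higher is better, in dB
    print('SSIM:', calculate_ssim(gt, noisy))   # in [0, 1], 1 means identical
    print('PSNR (Y):', calculate_psnr(gt, noisy, test_y_channel=True))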
utils/optimize.py ADDED
@@ -0,0 +1,13 @@
+ import torch.optim as optim
+
+
+ def get_optimizer(config, parameters):
+     if config.optim.optimizer == 'Adam':
+         return optim.Adam(parameters, lr=config.optim.lr, weight_decay=config.optim.weight_decay,
+                           betas=(0.9, 0.999), amsgrad=config.optim.amsgrad, eps=config.optim.eps)
+     elif config.optim.optimizer == 'RMSProp':
+         return optim.RMSprop(parameters, lr=config.optim.lr, weight_decay=config.optim.weight_decay)
+     elif config.optim.optimizer == 'SGD':
+         return optim.SGD(parameters, lr=config.optim.lr, momentum=0.9)
+     else:
+         raise NotImplementedError('Optimizer {} not understood.'.format(config.optim.optimizer))
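
get_optimizer reads everything from config.optim, so a minimal compatible config can be sketched with types.SimpleNamespace (the hyperparameter values here are illustrative, not taken from the commit):

    import types
    import torch
    from utils.optimize import get_optimizer

    config = types.SimpleNamespace(optim=types.SimpleNamespace(
        optimizer='Adam',    # one of 'Adam', 'RMSProp', 'SGD'
        lr=2e-4,
        weight_decay=0.0,
        amsgrad=False,       # only read for Adam
        eps=1e-8,            # only read for Adam
    ))

    model = torch.nn.Linear(4, 4)
    optimizer = get_optimizer(config, model.parameters())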
utils/sampling.py ADDED
@@ -0,0 +1,89 @@
+ import torch
+ from torchvision.transforms.functional import crop
+
+
+ def compute_alpha(beta, t):
+     # Prepend a zero beta so that t = -1 indexes alpha = 1 (the clean-image step).
+     beta = torch.cat([torch.zeros(1).to(beta.device), beta], dim=0)
+     a = (1 - beta).cumprod(dim=0).index_select(0, t + 1).view(-1, 1, 1, 1)
+     return a
+
+
+ def data_transform(X):
+     # [0, 1] -> [-1, 1]
+     return 2 * X - 1.0
+
+
+ def inverse_data_transform(X):
+     # [-1, 1] -> [0, 1], clamped
+     return torch.clamp((X + 1.0) / 2.0, 0.0, 1.0)
+
+
+ def generalized_steps(x, x_cond, seq, model, b, eta=0., device=None):
+     with torch.no_grad():
+         n = x.size(0)
+         seq_next = [-1] + list(seq[:-1])
+         x0_preds = []
+         xs = [x]
+         for i, j in zip(reversed(seq), reversed(seq_next)):
+             t = (torch.ones(n) * i).to(x.device)
+             next_t = (torch.ones(n) * j).to(x.device)
+             at = compute_alpha(b, t.long())
+             at_next = compute_alpha(b, next_t.long())
+             xt = xs[-1].to(device)
+
+             et = model(torch.cat([x_cond, xt], dim=1), t)
+             x0_t = (xt - et * (1 - at).sqrt()) / at.sqrt()
+             x0_preds.append(x0_t.to(device))
+
+             c1 = eta * ((1 - at / at_next) * (1 - at_next) / (1 - at)).sqrt()
+             c2 = ((1 - at_next) - c1 ** 2).sqrt()
+             xt_next = at_next.sqrt() * x0_t + c1 * torch.randn_like(x) + c2 * et
+             xs.append(xt_next.to(device))
+     return xs, x0_preds
+
+
+ def generalized_steps_overlapping(x, x_cond, seq, model, b, eta=0., corners=None, p_size=None,
+                                   manual_batching=True, device=None):
+     with torch.no_grad():
+         n = x.size(0)
+         seq_next = [-1] + list(seq[:-1])
+         x0_preds = []
+         xs = [x]
+
+         # Count how many overlapping patches cover each pixel, for averaging.
+         x_grid_mask = torch.zeros_like(x_cond, device=x.device)
+         for (hi, wi) in corners:
+             x_grid_mask[:, :, hi:hi + p_size, wi:wi + p_size] += 1
+
+         for i, j in zip(reversed(seq), reversed(seq_next)):
+             t = (torch.ones(n) * i).to(x.device)
+             next_t = (torch.ones(n) * j).to(x.device)
+             at = compute_alpha(b, t.long())
+             at_next = compute_alpha(b, next_t.long())
+             xt = xs[-1].to(device)
+             et_output = torch.zeros_like(x_cond, device=x.device)
+
+             if manual_batching:
+                 # Stack all patches along the batch dimension and run the model
+                 # in chunks (this path assumes batch size n == 1).
+                 manual_batching_size = 64
+                 xt_patch = torch.cat([crop(xt, hi, wi, p_size, p_size) for (hi, wi) in corners], dim=0)
+                 x_cond_patch = torch.cat([data_transform(crop(x_cond, hi, wi, p_size, p_size))
+                                           for (hi, wi) in corners], dim=0)
+                 for k in range(0, len(corners), manual_batching_size):
+                     outputs = model(torch.cat([x_cond_patch[k:k + manual_batching_size],
+                                                xt_patch[k:k + manual_batching_size]], dim=1), t)
+                     for idx, (hi, wi) in enumerate(corners[k:k + manual_batching_size]):
+                         et_output[0, :, hi:hi + p_size, wi:wi + p_size] += outputs[idx]
+             else:
+                 for (hi, wi) in corners:
+                     xt_patch = crop(xt, hi, wi, p_size, p_size)
+                     x_cond_patch = crop(x_cond, hi, wi, p_size, p_size)
+                     x_cond_patch = data_transform(x_cond_patch)
+                     et_output[:, :, hi:hi + p_size, wi:wi + p_size] += model(
+                         torch.cat([x_cond_patch, xt_patch], dim=1), t)
+
+             # Average the per-patch noise estimates over the overlap counts.
+             et = torch.div(et_output, x_grid_mask)
+             x0_t = (xt - et * (1 - at).sqrt()) / at.sqrt()
+             x0_preds.append(x0_t.to(device))
+
+             c1 = eta * ((1 - at / at_next) * (1 - at_next) / (1 - at)).sqrt()
+             c2 = ((1 - at_next) - c1 ** 2).sqrt()
+             xt_next = at_next.sqrt() * x0_t + c1 * torch.randn_like(x) + c2 * et
+             xs.append(xt_next.to(device))
+     return xs, x0_preds
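
A compact sketch of driving generalized_steps end to end (the toy denoiser, beta schedule, and step sequence below are placeholders; the real project wires these from its own model and config). With eta=0 the update is deterministic DDIM sampling, xt_next = sqrt(at_next) * x0_t + sqrt(1 - at_next) * et; following the patch-based variant above, the conditioning image is mapped to [-1, 1] with data_transform before being fed to the model:

    import torch
    from utils.sampling import data_transform, generalized_steps, inverse_data_transform

    class ToyDenoiser(torch.nn.Module):
        # Takes x_cond concatenated with x_t (6 channels) plus the timestep t,
        # and predicts the noise for the 3 image channels.
        def __init__(self):
            super().__init__()
            self.net = torch.nn.Conv2d(6, 3, kernel_size=3, padding=1)

        def forward(self, x, t):
            return self.net(x)

    device = torch.device('cpu')
    num_timesteps = 100
    betas = torch.linspace(1e-4, 2e-2, num_timesteps)  # illustrative linear schedule

    model = ToyDenoiser().to(device)
    x = torch.randn(1, 3, 32, 32, device=device)                       # initial noise x_T
    x_cond = data_transform(torch.rand(1, 3, 32, 32, device=device))   # conditioning image in [-1, 1]
    seq = range(0, num_timesteps, 10)                                  # subsampled timestep sequence

    xs, x0_preds = generalized_steps(x, x_cond, seq, model, betas, eta=0., device=device)
    restored = inverse_data_transform(xs[-1])                          # map back to [0, 1]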