XiangZ committed
Commit 5c7484a · verified · 1 Parent(s): 1233664

Update hit_sir_arch.py

Files changed (1)
  1. hit_sir_arch.py +898 -900
hit_sir_arch.py CHANGED
@@ -1,900 +1,898 @@
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from timm.models.layers import DropPath, to_2tuple, trunc_normal_

import numpy as np
from huggingface_hub import PyTorchModelHubMixin
from utils import FileClient, imfrombytes, img2tensor, tensor2img

class DFE(nn.Module):
    """ Dual Feature Extraction
    Args:
        in_features (int): Number of input channels.
        out_features (int): Number of output channels.
    """
    def __init__(self, in_features, out_features):
        super().__init__()

        self.out_features = out_features

        self.conv = nn.Sequential(nn.Conv2d(in_features, in_features // 5, 1, 1, 0),
                                  nn.LeakyReLU(negative_slope=0.2, inplace=True),
                                  nn.Conv2d(in_features // 5, in_features // 5, 3, 1, 1),
                                  nn.LeakyReLU(negative_slope=0.2, inplace=True),
                                  nn.Conv2d(in_features // 5, out_features, 1, 1, 0))

        self.linear = nn.Conv2d(in_features, out_features, 1, 1, 0)

    def forward(self, x, x_size):

        B, L, C = x.shape
        H, W = x_size
        x = x.permute(0, 2, 1).contiguous().view(B, C, H, W)
        x = self.conv(x) * self.linear(x)
        x = x.view(B, -1, H*W).permute(0, 2, 1).contiguous()

        return x

class Mlp(nn.Module):
    """ MLP-based Feed-Forward Network
    Args:
        in_features (int): Number of input channels.
        hidden_features (int | None): Number of hidden channels. Default: None
        out_features (int | None): Number of output channels. Default: None
        act_layer (nn.Module): Activation layer. Default: nn.GELU
        drop (float): Dropout rate. Default: 0.0
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (tuple): window size

    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0], window_size[1], C)
    return windows


def window_reverse(windows, window_size, H, W):
    """
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (tuple): Window size
        H (int): Height of image
        W (int): Width of image

    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] * (window_size[0] * window_size[1]) / (H * W))
    x = windows.view(B, H // window_size[0], W // window_size[1], window_size[0], window_size[1], -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x

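# Shape sketch for the two helpers above (illustrative): partitioning a
# (B, H, W, C) tensor into non-overlapping windows and then reversing the
# operation recovers the original tensor, e.g.
#
#   x = torch.randn(2, 16, 16, 60)            # B, H, W, C
#   wins = window_partition(x, (8, 8))        # (2*2*2, 8, 8, 60)
#   y = window_reverse(wins, (8, 8), 16, 16)  # back to (2, 16, 16, 60), equal to x
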
class DynamicPosBias(nn.Module):
    # The implementation builds on Crossformer code https://github.com/cheerss/CrossFormer/blob/main/models/crossformer.py
    """ Dynamic Relative Position Bias.
    Args:
        dim (int): Number of input channels.
        num_heads (int): Number of heads for spatial self-correlation.
        residual (bool): If True, use residual strategy to connect conv.
    """
    def __init__(self, dim, num_heads, residual):
        super().__init__()
        self.residual = residual
        self.num_heads = num_heads
        self.pos_dim = dim // 4
        self.pos_proj = nn.Linear(2, self.pos_dim)
        self.pos1 = nn.Sequential(
            nn.LayerNorm(self.pos_dim),
            nn.ReLU(inplace=True),
            nn.Linear(self.pos_dim, self.pos_dim),
        )
        self.pos2 = nn.Sequential(
            nn.LayerNorm(self.pos_dim),
            nn.ReLU(inplace=True),
            nn.Linear(self.pos_dim, self.pos_dim)
        )
        self.pos3 = nn.Sequential(
            nn.LayerNorm(self.pos_dim),
            nn.ReLU(inplace=True),
            nn.Linear(self.pos_dim, self.num_heads)
        )

    def forward(self, biases):
        if self.residual:
            pos = self.pos_proj(biases)  # 2Gh-1 * 2Gw-1, heads
            pos = pos + self.pos1(pos)
            pos = pos + self.pos2(pos)
            pos = self.pos3(pos)
        else:
            pos = self.pos3(self.pos2(self.pos1(self.pos_proj(biases))))
        return pos

class SCC(nn.Module):
    """ Spatial-Channel Correlation.
    Args:
        dim (int): Number of input channels.
        base_win_size (tuple[int]): The height and width of the base window.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of heads for spatial self-correlation.
        value_drop (float, optional): Dropout ratio of value. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, base_win_size, window_size, num_heads, value_drop=0., proj_drop=0.):

        super().__init__()
        # parameters
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads

        # feature projection
        self.qv = DFE(dim, dim)
        self.proj = nn.Linear(dim, dim)

        # dropout
        self.value_drop = nn.Dropout(value_drop)
        self.proj_drop = nn.Dropout(proj_drop)

        # base window size
        min_h = min(self.window_size[0], base_win_size[0])
        min_w = min(self.window_size[1], base_win_size[1])
        self.base_win_size = (min_h, min_w)

        # normalization factor and spatial linear layer for S-SC
        head_dim = dim // (2*num_heads)
        self.scale = head_dim
        self.spatial_linear = nn.Linear(self.window_size[0]*self.window_size[1] // (self.base_win_size[0]*self.base_win_size[1]), 1)

        # define a parameter table of relative position bias
        self.H_sp, self.W_sp = self.window_size
        self.pos = DynamicPosBias(self.dim // 4, self.num_heads, residual=False)

    def spatial_linear_projection(self, x):
        B, num_h, L, C = x.shape
        H, W = self.window_size
        map_H, map_W = self.base_win_size

        x = x.view(B, num_h, map_H, H//map_H, map_W, W//map_W, C).permute(0, 1, 2, 4, 6, 3, 5).contiguous().view(B, num_h, map_H*map_W, C, -1)
        x = self.spatial_linear(x).view(B, num_h, map_H*map_W, C)
        return x

    def spatial_self_correlation(self, q, v):

        B, num_head, L, C = q.shape

        # spatial projection
        v = self.spatial_linear_projection(v)

        # compute correlation map
        corr_map = (q @ v.transpose(-2, -1)) / self.scale

        # add relative position bias
        # generate mother-set
        position_bias_h = torch.arange(1 - self.H_sp, self.H_sp, device=v.device)
        position_bias_w = torch.arange(1 - self.W_sp, self.W_sp, device=v.device)
        biases = torch.stack(torch.meshgrid([position_bias_h, position_bias_w]))
        rpe_biases = biases.flatten(1).transpose(0, 1).contiguous().float()
        pos = self.pos(rpe_biases)

        # select position bias
        coords_h = torch.arange(self.H_sp, device=v.device)
        coords_w = torch.arange(self.W_sp, device=v.device)
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))
        coords_flatten = torch.flatten(coords, 1)
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()
        relative_coords[:, :, 0] += self.H_sp - 1
        relative_coords[:, :, 1] += self.W_sp - 1
        relative_coords[:, :, 0] *= 2 * self.W_sp - 1
        relative_position_index = relative_coords.sum(-1)
        relative_position_bias = pos[relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.base_win_size[0], self.window_size[0]//self.base_win_size[0], self.base_win_size[1], self.window_size[1]//self.base_win_size[1], -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(0, 1, 3, 5, 2, 4).contiguous().view(
            self.window_size[0] * self.window_size[1], self.base_win_size[0]*self.base_win_size[1], self.num_heads, -1).mean(-1)
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
        corr_map = corr_map + relative_position_bias.unsqueeze(0)

        # transformation
        v_drop = self.value_drop(v)
        x = (corr_map @ v_drop).permute(0, 2, 1, 3).contiguous().view(B, L, -1)

        return x

    def channel_self_correlation(self, q, v):

        B, num_head, L, C = q.shape

        # apply single head strategy
        q = q.permute(0, 2, 1, 3).contiguous().view(B, L, num_head*C)
        v = v.permute(0, 2, 1, 3).contiguous().view(B, L, num_head*C)

        # compute correlation map
        corr_map = (q.transpose(-2, -1) @ v) / L

        # transformation
        v_drop = self.value_drop(v)
        x = (corr_map @ v_drop.transpose(-2, -1)).permute(0, 2, 1).contiguous().view(B, L, -1)

        return x

    def forward(self, x):
        """
        Args:
            x: input features with shape of (B, H, W, C)
        """
        xB, xH, xW, xC = x.shape
        qv = self.qv(x.view(xB, -1, xC), (xH, xW)).view(xB, xH, xW, xC)

        # window partition
        qv = window_partition(qv, self.window_size)
        qv = qv.view(-1, self.window_size[0]*self.window_size[1], xC)

        # qv splitting
        B, L, C = qv.shape
        qv = qv.view(B, L, 2, self.num_heads, C // (2*self.num_heads)).permute(2, 0, 3, 1, 4).contiguous()
        q, v = qv[0], qv[1]  # B, num_heads, L, C//num_heads

        # spatial self-correlation (S-SC)
        x_spatial = self.spatial_self_correlation(q, v)
        x_spatial = x_spatial.view(-1, self.window_size[0], self.window_size[1], C//2)
        x_spatial = window_reverse(x_spatial, (self.window_size[0], self.window_size[1]), xH, xW)  # xB xH xW xC

        # channel self-correlation (C-SC)
        x_channel = self.channel_self_correlation(q, v)
        x_channel = x_channel.view(-1, self.window_size[0], self.window_size[1], C//2)
        x_channel = window_reverse(x_channel, (self.window_size[0], self.window_size[1]), xH, xW)  # xB xH xW xC

        # spatial-channel information fusion
        x = torch.cat([x_spatial, x_channel], -1)
        x = self.proj_drop(self.proj(x))

        return x

    def extra_repr(self) -> str:
        return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'


class HierarchicalTransformerBlock(nn.Module):
    """ Hierarchical Transformer Block.
    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        num_heads (int): Number of heads for spatial self-correlation.
        base_win_size (tuple[int]): The height and width of the base window.
        window_size (tuple[int]): The height and width of the window.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        drop (float, optional): Dropout rate. Default: 0.0
        value_drop (float, optional): Dropout ratio of value. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(self, dim, input_resolution, num_heads, base_win_size, window_size,
                 mlp_ratio=4., drop=0., value_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.mlp_ratio = mlp_ratio

        # check window size
        if (window_size[0] > base_win_size[0]) and (window_size[1] > base_win_size[1]):
            assert window_size[0] % base_win_size[0] == 0, "please ensure the window size is smaller than or divisible by the base window size"
            assert window_size[1] % base_win_size[1] == 0, "please ensure the window size is smaller than or divisible by the base window size"

        self.norm1 = norm_layer(dim)
        self.correlation = SCC(
            dim, base_win_size=base_win_size, window_size=self.window_size, num_heads=num_heads,
            value_drop=value_drop, proj_drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def check_image_size(self, x, win_size):
        x = x.permute(0, 3, 1, 2).contiguous()
        _, _, h, w = x.size()
        mod_pad_h = (win_size[0] - h % win_size[0]) % win_size[0]
        mod_pad_w = (win_size[1] - w % win_size[1]) % win_size[1]
        x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
        x = x.permute(0, 2, 3, 1).contiguous()
        return x

    def forward(self, x, x_size, win_size):
        H, W = x_size
        B, L, C = x.shape

        shortcut = x
        x = x.view(B, H, W, C)

        # padding
        x = self.check_image_size(x, win_size)
        _, H_pad, W_pad, _ = x.shape  # shape after padding

        x = self.correlation(x)

        # unpad
        x = x[:, :H, :W, :].contiguous()

        # norm
        x = x.view(B, H * W, C)
        x = self.norm1(x)

        # FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.norm2(self.mlp(x)))

        return x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
               f"window_size={self.window_size}, mlp_ratio={self.mlp_ratio}"


class PatchMerging(nn.Module):
    """ Patch Merging Layer.
    Args:
        input_resolution (tuple[int]): Resolution of input feature.
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        """
        x: B, H*W, C
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."

        x = x.view(B, H, W, C)

        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C

        x = self.norm(x)
        x = self.reduction(x)

        return x

    def extra_repr(self) -> str:
        return f"input_resolution={self.input_resolution}, dim={self.dim}"


class BasicLayer(nn.Module):
    """ A basic Hierarchical Transformer layer for one stage.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        num_heads (int): Number of heads for spatial self-correlation.
        base_win_size (tuple[int]): The height and width of the base window.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        drop (float, optional): Dropout rate. Default: 0.0
        value_drop (float, optional): Dropout ratio of value. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
        hier_win_ratios (list): hierarchical window ratios for a transformer block. Default: [0.5,1,2,4,6,8].
    """

    def __init__(self, dim, input_resolution, depth, num_heads, base_win_size,
                 mlp_ratio=4., drop=0., value_drop=0., drop_path=0., norm_layer=nn.LayerNorm,
                 downsample=None, use_checkpoint=False, hier_win_ratios=[0.5, 1, 2, 4, 6, 8]):

        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        self.win_hs = [int(base_win_size[0] * ratio) for ratio in hier_win_ratios]
        self.win_ws = [int(base_win_size[1] * ratio) for ratio in hier_win_ratios]

        # build blocks
        self.blocks = nn.ModuleList([
            HierarchicalTransformerBlock(dim=dim, input_resolution=input_resolution,
                                         num_heads=num_heads,
                                         base_win_size=base_win_size,
                                         window_size=(self.win_hs[i], self.win_ws[i]),
                                         mlp_ratio=mlp_ratio,
                                         drop=drop, value_drop=value_drop,
                                         drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                                         norm_layer=norm_layer)
            for i in range(depth)])

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    def forward(self, x, x_size):

        i = 0
        for blk in self.blocks:
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x, x_size, (self.win_hs[i], self.win_ws[i]))
            else:
                x = blk(x, x_size, (self.win_hs[i], self.win_ws[i]))
            i = i + 1

        if self.downsample is not None:
            x = self.downsample(x)
        return x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"

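# Note on the hierarchical windows built by BasicLayer above (illustrative):
# with base_win_size = (8, 8) and the default hier_win_ratios [0.5, 1, 2, 4, 6, 8],
# the six blocks of a layer use window sizes
#   (4, 4), (8, 8), (16, 16), (32, 32), (48, 48), (64, 64),
# computed per dimension as int(base * ratio).
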
class RHTB(nn.Module):
    """Residual Hierarchical Transformer Block (RHTB).
    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        num_heads (int): Number of heads for spatial self-correlation.
        base_win_size (tuple[int]): The height and width of the base window.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        drop (float, optional): Dropout rate. Default: 0.0
        value_drop (float, optional): Dropout ratio of value. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
        img_size: Input image size.
        patch_size: Patch size.
        resi_connection: The convolutional block before residual connection.
        hier_win_ratios (list): hierarchical window ratios for a transformer block. Default: [0.5,1,2,4,6,8].
    """

    def __init__(self, dim, input_resolution, depth, num_heads, base_win_size,
                 mlp_ratio=4., drop=0., value_drop=0., drop_path=0., norm_layer=nn.LayerNorm,
                 downsample=None, use_checkpoint=False, img_size=224, patch_size=4,
                 resi_connection='1conv', hier_win_ratios=[0.5, 1, 2, 4, 6, 8]):
        super(RHTB, self).__init__()

        self.dim = dim
        self.input_resolution = input_resolution

        self.residual_group = BasicLayer(dim=dim,
                                         input_resolution=input_resolution,
                                         depth=depth,
                                         num_heads=num_heads,
                                         base_win_size=base_win_size,
                                         mlp_ratio=mlp_ratio,
                                         drop=drop, value_drop=value_drop,
                                         drop_path=drop_path,
                                         norm_layer=norm_layer,
                                         downsample=downsample,
                                         use_checkpoint=use_checkpoint,
                                         hier_win_ratios=hier_win_ratios)

        if resi_connection == '1conv':
            self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
        elif resi_connection == '3conv':
            # to save parameters and memory
            self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True),
                                      nn.Conv2d(dim // 4, dim // 4, 1, 1, 0),
                                      nn.LeakyReLU(negative_slope=0.2, inplace=True),
                                      nn.Conv2d(dim // 4, dim, 3, 1, 1))

        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim,
            norm_layer=None)

        self.patch_unembed = PatchUnEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim,
            norm_layer=None)

    def forward(self, x, x_size):
        return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x


class PatchEmbed(nn.Module):
    r""" Image to Patch Embedding

    Args:
        img_size (int): Image size. Default: 224.
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = patches_resolution
        self.num_patches = patches_resolution[0] * patches_resolution[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        x = x.flatten(2).transpose(1, 2)  # B Ph*Pw C
        if self.norm is not None:
            x = self.norm(x)
        return x


class PatchUnEmbed(nn.Module):
    r""" Image to Patch Unembedding

    Args:
        img_size (int): Image size. Default: 224.
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = patches_resolution
        self.num_patches = patches_resolution[0] * patches_resolution[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

    def forward(self, x, x_size):
        B, HW, C = x.shape
        x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1])  # B C H W
        return x


class Upsample(nn.Sequential):
    """Upsample module.

    Args:
        scale (int): Scale factor. Supported scales: 2^n and 3.
        num_feat (int): Channel number of intermediate features.
    """

    def __init__(self, scale, num_feat):
        m = []
        if (scale & (scale - 1)) == 0:  # scale = 2^n
            for _ in range(int(math.log(scale, 2))):
                m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
                m.append(nn.PixelShuffle(2))
        elif scale == 3:
            m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
            m.append(nn.PixelShuffle(3))
        else:
            raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
        super(Upsample, self).__init__(*m)


class UpsampleOneStep(nn.Sequential):
    """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle)
    Used in lightweight SR to save parameters.

    Args:
        scale (int): Scale factor. Supported scales: 2^n and 3.
        num_feat (int): Channel number of intermediate features.

    """

    def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
        self.num_feat = num_feat
        self.input_resolution = input_resolution
        m = []
        m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1))
        m.append(nn.PixelShuffle(scale))
        super(UpsampleOneStep, self).__init__(*m)

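# Illustrative shape check for the two upsampling modules above: for scale 4,
# Upsample stacks two (3x3 conv -> PixelShuffle(2)) stages, while UpsampleOneStep
# uses a single conv to scale**2 * num_out_ch channels followed by PixelShuffle(scale):
#
#   Upsample(4, 64)(torch.randn(1, 64, 16, 16)).shape            # (1, 64, 64, 64)
#   UpsampleOneStep(4, 60, 3)(torch.randn(1, 60, 16, 16)).shape  # (1, 3, 64, 64)
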
class HiT_SIR(nn.Module, PyTorchModelHubMixin):
    """ HiT-SIR network.

    Args:
        img_size (int | tuple(int)): Input image size. Default 64
        patch_size (int | tuple(int)): Patch size. Default: 1
        in_chans (int): Number of input image channels. Default: 3
        embed_dim (int): Patch embedding dimension. Default: 96
        depths (tuple(int)): Depth of each Transformer block.
        num_heads (tuple(int)): Number of heads for spatial self-correlation in different layers.
        base_win_size (tuple[int]): The height and width of the base window.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
        drop_rate (float): Dropout rate. Default: 0
        value_drop_rate (float): Dropout ratio of value. Default: 0.0
        drop_path_rate (float): Stochastic depth rate. Default: 0.1
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
        patch_norm (bool): If True, add normalization after patch embedding. Default: True
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
        upscale (int): Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction
        img_range (float): Image range. 1. or 255.
        upsampler (str): The reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None
        resi_connection (str): The convolutional block before residual connection. '1conv'/'3conv'
        hier_win_ratios (list): hierarchical window ratios for a transformer block. Default: [0.5,1,2,4,6,8].
    """

    def __init__(self, img_size=64, patch_size=1, in_chans=3,
                 embed_dim=60, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6],
                 base_win_size=[8, 8], mlp_ratio=2.,
                 drop_rate=0., value_drop_rate=0., drop_path_rate=0.,
                 norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
                 use_checkpoint=False, upscale=4, img_range=1., upsampler='pixelshuffledirect', resi_connection='1conv',
                 hier_win_ratios=[0.5, 1, 2, 4, 6, 8],
                 **kwargs):
        super(HiT_SIR, self).__init__()
        num_in_ch = in_chans
        num_out_ch = in_chans
        num_feat = 64
        self.img_range = img_range
        if in_chans == 3:
            rgb_mean = (0.4488, 0.4371, 0.4040)
            self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
        else:
            self.mean = torch.zeros(1, 1, 1, 1)
        self.upscale = upscale
        self.upsampler = upsampler
        self.base_win_size = base_win_size

        #####################################################################################################
        ################################### 1, shallow feature extraction ###################################
        self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)

        #####################################################################################################
        ################################### 2, deep feature extraction ######################################
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.ape = ape
        self.patch_norm = patch_norm
        self.num_features = embed_dim
        self.mlp_ratio = mlp_ratio

        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)
        num_patches = self.patch_embed.num_patches
        patches_resolution = self.patch_embed.patches_resolution
        self.patches_resolution = patches_resolution

        # merge non-overlapping patches into image
        self.patch_unembed = PatchUnEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)

        # absolute position embedding
        if self.ape:
            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
            trunc_normal_(self.absolute_pos_embed, std=.02)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule

        # build Residual Hierarchical Transformer blocks (RHTB)
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = RHTB(dim=embed_dim,
                         input_resolution=(patches_resolution[0],
                                           patches_resolution[1]),
                         depth=depths[i_layer],
                         num_heads=num_heads[i_layer],
                         base_win_size=base_win_size,
                         mlp_ratio=self.mlp_ratio,
                         drop=drop_rate, value_drop=value_drop_rate,
                         drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],  # no impact on SR results
                         norm_layer=norm_layer,
                         downsample=None,
                         use_checkpoint=use_checkpoint,
                         img_size=img_size,
                         patch_size=patch_size,
                         resi_connection=resi_connection,
                         hier_win_ratios=hier_win_ratios
                         )
            self.layers.append(layer)
        self.norm = norm_layer(self.num_features)

        # build the last conv layer in deep feature extraction
        if resi_connection == '1conv':
            self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
        elif resi_connection == '3conv':
            # to save parameters and memory
            self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1),
                                                 nn.LeakyReLU(negative_slope=0.2, inplace=True),
                                                 nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0),
                                                 nn.LeakyReLU(negative_slope=0.2, inplace=True),
                                                 nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1))

        #####################################################################################################
        ################################ 3, high quality image reconstruction ################################
        if self.upsampler == 'pixelshuffle':
            # for classical SR
            self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
                                                      nn.LeakyReLU(inplace=True))
            self.upsample = Upsample(upscale, num_feat)
            self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
        elif self.upsampler == 'pixelshuffledirect':
            # for lightweight SR (to save parameters)
            self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch,
                                            (patches_resolution[0], patches_resolution[1]))
        elif self.upsampler == 'nearest+conv':
            # for real-world SR (less artifacts)
            assert self.upscale == 4, 'only support x4 now.'
            self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
                                                      nn.LeakyReLU(inplace=True))
            self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
            self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
            self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
            self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
            self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        else:
            # for image denoising and JPEG compression artifact reduction
            self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'absolute_pos_embed'}

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        return {'relative_position_bias_table'}

    def forward_features(self, x):
        x_size = (x.shape[2], x.shape[3])
        x = self.patch_embed(x)
        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)

        for layer in self.layers:
            x = layer(x, x_size)

        x = self.norm(x)  # B L C
        x = self.patch_unembed(x, x_size)

        return x

    def infer_image(self, image_path, cuda=True):

        io_backend_opt = {'type': 'disk'}
        self.file_client = FileClient(io_backend_opt.pop('type'), **io_backend_opt)

        # load lq image
        lq_path = image_path
        img_bytes = self.file_client.get(lq_path, 'lq')
        img_lq = imfrombytes(img_bytes, float32=True)

        # BGR to RGB, HWC to CHW, numpy to tensor
        x = img2tensor(img_lq, bgr2rgb=True, float32=True)[None, ...]

        if cuda:
            x = x.cuda()

        out = self(x)

        if cuda:
            out = out.cpu()

        out = tensor2img(out)

        return out

    def forward(self, x):
        H, W = x.shape[2:]

        self.mean = self.mean.type_as(x)
        x = (x - self.mean) * self.img_range

        if self.upsampler == 'pixelshuffle':
            # for classical SR
            x = self.conv_first(x)
            x = self.conv_after_body(self.forward_features(x)) + x
            x = self.conv_before_upsample(x)
            x = self.conv_last(self.upsample(x))
        elif self.upsampler == 'pixelshuffledirect':
            # for lightweight SR
            x = self.conv_first(x)
            x = self.conv_after_body(self.forward_features(x)) + x
            x = self.upsample(x)
        elif self.upsampler == 'nearest+conv':
            # for real-world SR
            x = self.conv_first(x)
            x = self.conv_after_body(self.forward_features(x)) + x
            x = self.conv_before_upsample(x)
            x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
            x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
            x = self.conv_last(self.lrelu(self.conv_hr(x)))
        else:
            # for image denoising and JPEG compression artifact reduction
            x_first = self.conv_first(x)
            res = self.conv_after_body(self.forward_features(x_first)) + x_first
            x = x + self.conv_last(res)

        x = x / self.img_range + self.mean

        return x[:, :, :H*self.upscale, :W*self.upscale]


if __name__ == '__main__':
    upscale = 4
    base_win_size = [8, 8]
    height = (1024 // upscale // base_win_size[0] + 1) * base_win_size[0]
    width = (720 // upscale // base_win_size[1] + 1) * base_win_size[1]

    ## HiT-SIR
    model = HiT_SIR(upscale=4, img_size=(height, width),
                    base_win_size=base_win_size, img_range=1., depths=[6, 6, 6, 6],
                    embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect')

    params_num = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("params: ", params_num)
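
# Illustrative usage sketch for the PyTorchModelHubMixin integration above.
# Assumptions: the repo ID below is a hypothetical placeholder for whichever Hub
# repository actually hosts the HiT-SIR weights; the function is documentation
# only and is never called in this file.
def demo_from_pretrained(repo_id="namespace/hit-sir-checkpoint"):
    """Load HiT_SIR weights from the Hub (hypothetical repo ID) and run a
    forward pass on a random low-quality input."""
    model = HiT_SIR.from_pretrained(repo_id)  # classmethod provided by PyTorchModelHubMixin
    model.eval()
    lq = torch.randn(1, 3, 64, 64)  # B, C, H, W low-quality input
    with torch.no_grad():
        sr = model(lq)  # e.g. (1, 3, 256, 256) for an x4 model
    return sr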
 
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import torch.utils.checkpoint as checkpoint
6
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
7
+
8
+ import numpy as np
9
+ from huggingface_hub import PyTorchModelHubMixin
10
+ from utils import FileClient, imfrombytes, img2tensor, tensor2img
11
+
12
+ class DFE(nn.Module):
13
+ """ Dual Feature Extraction
14
+ Args:
15
+ in_features (int): Number of input channels.
16
+ out_features (int): Number of output channels.
17
+ """
18
+ def __init__(self, in_features, out_features):
19
+ super().__init__()
20
+
21
+ self.out_features = out_features
22
+
23
+ self.conv = nn.Sequential(nn.Conv2d(in_features, in_features // 5, 1, 1, 0),
24
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
25
+ nn.Conv2d(in_features // 5, in_features // 5, 3, 1, 1),
26
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
27
+ nn.Conv2d(in_features // 5, out_features, 1, 1, 0))
28
+
29
+ self.linear = nn.Conv2d(in_features, out_features,1,1,0)
30
+
31
+ def forward(self, x, x_size):
32
+
33
+ B, L, C = x.shape
34
+ H, W = x_size
35
+ x = x.permute(0, 2, 1).contiguous().view(B, C, H, W)
36
+ x = self.conv(x) * self.linear(x)
37
+ x = x.view(B, -1, H*W).permute(0,2,1).contiguous()
38
+
39
+ return x
40
+
41
+ class Mlp(nn.Module):
42
+ """ MLP-based Feed-Forward Network
43
+ Args:
44
+ in_features (int): Number of input channels.
45
+ hidden_features (int | None): Number of hidden channels. Default: None
46
+ out_features (int | None): Number of output channels. Default: None
47
+ act_layer (nn.Module): Activation layer. Default: nn.GELU
48
+ drop (float): Dropout rate. Default: 0.0
49
+ """
50
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
51
+ super().__init__()
52
+ out_features = out_features or in_features
53
+ hidden_features = hidden_features or in_features
54
+ self.fc1 = nn.Linear(in_features, hidden_features)
55
+ self.act = act_layer()
56
+ self.fc2 = nn.Linear(hidden_features, out_features)
57
+ self.drop = nn.Dropout(drop)
58
+
59
+ def forward(self, x):
60
+ x = self.fc1(x)
61
+ x = self.act(x)
62
+ x = self.drop(x)
63
+ x = self.fc2(x)
64
+ x = self.drop(x)
65
+ return x
66
+
67
+
68
+ def window_partition(x, window_size):
69
+ """
70
+ Args:
71
+ x: (B, H, W, C)
72
+ window_size (tuple): window size
73
+
74
+ Returns:
75
+ windows: (num_windows*B, window_size, window_size, C)
76
+ """
77
+ B, H, W, C = x.shape
78
+ x = x.view(B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], C)
79
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0], window_size[1], C)
80
+ return windows
81
+
82
+
83
+ def window_reverse(windows, window_size, H, W):
84
+ """
85
+ Args:
86
+ windows: (num_windows*B, window_size, window_size, C)
87
+ window_size (tuple): Window size
88
+ H (int): Height of image
89
+ W (int): Width of image
90
+
91
+ Returns:
92
+ x: (B, H, W, C)
93
+ """
94
+ B = int(windows.shape[0] * (window_size[0] * window_size[1]) / (H * W))
95
+ x = windows.view(B, H // window_size[0], W // window_size[1], window_size[0], window_size[1], -1)
96
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
97
+ return x
98
+
99
+ class DynamicPosBias(nn.Module):
100
+ # The implementation builds on Crossformer code https://github.com/cheerss/CrossFormer/blob/main/models/crossformer.py
101
+ """ Dynamic Relative Position Bias.
102
+ Args:
103
+ dim (int): Number of input channels.
104
+ num_heads (int): Number of heads for spatial self-correlation.
105
+ residual (bool): If True, use residual strage to connect conv.
106
+ """
107
+ def __init__(self, dim, num_heads, residual):
108
+ super().__init__()
109
+ self.residual = residual
110
+ self.num_heads = num_heads
111
+ self.pos_dim = dim // 4
112
+ self.pos_proj = nn.Linear(2, self.pos_dim)
113
+ self.pos1 = nn.Sequential(
114
+ nn.LayerNorm(self.pos_dim),
115
+ nn.ReLU(inplace=True),
116
+ nn.Linear(self.pos_dim, self.pos_dim),
117
+ )
118
+ self.pos2 = nn.Sequential(
119
+ nn.LayerNorm(self.pos_dim),
120
+ nn.ReLU(inplace=True),
121
+ nn.Linear(self.pos_dim, self.pos_dim)
122
+ )
123
+ self.pos3 = nn.Sequential(
124
+ nn.LayerNorm(self.pos_dim),
125
+ nn.ReLU(inplace=True),
126
+ nn.Linear(self.pos_dim, self.num_heads)
127
+ )
128
+ def forward(self, biases):
129
+ if self.residual:
130
+ pos = self.pos_proj(biases) # 2Gh-1 * 2Gw-1, heads
131
+ pos = pos + self.pos1(pos)
132
+ pos = pos + self.pos2(pos)
133
+ pos = self.pos3(pos)
134
+ else:
135
+ pos = self.pos3(self.pos2(self.pos1(self.pos_proj(biases))))
136
+ return pos
137
+
138
+ class SCC(nn.Module):
139
+ """ Spatial-Channel Correlation.
140
+ Args:
141
+ dim (int): Number of input channels.
142
+ base_win_size (tuple[int]): The height and width of the base window.
143
+ window_size (tuple[int]): The height and width of the window.
144
+ num_heads (int): Number of heads for spatial self-correlation.
145
+ value_drop (float, optional): Dropout ratio of value. Default: 0.0
146
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
147
+ """
148
+
149
+ def __init__(self, dim, base_win_size, window_size, num_heads, value_drop=0., proj_drop=0.):
150
+
151
+ super().__init__()
152
+ # parameters
153
+ self.dim = dim
154
+ self.window_size = window_size
155
+ self.num_heads = num_heads
156
+
157
+ # feature projection
158
+ self.qv = DFE(dim, dim)
159
+ self.proj = nn.Linear(dim, dim)
160
+
161
+ # dropout
162
+ self.value_drop = nn.Dropout(value_drop)
163
+ self.proj_drop = nn.Dropout(proj_drop)
164
+
165
+ # base window size
166
+ min_h = min(self.window_size[0], base_win_size[0])
167
+ min_w = min(self.window_size[1], base_win_size[1])
168
+ self.base_win_size = (min_h, min_w)
169
+
170
+ # normalization factor and spatial linear layer for S-SC
171
+ head_dim = dim // (2*num_heads)
172
+ self.scale = head_dim
173
+ self.spatial_linear = nn.Linear(self.window_size[0]*self.window_size[1] // (self.base_win_size[0]*self.base_win_size[1]), 1)
174
+
175
+ # define a parameter table of relative position bias
176
+ self.H_sp, self.W_sp = self.window_size
177
+ self.pos = DynamicPosBias(self.dim // 4, self.num_heads, residual=False)
178
+
179
+ def spatial_linear_projection(self, x):
180
+ B, num_h, L, C = x.shape
181
+ H, W = self.window_size
182
+ map_H, map_W = self.base_win_size
183
+
184
+ x = x.view(B, num_h, map_H, H//map_H, map_W, W//map_W, C).permute(0,1,2,4,6,3,5).contiguous().view(B, num_h, map_H*map_W, C, -1)
185
+ x = self.spatial_linear(x).view(B, num_h, map_H*map_W, C)
186
+ return x
187
+
188
+ def spatial_self_correlation(self, q, v):
189
+
190
+ B, num_head, L, C = q.shape
191
+
192
+ # spatial projection
193
+ v = self.spatial_linear_projection(v)
194
+
195
+ # compute correlation map
196
+ corr_map = (q @ v.transpose(-2,-1)) / self.scale
197
+
198
+ # add relative position bias
199
+ # generate mother-set
200
+ position_bias_h = torch.arange(1 - self.H_sp, self.H_sp, device=v.device)
201
+ position_bias_w = torch.arange(1 - self.W_sp, self.W_sp, device=v.device)
202
+ biases = torch.stack(torch.meshgrid([position_bias_h, position_bias_w]))
203
+ rpe_biases = biases.flatten(1).transpose(0, 1).contiguous().float()
204
+ pos = self.pos(rpe_biases)
205
+
206
+ # select position bias
207
+ coords_h = torch.arange(self.H_sp, device=v.device)
208
+ coords_w = torch.arange(self.W_sp, device=v.device)
209
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w]))
210
+ coords_flatten = torch.flatten(coords, 1)
211
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
212
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous()
213
+ relative_coords[:, :, 0] += self.H_sp - 1
214
+ relative_coords[:, :, 1] += self.W_sp - 1
215
+ relative_coords[:, :, 0] *= 2 * self.W_sp - 1
216
+ relative_position_index = relative_coords.sum(-1)
217
+ relative_position_bias = pos[relative_position_index.view(-1)].view(
218
+ self.window_size[0] * self.window_size[1], self.base_win_size[0], self.window_size[0]//self.base_win_size[0], self.base_win_size[1], self.window_size[1]//self.base_win_size[1], -1) # Wh*Ww,Wh*Ww,nH
219
+ relative_position_bias = relative_position_bias.permute(0,1,3,5,2,4).contiguous().view(
220
+ self.window_size[0] * self.window_size[1], self.base_win_size[0]*self.base_win_size[1], self.num_heads, -1).mean(-1)
221
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
222
+ corr_map = corr_map + relative_position_bias.unsqueeze(0)
223
+
224
+ # transformation
225
+ v_drop = self.value_drop(v)
226
+ x = (corr_map @ v_drop).permute(0,2,1,3).contiguous().view(B, L, -1)
227
+
228
+ return x
229
+
230
+ def channel_self_correlation(self, q, v):
231
+
232
+ B, num_head, L, C = q.shape
233
+
234
+ # apply single head strategy
235
+ q = q.permute(0,2,1,3).contiguous().view(B, L, num_head*C)
236
+ v = v.permute(0,2,1,3).contiguous().view(B, L, num_head*C)
237
+
238
+ # compute correlation map
239
+ corr_map = (q.transpose(-2,-1) @ v) / L
240
+
241
+ # transformation
242
+ v_drop = self.value_drop(v)
243
+ x = (corr_map @ v_drop.transpose(-2,-1)).permute(0,2,1).contiguous().view(B, L, -1)
244
+
245
+ return x
246
+
247
+ def forward(self, x):
248
+ """
249
+ Args:
250
+ x: input features with shape of (B, H, W, C)
251
+ """
252
+ xB,xH,xW,xC = x.shape
253
+ qv = self.qv(x.view(xB,-1,xC), (xH,xW)).view(xB, xH, xW, xC)
254
+
255
+ # window partition
256
+ qv = window_partition(qv, self.window_size)
257
+ qv = qv.view(-1, self.window_size[0]*self.window_size[1], xC)
258
+
259
+ # qv splitting
260
+ B, L, C = qv.shape
261
+ qv = qv.view(B, L, 2, self.num_heads, C // (2*self.num_heads)).permute(2,0,3,1,4).contiguous()
262
+ q, v = qv[0], qv[1] # B, num_heads, L, C//num_heads
263
+
264
+ # spatial self-correlation (S-SC)
265
+ x_spatial = self.spatial_self_correlation(q, v)
266
+ x_spatial = x_spatial.view(-1, self.window_size[0], self.window_size[1], C//2)
267
+ x_spatial = window_reverse(x_spatial, (self.window_size[0],self.window_size[1]), xH, xW) # xB xH xW xC
268
+
269
+ # channel self-correlation (C-SC)
270
+ x_channel = self.channel_self_correlation(q, v)
271
+ x_channel = x_channel.view(-1, self.window_size[0], self.window_size[1], C//2)
272
+ x_channel = window_reverse(x_channel, (self.window_size[0], self.window_size[1]), xH, xW) # xB xH xW xC
273
+
274
+ # spatial-channel information fusion
275
+ x = torch.cat([x_spatial, x_channel], -1)
276
+ x = self.proj_drop(self.proj(x))
277
+
278
+ return x
279
+
280
+ def extra_repr(self) -> str:
281
+ return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
282
+
283
+
284
+ class HierarchicalTransformerBlock(nn.Module):
285
+ """ Hierarchical Transformer Block.
286
+ Args:
287
+ dim (int): Number of input channels.
288
+ input_resolution (tuple[int]): Input resulotion.
289
+ num_heads (int): Number of heads for spatial self-correlation.
290
+ base_win_size (tuple[int]): The height and width of the base window.
291
+ window_size (tuple[int]): The height and width of the window.
292
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
293
+ drop (float, optional): Dropout rate. Default: 0.0
294
+ value_drop (float, optional): Dropout ratio of value. Default: 0.0
295
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
296
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
297
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
298
+ """
299
+
300
+ def __init__(self, dim, input_resolution, num_heads, base_win_size, window_size,
301
+ mlp_ratio=4., drop=0., value_drop=0., drop_path=0.,
302
+ act_layer=nn.GELU, norm_layer=nn.LayerNorm):
303
+ super().__init__()
304
+ self.dim = dim
305
+ self.input_resolution = input_resolution
306
+ self.num_heads = num_heads
307
+ self.window_size = window_size
308
+ self.mlp_ratio = mlp_ratio
309
+
310
+ # check window size
311
+ if (window_size[0] > base_win_size[0]) and (window_size[1] > base_win_size[1]):
312
+ assert window_size[0] % base_win_size[0] == 0, "please ensure the window size is smaller than or divisible by the base window size"
313
+ assert window_size[1] % base_win_size[1] == 0, "please ensure the window size is smaller than or divisible by the base window size"
314
+
315
+
316
+ self.norm1 = norm_layer(dim)
317
+ self.correlation = SCC(
318
+ dim, base_win_size=base_win_size, window_size=self.window_size, num_heads=num_heads,
319
+ value_drop=value_drop, proj_drop=drop)
320
+
321
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
322
+ self.norm2 = norm_layer(dim)
323
+ mlp_hidden_dim = int(dim * mlp_ratio)
324
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
325
+
326
+ def check_image_size(self, x, win_size):
327
+ x = x.permute(0,3,1,2).contiguous()
328
+ _, _, h, w = x.size()
329
+ mod_pad_h = (win_size[0] - h % win_size[0]) % win_size[0]
330
+ mod_pad_w = (win_size[1] - w % win_size[1]) % win_size[1]
331
+ x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
332
+ x = x.permute(0,2,3,1).contiguous()
333
+ return x
334
+
335
+ def forward(self, x, x_size, win_size):
336
+ H, W = x_size
337
+ B, L, C = x.shape
338
+
339
+ shortcut = x
340
+ x = x.view(B, H, W, C)
341
+
342
+ # padding
343
+ x = self.check_image_size(x, win_size)
344
+ _, H_pad, W_pad, _ = x.shape # shape after padding
345
+
346
+ x = self.correlation(x)
347
+
348
+ # unpad
349
+ x = x[:, :H, :W, :].contiguous()
350
+
351
+ # norm
352
+ x = x.view(B, H * W, C)
353
+ x = self.norm1(x)
354
+
355
+ # FFN
356
+ x = shortcut + self.drop_path(x)
357
+ x = x + self.drop_path(self.norm2(self.mlp(x)))
358
+
359
+ return x
360
+
361
+ def extra_repr(self) -> str:
362
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
363
+ f"window_size={self.window_size}, mlp_ratio={self.mlp_ratio}"
364
+
365
+
366
+ class PatchMerging(nn.Module):
367
+ """ Patch Merging Layer.
368
+ Args:
369
+ input_resolution (tuple[int]): Resolution of input feature.
370
+ dim (int): Number of input channels.
371
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
372
+ """
373
+
374
+ def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
375
+ super().__init__()
376
+ self.input_resolution = input_resolution
377
+ self.dim = dim
378
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
379
+ self.norm = norm_layer(4 * dim)
380
+
381
+ def forward(self, x):
382
+ """
383
+ x: B, H*W, C
384
+ """
385
+ H, W = self.input_resolution
386
+ B, L, C = x.shape
387
+ assert L == H * W, "input feature has wrong size"
388
+ assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
389
+
390
+ x = x.view(B, H, W, C)
391
+
392
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
393
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
394
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
395
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
396
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
397
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
398
+
399
+ x = self.norm(x)
400
+ x = self.reduction(x)
401
+
402
+ return x
403
+
404
+ def extra_repr(self) -> str:
405
+ return f"input_resolution={self.input_resolution}, dim={self.dim}"
406
+
407
+
408
+ class BasicLayer(nn.Module):
409
+ """ A basic Hierarchical Transformer layer for one stage.
410
+
411
+ Args:
412
+ dim (int): Number of input channels.
413
+ input_resolution (tuple[int]): Input resolution.
414
+ depth (int): Number of blocks.
415
+ num_heads (int): Number of heads for spatial self-correlation.
416
+ base_win_size (tuple[int]): The height and width of the base window.
417
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
418
+ drop (float, optional): Dropout rate. Default: 0.0
419
+ value_drop (float, optional): Dropout ratio of value. Default: 0.0
420
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
421
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
422
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
423
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
424
+ hier_win_ratios (list): hierarchical window ratios for a transformer block. Default: [0.5,1,2,4,6,8].
425
+ """
426
+
427
+ def __init__(self, dim, input_resolution, depth, num_heads, base_win_size,
428
+ mlp_ratio=4., drop=0., value_drop=0.,drop_path=0., norm_layer=nn.LayerNorm,
429
+ downsample=None, use_checkpoint=False, hier_win_ratios=[0.5,1,2,4,6,8]):
430
+
431
+ super().__init__()
432
+ self.dim = dim
433
+ self.input_resolution = input_resolution
434
+ self.depth = depth
435
+ self.use_checkpoint = use_checkpoint
436
+
437
+ self.win_hs = [int(base_win_size[0] * ratio) for ratio in hier_win_ratios]
438
+ self.win_ws = [int(base_win_size[1] * ratio) for ratio in hier_win_ratios]
439
+
440
+ # build blocks
441
+ self.blocks = nn.ModuleList([
442
+ HierarchicalTransformerBlock(dim=dim, input_resolution=input_resolution,
443
+ num_heads=num_heads,
444
+ base_win_size=base_win_size,
445
+ window_size=(self.win_hs[i], self.win_ws[i]),
446
+ mlp_ratio=mlp_ratio,
447
+ drop=drop, value_drop=value_drop,
448
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
449
+ norm_layer=norm_layer)
450
+ for i in range(depth)])
451
+
452
+ # patch merging layer
453
+ if downsample is not None:
454
+ self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
455
+ else:
456
+ self.downsample = None
457
+
458
+ def forward(self, x, x_size):
459
+
460
+ i = 0
461
+ for blk in self.blocks:
462
+ if self.use_checkpoint:
463
+ x = checkpoint.checkpoint(blk, x, x_size, (self.win_hs[i], self.win_ws[i]))
464
+ else:
465
+ x = blk(x, x_size, (self.win_hs[i], self.win_ws[i]))
466
+ i = i + 1
467
+
468
+ if self.downsample is not None:
469
+ x = self.downsample(x)
470
+ return x
471
+
472
+ def extra_repr(self) -> str:
473
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
474
+
475
+
476
+ class RHTB(nn.Module):
+     """Residual Hierarchical Transformer Block (RHTB).
+     Args:
+         dim (int): Number of input channels.
+         input_resolution (tuple[int]): Input resolution.
+         depth (int): Number of blocks.
+         num_heads (int): Number of heads for spatial self-correlation.
+         base_win_size (tuple[int]): The height and width of the base window.
+         mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+         drop (float, optional): Dropout rate. Default: 0.0
+         value_drop (float, optional): Dropout ratio of value. Default: 0.0
+         drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
+         norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+         downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
+         use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
+         img_size: Input image size.
+         patch_size: Patch size.
+         resi_connection: The convolutional block before residual connection.
+         hier_win_ratios (list): hierarchical window ratios for a transformer block. Default: [0.5,1,2,4,6,8].
+     """
+
+     def __init__(self, dim, input_resolution, depth, num_heads, base_win_size,
+                  mlp_ratio=4., drop=0., value_drop=0., drop_path=0., norm_layer=nn.LayerNorm,
+                  downsample=None, use_checkpoint=False, img_size=224, patch_size=4,
+                  resi_connection='1conv', hier_win_ratios=[0.5,1,2,4,6,8]):
+         super(RHTB, self).__init__()
+
+         self.dim = dim
+         self.input_resolution = input_resolution
+
+         self.residual_group = BasicLayer(dim=dim,
+                                          input_resolution=input_resolution,
+                                          depth=depth,
+                                          num_heads=num_heads,
+                                          base_win_size=base_win_size,
+                                          mlp_ratio=mlp_ratio,
+                                          drop=drop, value_drop=value_drop,
+                                          drop_path=drop_path,
+                                          norm_layer=norm_layer,
+                                          downsample=downsample,
+                                          use_checkpoint=use_checkpoint,
+                                          hier_win_ratios=hier_win_ratios)
+
+         if resi_connection == '1conv':
+             self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
+         elif resi_connection == '3conv':
+             # to save parameters and memory
+             self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True),
+                                       nn.Conv2d(dim // 4, dim // 4, 1, 1, 0),
+                                       nn.LeakyReLU(negative_slope=0.2, inplace=True),
+                                       nn.Conv2d(dim // 4, dim, 3, 1, 1))
+
+         self.patch_embed = PatchEmbed(
+             img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim,
+             norm_layer=None)
+
+         self.patch_unembed = PatchUnEmbed(
+             img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim,
+             norm_layer=None)
+
+     def forward(self, x, x_size):
+         return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x
+
+
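# --- Illustrative sketch (editor addition): constructing a single RHTB ------------------
# The values below mirror the lightweight-SR defaults used by HiT_SIR further down and are
# chosen only to show that the block maps a (B, H*W, C) token sequence to a tensor of the
# same shape, which is what makes the trailing "+ x" residual valid. Whether these exact
# sizes run also depends on the window handling defined earlier in the file.
block = RHTB(dim=60, input_resolution=(64, 64), depth=6, num_heads=6,
             base_win_size=(8, 8), mlp_ratio=2., img_size=64, patch_size=1,
             resi_connection='1conv', hier_win_ratios=[0.5, 1, 2, 4, 6, 8])
tokens = torch.randn(1, 64 * 64, 60)
out = block(tokens, (64, 64))
print(out.shape)  # expected: torch.Size([1, 4096, 60])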
+ class PatchEmbed(nn.Module):
+     r""" Image to Patch Embedding
+
+     Args:
+         img_size (int): Image size. Default: 224.
+         patch_size (int): Patch token size. Default: 4.
+         in_chans (int): Number of input image channels. Default: 3.
+         embed_dim (int): Number of linear projection output channels. Default: 96.
+         norm_layer (nn.Module, optional): Normalization layer. Default: None
+     """
+
+     def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
+         super().__init__()
+         img_size = to_2tuple(img_size)
+         patch_size = to_2tuple(patch_size)
+         patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
+         self.img_size = img_size
+         self.patch_size = patch_size
+         self.patches_resolution = patches_resolution
+         self.num_patches = patches_resolution[0] * patches_resolution[1]
+
+         self.in_chans = in_chans
+         self.embed_dim = embed_dim
+
+         if norm_layer is not None:
+             self.norm = norm_layer(embed_dim)
+         else:
+             self.norm = None
+
+     def forward(self, x):
+         x = x.flatten(2).transpose(1, 2)  # B Ph*Pw C
+         if self.norm is not None:
+             x = self.norm(x)
+         return x
+
+
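# --- Illustrative sketch (editor addition): PatchEmbed shape behaviour ------------------
# With patch_size=1 (the setting HiT_SIR uses below), PatchEmbed is just a flatten plus a
# transpose: a (B, C, H, W) feature map becomes a (B, H*W, C) token sequence, with no
# projection applied.
embed = PatchEmbed(img_size=64, patch_size=1, in_chans=60, embed_dim=60, norm_layer=None)
feat = torch.randn(2, 60, 64, 64)
print(embed(feat).shape)  # torch.Size([2, 4096, 60])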
+ class PatchUnEmbed(nn.Module):
+     r""" Image to Patch Unembedding
+
+     Args:
+         img_size (int): Image size. Default: 224.
+         patch_size (int): Patch token size. Default: 4.
+         in_chans (int): Number of input image channels. Default: 3.
+         embed_dim (int): Number of linear projection output channels. Default: 96.
+         norm_layer (nn.Module, optional): Normalization layer. Default: None
+     """
+
+     def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
+         super().__init__()
+         img_size = to_2tuple(img_size)
+         patch_size = to_2tuple(patch_size)
+         patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
+         self.img_size = img_size
+         self.patch_size = patch_size
+         self.patches_resolution = patches_resolution
+         self.num_patches = patches_resolution[0] * patches_resolution[1]
+
+         self.in_chans = in_chans
+         self.embed_dim = embed_dim
+
+     def forward(self, x, x_size):
+         B, HW, C = x.shape
+         x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1])  # B C H W
+         return x
+
+
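# --- Illustrative sketch (editor addition): PatchEmbed / PatchUnEmbed round trip --------
# PatchUnEmbed inverts PatchEmbed, so tokens can be folded back into a spatial map (e.g.
# for the 3x3 convolutions in RHTB) and re-flattened without changing any values.
unembed = PatchUnEmbed(img_size=64, patch_size=1, in_chans=60, embed_dim=60)
tokens = torch.randn(2, 64 * 64, 60)
spatial = unembed(tokens, (64, 64))            # (2, 60, 64, 64)
restored = spatial.flatten(2).transpose(1, 2)  # what PatchEmbed.forward would return
print(torch.equal(restored, tokens))           # True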
+ class Upsample(nn.Sequential):
+     """Upsample module.
+
+     Args:
+         scale (int): Scale factor. Supported scales: 2^n and 3.
+         num_feat (int): Channel number of intermediate features.
+     """
+
+     def __init__(self, scale, num_feat):
+         m = []
+         if (scale & (scale - 1)) == 0:  # scale = 2^n
+             for _ in range(int(math.log(scale, 2))):
+                 m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
+                 m.append(nn.PixelShuffle(2))
+         elif scale == 3:
+             m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
+             m.append(nn.PixelShuffle(3))
+         else:
+             raise ValueError(f'scale {scale} is not supported. Supported scales: 2^n and 3.')
+         super(Upsample, self).__init__(*m)
+
+
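# --- Illustrative sketch (editor addition): what Upsample builds for scale=4 ------------
# 4 = 2^2, so the nn.Sequential stacks two (3x3 conv -> PixelShuffle(2)) stages; each conv
# quadruples the channels and PixelShuffle folds them back into a 2x larger feature map.
up = Upsample(scale=4, num_feat=64)
feat = torch.randn(1, 64, 48, 48)
print(up(feat).shape)  # torch.Size([1, 64, 192, 192])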
+ class UpsampleOneStep(nn.Sequential):
+     """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle)
+     Used in lightweight SR to save parameters.
+
+     Args:
+         scale (int): Scale factor. Supported scales: 2^n and 3.
+         num_feat (int): Channel number of intermediate features.
+         num_out_ch (int): Channel number of the output.
+         input_resolution (tuple[int] | None): Input resolution. Default: None
+     """
+
+     def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
+         self.num_feat = num_feat
+         self.input_resolution = input_resolution
+         m = []
+         m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1))
+         m.append(nn.PixelShuffle(scale))
+         super(UpsampleOneStep, self).__init__(*m)
+
+
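# --- Illustrative sketch (editor addition): UpsampleOneStep output shape ----------------
# A single 3x3 conv maps num_feat -> (scale**2) * num_out_ch and one PixelShuffle(scale)
# produces the output image directly, which is why the lightweight path below needs no
# conv_before_upsample.
up_light = UpsampleOneStep(scale=4, num_feat=60, num_out_ch=3)
feat = torch.randn(1, 60, 48, 48)
print(up_light(feat).shape)  # torch.Size([1, 3, 192, 192])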
+ class HiT_SIR(nn.Module, PyTorchModelHubMixin):
+     """ HiT-SIR network.
+
+     Args:
+         img_size (int | tuple(int)): Input image size. Default 64
+         patch_size (int | tuple(int)): Patch size. Default: 1
+         in_chans (int): Number of input image channels. Default: 3
+         embed_dim (int): Patch embedding dimension. Default: 96
+         depths (tuple(int)): Depth of each Transformer block.
+         num_heads (tuple(int)): Number of heads for spatial self-correlation in different layers.
+         base_win_size (tuple[int]): The height and width of the base window.
+         mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
+         drop_rate (float): Dropout rate. Default: 0
+         value_drop_rate (float): Dropout ratio of value. Default: 0.0
+         drop_path_rate (float): Stochastic depth rate. Default: 0.1
+         norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
+         ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
+         patch_norm (bool): If True, add normalization after patch embedding. Default: True
+         use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
+         upscale (int): Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compression artifact reduction
+         img_range (float): Image range. 1. or 255.
+         upsampler (str): The reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None
+         resi_connection (str): The convolutional block before residual connection. '1conv'/'3conv'
+         hier_win_ratios (list): hierarchical window ratios for a transformer block. Default: [0.5,1,2,4,6,8].
+     """
+
+     def __init__(self, img_size=64, patch_size=1, in_chans=3,
+                  embed_dim=60, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6],
+                  base_win_size=[8,8], mlp_ratio=2.,
+                  drop_rate=0., value_drop_rate=0., drop_path_rate=0.,
+                  norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
+                  use_checkpoint=False, upscale=4, img_range=1., upsampler='pixelshuffledirect', resi_connection='1conv',
+                  hier_win_ratios=[0.5,1,2,4,6,8],
+                  **kwargs):
+         super(HiT_SIR, self).__init__()
+         num_in_ch = in_chans
+         num_out_ch = in_chans
+         num_feat = 64
+         self.img_range = img_range
+         if in_chans == 3:
+             rgb_mean = (0.4488, 0.4371, 0.4040)
+             self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
+         else:
+             self.mean = torch.zeros(1, 1, 1, 1)
+         self.upscale = upscale
+         self.upsampler = upsampler
+         self.base_win_size = base_win_size
+
+         #####################################################################################################
+         ################################### 1, shallow feature extraction ###################################
+         self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
+
+         #####################################################################################################
+         ################################### 2, deep feature extraction ######################################
+         self.num_layers = len(depths)
+         self.embed_dim = embed_dim
+         self.ape = ape
+         self.patch_norm = patch_norm
+         self.num_features = embed_dim
+         self.mlp_ratio = mlp_ratio
+
+         # split image into non-overlapping patches
+         self.patch_embed = PatchEmbed(
+             img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
+             norm_layer=norm_layer if self.patch_norm else None)
+         num_patches = self.patch_embed.num_patches
+         patches_resolution = self.patch_embed.patches_resolution
+         self.patches_resolution = patches_resolution
+
+         # merge non-overlapping patches into image
+         self.patch_unembed = PatchUnEmbed(
+             img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
+             norm_layer=norm_layer if self.patch_norm else None)
+
+         # absolute position embedding
+         if self.ape:
+             self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
+             trunc_normal_(self.absolute_pos_embed, std=.02)
+
+         self.pos_drop = nn.Dropout(p=drop_rate)
+
+         # stochastic depth
+         dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule
+
+         # build Residual Hierarchical Transformer blocks (RHTB)
+         self.layers = nn.ModuleList()
+         for i_layer in range(self.num_layers):
+             layer = RHTB(dim=embed_dim,
+                          input_resolution=(patches_resolution[0],
+                                            patches_resolution[1]),
+                          depth=depths[i_layer],
+                          num_heads=num_heads[i_layer],
+                          base_win_size=base_win_size,
+                          mlp_ratio=self.mlp_ratio,
+                          drop=drop_rate, value_drop=value_drop_rate,
+                          drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],  # no impact on SR results
+                          norm_layer=norm_layer,
+                          downsample=None,
+                          use_checkpoint=use_checkpoint,
+                          img_size=img_size,
+                          patch_size=patch_size,
+                          resi_connection=resi_connection,
+                          hier_win_ratios=hier_win_ratios
+                          )
+             self.layers.append(layer)
+         self.norm = norm_layer(self.num_features)
+
+         # build the last conv layer in deep feature extraction
+         if resi_connection == '1conv':
+             self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
+         elif resi_connection == '3conv':
+             # to save parameters and memory
+             self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1),
+                                                  nn.LeakyReLU(negative_slope=0.2, inplace=True),
+                                                  nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0),
+                                                  nn.LeakyReLU(negative_slope=0.2, inplace=True),
+                                                  nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1))
+
+         #####################################################################################################
+         ################################ 3, high quality image reconstruction ################################
+         if self.upsampler == 'pixelshuffle':
+             # for classical SR
+             self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
+                                                       nn.LeakyReLU(inplace=True))
+             self.upsample = Upsample(upscale, num_feat)
+             self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
+         elif self.upsampler == 'pixelshuffledirect':
+             # for lightweight SR (to save parameters)
+             self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch,
+                                             (patches_resolution[0], patches_resolution[1]))
+         elif self.upsampler == 'nearest+conv':
+             # for real-world SR (less artifacts)
+             assert self.upscale == 4, 'only support x4 now.'
+             self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
+                                                       nn.LeakyReLU(inplace=True))
+             self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+             self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+             self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+             self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
+             self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+         else:
+             # for image denoising and JPEG compression artifact reduction
+             self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1)
+
+         self.apply(self._init_weights)
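# --- Illustrative sketch (editor addition): how the dpr slice assigns drop-path rates ---
# torch.linspace spreads drop_path_rate over every block in the network, and each RHTB takes
# the contiguous slice belonging to its own depth. With the default depths=[6, 6, 6, 6] and,
# say, drop_path_rate=0.1 (a value chosen only for illustration):
depths = [6, 6, 6, 6]
drop_path_rate = 0.1
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
for i_layer in range(len(depths)):
    sl = dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])]
    print(i_layer, [round(v, 3) for v in sl])  # layer 0 gets the smallest rates, layer 3 the largest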
+
+     def _init_weights(self, m):
+         if isinstance(m, nn.Linear):
+             trunc_normal_(m.weight, std=.02)
+             if isinstance(m, nn.Linear) and m.bias is not None:
+                 nn.init.constant_(m.bias, 0)
+         elif isinstance(m, nn.LayerNorm):
+             nn.init.constant_(m.bias, 0)
+             nn.init.constant_(m.weight, 1.0)
+
+     @torch.jit.ignore
+     def no_weight_decay(self):
+         return {'absolute_pos_embed'}
+
+     @torch.jit.ignore
+     def no_weight_decay_keywords(self):
+         return {'relative_position_bias_table'}
+
+
+     def forward_features(self, x):
+         x_size = (x.shape[2], x.shape[3])
+         x = self.patch_embed(x)
+         if self.ape:
+             x = x + self.absolute_pos_embed
+         x = self.pos_drop(x)
+
+         for layer in self.layers:
+             x = layer(x, x_size)
+
+         x = self.norm(x)  # B L C
+         x = self.patch_unembed(x, x_size)
+
+         return x
+
+     def infer_image(self, image_path, device):
+
+         io_backend_opt = {'type': 'disk'}
+         self.file_client = FileClient(io_backend_opt.pop('type'), **io_backend_opt)
+
+         # load lq image
+         lq_path = image_path
+         img_bytes = self.file_client.get(lq_path, 'lq')
+         img_lq = imfrombytes(img_bytes, float32=True)
+
+         # BGR to RGB, HWC to CHW, numpy to tensor
+         x = img2tensor(img_lq, bgr2rgb=True, float32=True)[None, ...]
+         x = x.to(device)
+
+         out = self(x)
+         out = out.cpu()
+         out = tensor2img(out)
+
+         return out
+
+     def forward(self, x):
+         H, W = x.shape[2:]
+
+         self.mean = self.mean.type_as(x)
+         x = (x - self.mean) * self.img_range
+
+         if self.upsampler == 'pixelshuffle':
+             # for classical SR
+             x = self.conv_first(x)
+             x = self.conv_after_body(self.forward_features(x)) + x
+             x = self.conv_before_upsample(x)
+             x = self.conv_last(self.upsample(x))
+         elif self.upsampler == 'pixelshuffledirect':
+             # for lightweight SR
+             x = self.conv_first(x)
+             x = self.conv_after_body(self.forward_features(x)) + x
+             x = self.upsample(x)
+         elif self.upsampler == 'nearest+conv':
+             # for real-world SR
+             x = self.conv_first(x)
+             x = self.conv_after_body(self.forward_features(x)) + x
+             x = self.conv_before_upsample(x)
+             x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
+             x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
+             x = self.conv_last(self.lrelu(self.conv_hr(x)))
+         else:
+             # for image denoising and JPEG compression artifact reduction
+             x_first = self.conv_first(x)
+             res = self.conv_after_body(self.forward_features(x_first)) + x_first
+             x = x + self.conv_last(res)
+
+         x = x / self.img_range + self.mean
+
+         return x[:, :, :H * self.upscale, :W * self.upscale]
+
+
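# --- Illustrative sketch (editor addition): loading published weights -------------------
# HiT_SIR inherits PyTorchModelHubMixin, so a checkpoint pushed to the Hugging Face Hub can
# be restored with from_pretrained. The repo id and image path below are placeholders, not
# names taken from this file.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = HiT_SIR.from_pretrained("<hub-repo-id>").to(device).eval()
sr_image = model.infer_image("<path/to/low_res.png>", device)  # numpy image produced by tensor2img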
+ if __name__ == '__main__':
+     upscale = 4
+     base_win_size = [8, 8]
+     height = (1024 // upscale // base_win_size[0] + 1) * base_win_size[0]
+     width = (720 // upscale // base_win_size[1] + 1) * base_win_size[1]
+
+     ## HiT-SIR
+     model = HiT_SIR(upscale=4, img_size=(height, width),
+                     base_win_size=base_win_size, img_range=1., depths=[6, 6, 6, 6],
+                     embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect')
+
+     params_num = sum(p.numel() for p in model.parameters() if p.requires_grad)
+     print("params: ", params_num)
+
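# --- Illustrative addition (editor): quick shape check continuing the block above -------
# These lines are not part of the original __main__; they are indented to continue it. A
# dummy LR batch of size (height, width) should come out upscaled by 4x on both sides.
    x = torch.randn(1, 3, height, width)
    with torch.no_grad():
        out = model(x)
    print("output shape:", out.shape)  # expected: (1, 3, height * 4, width * 4)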