Patil committed
Commit 386787e · verified · 1 Parent(s): 3e29a7d

Update birefnet.py

Files changed (1):
  1. birefnet.py +35 -27
birefnet.py CHANGED
@@ -4,10 +4,14 @@ import os
 import math
 from transformers import PretrainedConfig
 
+
 class Config(PretrainedConfig):
     def __init__(self) -> None:
+        # Compatible with the latest version of transformers
+        super().__init__()
+
         # PATH settings
-        self.sys_home_dir = os.path.expanduser('~')  # Make up your file system as: SYS_HOME_DIR/codes/dis/BiRefNet, SYS_HOME_DIR/datasets/dis/xx, SYS_HOME_DIR/weights/xx
+        self.sys_home_dir = os.path.expanduser('~')  # Make up your file system as: SYS_HOME_DIR/codes/dis/BiRefNet, SYS_HOME_DIR/datasets/dis/xx, SYS_HOME_DIR/weights/xx
 
         # TASK settings
         self.task = ['DIS5K', 'COD', 'HRSOD', 'DIS5K+HRSOD+HRS10K', 'P3M-10k'][0]
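For context on the `super().__init__()` call added above: a `PretrainedConfig` subclass should invoke the base constructor so that transformers' config machinery (serialization, common attributes) is set up. A minimal sketch of the pattern, not from this repository — `ToyConfig`, `model_type`, and `hidden_size` are made-up illustration names:

import torch  # not needed here, shown only because the surrounding code is torch-based
from transformers import PretrainedConfig

class ToyConfig(PretrainedConfig):
    # Hypothetical config class, only to illustrate the base-constructor call.
    model_type = "toy"

    def __init__(self, hidden_size=64, **kwargs):
        super().__init__(**kwargs)   # sets up the attributes PretrainedConfig expects
        self.hidden_size = hidden_size

cfg = ToyConfig()
print(cfg.to_dict()["hidden_size"])  # 64; to_dict()/save_pretrained() build on the base init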
@@ -615,6 +619,7 @@ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
 
 # config = Config()
 
+
 class Mlp(nn.Module):
     """ Multilayer perceptron."""
 
@@ -739,7 +744,8 @@ class WindowAttention(nn.Module):
         attn = (q @ k.transpose(-2, -1))
 
         relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
-            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww, Wh*Ww, nH
+            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1
+        )  # Wh*Ww, Wh*Ww, nH
         relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
         attn = attn + relative_position_bias.unsqueeze(0)
 
@@ -974,8 +980,9 @@ class BasicLayer(nn.Module):
         """
 
         # calculate attention mask for SW-MSA
-        Hp = int(np.ceil(H / self.window_size)) * self.window_size
-        Wp = int(np.ceil(W / self.window_size)) * self.window_size
+        # Turn int into torch.Tensor for compatibility with torch.compile in PyTorch 2.5.
+        Hp = torch.ceil(torch.tensor(H) / self.window_size).to(torch.int64) * self.window_size
+        Wp = torch.ceil(torch.tensor(W) / self.window_size).to(torch.int64) * self.window_size
         img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device)  # 1 Hp Wp 1
         h_slices = (slice(0, -self.window_size),
                     slice(-self.window_size, -self.shift_size),
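A quick sanity check, not part of the commit, that the tensor-based rounding introduced above matches the previous NumPy arithmetic; `H`, `W`, and `window_size` below are arbitrary example values, not taken from the model config:

import numpy as np
import torch

H, W, window_size = 57, 92, 7  # arbitrary example sizes

# Previous formulation: plain Python ints via NumPy.
Hp_old = int(np.ceil(H / window_size)) * window_size
Wp_old = int(np.ceil(W / window_size)) * window_size

# New formulation from the commit: stays in torch (see the torch.compile comment above).
Hp_new = torch.ceil(torch.tensor(H) / window_size).to(torch.int64) * window_size
Wp_new = torch.ceil(torch.tensor(W) / window_size).to(torch.int64) * window_size

assert (Hp_old, Wp_old) == (int(Hp_new), int(Wp_new))  # both give (63, 98)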
@@ -992,7 +999,7 @@ class BasicLayer(nn.Module):
         mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
         mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
         attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
-        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
+        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)).to(x.dtype)
 
         for blk in self.blocks:
             blk.H, blk.W = H, W
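A small illustration, not from the commit, of what the appended `.to(x.dtype)` changes: adding a default float32 mask to half-precision attention logits promotes the result to float32 under PyTorch's type-promotion rules, while casting the mask first keeps everything in the activations' dtype. The shapes below are made up:

import torch

attn = torch.zeros(2, 4, 4, dtype=torch.float16)  # stand-in for fp16 attention logits
mask = torch.full((4, 4), -100.0)                  # float32 by default

print((attn + mask).dtype)                 # torch.float32 -- silently promoted
print((attn + mask.to(attn.dtype)).dtype)  # torch.float16 -- stays in the activation dtype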
@@ -1961,6 +1968,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from kornia.filters import laplacian
 from transformers import PreTrainedModel
+from einops import rearrange
 
 # from config import Config
 # from dataset import class_labels_TR_sorted
@@ -1974,6 +1982,18 @@ from transformers import PreTrainedModel
 from .BiRefNet_config import BiRefNetConfig
 
 
+def image2patches(image, grid_h=2, grid_w=2, patch_ref=None, transformation='b c (hg h) (wg w) -> (b hg wg) c h w'):
+    if patch_ref is not None:
+        grid_h, grid_w = image.shape[-2] // patch_ref.shape[-2], image.shape[-1] // patch_ref.shape[-1]
+    patches = rearrange(image, transformation, hg=grid_h, wg=grid_w)
+    return patches
+
+def patches2image(patches, grid_h=2, grid_w=2, patch_ref=None, transformation='(b hg wg) c h w -> b c (hg h) (wg w)'):
+    if patch_ref is not None:
+        grid_h, grid_w = patch_ref.shape[-2] // patches[0].shape[-2], patch_ref.shape[-1] // patches[0].shape[-1]
+    image = rearrange(patches, transformation, hg=grid_h, wg=grid_w)
+    return image
+
 class BiRefNet(
     PreTrainedModel
 ):
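For readers unfamiliar with einops, the sketch below (not part of the commit; the tensor and its 8×8 size are toy values) shows what the new helpers do with the default 2×2 grid: `image2patches` cuts an image into grid cells and, with the default pattern, stacks them along the batch axis, while the alternative pattern used in the decoder stacks them along channels; `patches2image` reverses the default split:

import torch
from einops import rearrange

# Toy input: batch of 1, 3 channels, 8x8 image.
image = torch.arange(1 * 3 * 8 * 8, dtype=torch.float32).reshape(1, 3, 8, 8)

# Default pattern: each 2x2 grid cell becomes its own batch element.
patches = rearrange(image, 'b c (hg h) (wg w) -> (b hg wg) c h w', hg=2, wg=2)
print(patches.shape)  # torch.Size([4, 3, 4, 4])

# Pattern used in Decoder.forward: grid cells are stacked along channels instead,
# so a 3-channel image becomes a 12-channel tensor at quarter resolution.
stacked = rearrange(image, 'b c (hg h) (wg w) -> b (c hg wg) h w', hg=2, wg=2)
print(stacked.shape)  # torch.Size([1, 12, 4, 4])

# Round trip back to the original layout.
restored = rearrange(patches, '(b hg wg) c h w -> b c (hg h) (wg w)', hg=2, wg=2)
assert torch.equal(restored, image)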
@@ -2124,18 +2144,6 @@ class Decoder(nn.Module):
         self.gdt_convs_attn_3 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0))
         self.gdt_convs_attn_2 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0))
 
-    def get_patches_batch(self, x, p):
-        _size_h, _size_w = p.shape[2:]
-        patches_batch = []
-        for idx in range(x.shape[0]):
-            columns_x = torch.split(x[idx], split_size_or_sections=_size_w, dim=-1)
-            patches_x = []
-            for column_x in columns_x:
-                patches_x += [p.unsqueeze(0) for p in torch.split(column_x, split_size_or_sections=_size_h, dim=-2)]
-            patch_sample = torch.cat(patches_x, dim=1)
-            patches_batch.append(patch_sample)
-        return torch.cat(patches_batch, dim=0)
-
     def forward(self, features):
         if self.training and self.config.out_ref:
             outs_gdt_pred = []
@@ -2146,10 +2154,10 @@ class Decoder(nn.Module):
         outs = []
 
         if self.config.dec_ipt:
-            patches_batch = self.get_patches_batch(x, x4) if self.split else x
+            patches_batch = image2patches(x, patch_ref=x4, transformation='b c (hg h) (wg w) -> b (c hg wg) h w') if self.split else x
             x4 = torch.cat((x4, self.ipt_blk5(F.interpolate(patches_batch, size=x4.shape[2:], mode='bilinear', align_corners=True))), 1)
         p4 = self.decoder_block4(x4)
-        m4 = self.conv_ms_spvn_4(p4) if self.config.ms_supervision else None
+        m4 = self.conv_ms_spvn_4(p4) if self.config.ms_supervision and self.training else None
         if self.config.out_ref:
             p4_gdt = self.gdt_convs_4(p4)
             if self.training:
@@ -2167,10 +2175,10 @@ class Decoder(nn.Module):
         _p3 = _p4 + self.lateral_block4(x3)
 
         if self.config.dec_ipt:
-            patches_batch = self.get_patches_batch(x, _p3) if self.split else x
+            patches_batch = image2patches(x, patch_ref=_p3, transformation='b c (hg h) (wg w) -> b (c hg wg) h w') if self.split else x
             _p3 = torch.cat((_p3, self.ipt_blk4(F.interpolate(patches_batch, size=x3.shape[2:], mode='bilinear', align_corners=True))), 1)
         p3 = self.decoder_block3(_p3)
-        m3 = self.conv_ms_spvn_3(p3) if self.config.ms_supervision else None
+        m3 = self.conv_ms_spvn_3(p3) if self.config.ms_supervision and self.training else None
         if self.config.out_ref:
             p3_gdt = self.gdt_convs_3(p3)
             if self.training:
@@ -2193,10 +2201,10 @@ class Decoder(nn.Module):
         _p2 = _p3 + self.lateral_block3(x2)
 
         if self.config.dec_ipt:
-            patches_batch = self.get_patches_batch(x, _p2) if self.split else x
+            patches_batch = image2patches(x, patch_ref=_p2, transformation='b c (hg h) (wg w) -> b (c hg wg) h w') if self.split else x
             _p2 = torch.cat((_p2, self.ipt_blk3(F.interpolate(patches_batch, size=x2.shape[2:], mode='bilinear', align_corners=True))), 1)
         p2 = self.decoder_block2(_p2)
-        m2 = self.conv_ms_spvn_2(p2) if self.config.ms_supervision else None
+        m2 = self.conv_ms_spvn_2(p2) if self.config.ms_supervision and self.training else None
         if self.config.out_ref:
             p2_gdt = self.gdt_convs_2(p2)
             if self.training:
@@ -2214,17 +2222,17 @@ class Decoder(nn.Module):
         _p1 = _p2 + self.lateral_block2(x1)
 
         if self.config.dec_ipt:
-            patches_batch = self.get_patches_batch(x, _p1) if self.split else x
+            patches_batch = image2patches(x, patch_ref=_p1, transformation='b c (hg h) (wg w) -> b (c hg wg) h w') if self.split else x
             _p1 = torch.cat((_p1, self.ipt_blk2(F.interpolate(patches_batch, size=x1.shape[2:], mode='bilinear', align_corners=True))), 1)
         _p1 = self.decoder_block1(_p1)
         _p1 = F.interpolate(_p1, size=x.shape[2:], mode='bilinear', align_corners=True)
 
         if self.config.dec_ipt:
-            patches_batch = self.get_patches_batch(x, _p1) if self.split else x
+            patches_batch = image2patches(x, patch_ref=_p1, transformation='b c (hg h) (wg w) -> b (c hg wg) h w') if self.split else x
             _p1 = torch.cat((_p1, self.ipt_blk1(F.interpolate(patches_batch, size=x.shape[2:], mode='bilinear', align_corners=True))), 1)
         p1_out = self.conv_out1(_p1)
 
-        if self.config.ms_supervision:
+        if self.config.ms_supervision and self.training:
             outs.append(m4)
             outs.append(m3)
             outs.append(m2)
@@ -2241,4 +2249,4 @@ class SimpleConvs(nn.Module):
         self.conv_out = nn.Conv2d(inter_channels, out_channels, 3, 1, 1)
 
     def forward(self, x):
-        return self.conv_out(self.conv1(x))
+        return self.conv_out(self.conv1(x))
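One behavioral consequence of the `and self.training` guards added in the decoder hunks above, shown as a standalone sketch rather than code from this repository (TinyDecoder, its layer names, and the 1×8×16×16 input are made up): the intermediate supervision heads are only computed and appended during training, so at inference the output list holds just the final prediction.

import torch
import torch.nn as nn

class TinyDecoder(nn.Module):
    """Hypothetical miniature decoder illustrating the training-only supervision gating."""
    def __init__(self, ms_supervision=True):
        super().__init__()
        self.ms_supervision = ms_supervision
        self.mid_head = nn.Conv2d(8, 1, 1)
        self.out_head = nn.Conv2d(8, 1, 1)

    def forward(self, feat):
        outs = []
        # Mirrors: m4 = self.conv_ms_spvn_4(p4) if self.config.ms_supervision and self.training else None
        m = self.mid_head(feat) if self.ms_supervision and self.training else None
        if self.ms_supervision and self.training:
            outs.append(m)
        outs.append(self.out_head(feat))
        return outs

dec = TinyDecoder()
x = torch.randn(1, 8, 16, 16)
print(len(dec(x)))  # training mode: 2 outputs (auxiliary + final)
dec.eval()
print(len(dec(x)))  # eval mode: 1 output (final only)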
 