Update birefnet.py

birefnet.py  (+35 -27)
@@ -4,10 +4,14 @@ import os
 import math
 from transformers import PretrainedConfig
 
+
 class Config(PretrainedConfig):
     def __init__(self) -> None:
+        # Compatible with the latest version of transformers
+        super().__init__()
+
         # PATH settings
-        self.sys_home_dir = os.path.expanduser('~')
+        self.sys_home_dir = os.path.expanduser('~')     # Make up your file system as: SYS_HOME_DIR/codes/dis/BiRefNet, SYS_HOME_DIR/datasets/dis/xx, SYS_HOME_DIR/weights/xx
 
         # TASK settings
         self.task = ['DIS5K', 'COD', 'HRSOD', 'DIS5K+HRSOD+HRS10K', 'P3M-10k'][0]
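Note: the added `super().__init__()` matters because recent `transformers` versions expect every `PretrainedConfig` subclass to initialize the base class before custom attributes are set; otherwise methods such as `to_dict()` or `save_pretrained()` can fail on missing base attributes. A minimal sketch (the `MyConfig` name is illustrative, not from this file):

    from transformers import PretrainedConfig

    class MyConfig(PretrainedConfig):
        def __init__(self) -> None:
            super().__init__()   # populate base-class defaults first
            self.task = 'DIS5K'  # custom settings go after

    cfg = MyConfig()
    print(cfg.to_dict()['task'])  # 'DIS5K'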
@@ -615,6 +619,7 @@ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
 
 # config = Config()
 
+
 class Mlp(nn.Module):
     """ Multilayer perceptron."""
 
@@ -739,7 +744,8 @@ class WindowAttention(nn.Module):
         attn = (q @ k.transpose(-2, -1))
 
         relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
-            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww, Wh*Ww, nH
+            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1
+        )   # Wh*Ww, Wh*Ww, nH
         relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
         attn = attn + relative_position_bias.unsqueeze(0)
 
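Note: the hunk only splits the long `.view(...)` call; the shapes are unchanged. A standalone shape check (7x7 window, 3 heads; names mirror the attributes above, values are dummies):

    import torch

    Wh = Ww = 7
    nH = 3
    table = torch.zeros((2 * Wh - 1) * (2 * Ww - 1), nH)           # relative_position_bias_table
    index = torch.randint(0, table.shape[0], (Wh * Ww, Wh * Ww))   # relative_position_index

    bias = table[index.view(-1)].view(Wh * Ww, Wh * Ww, -1)  # Wh*Ww, Wh*Ww, nH
    bias = bias.permute(2, 0, 1).contiguous()                # nH, Wh*Ww, Wh*Ww
    assert bias.shape == (nH, Wh * Ww, Wh * Ww)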
@@ -974,8 +980,9 @@ class BasicLayer(nn.Module):
         """
 
         # calculate attention mask for SW-MSA
-        Hp = int(np.ceil(H / self.window_size)) * self.window_size
-        Wp = int(np.ceil(W / self.window_size)) * self.window_size
+        # Turn int into torch.tensor for compatibility with torch.compile in PyTorch 2.5.
+        Hp = torch.ceil(torch.tensor(H) / self.window_size).to(torch.int64) * self.window_size
+        Wp = torch.ceil(torch.tensor(W) / self.window_size).to(torch.int64) * self.window_size
         img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device)  # 1 Hp Wp 1
         h_slices = (slice(0, -self.window_size),
                     slice(-self.window_size, -self.shift_size),
@@ -992,7 +999,7 @@ class BasicLayer(nn.Module):
         mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
         mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
         attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
-        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
+        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)).to(x.dtype)
 
         for blk in self.blocks:
             blk.H, blk.W = H, W
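Note: the tensorized ceil yields the same padded sizes as the old integer arithmetic while keeping the graph traceable for torch.compile, and the new `.to(x.dtype)` keeps the attention mask in the activations' dtype (e.g. under fp16 autocast). A quick equivalence check (standalone sketch):

    import math
    import torch

    H, window_size = 50, 7
    Hp_old = int(math.ceil(H / window_size)) * window_size
    Hp_new = torch.ceil(torch.tensor(H) / window_size).to(torch.int64) * window_size
    assert Hp_old == int(Hp_new) == 56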
@@ -1961,6 +1968,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from kornia.filters import laplacian
 from transformers import PreTrainedModel
+from einops import rearrange
 
 # from config import Config
 # from dataset import class_labels_TR_sorted
@@ -1974,6 +1982,18 @@ from transformers import PreTrainedModel
 from .BiRefNet_config import BiRefNetConfig
 
 
+def image2patches(image, grid_h=2, grid_w=2, patch_ref=None, transformation='b c (hg h) (wg w) -> (b hg wg) c h w'):
+    if patch_ref is not None:
+        grid_h, grid_w = image.shape[-2] // patch_ref.shape[-2], image.shape[-1] // patch_ref.shape[-1]
+    patches = rearrange(image, transformation, hg=grid_h, wg=grid_w)
+    return patches
+
+def patches2image(patches, grid_h=2, grid_w=2, patch_ref=None, transformation='(b hg wg) c h w -> b c (hg h) (wg w)'):
+    if patch_ref is not None:
+        grid_h, grid_w = patch_ref.shape[-2] // patches[0].shape[-2], patch_ref.shape[-1] // patches[0].shape[-1]
+    image = rearrange(patches, transformation, hg=grid_h, wg=grid_w)
+    return image
+
 class BiRefNet(
     PreTrainedModel
 ):
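Note: these einops helpers replace the hand-rolled `get_patches_batch` method removed further down. A usage sketch showing the two grid layouts used in this file (assumes the two helpers above are in scope):

    import torch
    from einops import rearrange

    x = torch.randn(2, 3, 64, 64)

    # Default layout: stack the 2x2 grid along the batch dim -> (8, 3, 32, 32)
    p = image2patches(x)                     # 'b c (hg h) (wg w) -> (b hg wg) c h w'
    assert p.shape == (8, 3, 32, 32)
    assert torch.equal(patches2image(p), x)  # exact round trip

    # Decoder variant: stack the grid along channels instead -> (2, 12, 32, 32)
    q = rearrange(x, 'b c (hg h) (wg w) -> b (c hg wg) h w', hg=2, wg=2)
    assert q.shape == (2, 12, 32, 32)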
@@ -2124,18 +2144,6 @@ class Decoder(nn.Module):
             self.gdt_convs_attn_3 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0))
             self.gdt_convs_attn_2 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0))
 
-    def get_patches_batch(self, x, p):
-        _size_h, _size_w = p.shape[2:]
-        patches_batch = []
-        for idx in range(x.shape[0]):
-            columns_x = torch.split(x[idx], split_size_or_sections=_size_w, dim=-1)
-            patches_x = []
-            for column_x in columns_x:
-                patches_x += [p.unsqueeze(0) for p in torch.split(column_x, split_size_or_sections=_size_h, dim=-2)]
-            patch_sample = torch.cat(patches_x, dim=1)
-            patches_batch.append(patch_sample)
-        return torch.cat(patches_batch, dim=0)
-
     def forward(self, features):
         if self.training and self.config.out_ref:
             outs_gdt_pred = []
@@ -2146,10 +2154,10 @@ class Decoder(nn.Module):
         outs = []
 
         if self.config.dec_ipt:
-            patches_batch = self.get_patches_batch(x, x4) if self.split else x
+            patches_batch = image2patches(x, patch_ref=x4, transformation='b c (hg h) (wg w) -> b (c hg wg) h w') if self.split else x
             x4 = torch.cat((x4, self.ipt_blk5(F.interpolate(patches_batch, size=x4.shape[2:], mode='bilinear', align_corners=True))), 1)
         p4 = self.decoder_block4(x4)
-        m4 = self.conv_ms_spvn_4(p4) if self.config.ms_supervision else None
+        m4 = self.conv_ms_spvn_4(p4) if self.config.ms_supervision and self.training else None
         if self.config.out_ref:
             p4_gdt = self.gdt_convs_4(p4)
             if self.training:
@@ -2167,10 +2175,10 @@ class Decoder(nn.Module):
         _p3 = _p4 + self.lateral_block4(x3)
 
         if self.config.dec_ipt:
-            patches_batch = self.get_patches_batch(x, _p3) if self.split else x
+            patches_batch = image2patches(x, patch_ref=_p3, transformation='b c (hg h) (wg w) -> b (c hg wg) h w') if self.split else x
            _p3 = torch.cat((_p3, self.ipt_blk4(F.interpolate(patches_batch, size=x3.shape[2:], mode='bilinear', align_corners=True))), 1)
         p3 = self.decoder_block3(_p3)
-        m3 = self.conv_ms_spvn_3(p3) if self.config.ms_supervision else None
+        m3 = self.conv_ms_spvn_3(p3) if self.config.ms_supervision and self.training else None
         if self.config.out_ref:
             p3_gdt = self.gdt_convs_3(p3)
             if self.training:
@@ -2193,10 +2201,10 @@ class Decoder(nn.Module):
         _p2 = _p3 + self.lateral_block3(x2)
 
         if self.config.dec_ipt:
-            patches_batch = self.get_patches_batch(x, _p2) if self.split else x
+            patches_batch = image2patches(x, patch_ref=_p2, transformation='b c (hg h) (wg w) -> b (c hg wg) h w') if self.split else x
             _p2 = torch.cat((_p2, self.ipt_blk3(F.interpolate(patches_batch, size=x2.shape[2:], mode='bilinear', align_corners=True))), 1)
         p2 = self.decoder_block2(_p2)
-        m2 = self.conv_ms_spvn_2(p2) if self.config.ms_supervision else None
+        m2 = self.conv_ms_spvn_2(p2) if self.config.ms_supervision and self.training else None
         if self.config.out_ref:
             p2_gdt = self.gdt_convs_2(p2)
             if self.training:
@@ -2214,17 +2222,17 @@ class Decoder(nn.Module):
         _p1 = _p2 + self.lateral_block2(x1)
 
         if self.config.dec_ipt:
-            patches_batch = self.get_patches_batch(x, _p1) if self.split else x
+            patches_batch = image2patches(x, patch_ref=_p1, transformation='b c (hg h) (wg w) -> b (c hg wg) h w') if self.split else x
             _p1 = torch.cat((_p1, self.ipt_blk2(F.interpolate(patches_batch, size=x1.shape[2:], mode='bilinear', align_corners=True))), 1)
         _p1 = self.decoder_block1(_p1)
         _p1 = F.interpolate(_p1, size=x.shape[2:], mode='bilinear', align_corners=True)
 
         if self.config.dec_ipt:
-            patches_batch = self.get_patches_batch(x, _p1) if self.split else x
+            patches_batch = image2patches(x, patch_ref=_p1, transformation='b c (hg h) (wg w) -> b (c hg wg) h w') if self.split else x
             _p1 = torch.cat((_p1, self.ipt_blk1(F.interpolate(patches_batch, size=x.shape[2:], mode='bilinear', align_corners=True))), 1)
         p1_out = self.conv_out1(_p1)
 
-        if self.config.ms_supervision:
+        if self.config.ms_supervision and self.training:
             outs.append(m4)
             outs.append(m3)
             outs.append(m2)
@@ -2241,4 +2249,4 @@ class SimpleConvs(nn.Module):
         self.conv_out = nn.Conv2d(inter_channels, out_channels, 3, 1, 1)
 
     def forward(self, x):
-        return self.conv_out(self.conv1(x))
+        return self.conv_out(self.conv1(x))
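Note: with the new `and self.training` gates, the multi-scale heads (m4/m3/m2) are computed only during training, so an eval-mode forward skips them and appends only the final prediction. A hypothetical inference call, assuming the usual BiRefNet convention that the forward returns a list whose last element is the final map:

    model.eval()
    with torch.no_grad():
        pred = model(image)[-1].sigmoid()  # final high-resolution map only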