# coding=utf-8 # @Author : Saurabhchand Bhati # @Affiliation : Massachusetts Institute of Technology # VMamba backbone is from https://github.com/MzeroMiko/VMamba/blob/main/vmamba.py # VMambaLayer, VMambaModel, VMambaForImageClassification are implemnted based on VMamba # SS2Dv0, SS2Dv1, SS2S are merged into one class and initiliazation is limited to v05_noz, # patch embeddings is limited to v2 and downsample is limited to v3. # MIT License # Copyright (c) 2024 MzeroMiko, Saurabhchand Bhati # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # The above copyright notice and this permission notice shall be included in all # copies or substantial portions of the Software. # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. """VMamba: Visual State Space Model configuration model""" import math import torch import warnings import torch.nn as nn import torch.nn.functional as F import torch.utils.checkpoint as checkpoint from timm.models.layers import DropPath, trunc_normal_ from functools import partial from typing import Optional, Callable, Any, Union from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss from transformers.modeling_outputs import ImageClassifierOutput from transformers.utils import logging from transformers.modeling_utils import PreTrainedModel from .configuration_vmamba import VMambaConfig logger = logging.get_logger(__name__) # General docstring _CONFIG_FOR_DOC = "VMambaConfig" WITH_TRITON = True # WITH_TRITON = False try: import triton import triton.language as tl except: WITH_TRITON = False warnings.warn("Triton not installed, fall back to pytorch implements.") # to make sure cached_property can be loaded for triton if WITH_TRITON: try: from functools import cached_property except: warnings.warn("if you are using py37, add this line to functools.py: " "cached_property = lambda func: property(lru_cache()(func))") # torch implementation ======================================== def cross_scan_fwd(x: torch.Tensor, in_channel_first=True, out_channel_first=True, scans=0): if in_channel_first: B, C, H, W = x.shape if scans == 0: y = x.new_empty((B, 4, C, H * W)) y[:, 0, :, :] = x.flatten(2, 3) y[:, 1, :, :] = x.transpose(dim0=2, dim1=3).flatten(2, 3) y[:, 2:4, :, :] = torch.flip(y[:, 0:2, :, :], dims=[-1]) elif scans == 1: y = x.view(B, 1, C, H * W).repeat(1, 4, 1, 1) elif scans == 2: y = x.view(B, 1, C, H * W).repeat(1, 2, 1, 1) y = torch.cat([y, y.flip(dims=[-1])], dim=1) elif scans == 3: y = x.new_empty((B, 4, C, H * W)) y[:, 0, :, :] = x.flatten(2, 3) y[:, 1, :, :] = torch.rot90(x, 1, dims=(2, 3)).flatten(2, 3) y[:, 2, :, :] = torch.rot90(x, 2, dims=(2, 3)).flatten(2, 3) y[:, 3, :, :] = torch.rot90(x, 3, dims=(2, 3)).flatten(2, 3) else: B, H, W, C = x.shape if scans == 0: y = x.new_empty((B, H * W, 4, C)) y[:, :, 0, :] = x.flatten(1, 2) y[:, :, 1, :] = x.transpose(dim0=1, dim1=2).flatten(1, 2) y[:, :, 2:4, :] = torch.flip(y[:, :, 0:2, :], dims=[1]) elif scans == 1: y = x.view(B, H * W, 1, C).repeat(1, 1, 4, 1) elif scans == 2: y = x.view(B, H * W, 1, C).repeat(1, 1, 2, 1) y = torch.cat([y, y.flip(dims=[1])], dim=2) elif scans == 3: y = x.new_empty((B, H * W, 4, C)) y[:, :, 0, :] = x.flatten(1, 2) y[:, :, 1, :] = torch.rot90(x, 1, dims=(1, 2)).flatten(1, 2) y[:, :, 2, :] = torch.rot90(x, 2, dims=(1, 2)).flatten(1, 2) y[:, :, 3, :] = torch.rot90(x, 3, dims=(1, 2)).flatten(1, 2) if in_channel_first and (not out_channel_first): y = y.permute(0, 3, 1, 2).contiguous() elif (not in_channel_first) and out_channel_first: y = y.permute(0, 2, 3, 1).contiguous() return y def cross_merge_fwd(y: torch.Tensor, in_channel_first=True, out_channel_first=True, scans=0): if out_channel_first: B, K, D, H, W = y.shape y = y.view(B, K, D, -1) if scans == 0: y = y[:, 0:2] + y[:, 2:4].flip(dims=[-1]).view(B, 2, D, -1) y = y[:, 0] + y[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).contiguous().view(B, D, -1) elif scans == 1: y = y.sum(1) elif scans == 2: y = y[:, 0:2] + y[:, 2:4].flip(dims=[-1]).view(B, 2, D, -1) y = y.sum(1) elif scans == 3: oy = y[:, 0, :, :].contiguous().view(B, D, -1) oy = oy + torch.rot90(y.view(B, K, D, W, H)[:, 1, :, :, :], -1, dims=(2, 3)).flatten(2, 3) oy = oy + torch.rot90(y.view(B, K, D, H, W)[:, 2, :, :, :], -2, dims=(2, 3)).flatten(2, 3) oy = oy + torch.rot90(y.view(B, K, D, W, H)[:, 3, :, :, :], -3, dims=(2, 3)).flatten(2, 3) y = oy else: B, H, W, K, D = y.shape y = y.view(B, -1, K, D) if scans == 0: y = y[:, :, 0:2] + y[:, :, 2:4].flip(dims=[1]).view(B, -1, 2, D) y = y[:, :, 0] + y[:, :, 1].view(B, W, H, -1).transpose(dim0=1, dim1=2).contiguous().view(B, -1, D) elif scans == 1: y = y.sum(2) elif scans == 2: y = y[:, :, 0:2] + y[:, :, 2:4].flip(dims=[1]).view(B, -1, 2, D) y = y.sum(2) elif scans == 3: oy = y[:, :, 0, :].contiguous().view(B, -1, D) oy = oy + torch.rot90(y.view(B, W, H, K, D)[:, :, :, 1, :], -1, dims=(1, 2)).flatten(1, 2) oy = oy + torch.rot90(y.view(B, H, W, K, D)[:, :, :, 2, :], -2, dims=(1, 2)).flatten(1, 2) oy = oy + torch.rot90(y.view(B, W, H, K, D)[:, :, :, 3, :], -3, dims=(1, 2)).flatten(1, 2) y = oy if in_channel_first and (not out_channel_first): y = y.permute(0, 2, 1).contiguous() elif (not in_channel_first) and out_channel_first: y = y.permute(0, 2, 1).contiguous() return y def cross_scan1b1_fwd(x: torch.Tensor, in_channel_first=True, out_channel_first=True, scans=0): if in_channel_first: B, _, C, H, W = x.shape if scans == 0: y = torch.stack([ x[:, 0].flatten(2, 3), x[:, 1].transpose(dim0=2, dim1=3).flatten(2, 3), torch.flip(x[:, 2].flatten(2, 3), dims=[-1]), torch.flip(x[:, 3].transpose(dim0=2, dim1=3).flatten(2, 3), dims=[-1]), ], dim=1) elif scans == 1: y = x.flatten(2, 3) elif scans == 2: y = torch.stack([ x[:, 0].flatten(2, 3), x[:, 1].flatten(2, 3), torch.flip(x[:, 2].flatten(2, 3), dims=[-1]), torch.flip(x[:, 3].flatten(2, 3), dims=[-1]), ], dim=1) elif scans == 3: y = torch.stack([ x[:, 0, :, :, :].flatten(2, 3), torch.rot90(x[:, 1, :, :, :], 1, dims=(2, 3)).flatten(2, 3), torch.rot90(x[:, 2, :, :, :], 2, dims=(2, 3)).flatten(2, 3), torch.rot90(x[:, 3, :, :, :], 3, dims=(2, 3)).flatten(2, 3), ], dim=1) else: B, H, W, _, C = x.shape if scans == 0: y = torch.stack([ x[:, :, :, 0].flatten(1, 2), x[:, :, :, 1].transpose(dim0=1, dim1=2).flatten(1, 2), torch.flip(x[:, :, :, 2].flatten(1, 2), dims=[1]), torch.flip(x[:, :, :, 3].transpose(dim0=1, dim1=2).flatten(1, 2), dims=[1]), ], dim=2) elif scans == 1: y = x.flatten(1, 2) elif scans == 2: y = torch.stack([ x[:, 0].flatten(1, 2), x[:, 1].flatten(1, 2), torch.flip(x[:, 2].flatten(1, 2), dims=[-1]), torch.flip(x[:, 3].flatten(1, 2), dims=[-1]), ], dim=2) elif scans == 3: y = torch.stack([ x[:, :, :, 0, :].flatten(1, 2), torch.rot90(x[:, :, :, 1, :], 1, dims=(1, 2)).flatten(1, 2), torch.rot90(x[:, :, :, 2, :], 2, dims=(1, 2)).flatten(1, 2), torch.rot90(x[:, :, :, 3, :], 3, dims=(1, 2)).flatten(1, 2), ], dim=1) if in_channel_first and (not out_channel_first): y = y.permute(0, 3, 1, 2).contiguous() elif (not in_channel_first) and out_channel_first: y = y.permute(0, 2, 3, 1).contiguous() return y def cross_merge1b1_fwd(y: torch.Tensor, in_channel_first=True, out_channel_first=True, scans=0): if out_channel_first: B, K, D, H, W = y.shape y = y.view(B, K, D, -1) if scans == 0: y = torch.stack([ y[:, 0], y[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).flatten(2, 3), torch.flip(y[:, 2], dims=[-1]), torch.flip(y[:, 3].view(B, -1, W, H).transpose(dim0=2, dim1=3).flatten(2, 3), dims=[-1]), ], dim=1) elif scans == 1: y = y elif scans == 2: y = torch.stack([ y[:, 0], y[:, 1], torch.flip(y[:, 2], dims=[-1]), torch.flip(y[:, 3], dims=[-1]), ], dim=1) elif scans == 3: y = torch.stack([ y[:, 0, :, :].contiguous().view(B, D, -1), torch.rot90(y.view(B, K, D, W, H)[:, 1, :, :, :], -1, dims=(2, 3)).flatten(2, 3), torch.rot90(y.view(B, K, D, H, W)[:, 2, :, :, :], -2, dims=(2, 3)).flatten(2, 3), torch.rot90(y.view(B, K, D, W, H)[:, 3, :, :, :], -3, dims=(2, 3)).flatten(2, 3), ], dim=1) else: B, H, W, K, D = y.shape y = y.view(B, -1, K, D) if scans == 0: y = torch.stack([ y[:, :, 0], y[:, :, 1].view(B, W, H, -1).transpose(dim0=1, dim1=2).flatten(1, 2), torch.flip(y[:, :, 2], dims=[1]), torch.flip(y[:, :, 3].view(B, W, H, -1).transpose(dim0=1, dim1=2).flatten(1, 2), dims=[1]), ], dim=2) elif scans == 1: y = y elif scans == 2: y = torch.stack([ y[:, :, 0], y[:, :, 1], torch.flip(y[:, :, 2], dims=[1]), torch.flip(y[:, :, 3], dims=[1]), ], dim=2) elif scans == 3: y = torch.stack([ y[:, :, 0, :].contiguous().view(B, -1, D), torch.rot90(y.view(B, W, H, K, D)[:, :, :, 1, :], -1, dims=(1, 2)).flatten(1, 2), torch.rot90(y.view(B, H, W, K, D)[:, :, :, 2, :], -2, dims=(1, 2)).flatten(1, 2), torch.rot90(y.view(B, W, H, K, D)[:, :, :, 3, :], -3, dims=(1, 2)).flatten(1, 2), ], dim=2) if out_channel_first and (not in_channel_first): y = y.permute(0, 3, 1, 2).contiguous() elif (not out_channel_first) and in_channel_first: y = y.permute(0, 2, 3, 1).contiguous() return y class CrossScanF(torch.autograd.Function): @staticmethod def forward(ctx, x: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0): # x: (B, C, H, W) | (B, H, W, C) | (B, 4, C, H, W) | (B, H, W, 4, C) # y: (B, 4, C, H * W) | (B, H * W, 4, C) ctx.in_channel_first = in_channel_first ctx.out_channel_first = out_channel_first ctx.one_by_one = one_by_one ctx.scans = scans if one_by_one: B, K, C, H, W = x.shape if not in_channel_first: B, H, W, K, C = x.shape else: B, C, H, W = x.shape if not in_channel_first: B, H, W, C = x.shape ctx.shape = (B, C, H, W) _fn = cross_scan1b1_fwd if one_by_one else cross_scan_fwd y = _fn(x, in_channel_first, out_channel_first, scans) return y @staticmethod def backward(ctx, ys: torch.Tensor): # out: (b, k, d, l) in_channel_first = ctx.in_channel_first out_channel_first = ctx.out_channel_first one_by_one = ctx.one_by_one scans = ctx.scans B, C, H, W = ctx.shape ys = ys.view(B, -1, C, H, W) if out_channel_first else ys.view(B, H, W, -1, C) _fn = cross_merge1b1_fwd if one_by_one else cross_merge_fwd y = _fn(ys, in_channel_first, out_channel_first, scans) if one_by_one: y = y.view(B, 4, -1, H, W) if in_channel_first else y.view(B, H, W, 4, -1) else: y = y.view(B, -1, H, W) if in_channel_first else y.view(B, H, W, -1) return y, None, None, None, None class CrossMergeF(torch.autograd.Function): @staticmethod def forward(ctx, ys: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0): # x: (B, C, H, W) | (B, H, W, C) | (B, 4, C, H, W) | (B, H, W, 4, C) # y: (B, 4, C, H * W) | (B, H * W, 4, C) ctx.in_channel_first = in_channel_first ctx.out_channel_first = out_channel_first ctx.one_by_one = one_by_one ctx.scans = scans B, K, C, H, W = ys.shape if not out_channel_first: B, H, W, K, C = ys.shape ctx.shape = (B, C, H, W) _fn = cross_merge1b1_fwd if one_by_one else cross_merge_fwd y = _fn(ys, in_channel_first, out_channel_first, scans) return y @staticmethod def backward(ctx, x: torch.Tensor): # B, D, L = x.shape # out: (b, k, d, h, w) in_channel_first = ctx.in_channel_first out_channel_first = ctx.out_channel_first one_by_one = ctx.one_by_one scans = ctx.scans B, C, H, W = ctx.shape if not one_by_one: if in_channel_first: x = x.view(B, C, H, W) else: x = x.view(B, H, W, C) else: if in_channel_first: x = x.view(B, 4, C, H, W) else: x = x.view(B, H, W, 4, C) _fn = cross_scan1b1_fwd if one_by_one else cross_scan_fwd x = _fn(x, in_channel_first, out_channel_first, scans) x = x.view(B, 4, C, H, W) if out_channel_first else x.view(B, H, W, 4, C) return x, None, None, None, None # triton implements ======================================== @triton.jit def triton_cross_scan_flex( x: tl.tensor, # (B, C, H, W) | (B, H, W, C) | (B, 4, C, H, W) | (B, H, W, 4, C) y: tl.tensor, # (B, 4, C, H, W) | (B, H, W, 4, C) x_layout: tl.constexpr, y_layout: tl.constexpr, operation: tl.constexpr, onebyone: tl.constexpr, scans: tl.constexpr, BC: tl.constexpr, BH: tl.constexpr, BW: tl.constexpr, DC: tl.constexpr, DH: tl.constexpr, DW: tl.constexpr, NH: tl.constexpr, NW: tl.constexpr, ): # x_layout = 0 # y_layout = 1 # 0 BCHW, 1 BHWC # operation = 0 # 0 scan, 1 merge # onebyone = 0 # 0 false, 1 true # scans = 0 # 0 cross scan, 1 unidirectional, 2 bidirectional i_hw, i_c, i_b = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_h, i_w = (i_hw // NW), (i_hw % NW) _mask_h = (i_h * BH + tl.arange(0, BH)) < DH _mask_w = (i_w * BW + tl.arange(0, BW)) < DW _mask_hw = _mask_h[:, None] & _mask_w[None, :] _for_C = min(DC - i_c * BC, BC) pos_h = (i_h * BH + tl.arange(0, BH)[:, None]) pos_w = (i_w * BW + tl.arange(0, BW)[None, :]) neg_h = (DH - i_h * BH - 1 - tl.arange(0, BH)[:, None]) neg_w = (DW - i_w * BW - 1 - tl.arange(0, BW)[None, :]) if scans == 0: # none; trans; flip; trans + flip; HWRoute0 = pos_h * DW + pos_w HWRoute1 = pos_w * DH + pos_h # trans HWRoute2 = neg_h * DW + neg_w # flip HWRoute3 = neg_w * DH + neg_h # trans + flip elif scans == 1: # none; none; none; none; HWRoute0 = pos_h * DW + pos_w HWRoute1 = HWRoute0 HWRoute2 = HWRoute0 HWRoute3 = HWRoute0 elif scans == 2: # none; none; flip; flip; HWRoute0 = pos_h * DW + pos_w HWRoute1 = HWRoute0 HWRoute2 = neg_h * DW + neg_w # flip HWRoute3 = HWRoute2 elif scans == 3: # none; rot90; rot180==flip; rot270; HWRoute0 = pos_h * DW + pos_w HWRoute1 = neg_w * DH + pos_h HWRoute2 = neg_h * DW + neg_w HWRoute3 = pos_w * DH + neg_h _tmp1 = DC * DH * DW y_ptr_base = y + i_b * 4 * _tmp1 + (i_c * BC * DH * DW if y_layout == 0 else i_c * BC) if y_layout == 0: p_y1 = y_ptr_base + HWRoute0 p_y2 = y_ptr_base + _tmp1 + HWRoute1 p_y3 = y_ptr_base + 2 * _tmp1 + HWRoute2 p_y4 = y_ptr_base + 3 * _tmp1 + HWRoute3 else: p_y1 = y_ptr_base + HWRoute0 * 4 * DC p_y2 = y_ptr_base + DC + HWRoute1 * 4 * DC p_y3 = y_ptr_base + 2 * DC + HWRoute2 * 4 * DC p_y4 = y_ptr_base + 3 * DC + HWRoute3 * 4 * DC if onebyone == 0: x_ptr_base = x + i_b * _tmp1 + (i_c * BC * DH * DW if x_layout == 0 else i_c * BC) if x_layout == 0: p_x = x_ptr_base + HWRoute0 else: p_x = x_ptr_base + HWRoute0 * DC if operation == 0: for idxc in range(_for_C): _idx_x = idxc * DH * DW if x_layout == 0 else idxc _idx_y = idxc * DH * DW if y_layout == 0 else idxc _x = tl.load(p_x + _idx_x, mask=_mask_hw) tl.store(p_y1 + _idx_y, _x, mask=_mask_hw) tl.store(p_y2 + _idx_y, _x, mask=_mask_hw) tl.store(p_y3 + _idx_y, _x, mask=_mask_hw) tl.store(p_y4 + _idx_y, _x, mask=_mask_hw) elif operation == 1: for idxc in range(_for_C): _idx_x = idxc * DH * DW if x_layout == 0 else idxc _idx_y = idxc * DH * DW if y_layout == 0 else idxc _y1 = tl.load(p_y1 + _idx_y, mask=_mask_hw) _y2 = tl.load(p_y2 + _idx_y, mask=_mask_hw) _y3 = tl.load(p_y3 + _idx_y, mask=_mask_hw) _y4 = tl.load(p_y4 + _idx_y, mask=_mask_hw) tl.store(p_x + _idx_x, _y1 + _y2 + _y3 + _y4, mask=_mask_hw) else: x_ptr_base = x + i_b * 4 * _tmp1 + (i_c * BC * DH * DW if x_layout == 0 else i_c * BC) if x_layout == 0: p_x1 = x_ptr_base + HWRoute0 p_x2 = p_x1 + _tmp1 p_x3 = p_x2 + _tmp1 p_x4 = p_x3 + _tmp1 else: p_x1 = x_ptr_base + HWRoute0 * 4 * DC p_x2 = p_x1 + DC p_x3 = p_x2 + DC p_x4 = p_x3 + DC if operation == 0: for idxc in range(_for_C): _idx_x = idxc * DH * DW if x_layout == 0 else idxc _idx_y = idxc * DH * DW if y_layout == 0 else idxc tl.store(p_y1 + _idx_y, tl.load(p_x1 + _idx_x, mask=_mask_hw), mask=_mask_hw) tl.store(p_y2 + _idx_y, tl.load(p_x2 + _idx_x, mask=_mask_hw), mask=_mask_hw) tl.store(p_y3 + _idx_y, tl.load(p_x3 + _idx_x, mask=_mask_hw), mask=_mask_hw) tl.store(p_y4 + _idx_y, tl.load(p_x4 + _idx_x, mask=_mask_hw), mask=_mask_hw) else: for idxc in range(_for_C): _idx_x = idxc * DH * DW if x_layout == 0 else idxc _idx_y = idxc * DH * DW if y_layout == 0 else idxc tl.store(p_x1 + _idx_x, tl.load(p_y1 + _idx_y), mask=_mask_hw) tl.store(p_x2 + _idx_x, tl.load(p_y2 + _idx_y), mask=_mask_hw) tl.store(p_x3 + _idx_x, tl.load(p_y3 + _idx_y), mask=_mask_hw) tl.store(p_x4 + _idx_x, tl.load(p_y4 + _idx_y), mask=_mask_hw) class CrossScanTritonF(torch.autograd.Function): @staticmethod def forward(ctx, x: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0): if one_by_one: if in_channel_first: B, _, C, H, W = x.shape else: B, H, W, _, C = x.shape else: if in_channel_first: B, C, H, W = x.shape else: B, H, W, C = x.shape B, C, H, W = int(B), int(C), int(H), int(W) BC, BH, BW = 1, 32, 32 NH, NW, NC = triton.cdiv(H, BH), triton.cdiv(W, BW), triton.cdiv(C, BC) ctx.in_channel_first = in_channel_first ctx.out_channel_first = out_channel_first ctx.one_by_one = one_by_one ctx.scans = scans ctx.shape = (B, C, H, W) ctx.triton_shape = (BC, BH, BW, NC, NH, NW) y = x.new_empty((B, 4, C, H * W)) if out_channel_first else x.new_empty((B, H * W, 4, C)) triton_cross_scan_flex[(NH * NW, NC, B)]( x.contiguous(), y, (0 if in_channel_first else 1), (0 if out_channel_first else 1), 0, (0 if not one_by_one else 1), scans, BC, BH, BW, C, H, W, NH, NW ) return y @staticmethod def backward(ctx, y: torch.Tensor): in_channel_first = ctx.in_channel_first out_channel_first = ctx.out_channel_first one_by_one = ctx.one_by_one scans = ctx.scans B, C, H, W = ctx.shape BC, BH, BW, NC, NH, NW = ctx.triton_shape if one_by_one: x = y.new_empty((B, 4, C, H, W)) if in_channel_first else y.new_empty((B, H, W, 4, C)) else: x = y.new_empty((B, C, H, W)) if in_channel_first else y.new_empty((B, H, W, C)) triton_cross_scan_flex[(NH * NW, NC, B)]( x, y.contiguous(), (0 if in_channel_first else 1), (0 if out_channel_first else 1), 1, (0 if not one_by_one else 1), scans, BC, BH, BW, C, H, W, NH, NW ) return x, None, None, None, None class CrossMergeTritonF(torch.autograd.Function): @staticmethod def forward(ctx, y: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0): if out_channel_first: B, _, C, H, W = y.shape else: B, H, W, _, C = y.shape B, C, H, W = int(B), int(C), int(H), int(W) BC, BH, BW = 1, 32, 32 NH, NW, NC = triton.cdiv(H, BH), triton.cdiv(W, BW), triton.cdiv(C, BC) ctx.in_channel_first = in_channel_first ctx.out_channel_first = out_channel_first ctx.one_by_one = one_by_one ctx.scans = scans ctx.shape = (B, C, H, W) ctx.triton_shape = (BC, BH, BW, NC, NH, NW) if one_by_one: x = y.new_empty((B, 4, C, H * W)) if in_channel_first else y.new_empty((B, H * W, 4, C)) else: x = y.new_empty((B, C, H * W)) if in_channel_first else y.new_empty((B, H * W, C)) triton_cross_scan_flex[(NH * NW, NC, B)]( x, y.contiguous(), (0 if in_channel_first else 1), (0 if out_channel_first else 1), 1, (0 if not one_by_one else 1), scans, BC, BH, BW, C, H, W, NH, NW ) return x @staticmethod def backward(ctx, x: torch.Tensor): in_channel_first = ctx.in_channel_first out_channel_first = ctx.out_channel_first one_by_one = ctx.one_by_one scans = ctx.scans B, C, H, W = ctx.shape BC, BH, BW, NC, NH, NW = ctx.triton_shape y = x.new_empty((B, 4, C, H, W)) if out_channel_first else x.new_empty((B, H, W, 4, C)) triton_cross_scan_flex[(NH * NW, NC, B)]( x.contiguous(), y, (0 if in_channel_first else 1), (0 if out_channel_first else 1), 0, (0 if not one_by_one else 1), scans, BC, BH, BW, C, H, W, NH, NW ) return y, None, None, None, None, None # @torch.compile(options={"triton.cudagraphs": True}, fullgraph=True) def cross_scan_fn(x: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0, force_torch=False): # x: (B, C, H, W) | (B, H, W, C) | (B, 4, C, H, W) | (B, H, W, 4, C) # y: (B, 4, C, L) | (B, L, 4, C) # scans: 0: cross scan; 1 unidirectional; 2: bidirectional; CSF = CrossScanTritonF if WITH_TRITON and x.is_cuda and (not force_torch) else CrossScanF if x.is_cuda: with torch.cuda.device(x.device): return CSF.apply(x, in_channel_first, out_channel_first, one_by_one, scans) else: return CrossScanF.apply(x, in_channel_first, out_channel_first, one_by_one, scans) # @torch.compile(options={"triton.cudagraphs": True}, fullgraph=True) def cross_merge_fn(y: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0, force_torch=False): # y: (B, 4, C, L) | (B, L, 4, C) # x: (B, C, H * W) | (B, H * W, C) | (B, 4, C, H * W) | (B, H * W, 4, C) # scans: 0: cross scan; 1 unidirectional; 2: bidirectional; CMF = CrossMergeTritonF if WITH_TRITON and y.is_cuda and (not force_torch) else CrossMergeF if y.is_cuda: with torch.cuda.device(y.device): return CMF.apply(y, in_channel_first, out_channel_first, one_by_one, scans) else: return CrossMergeF.apply(y, in_channel_first, out_channel_first, one_by_one, scans) ########################################################## # csms6s.py ########################################################## WITH_SELECTIVESCAN_MAMBA = True try: import selective_scan_cuda except ImportError: WITH_SELECTIVESCAN_MAMBA = False def selective_scan_torch( u: torch.Tensor, # (B, K * C, L) delta: torch.Tensor, # (B, K * C, L) A: torch.Tensor, # (K * C, N) B: torch.Tensor, # (B, K, N, L) C: torch.Tensor, # (B, K, N, L) D: torch.Tensor = None, # (K * C) delta_bias: torch.Tensor = None, # (K * C) delta_softplus=True, oflex=True, *args, **kwargs ): dtype_in = u.dtype Batch, K, N, L = B.shape KCdim = u.shape[1] Cdim = int(KCdim / K) assert u.shape == (Batch, KCdim, L) assert delta.shape == (Batch, KCdim, L) assert A.shape == (KCdim, N) assert C.shape == B.shape if delta_bias is not None: delta = delta + delta_bias[..., None] if delta_softplus: delta = torch.nn.functional.softplus(delta) u, delta, A, B, C = u.float(), delta.float(), A.float(), B.float(), C.float() B = B.view(Batch, K, 1, N, L).repeat(1, 1, Cdim, 1, 1).view(Batch, KCdim, N, L) C = C.view(Batch, K, 1, N, L).repeat(1, 1, Cdim, 1, 1).view(Batch, KCdim, N, L) deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A)) deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u) if True: x = A.new_zeros((Batch, KCdim, N)) ys = [] for i in range(L): x = deltaA[:, :, i, :] * x + deltaB_u[:, :, i, :] y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i]) ys.append(y) y = torch.stack(ys, dim=2) # (B, C, L) out = y if D is None else y + u * D.unsqueeze(-1) return out if oflex else out.to(dtype=dtype_in) class SelectiveScanCuda(torch.autograd.Function): @staticmethod @torch.cuda.amp.custom_fwd def forward(ctx, u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=False, oflex=True, backend=None): ctx.delta_softplus = delta_softplus # backend = "oflex" if WITH_SELECTIVESCAN_OFLEX and (backend is None) else backend # backend = "core" if WITH_SELECTIVESCAN_CORE and (backend is None) else backend backend = "mamba" if WITH_SELECTIVESCAN_MAMBA and (backend is None) else backend ctx.backend = backend if backend == "oflex": out, x, *rest = selective_scan_cuda_oflex.fwd(u, delta, A, B, C, D, delta_bias, delta_softplus, 1, oflex) elif backend == "mamba": out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, None, delta_bias, delta_softplus) ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x) return out @staticmethod @torch.cuda.amp.custom_bwd def backward(ctx, dout, *args): u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors backend = ctx.backend if dout.stride(-1) != 1: dout = dout.contiguous() if backend == "oflex": du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda_oflex.bwd( u, delta, A, B, C, D, delta_bias, dout, x, ctx.delta_softplus, 1 ) elif backend == "mamba": du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd( u, delta, A, B, C, D, None, delta_bias, dout, x, None, None, ctx.delta_softplus, False ) return du, ddelta, dA, dB, dC, dD, ddelta_bias, None, None, None def selective_scan_fn( u: torch.Tensor, # (B, K * C, L) delta: torch.Tensor, # (B, K * C, L) A: torch.Tensor, # (K * C, N) B: torch.Tensor, # (B, K, N, L) C: torch.Tensor, # (B, K, N, L) D: torch.Tensor = None, # (K * C) delta_bias: torch.Tensor = None, # (K * C) delta_softplus=True, oflex=True, backend=None, ): fn = selective_scan_torch if backend == "torch" or (not WITH_SELECTIVESCAN_MAMBA) else SelectiveScanCuda.apply return fn(u, delta, A, B, C, D, delta_bias, delta_softplus, oflex, backend) ########################################################## ############## HuggingFace modeling file ################# ########################################################## class VMambaLinear2d(nn.Linear): def __init__(self, *args, groups=1, **kwargs): nn.Linear.__init__(self, *args, **kwargs) self.groups = groups def forward(self, x: torch.Tensor): if len(x.shape) == 4: return F.conv2d(x, self.weight[:, :, None, None], self.bias, groups=self.groups) elif len(x.shape) == 3: return F.conv1d(x, self.weight[:, :, None], self.bias, groups=self.groups) def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): self_state_dict = self.state_dict() load_state_dict_keys = list(state_dict.keys()) if prefix + "weight" in load_state_dict_keys: state_dict[prefix + "weight"] = state_dict[prefix + "weight"].view_as(self_state_dict["weight"]) return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) class VMambaLayerNorm2d(nn.LayerNorm): def __init__(self, *args, **kwargs): nn.LayerNorm.__init__(self, *args, **kwargs) def forward(self, x: torch.Tensor): x = x.permute(0, 2, 3, 1) x = nn.LayerNorm.forward(self, x) x = x.permute(0, 3, 1, 2) return x class VMambaPatchEmbeddings(nn.Module): """ This class turns `input_values` into the initial `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a State-space model. """ def __init__(self, num_channels=3,patch_size=4,embed_dim=96): super().__init__() stride = patch_size // 2 kernel_size = stride + 1 padding = 1 self.projection = nn.Sequential( nn.Conv2d(num_channels, embed_dim // 2, kernel_size=kernel_size, stride=stride, padding=padding), VMambaLayerNorm2d(embed_dim // 2), nn.GELU(), nn.Conv2d(embed_dim // 2, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding), VMambaLayerNorm2d(embed_dim), ) def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.projection(x) return x class VMambaDowsample(nn.Module): """ This class downsamples the input tensor using a convolutional layer followed by a layer normalization. """ def __init__(self, dim, out_dim, use_norm=True): super().__init__() self.down = nn.Conv2d(dim, out_dim, kernel_size=3, stride=2, padding=1) self.norm = VMambaLayerNorm2d(out_dim) if use_norm else nn.Identity() def forward(self, x): x = self.down(x) x = self.norm(x) return x class VMambaMlp(nn.Module): def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features self.fc1 = VMambaLinear2d(in_features, hidden_features) self.act = act_layer() self.fc2 = VMambaLinear2d(hidden_features, out_features) self.drop = nn.Dropout(drop) def forward(self, x): x = self.fc1(x) x = self.act(x) x = self.drop(x) x = self.fc2(x) x = self.drop(x) return x class SS2D(nn.Module): def __init__( self, # basic dims =========== d_model=96, d_state=16, ssm_ratio=2.0, dt_rank="auto", act_layer=nn.SiLU, # dwconv =============== d_conv=3, conv_bias=True, # ====================== dropout=0.0, bias=False, # dt init ============== dt_min=0.001, dt_max=0.1, dt_init="random", dt_scale=1.0, dt_init_floor=1e-4, # forward_type="v05_noz" is always used # ====================== **kwargs, ): super().__init__() self.k_group = 4 self.d_model = int(d_model) self.d_state = int(d_state) self.d_inner = int(ssm_ratio * d_model) self.dt_rank = int(math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank) self.forward_core = partial(self.forward_corev2, force_fp32=False, no_einsum=True) self.with_dconv = d_conv > 1 # In projection self.in_proj = VMambaLinear2d(self.d_model, self.d_inner, bias=bias) self.act: nn.Module = act_layer() # Convolution if self.with_dconv: self.conv2d = nn.Conv2d( in_channels=self.d_inner, out_channels=self.d_inner, groups=self.d_inner, bias=conv_bias, kernel_size=d_conv, padding=(d_conv - 1) // 2, ) # x_proj and dt_proj self.x_proj = VMambaLinear2d(self.d_inner, self.k_group * (self.dt_rank + self.d_state * 2), groups=self.k_group, bias=False) self.dt_projs = VMambaLinear2d(self.dt_rank, self.k_group * self.d_inner, groups=self.k_group, bias=False) # out projection self.out_proj = VMambaLinear2d(self.d_inner, self.d_model, bias=bias) self.dropout = nn.Dropout(dropout) if dropout > 0. else nn.Identity() # Initialization self.A_logs, self.Ds, self.dt_projs_weight, self.dt_projs_bias = self.init_dt_A_D( self.d_state, self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, k_group=self.k_group, ) self.dt_projs.weight.data = self.dt_projs_weight.data.view(self.dt_projs.weight.shape) # self.dt_projs.bias.data = self.dt_projs_bias.data.view(self.dt_projs.bias.shape) del self.dt_projs_weight # del self.dt_projs_bias # Define out_norm directly with "LN2D" self.out_norm = VMambaLayerNorm2d(self.d_inner) @staticmethod def dt_init(dt_rank, d_inner, dt_scale=1.0, dt_init="random", dt_min=0.001, dt_max=0.1, dt_init_floor=1e-4): dt_proj = nn.Linear(dt_rank, d_inner, bias=True) dt_init_std = dt_rank**-0.5 * dt_scale if dt_init == "constant": nn.init.constant_(dt_proj.weight, dt_init_std) elif dt_init == "random": nn.init.uniform_(dt_proj.weight, -dt_init_std, dt_init_std) else: raise NotImplementedError dt = torch.exp( torch.rand(d_inner) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min) ).clamp(min=dt_init_floor) inv_dt = dt + torch.log(-torch.expm1(-dt)) with torch.no_grad(): dt_proj.bias.copy_(inv_dt) return dt_proj @staticmethod def A_log_init(d_state, d_inner, copies=-1, device=None, merge=True): A = torch.arange(1, d_state + 1, dtype=torch.float32, device=device).view(1, -1).repeat(d_inner, 1).contiguous() A_log = torch.log(A) if copies > 0: A_log = A_log[None].repeat(copies, 1, 1).contiguous() if merge: A_log = A_log.flatten(0, 1) A_log = nn.Parameter(A_log) A_log._no_weight_decay = True return A_log @staticmethod def D_init(d_inner, copies=-1, device=None, merge=True): D = torch.ones(d_inner, device=device) if copies > 0: D = D[None].repeat(copies, 1).contiguous() if merge: D = D.flatten(0, 1) D = nn.Parameter(D) D._no_weight_decay = True return D @classmethod def init_dt_A_D(cls, d_state, dt_rank, d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, k_group=4): dt_projs = [ cls.dt_init(dt_rank, d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor) for _ in range(k_group) ] dt_projs_weight = nn.Parameter(torch.stack([t.weight for t in dt_projs], dim=0)) dt_projs_bias = nn.Parameter(torch.stack([t.bias for t in dt_projs], dim=0)) del dt_projs A_logs = cls.A_log_init(d_state, d_inner, copies=k_group, merge=True) Ds = cls.D_init(d_inner, copies=k_group, merge=True) return A_logs, Ds, dt_projs_weight, dt_projs_bias def forward_corev2( self, x: torch.Tensor, force_fp32=False, no_einsum=True, ): B, D, H, W = x.shape N = self.d_state L = H * W xs = cross_scan_fn(x, in_channel_first=True, out_channel_first=True) x_dbl = self.x_proj(xs.view(B, -1, L)) dts, Bs, Cs = torch.split(x_dbl.view(B, self.k_group, -1, L), [self.dt_rank, N, N], dim=2) dts = dts.contiguous().view(B, -1, L) dts = self.dt_projs(dts) xs = xs.view(B, -1, L) dts = dts.contiguous().view(B, -1, L) As = -self.A_logs.to(torch.float32).exp() Ds = self.Ds.to(torch.float32) Bs = Bs.contiguous().view(B, self.k_group, N, L) Cs = Cs.contiguous().view(B, self.k_group, N, L) delta_bias = self.dt_projs_bias.view(-1).to(torch.float32) ys = selective_scan_fn( xs, dts, As, Bs, Cs, Ds, delta_bias, delta_softplus=True, backend="mamba" ).view(B, self.k_group, -1, H, W) y = cross_merge_fn(ys, in_channel_first=True, out_channel_first=True) y = y.view(B, -1, H, W) y = self.out_norm(y) return y.to(x.dtype) def forward(self, x: torch.Tensor): x = self.in_proj(x) x = self.conv2d(x) x = self.act(x) y = self.forward_core(x) out = self.dropout(self.out_proj(y)) return out class VSSBlock(nn.Module): def __init__( self, hidden_dim: int = 0, drop_path: float = 0, ssm_d_state: int = 1, ssm_ratio=1.0, ssm_dt_rank: Any = "auto", ssm_act_layer=nn.SiLU, ssm_conv: int = 3, ssm_conv_bias=False, ssm_drop_rate: float = 0, mlp_ratio=4.0, mlp_act_layer=nn.GELU, mlp_drop_rate: float = 0.0, use_checkpoint: bool = False, post_norm: bool = False, **kwargs, ): super().__init__() self.ssm_branch = ssm_ratio > 0 self.mlp_branch = mlp_ratio > 0 self.use_checkpoint = use_checkpoint self.post_norm = post_norm if self.ssm_branch: self.norm = VMambaLayerNorm2d(hidden_dim) self.op = SS2D( d_model=hidden_dim, d_state=ssm_d_state, ssm_ratio=ssm_ratio, dt_rank=ssm_dt_rank, act_layer=ssm_act_layer, d_conv=ssm_conv, conv_bias=ssm_conv_bias, dropout=ssm_drop_rate, ) self.drop_path = DropPath(drop_path) if self.mlp_branch: self.norm2 = VMambaLayerNorm2d(hidden_dim) mlp_hidden_dim = int(hidden_dim * mlp_ratio) self.mlp = VMambaMlp(in_features=hidden_dim, hidden_features=mlp_hidden_dim, act_layer=mlp_act_layer, drop=mlp_drop_rate) def _forward(self, input: torch.Tensor): x = input if self.ssm_branch: if self.post_norm: x = x + self.drop_path(self.norm(self.op(x))) else: x = x + self.drop_path(self.op(self.norm(x))) if self.mlp_branch: if self.post_norm: x = x + self.drop_path(self.norm2(self.mlp(x))) else: x = x + self.drop_path(self.mlp(self.norm2(x))) return x def forward(self, input: torch.Tensor): if self.use_checkpoint: return checkpoint.checkpoint(self._forward, input) else: return self._forward(input) class VMambaLayer(nn.Module): def __init__( self, input_dim, depth, drop_path=0.0, norm_layer=VMambaLayerNorm2d, downsample=nn.Identity(), use_checkpoint=False, **kwargs, ): super().__init__() self.input_dim = input_dim self.use_checkpoint = use_checkpoint self.blocks = nn.ModuleList() for i in range(depth): self.blocks.append( VSSBlock(hidden_dim=input_dim, drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, norm_layer=norm_layer,use_checkpoint=use_checkpoint,**kwargs, ) ) self.downsample = downsample def forward(self, x): for block in self.blocks: x = block(x) x = self.downsample(x) return x class VMambaPreTrainedModel(PreTrainedModel): """ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models. """ config_class = VMambaConfig base_model_prefix = "vmamba" supports_gradient_checkpointing = False def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: """Initialize the weights""" if isinstance(module, nn.Linear): trunc_normal_(module.weight, std=0.02) if isinstance(module, nn.Linear) and module.bias is not None: nn.init.constant_(module.bias, 0) elif isinstance(module, nn.LayerNorm): nn.init.constant_(module.bias, 0) nn.init.constant_(module.weight, 1.0) class VMambaModel(VMambaPreTrainedModel): def __init__(self, config): super().__init__(config) self.config = config dims = config.dims if isinstance(dims, int): dims = [int(dims * 2**i_layer) for i_layer in range(self.num_layers)] self.dims = dims self.patch_embeddings = VMambaPatchEmbeddings(patch_size=config.patch_size, embed_dim=dims[0]) self.num_layers = len(config.depths) dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths))] self.num_features = dims[-1] self.layers = nn.ModuleList() for i in range(self.num_layers): layer = VMambaLayer( input_dim=self.dims[i], depth=config.depths[i], drop_path=dpr[sum(config.depths[:i]):sum(config.depths[:i+1])], downsample=VMambaDowsample(self.dims[i], self.dims[i+1]) if i < self.num_layers - 1 else nn.Identity(), use_checkpoint=config.use_checkpoint, ) self.layers.append(layer) self.norm = VMambaLayerNorm2d(self.num_features) self.avgpool = nn.AdaptiveAvgPool2d(1) def get_input_embeddings(self) -> VMambaPatchEmbeddings: return self.patch_embeddings def forward(self, input_values: torch.Tensor): x = self.patch_embeddings(input_values) for layer in self.layers: x = layer(x) x = self.norm(x) x = self.avgpool(x).flatten(1) return x class VMambaForImageClassification(VMambaPreTrainedModel): def __init__(self, config): super().__init__(config) self.num_classes = config.num_classes self.vmamba = VMambaModel(config) self.head = nn.Linear(self.vmamba.num_features, self.num_classes) if self.num_classes > 0 else nn.Identity() # Initialize weights and apply final processing self.post_init() def forward( self, pixel_values: Optional[torch.Tensor] = None, labels: Optional[torch.Tensor] = None, return_dict: Optional[bool] = None, ): outputs = self.vmamba( pixel_values, ) logits = self.head(outputs) loss = None if labels is not None: labels = labels.to(logits.device) if self.config.loss_type == "ce": loss_fct = CrossEntropyLoss() loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) elif self.config.problem_type == "bce": loss_fct = BCEWithLogitsLoss() loss = loss_fct(logits, labels) if return_dict: output = (logits,) + (outputs,) return ((loss,) + output) if loss is not None else output return ImageClassifierOutput( loss=loss, logits=logits, hidden_states=outputs, ) __all__ = [ "VMambaModel", "VMambaPreTrainedModel", "VMambaForImageClassification", ]