File size: 626 Bytes
93ce109 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 |
from transformers import PretrainedConfig
class MambaVisionConfig(PretrainedConfig):
model_type = "mambavision"
def __init__(
self,
depths=[1, 3, 11, 4],
num_heads=[2, 4, 8, 16],
window_size=[8, 8, 14, 7],
dim=80,
in_dim=32,
mlp_ratio=4,
drop_path_rate=0.2,
**kwargs,
):
self.depths = depths
self.num_heads = num_heads
self.window_size = window_size
self.dim = dim
self.in_dim = in_dim
self.mlp_ratio = mlp_ratio
self.drop_path_rate = drop_path_rate
super().__init__(**kwargs) |