|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from collections import namedtuple |
|
from typing import Optional, List, Union |
|
|
|
import torch |
|
from transformers import PretrainedConfig, PreTrainedModel |
|
from .mamba_vision import * |
|
from timm.models import create_model, load_checkpoint |
|
|
|
|
|
class MambaVisionConfig(PretrainedConfig):
    """Hugging Face configuration object for ``MambaVisionModel``.

    The only field added on top of ``PretrainedConfig`` is ``args``, an
    optional dictionary that is stored on the instance unchanged. All other
    keyword arguments are passed straight through to ``PretrainedConfig``.
    """

    def __init__(self, args: Optional[dict] = None, **kwargs):
        """Store ``args`` and delegate the remaining options to the base class.

        Args:
            args: Optional dictionary of model arguments; kept as given
                (``None`` when omitted).
            **kwargs: Forwarded verbatim to ``PretrainedConfig.__init__``.
        """
        # Assigned before the base initializer runs, so the attribute already
        # exists while ``PretrainedConfig`` processes ``kwargs``.
        self.args = args
        super().__init__(**kwargs)
|
|
|
|
|
class MambaVisionModel(PreTrainedModel):
    """Pretrained Hugging Face model for MambaVision.

    This class inherits from PreTrainedModel, which provides
    HuggingFace's functionality for loading and saving models. The wrapped
    network is built by timm's ``create_model`` from the architecture name
    stored under ``config.args["model"]``.
    """

    config_class = MambaVisionConfig

    def __init__(self, config):
        """Create the wrapped network named by ``config.args["model"]``.

        Args:
            config: Configuration whose ``args`` mapping contains a
                ``"model"`` entry; that value is passed to ``create_model``.

        Raises:
            ValueError: If ``config.args`` is empty/``None`` or has no
                ``"model"`` entry.
        """
        super().__init__(config)

        # Read the one required entry directly from the mapping. Going through
        # a namedtuple would reject otherwise harmless keys (keywords, leading
        # underscores, non-identifiers) and fail obscurely on ``args=None``.
        model_args = config.args
        if not model_args or "model" not in model_args:
            raise ValueError(
                "MambaVisionModel requires config.args to be a mapping with a "
                f"'model' entry naming the architecture; got {model_args!r}"
            )

        self.config = config
        self.model = create_model(model_args["model"])

    def forward(self, x: torch.Tensor):
        """Run ``x`` through the wrapped network and return its output.

        Args:
            x: Input tensor handed unchanged to the wrapped network.
        """
        # Invoke the module itself instead of ``.forward`` so that any hooks
        # registered on the wrapped module are executed.
        return self.model(x)
|
|