|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" OrthoViT model configuration""" |
|
|
|
from transformers.models.vit.configuration_vit import ViTConfig |
|
from transformers.utils import logging |
|
|
|
|
|
# Module-level logger obtained through transformers' logging utilities, keyed by this module's name.
logger = logging.get_logger(name=__name__)
|
|
|
class OrthoViTConfig(ViTConfig):
    r"""
    This is the configuration class to store the configuration of an [`OrthoViTModel`]. It is used to instantiate a
    ViT model with orthogonal residual connections according to the specified arguments, defining the model
    architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the
    ViT-base model.

    Configuration objects inherit from [`ViTConfig`] and can be used to control the model outputs. Read the
    documentation from [`ViTConfig`] for more information. Every keyword argument not listed below is forwarded
    unchanged to [`ViTConfig`].

    Args:
        residual_connection (`str`, *optional*, defaults to `"linear"`):
            The type of residual connection to use. Can be `"linear"` or `"orthogonal"`.
        orthogonal_method (`str`, *optional*, defaults to `"channel"`):
            The method for orthogonalization if `residual_connection` is `"orthogonal"`. Can be `"channel"` or
            `"global"`. For ViT, `"channel"` is typically used for token embeddings.
        residual_connection_dim (`int`, *optional*, defaults to -1):
            The dimension along which to compute orthogonality. Defaults to -1 (last dimension).
        residual_eps (`float`, *optional*, defaults to 1e-6):
            Epsilon value for numerical stability in orthogonalization. It is stored as `config.residual_eps` but is
            *not* part of the dictionary returned by `residual_kwargs`.
        residual_perturbation (`float`, *optional*, defaults to `None`):
            Magnitude of random perturbation to add to the module output before connection.
        elementwise_affine_ln (`bool`, *optional*, defaults to `False`):
            Flag stored as `config.elementwise_affine_ln` for the modeling code. Judging by its name, it presumably
            toggles learnable elementwise affine parameters in the LayerNorm layers (verify against the modeling
            file).

    Example:

    ```python
    >>> from modeling_ortho_vit import OrthoViTModel
    >>> from configuration_ortho_vit import OrthoViTConfig

    >>> # Initializing a ViT-base style configuration with orthogonal connections
    >>> configuration = OrthoViTConfig(residual_connection="orthogonal")

    >>> # Initializing a model (with random weights) from the ViT-base style configuration
    >>> model = OrthoViTModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    # NOTE(review): this reuses the stock ViT `model_type`. As a result, `AutoConfig` will presumably resolve a saved
    # config to the plain `ViTConfig`, which has no `residual_kwargs`. Registering this class with
    # `AutoConfig.register` under its own key would also require a distinct name. The value is kept unchanged
    # because it is written into serialized configs.
    model_type = "vit"

    def __init__(
        self,
        residual_connection="linear",
        orthogonal_method="channel",
        residual_connection_dim=-1,
        residual_eps=1e-6,
        residual_perturbation=None,
        elementwise_affine_ln=False,
        **kwargs,
    ):
        # Let ViTConfig consume the standard ViT hyperparameters (hidden_size, num_hidden_layers, ...) first.
        super().__init__(**kwargs)
        # Residual-connection settings specific to the orthogonal variant.
        self.residual_connection = residual_connection
        self.orthogonal_method = orthogonal_method
        self.residual_connection_dim = residual_connection_dim
        self.residual_eps = residual_eps
        self.residual_perturbation = residual_perturbation
        self.elementwise_affine_ln = elementwise_affine_ln

    @property
    def residual_kwargs(self) -> dict:
        """Return the residual-connection settings as a keyword-argument dictionary.

        Returns:
            `dict` with the keys `method`, `orthogonal_method`, `dim` and `perturbation`. Their values are taken
            from `residual_connection`, `orthogonal_method`, `residual_connection_dim` and
            `residual_perturbation` respectively.
        """
        # NOTE(review): `residual_eps` is not included here. If the consumer of these kwargs accepts an `eps`
        # argument, it will fall back to its own default instead of `config.residual_eps`. Confirm that this is
        # intentional before relying on `residual_eps`.
        return {
            "method": self.residual_connection,
            "orthogonal_method": self.orthogonal_method,
            "dim": self.residual_connection_dim,
            "perturbation": self.residual_perturbation,
        }