"""Caduceus config for Hugging Face.
"""
from typing import Optional

from transformers import PretrainedConfig


class CaduceusConfig(PretrainedConfig):
    """Config that extends the original MambaConfig with params relevant to bi-directionality and RC equivariance."""

    model_type = "caduceus"

    def __init__(
        self,
        # From original MambaConfig
        d_model: int = 2560,
        n_layer: int = 64,
        vocab_size: int = 50277,
        ssm_cfg: Optional[dict] = None,
        rms_norm: bool = True,
        residual_in_fp32: bool = True,
        fused_add_norm: bool = True,
        pad_vocab_size_multiple: int = 8,
        # Not in original MambaConfig, but default arg in create_block in mamba_ssm repo; used in layer norm
        norm_epsilon: float = 1e-5,
        # Used in init_weights
        initializer_cfg: Optional[dict] = None,
        # Caduceus-specific params
        bidirectional: bool = True,
        bidirectional_strategy: Optional[str] = "add",
        bidirectional_weight_tie: bool = True,
        rcps: bool = False,
        complement_map: Optional[dict] = None,  # used for RCPSEmbedding / RCPSLMHead
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.d_model = d_model
        self.n_layer = n_layer
        self.vocab_size = vocab_size
        self.ssm_cfg = ssm_cfg
        self.rms_norm = rms_norm
        self.residual_in_fp32 = residual_in_fp32
        self.fused_add_norm = fused_add_norm
        self.pad_vocab_size_multiple = pad_vocab_size_multiple
        self.norm_epsilon = norm_epsilon
        self.initializer_cfg = initializer_cfg
        self.bidirectional = bidirectional
        self.bidirectional_strategy = bidirectional_strategy
        self.bidirectional_weight_tie = bidirectional_weight_tie
        self.rcps = rcps
        self.complement_map = complement_map
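

# --- Illustrative usage ---
# A minimal sketch showing how this config might be instantiated and registered
# with `AutoConfig` so the "caduceus" model type resolves to CaduceusConfig.
# The small dimensions and the `complement_map` values below are hypothetical
# placeholders, not the values of any released Caduceus checkpoint.
if __name__ == "__main__":
    from transformers import AutoConfig

    # Register the custom config under its model_type ("caduceus").
    AutoConfig.register("caduceus", CaduceusConfig)

    config = CaduceusConfig(
        d_model=256,
        n_layer=4,
        vocab_size=12,
        bidirectional=True,
        bidirectional_strategy="add",
        bidirectional_weight_tie=True,
        rcps=True,
        complement_map={0: 3, 1: 2, 2: 1, 3: 0},  # hypothetical token-id complement mapping
    )
    print(config)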