|
from transformers import PretrainedConfig |
|
|
|
|
|
class USADConfig(PretrainedConfig):
    """Configuration object for the ``usad`` model type.

    Each explicit constructor argument is stored unchanged on the instance
    under an attribute of the same name. Any extra keyword arguments are
    forwarded to ``PretrainedConfig.__init__``, which runs before the
    attributes below are assigned.

    NOTE(review): the model code that consumes these values is not visible
    from this file. The descriptions below are inferred from the parameter
    names only -- confirm them against the model implementation.

    Args:
        encoder_dim: presumably the encoder hidden width.
        num_layers: presumably the number of encoder layers.
        attention_type: name of the attention variant (default ``"mhsa"``);
            the set of accepted values is not defined here.
        num_attention_heads: presumably the attention head count.
        mamba_d_state: ``mamba_*`` option; presumably the state size.
        mamba_d_conv: ``mamba_*`` option; presumably the conv width.
        mamba_expand: ``mamba_*`` option; presumably the expansion factor.
        mamba_bidirectional: ``mamba_*`` option; presumably toggles a
            bidirectional variant.
        feed_forward_expansion_factor: presumably the feed-forward widening
            factor.
        conv_expansion_factor: presumably the convolution-module widening
            factor.
        feed_forward_dropout_p: presumably a dropout probability.
        attention_dropout_p: presumably a dropout probability.
        conv_dropout_p: presumably a dropout probability.
        conv_kernel_size: presumably the convolution-module kernel size.
        half_step_residual: flag; semantics defined by the model code.
        transformer_style: flag; semantics defined by the model code.
        use_framewise_subsample: flag selecting a frame-wise subsampler
            (presumed).
        use_patchwise_subsample: flag selecting a patch-wise subsampler
            (presumed).
        conv_subsample_channels: presumably the subsampler channel count.
        conv_subsample_rate: presumably the subsampling rate.
        input_dim: presumably the input feature dimension.
        input_dropout_p: presumably a dropout probability on the input.
        conv_pos: flag; presumably enables a convolutional positional
            component.
        conv_pos_depth: ``conv_pos_*`` option.
        conv_pos_width: ``conv_pos_*`` option.
        conv_pos_groups: ``conv_pos_*`` option.
        subsample_normalization: flag; semantics defined by the model code.
        **kwargs: passed through to ``PretrainedConfig.__init__``.
    """

    # Model-type identifier string for this configuration class.
    model_type = "usad"

    def __init__(
        self,
        encoder_dim: int = 384,
        num_layers: int = 12,
        attention_type: str = "mhsa",
        num_attention_heads: int = 6,
        mamba_d_state: int = 16,
        mamba_d_conv: int = 4,
        mamba_expand: int = 2,
        mamba_bidirectional: bool = False,
        feed_forward_expansion_factor: int = 4,
        conv_expansion_factor: int = 2,
        feed_forward_dropout_p: float = 0.1,
        attention_dropout_p: float = 0.1,
        conv_dropout_p: float = 0.1,
        conv_kernel_size: int = 31,
        half_step_residual: bool = True,
        transformer_style: bool = True,
        use_framewise_subsample: bool = True,
        use_patchwise_subsample: bool = False,
        conv_subsample_channels: int = 64,
        conv_subsample_rate: int = 2,
        input_dim: int = 128,
        input_dropout_p: float = 0.0,
        conv_pos: bool = True,
        conv_pos_depth: int = 5,
        conv_pos_width: int = 95,
        conv_pos_groups: int = 16,
        subsample_normalization: bool = True,
        **kwargs,
    ):
        # The base class consumes the generic keyword arguments first; the
        # model-specific fields are assigned afterwards.
        super().__init__(**kwargs)

        # Core encoder / attention fields.
        self.encoder_dim = encoder_dim
        self.num_layers = num_layers
        self.attention_type = attention_type
        self.num_attention_heads = num_attention_heads

        # mamba_* fields.
        self.mamba_d_state = mamba_d_state
        self.mamba_d_conv = mamba_d_conv
        self.mamba_expand = mamba_expand
        self.mamba_bidirectional = mamba_bidirectional

        # Expansion factors, dropout probabilities and kernel size.
        self.feed_forward_expansion_factor = feed_forward_expansion_factor
        self.conv_expansion_factor = conv_expansion_factor
        self.feed_forward_dropout_p = feed_forward_dropout_p
        self.attention_dropout_p = attention_dropout_p
        self.conv_dropout_p = conv_dropout_p
        self.conv_kernel_size = conv_kernel_size

        # Block-layout flags.
        self.half_step_residual = half_step_residual
        self.transformer_style = transformer_style

        # Subsampling and input fields.
        self.use_framewise_subsample = use_framewise_subsample
        self.use_patchwise_subsample = use_patchwise_subsample
        self.conv_subsample_channels = conv_subsample_channels
        self.conv_subsample_rate = conv_subsample_rate
        self.input_dim = input_dim
        self.input_dropout_p = input_dropout_p

        # conv_pos* fields.
        self.conv_pos = conv_pos
        self.conv_pos_depth = conv_pos_depth
        self.conv_pos_width = conv_pos_width
        self.conv_pos_groups = conv_pos_groups

        self.subsample_normalization = subsample_normalization
|
|