# -*- coding: utf-8 -*-
"""Public entry point for the linear-attention model package.

Importing this module has a side effect: it registers the linear-attention
configuration and model classes with the Hugging Face ``transformers``
Auto* factories (``AutoConfig``, ``AutoModel``, ``AutoModelForCausalLM``).

The names listed in ``__all__`` are re-exported for callers.
"""

from transformers import AutoConfig, AutoModel, AutoModelForCausalLM

from fla.models.linear_attn.configuration_linear_attn import LinearAttentionConfig
from fla.models.linear_attn.modeling_linear_attn import LinearAttentionForCausalLM, LinearAttentionModel

# Register the config class under its own ``model_type`` identifier.
AutoConfig.register(LinearAttentionConfig.model_type, LinearAttentionConfig)

# Associate the config class with the base model and the causal-LM model.
AutoModel.register(LinearAttentionConfig, LinearAttentionModel)
AutoModelForCausalLM.register(LinearAttentionConfig, LinearAttentionForCausalLM)

# Explicit public API of this package.
__all__ = ['LinearAttentionConfig', 'LinearAttentionForCausalLM', 'LinearAttentionModel']