"""Register the GatedDeltaProduct config and models with the Transformers Auto classes.

Importing this module has a side effect: it registers
``GatedDeltaProductConfig`` under its ``model_type`` with ``AutoConfig``, and
maps that config class to ``GatedDeltaProductModel`` (``AutoModel``) and
``GatedDeltaProductForCausalLM`` (``AutoModelForCausalLM``).
"""

from transformers import AutoConfig, AutoModel, AutoModelForCausalLM

from fla.models.gated_deltaproduct.configuration_gated_deltaproduct import GatedDeltaProductConfig
from fla.models.gated_deltaproduct.modeling_gated_deltaproduct import (
    GatedDeltaProductForCausalLM,
    GatedDeltaProductModel,
)

# `exist_ok=True` keeps the import from failing with ValueError when an entry
# with the same name is already known to Transformers (for example, if a
# Transformers release ships this model type itself).
AutoConfig.register(GatedDeltaProductConfig.model_type, GatedDeltaProductConfig, exist_ok=True)
AutoModel.register(GatedDeltaProductConfig, GatedDeltaProductModel, exist_ok=True)
AutoModelForCausalLM.register(GatedDeltaProductConfig, GatedDeltaProductForCausalLM, exist_ok=True)

__all__ = [
    "GatedDeltaProductConfig",
    "GatedDeltaProductForCausalLM",
    "GatedDeltaProductModel",
]