from transformers.models.llama.configuration_llama import LlamaConfig


class LlamaMLAConfig(LlamaConfig):
    """LlamaConfig extended with Multi-head Latent Attention (MLA) fields."""

    model_type = "llamamla"

    def __init__(
        self,
        *args,
        kv_lora_rank=512,
        q_lora_rank=None,
        qk_rope_head_dim=64,
        qk_nope_head_dim=128,
        v_head_dim=128,
        query_pre_attn_scalar=128,
        softcap=None,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        # Rank of the low-rank latent that keys/values are compressed into.
        self.kv_lora_rank = kv_lora_rank
        # Optional low-rank compression for queries; None disables it.
        self.q_lora_rank = q_lora_rank
        # Per-head query/key dims split into a rotary (positional) part and a
        # "nope" (non-positional) part, concatenated at attention time.
        self.qk_rope_head_dim = qk_rope_head_dim
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_head_dim = qk_rope_head_dim + qk_nope_head_dim
        self.v_head_dim = v_head_dim
        # Attention scores are scaled by 1 / sqrt(query_pre_attn_scalar).
        self.query_pre_attn_scalar = query_pre_attn_scalar
        # Optional tanh soft-capping of attention logits; None disables it.
        self.softcap = softcap
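

# --- Usage sketch (an assumption, not part of the original module): construct
# the config and register the custom "llamamla" model_type so AutoConfig can
# round-trip it from a saved checkpoint. AutoConfig.register is the standard
# transformers hook for custom model types; the sizes below are illustrative.
if __name__ == "__main__":
    from transformers import AutoConfig

    AutoConfig.register("llamamla", LlamaMLAConfig)

    config = LlamaMLAConfig(
        hidden_size=2048,
        num_hidden_layers=16,
        num_attention_heads=16,
        kv_lora_rank=512,
        qk_rope_head_dim=64,
        qk_nope_head_dim=128,
    )
    print(config.qk_head_dim)  # 192 = 64 + 128
    print(config.model_type)   # "llamamla"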