In modeling_minimax_text_01.py the attention mask is not passed correctly to the MiniMaxText01FlashAttention2::forward() method

#13
by sszymczyk

In the modeling_minimax_text_01.py file, the MiniMaxText01DecoderLayer::forward() method contains:

        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            position_ids=position_ids,
            attn_mask=attention_mask,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            slope_rate=slope_rate,
        )

where self_attn can be an object of either MiniMaxText01LightningAttention or MiniMaxText01FlashAttention2, depending on the layer number.
Note that the attention mask is always passed via the keyword argument attn_mask.

However, the signature of MiniMaxText01FlashAttention2::forward() is:

    def forward(
            self,
            hidden_states: torch.Tensor,
            attention_mask: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.LongTensor] = None,
            past_key_value: Optional[Union[Cache, Tuple[torch.Tensor]]] = None,
            output_attentions: bool = False,
            use_cache: bool = False,
            **kwargs,
    ):

so the parameter name for the attention mask is attention_mask here, not attn_mask as passed from MiniMaxText01DecoderLayer::forward().
Because the method also accepts **kwargs, the misnamed attn_mask keyword is silently absorbed instead of raising a TypeError, attention_mask keeps its default value of None, and the attention mask is therefore always None inside this method.
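
A minimal standalone snippet (independent of the model code; the function name and tensor shapes are just for illustration) reproduces this Python behavior:

    import torch
    from typing import Optional

    def forward(hidden_states: torch.Tensor,
                attention_mask: Optional[torch.Tensor] = None,
                **kwargs):
        # attn_mask is swallowed by **kwargs; attention_mask keeps its default of None
        print("attention_mask:", attention_mask)
        print("kwargs keys:", list(kwargs.keys()))

    forward(torch.zeros(1, 4, 8), attn_mask=torch.ones(1, 1, 4, 4))
    # prints:
    # attention_mask: None
    # kwargs keys: ['attn_mask']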

Is this intentional or an error? Did you use this code to train the model?
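
For what it's worth, one possible fix (just a sketch; it assumes MiniMaxText01LightningAttention::forward() also accepts the attention_mask keyword, which I have not verified) would be to rename the keyword at the call site in MiniMaxText01DecoderLayer::forward():

        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            slope_rate=slope_rate,
        )

Alternatively, MiniMaxText01FlashAttention2::forward() could be changed to read the mask from **kwargs under the attn_mask name.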
