Do you really use flash attention?

#5 · opened by GinnM

I noticed that:

        # In the NeoBERT attention forward, where scaled_dot_product_attention is
        # torch.nn.functional.scaled_dot_product_attention:
        attn = scaled_dot_product_attention(
            query=xq.transpose(1, 2),
            key=xk.transpose(1, 2),
            value=xv.transpose(1, 2),
            attn_mask=attention_mask.bool(),
            dropout_p=0,
        ).transpose(1, 2)

But when attn_mask is not None, scaled_dot_product_attention cannot dispatch to the flash attention kernel, which only supports causal masking via is_causal; it silently falls back to the memory-efficient or math backend instead.
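
Here is a minimal sketch that shows this (assuming PyTorch >= 2.3 for torch.nn.attention.sdpa_kernel and a CUDA device with a recent GPU): when SDPA is restricted to the flash backend, passing a dense attn_mask raises "No available kernel" instead of running, while is_causal works.

    import torch
    import torch.nn.functional as F
    from torch.nn.attention import SDPBackend, sdpa_kernel

    q = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
    k, v = torch.randn_like(q), torch.randn_like(q)
    # Padding-style boolean mask, like the attention_mask NeoBERT passes in.
    mask = torch.ones(1, 1, 128, 128, device="cuda", dtype=torch.bool)

    with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
        try:
            # The flash kernel does not accept an arbitrary attn_mask.
            F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
        except RuntimeError as e:
            print("flash backend refused attn_mask:", e)
        # The flash kernel is fine with is_causal (no explicit mask tensor).
        F.scaled_dot_product_attention(q, k, v, is_causal=True)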

I also tried loading the model with both the sdpa and flash_attention_2 attention implementations, and both fail with an error:
RuntimeError: Failed to load AutoModel for chandar-lab/NeoBERT. Error: NeoBERT does not support an attention implementation through torch.nn.functional.scaled_dot_product_attention yet.
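
For concreteness, the load call that produces the error above looks roughly like this (a sketch of the sdpa case; I'm assuming trust_remote_code is required for this remote-code checkpoint, and the exact kwargs in my script may differ):

    from transformers import AutoModel

    # Roughly the failing call (sdpa case); exact kwargs may differ.
    model = AutoModel.from_pretrained(
        "chandar-lab/NeoBERT",
        trust_remote_code=True,
        attn_implementation="sdpa",
    )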
