Mistral: Sliding Window Attention with Flash Attention and Sample Packing (#732)
* Implement Mistral FA + SWA + Sample Packing
* Handle unbroadcastable tensor
* chore: lint
* Simplify _prepare_decoder_attention_mask
* Uncomment window size
* Upgrade flash-attn to minimum of 2.3.0 to support SWA
* Add original condition to avoid error during inference
* chore: lint
* use torchscript to prevent oom
* chore: pylint
---------
Co-authored-by: Wing Lian <[email protected]>
- setup.py +1 -1
- src/axolotl/monkeypatch/mistral_attn_hijack_flash.py +104 -5
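For context, the patch below is activated by calling replace_mistral_attn_with_flash_attn before the Mistral model is instantiated, so the patched mask preparation and attention forward are in place when training starts. A minimal usage sketch, not part of this PR, assuming the function can be called with its defaults:

    from transformers import AutoModelForCausalLM

    from axolotl.monkeypatch.mistral_attn_hijack_flash import (
        replace_mistral_attn_with_flash_attn,
    )

    # Patch MistralAttention.forward and the decoder attention-mask preparation in place
    # (exact arguments may differ; only the function name is taken from the diff below).
    replace_mistral_attn_with_flash_attn()

    # Any Mistral checkpoint loaded afterwards goes through the flash-attn + SWA path.
    model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")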
setup.py
CHANGED
@@ -46,7 +46,7 @@ setup(
     dependency_links=dependency_links,
     extras_require={
         "flash-attn": [
-            "flash-attn>=2.
+            "flash-attn>=2.3.0",
         ],
         "deepspeed": [
             "deepspeed",
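The bump to flash-attn>=2.3.0 matters because, per the commit message, that is the first release supporting sliding window attention (the window_size argument passed in the patch below). A minimal runtime check, not part of this PR, that catches an older wheel before training starts:

    from importlib.metadata import PackageNotFoundError, version

    from packaging.version import Version

    try:
        # "flash-attn" is the PyPI distribution name pinned in setup.py above.
        installed = Version(version("flash-attn"))
    except PackageNotFoundError as err:
        raise RuntimeError("flash-attn is not installed; install the flash-attn extra") from err

    if installed < Version("2.3.0"):
        raise RuntimeError(
            f"flash-attn {installed} does not support sliding window attention; need >= 2.3.0"
        )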
src/axolotl/monkeypatch/mistral_attn_hijack_flash.py
CHANGED
@@ -14,6 +14,9 @@ from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-impor
     flash_attn_varlen_qkvpacked_func,
 )
 from transformers.modeling_outputs import BaseModelOutputWithPast
+from transformers.models.mistral.modeling_mistral import (
+    MistralAttention as OriginalMistralAttention,
+)
 from transformers.models.mistral.modeling_mistral import (
     MistralDecoderLayer as OriginalMistralDecoderLayer,
 )
@@ -42,6 +45,44 @@ def replace_mistral_attn_with_flash_attn(
     )


+@torch.jit.script
+def _make_sliding_window_causal_mask(
+    bsz: int,
+    tgt_len: int,
+    dtype: torch.dtype,
+    device: torch.device,
+    past_key_values_length: int = 0,
+    sliding_window: int = 4096,
+):
+    """
+    Make causal mask used for sliding window attention
+    """
+    tensor = torch.full(
+        (tgt_len, tgt_len),
+        fill_value=1,
+        device=device,
+    )
+    mask = torch.tril(tensor, diagonal=0)
+    # make the mask banded to account for sliding window
+    # NOTE: HF implementation is wrong as of 14-10-2023 for torch.triu, needs +1
+    mask = torch.triu(mask, diagonal=-sliding_window + 1)
+    mask = torch.log(mask).to(dtype)
+
+    if past_key_values_length > 0:
+        mask = torch.cat(
+            [
+                torch.zeros(
+                    tgt_len, past_key_values_length, dtype=dtype, device=device
+                ),
+                mask,
+            ],
+            dim=-1,
+        )
+    return mask[None, None, :, :].expand(
+        bsz, 1, tgt_len, tgt_len + past_key_values_length
+    )
+
+
 # Disable the transformation of the attention mask in LlamaModel as the flash attention
 # requires the attention mask to be the same as the key_padding_mask
 def _prepare_decoder_attention_mask(
@@ -53,11 +94,29 @@ def _prepare_decoder_attention_mask(
     sliding_window,
 ): # pylint: disable=unused-argument
     # [bsz, seq_len]
+    if attention_mask is None:
+        return attention_mask
+
+    # NOTE: attention mask and sliding masks are only broadcastable in certain scenarios.
+    # Without attention_mask.shape[0] == 1, error will trigger after eval loss but only when wandb is enabled.
+    if input_shape[-1] > 1 and attention_mask.shape[0] == 1:
+        sliding_window_mask = _make_sliding_window_causal_mask(
+            bsz=input_shape[0],
+            tgt_len=input_shape[1],
+            dtype=inputs_embeds.dtype,
+            device=inputs_embeds.device,
+            past_key_values_length=past_key_values_length,
+            sliding_window=sliding_window,
+        )
+        attention_mask = attention_mask + sliding_window_mask
+    else:
+        LOG.info("skipping sliding window mask, not broadcastable with attention mask")
+
     return attention_mask


 def flashattn_forward(
-    self,
+    self: OriginalMistralAttention,
     hidden_states: torch.Tensor,
     attention_mask: Optional[torch.Tensor] = None,
     position_ids: Optional[torch.LongTensor] = None,
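A standalone illustration, not part of the patch, of the banded mask that _make_sliding_window_causal_mask above builds: torch.tril keeps the causal half, torch.triu with diagonal=-sliding_window + 1 drops keys older than the window, and torch.log maps the surviving ones to 0.0 and everything else to -inf, the additive-mask convention the decoder expects:

    import torch

    tgt_len, sliding_window = 5, 3
    ones = torch.full((tgt_len, tgt_len), fill_value=1.0)
    mask = torch.tril(ones, diagonal=0)                    # causal: no attending to future positions
    mask = torch.triu(mask, diagonal=-sliding_window + 1)  # banded: drop keys older than the window
    mask = torch.log(mask)                                 # 1 -> 0.0 (attend), 0 -> -inf (masked)
    print(mask)
    # tensor([[0., -inf, -inf, -inf, -inf],
    #         [0., 0., -inf, -inf, -inf],
    #         [0., 0., 0., -inf, -inf],
    #         [-inf, 0., 0., 0., -inf],
    #         [-inf, -inf, 0., 0., 0.]])
    # Each row attends to itself and at most the previous sliding_window - 1 tokens.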
@@ -91,10 +150,41 @@ def flashattn_forward(
         query_states, key_states, cos, sin, position_ids
     )

+    use_sliding_windows = (
+        hasattr(self.config, "sliding_window") is not None
+        and kv_seq_len > self.config.sliding_window
+    )
+
+    if use_sliding_windows:
+        window_size = (self.config.sliding_window, self.config.sliding_window)
+    else:
+        window_size = (-1, -1)
+
     if past_key_value is not None:
-        #
-
-
+        # Activate slicing cache only if the config has a value `sliding_windows` attribute
+        if (
+            hasattr(self.config, "sliding_window")
+            and kv_seq_len > self.config.sliding_window
+        ):
+            slicing_tokens = kv_seq_len - self.config.sliding_window
+
+            past_key = past_key_value[0]
+            past_value = past_key_value[1]
+
+            past_key = past_key[:, :, slicing_tokens:, :].contiguous()
+            past_value = past_value[:, :, slicing_tokens:, :].contiguous()
+
+            if past_key.shape[-2] != self.config.sliding_window - 1:
+                raise ValueError(
+                    f"past key much have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
+                    f" {past_key.shape}"
+                )
+
+        past_key_value = (past_key, past_value) if use_cache else None
+
+    if past_key_value is not None:
+        key_states = torch.cat([past_key_value[0], key_states], dim=2)
+        value_states = torch.cat([past_key_value[1], value_states], dim=2)

     past_key_value = (key_states, value_states) if use_cache else None

@@ -120,7 +210,13 @@ def flashattn_forward(
         qkv = rearrange(qkv, "b s ... -> (b s) ...")

         output = flash_attn_varlen_qkvpacked_func(
-            qkv,
+            qkv,
+            cu_seqlens,
+            max_seqlen,
+            0.0,
+            softmax_scale=None,
+            causal=True,
+            window_size=window_size,
         )
         output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
     elif query_states.shape == key_states.shape:
@@ -146,6 +242,7 @@ def flashattn_forward(
             0.0,
             softmax_scale=None,
             causal=is_causal,
+            window_size=window_size,
         )
         output = output_pad_fn(output_unpad)
     else:
@@ -157,6 +254,7 @@ def flashattn_forward(
                 query_states,
                 torch.stack([key_states, value_states], 2),
                 causal=is_causal,
+                window_size=window_size,
             )
         else:
             ( # pylint: disable=unbalanced-tuple-unpacking
@@ -191,6 +289,7 @@ def flashattn_forward(
                 0.0,
                 softmax_scale=None,
                 causal=is_causal,
+                window_size=window_size,
             )
             output = output_pad_fn(output_unpad)

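Two details of the forward changes above, for context. flash-attn's window_size is a (left, right) pair bounding how far each query may attend in either direction, with (-1, -1) meaning unbounded, so the patch only enables the window once kv_seq_len exceeds config.sliding_window. The cache-slicing branch keeps just the most recent window of cached keys/values before the new step's key/value are appended; a pure-torch sketch of that arithmetic, not part of the patch:

    import torch

    bsz, num_heads, head_dim = 1, 8, 64
    sliding_window = 4096
    kv_seq_len = 4100  # cached positions plus the single token being decoded

    past_key = torch.randn(bsz, num_heads, kv_seq_len - 1, head_dim)  # cache before this step
    slicing_tokens = kv_seq_len - sliding_window
    past_key = past_key[:, :, slicing_tokens:, :].contiguous()  # drop positions outside the window

    # Matches the shape check in the patch: sliding_window - 1 cached positions remain,
    # so appending the new key brings the total back to exactly sliding_window.
    print(past_key.shape)  # torch.Size([1, 8, 4095, 64])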