Commit 11ba200
Parent(s): 77a17f7

refactor: revert alibi stuff

Signed-off-by: jupyterjazz <[email protected]>
mha.py CHANGED
@@ -56,7 +56,15 @@ class FlashSelfAttention(nn.Module):
             (default: 0.0)
     """
 
-    def __init__(
+    def __init__(
+        self,
+        causal=False,
+        softmax_scale=None,
+        attention_dropout=0.0,
+        window_size=(-1, -1),
+        alibi_slopes=None,
+        deterministic=False,
+    ):
         super().__init__()
         assert flash_attn_varlen_qkvpacked_func is not None, "FlashAttention is not installed"
         assert flash_attn_qkvpacked_func is not None, "FlashAttention is not installed"
@@ -64,6 +72,7 @@ class FlashSelfAttention(nn.Module):
         self.softmax_scale = softmax_scale
         self.drop = nn.Dropout(attention_dropout)
         self.register_buffer("alibi_slopes", alibi_slopes, persistent=False)
+        self.window_size = window_size
         self.deterministic = deterministic
 
     def forward(self, qkv, causal=None, cu_seqlens=None, max_seqlen=None):
@@ -87,6 +96,8 @@ class FlashSelfAttention(nn.Module):
         assert qkv.is_cuda
         causal = self.causal if causal is None else causal
         unpadded = cu_seqlens is not None
+        if self.alibi_slopes is not None:
+            self.alibi_slopes = self.alibi_slopes.to(torch.float32)
         if unpadded:
             assert cu_seqlens.dtype == torch.int32
             assert max_seqlen is not None
@@ -99,6 +110,7 @@ class FlashSelfAttention(nn.Module):
                 softmax_scale=self.softmax_scale,
                 causal=causal,
                 alibi_slopes=self.alibi_slopes,
+                window_size=self.window_size,
                 deterministic=self.deterministic,
             )
         else:
@@ -108,6 +120,7 @@ class FlashSelfAttention(nn.Module):
                 softmax_scale=self.softmax_scale,
                 causal=causal,
                 alibi_slopes=self.alibi_slopes,
+                window_size=self.window_size,
                 deterministic=self.deterministic,
             )
 
@@ -123,7 +136,15 @@ class FlashCrossAttention(nn.Module):
             (default: 0.0)
     """
 
-    def __init__(
+    def __init__(
+        self,
+        causal=False,
+        softmax_scale=None,
+        attention_dropout=0.0,
+        alibi_slopes=None,
+        window_size=(-1, -1),
+        deterministic=False,
+    ):
         super().__init__()
         assert flash_attn_varlen_kvpacked_func is not None, "FlashAttention is not installed"
         assert flash_attn_kvpacked_func is not None, "FlashAttention is not installed"
@@ -131,6 +152,7 @@ class FlashCrossAttention(nn.Module):
         self.softmax_scale = softmax_scale
         self.drop = nn.Dropout(attention_dropout)
         self.register_buffer("alibi_slopes", alibi_slopes, persistent=False)
+        self.window_size = window_size
         self.deterministic = deterministic
 
     def forward(
@@ -160,6 +182,8 @@ class FlashCrossAttention(nn.Module):
         assert q.is_cuda and kv.is_cuda
         causal = self.causal if causal is None else causal
         unpadded = cu_seqlens is not None
+        if self.alibi_slopes is not None:
+            self.alibi_slopes = self.alibi_slopes.to(torch.float32)
         if unpadded:
             assert cu_seqlens.dtype == torch.int32
             assert max_seqlen is not None
@@ -179,6 +203,7 @@ class FlashCrossAttention(nn.Module):
                 softmax_scale=self.softmax_scale,
                 causal=causal,
                 alibi_slopes=self.alibi_slopes,
+                window_size=self.window_size,
                 deterministic=self.deterministic,
             )
         else:
@@ -192,6 +217,7 @@ class FlashCrossAttention(nn.Module):
                 causal=causal,
                 softmax_scale=self.softmax_scale,
                 alibi_slopes=self.alibi_slopes,
+                window_size=self.window_size,
                 deterministic=self.deterministic,
             )
 
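In both wrappers, `forward` now casts `alibi_slopes` to `float32` before dispatching to the flash-attn kernels, which expect fp32 slopes regardless of the activation dtype. A minimal usage sketch of the reworked `FlashSelfAttention` follows; the head count, shapes, and the CUDA/fp16 setup are illustrative assumptions, and it presumes a flash-attn 2.x install plus the packed-QKV layout used throughout `mha.py`:

```python
import torch

# FlashSelfAttention and get_alibi_slopes come from this repo's mha.py
# (get_alibi_slopes is sketched further below); one slope per head.
num_heads, head_dim = 8, 64
slopes = torch.tensor(get_alibi_slopes(num_heads), device="cuda")

attn = FlashSelfAttention(
    causal=True,
    attention_dropout=0.0,
    window_size=(256, 0),  # causal sliding window of 256 tokens; (-1, -1) means full attention
    alibi_slopes=slopes,
)

# Packed QKV: (batch, seqlen, 3, num_heads, head_dim) on GPU in fp16/bf16.
qkv = torch.randn(2, 1024, 3, num_heads, head_dim, device="cuda", dtype=torch.float16)
out = attn(qkv)  # (batch, seqlen, num_heads, head_dim)
```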
@@ -367,6 +393,7 @@ class MHA(nn.Module):
         rotary_emb_scale_base=None,
         rotary_emb_interleaved=False,
         use_alibi=False,
+        window_size=(-1, -1),
         fused_bias_fc=False,
         use_flash_attn=False,
         return_residual=False,
@@ -396,6 +423,8 @@ class MHA(nn.Module):
             alibi_slopes = torch.tensor(get_alibi_slopes(num_heads), device=device)
         else:
             alibi_slopes = None
+        if window_size != (-1, -1):
+            assert use_flash_attn, "Local (sliding window) attention code path requires flash_attn"
 
         self.num_heads = num_heads
         self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
@@ -426,12 +455,12 @@ class MHA(nn.Module):
         )
         wqkv_cls = linear_cls if not self.return_residual else linear_resid_cls
         inner_attn_cls = (
-            partial(FlashSelfAttention, alibi_slopes=alibi_slopes)
+            partial(FlashSelfAttention, alibi_slopes=alibi_slopes, window_size=window_size)
             if use_flash_attn
             else SelfAttention
         )
         inner_cross_attn_cls = (
-            partial(FlashCrossAttention, alibi_slopes=alibi_slopes)
+            partial(FlashCrossAttention, alibi_slopes=alibi_slopes, window_size=window_size)
             if use_flash_attn
             else CrossAttention
         )
@@ -584,7 +613,6 @@ class MHA(nn.Module):
             assert key_padding_mask is None
             assert self.use_flash_attn
             assert not self.dwconv
-            # assert self.rotary_emb_dim == 0
         if key_padding_mask is not None:
             assert cu_seqlens is None
             assert max_seqlen is None
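`get_alibi_slopes` itself is not touched by this commit and is not shown in the diff. For context, the slope schedule from the ALiBi paper (one geometric series per head) is commonly implemented roughly as below; the helper in this repository may differ in details such as how non-power-of-two head counts are handled:

```python
import math

def get_alibi_slopes(num_heads: int):
    # Geometric series 2^(-8/n), 2^(-16/n), ... for n heads (Press et al., ALiBi).
    def slopes_power_of_2(n):
        start = 2 ** (-(2 ** -(math.log2(n) - 3)))  # == 2^(-8/n)
        return [start * (start ** i) for i in range(n)]

    if math.log2(num_heads).is_integer():
        return slopes_power_of_2(num_heads)
    # For head counts that are not powers of two, take the closest power of two
    # and interleave extra slopes from the next series, as in the reference code.
    closest = 2 ** math.floor(math.log2(num_heads))
    return (
        slopes_power_of_2(closest)
        + slopes_power_of_2(2 * closest)[0::2][: num_heads - closest]
    )
```

With `use_alibi=True`, `MHA.__init__` turns these per-head slopes into a tensor and threads them, together with `window_size`, into the flash attention wrappers via `functools.partial`, as shown in the hunk above.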
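The slopes are attached with `register_buffer(..., persistent=False)`, so they follow the module across device moves and dtype casts but are not written to (or expected in) checkpoints. A tiny self-contained sketch of that PyTorch behaviour; the `Toy` module is hypothetical and only illustrates the semantics:

```python
import torch
import torch.nn as nn

class Toy(nn.Module):
    def __init__(self, slopes=None):
        super().__init__()
        # Non-persistent buffer: moves with the module, skipped by state_dict().
        self.register_buffer("alibi_slopes", slopes, persistent=False)

m = Toy(torch.tensor([0.5, 0.25]))
print("alibi_slopes" in m.state_dict())  # False -> old checkpoints still load
m = m.to(torch.float16)                  # buffer is cast along with parameters
print(m.alibi_slopes.dtype)              # torch.float16
```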
|