ssmi153
commited on
Update XFormers Attention Monkeypatch to handle Llama-2 70B (GQA) (#339)
Browse files* Fix XFormers attention for Llama-2 70B (GQA)
Updated XFormers MonkeyPatch to handle GQA as used in Llama-2 70B. All the updated code is taken directly from the Transformers library: https://github.com/huggingface/transformers/commit/07360b6c9c9448d619a82798419ed291dfc6ac8f#diff-06392bad3b9e97be9ade60d4ac46f73b6809388f4d507c2ba1384ab872711c51 from their llama_modeling.py file.
* Catch configs without pretraining_tp
* Whitespace bug fix
Command had accidentally been moved out of if-else block.
* pre-commit formatting fixes
Thanks to
@winglian
src/axolotl/monkeypatch/llama_attn_hijack_xformers.py
CHANGED
|
@@ -7,6 +7,7 @@ import math
|
|
| 7 |
from typing import Optional, Tuple
|
| 8 |
|
| 9 |
import torch
|
|
|
|
| 10 |
import transformers.models.llama.modeling_llama
|
| 11 |
from torch import nn
|
| 12 |
|
|
@@ -38,21 +39,48 @@ def xformers_forward(
|
|
| 38 |
# pylint: disable=duplicate-code
|
| 39 |
bsz, q_len, _ = hidden_states.size()
|
| 40 |
|
| 41 |
-
|
| 42 |
-
self.
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
self.
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
self.v_proj(
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
|
| 57 |
kv_seq_len = key_states.shape[-2]
|
| 58 |
if past_key_value is not None:
|
|
@@ -73,6 +101,14 @@ def xformers_forward(
|
|
| 73 |
|
| 74 |
past_key_value = (key_states, value_states) if use_cache else None
|
| 75 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 76 |
# We only apply xformers optimizations if we don't need to output the whole attention matrix
|
| 77 |
if not output_attentions:
|
| 78 |
query_states = query_states.transpose(1, 2)
|
|
@@ -128,10 +164,23 @@ def xformers_forward(
|
|
| 128 |
f" {attn_output.size()}"
|
| 129 |
)
|
| 130 |
|
| 131 |
-
attn_output = attn_output.transpose(1, 2)
|
|
|
|
| 132 |
|
| 133 |
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
| 134 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 135 |
return attn_output, attn_weights, past_key_value
|
| 136 |
|
| 137 |
|
|
|
|
| 7 |
from typing import Optional, Tuple
|
| 8 |
|
| 9 |
import torch
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
import transformers.models.llama.modeling_llama
|
| 12 |
from torch import nn
|
| 13 |
|
|
|
|
| 39 |
# pylint: disable=duplicate-code
|
| 40 |
bsz, q_len, _ = hidden_states.size()
|
| 41 |
|
| 42 |
+
if not hasattr(self, "pretraining_tp"):
|
| 43 |
+
self.pretraining_tp = 1
|
| 44 |
+
|
| 45 |
+
if self.pretraining_tp > 1:
|
| 46 |
+
key_value_slicing = (
|
| 47 |
+
self.num_key_value_heads * self.head_dim
|
| 48 |
+
) // self.pretraining_tp
|
| 49 |
+
query_slices = self.q_proj.weight.split(
|
| 50 |
+
(self.num_heads * self.head_dim) // self.pretraining_tp, dim=0
|
| 51 |
+
)
|
| 52 |
+
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
|
| 53 |
+
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
|
| 54 |
+
|
| 55 |
+
query_states = [
|
| 56 |
+
F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)
|
| 57 |
+
]
|
| 58 |
+
query_states = torch.cat(query_states, dim=-1)
|
| 59 |
+
|
| 60 |
+
key_states = [
|
| 61 |
+
F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)
|
| 62 |
+
]
|
| 63 |
+
key_states = torch.cat(key_states, dim=-1)
|
| 64 |
+
|
| 65 |
+
value_states = [
|
| 66 |
+
F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)
|
| 67 |
+
]
|
| 68 |
+
value_states = torch.cat(value_states, dim=-1)
|
| 69 |
+
|
| 70 |
+
else:
|
| 71 |
+
query_states = self.q_proj(hidden_states)
|
| 72 |
+
key_states = self.k_proj(hidden_states)
|
| 73 |
+
value_states = self.v_proj(hidden_states)
|
| 74 |
+
|
| 75 |
+
query_states = query_states.view(
|
| 76 |
+
bsz, q_len, self.num_heads, self.head_dim
|
| 77 |
+
).transpose(1, 2)
|
| 78 |
+
key_states = key_states.view(
|
| 79 |
+
bsz, q_len, self.num_key_value_heads, self.head_dim
|
| 80 |
+
).transpose(1, 2)
|
| 81 |
+
value_states = value_states.view(
|
| 82 |
+
bsz, q_len, self.num_key_value_heads, self.head_dim
|
| 83 |
+
).transpose(1, 2)
|
| 84 |
|
| 85 |
kv_seq_len = key_states.shape[-2]
|
| 86 |
if past_key_value is not None:
|
|
|
|
| 101 |
|
| 102 |
past_key_value = (key_states, value_states) if use_cache else None
|
| 103 |
|
| 104 |
+
# repeat k/v heads if n_kv_heads < n_heads
|
| 105 |
+
key_states = transformers.models.llama.modeling_llama.repeat_kv(
|
| 106 |
+
key_states, self.num_key_value_groups
|
| 107 |
+
)
|
| 108 |
+
value_states = transformers.models.llama.modeling_llama.repeat_kv(
|
| 109 |
+
value_states, self.num_key_value_groups
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
# We only apply xformers optimizations if we don't need to output the whole attention matrix
|
| 113 |
if not output_attentions:
|
| 114 |
query_states = query_states.transpose(1, 2)
|
|
|
|
| 164 |
f" {attn_output.size()}"
|
| 165 |
)
|
| 166 |
|
| 167 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
| 168 |
+
# end x-formers vs. not x-formers if-else block
|
| 169 |
|
| 170 |
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
| 171 |
+
|
| 172 |
+
if self.pretraining_tp > 1:
|
| 173 |
+
attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2)
|
| 174 |
+
o_proj_slices = self.o_proj.weight.split(
|
| 175 |
+
self.hidden_size // self.pretraining_tp, dim=1
|
| 176 |
+
)
|
| 177 |
+
attn_output = sum(
|
| 178 |
+
F.linear(attn_output[i], o_proj_slices[i])
|
| 179 |
+
for i in range(self.pretraining_tp)
|
| 180 |
+
)
|
| 181 |
+
else:
|
| 182 |
+
attn_output = self.o_proj(attn_output)
|
| 183 |
+
|
| 184 |
return attn_output, attn_weights, past_key_value
|
| 185 |
|
| 186 |
|