Feat: Add support for upstream FA2 (#626)
* Feat: Add support for upstream FA2
* chore: add is_falcon_derived_model: true to examples
* chore: add config to readme for documentation
* feat: add extra model types
* fix: remove old falcon flash patch
* chore: pin transformers and accelerate
- README.md +4 -0
- examples/falcon/config-7b-lora.yml +1 -0
- examples/falcon/config-7b-qlora.yml +1 -0
- examples/falcon/config-7b.yml +1 -0
- requirements.txt +2 -2
- src/axolotl/monkeypatch/falcon_attn_hijack_flash.py +0 -101
- src/axolotl/utils/config.py +16 -0
- src/axolotl/utils/models.py +6 -14
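The gist of the change: instead of patching FalconAttention.forward inside axolotl, Falcon (and non-sample-packed llama) models now get FlashAttention-2 from transformers itself, by passing use_flash_attention_2=True through model_kwargs. A minimal sketch of that load path outside axolotl, assuming the pinned transformers commit and the flash-attn package are installed (the dtype and device_map choices are illustrative, not taken from this PR):

# Sketch only: load Falcon with transformers' built-in FlashAttention-2 support,
# which is what this PR switches to instead of the deleted monkeypatch.
# Requires a CUDA GPU, flash-attn installed, and a half-precision dtype.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "tiiuae/falcon-7b"  # same base model as the example configs below

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,   # FA2 needs fp16/bf16
    use_flash_attention_2=True,   # the kwarg added to model_kwargs in models.py
    device_map="auto",            # illustrative; needs accelerate
)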
README.md
CHANGED
@@ -408,6 +408,10 @@ tokenizer_legacy:
 # this is reported to improve training speed on some models
 resize_token_embeddings_to_32x:
 
+# used to identify if the model is falcon/llama based
+is_falcon_derived_model:
+is_llama_derived_model:
+
 # whether you are training a 4-bit GPTQ quantized model
 gptq: true
 gptq_groupsize: 128 # group size
examples/falcon/config-7b-lora.yml
CHANGED
@@ -3,6 +3,7 @@ base_model_config: tiiuae/falcon-7b
 trust_remote_code: true
 model_type: AutoModelForCausalLM
 tokenizer_type: AutoTokenizer
+is_falcon_derived_model: true
 load_in_8bit: true
 load_in_4bit: false
 gptq: false
examples/falcon/config-7b-qlora.yml
CHANGED
@@ -6,6 +6,7 @@ base_model_config: tiiuae/falcon-7b
 trust_remote_code: true
 model_type: AutoModelForCausalLM
 tokenizer_type: AutoTokenizer
+is_falcon_derived_model: true
 load_in_8bit: false
 # enable 4bit for QLoRA
 load_in_4bit: true
examples/falcon/config-7b.yml
CHANGED
@@ -3,6 +3,7 @@ base_model_config: tiiuae/falcon-7b
 trust_remote_code: true
 model_type: AutoModelForCausalLM
 tokenizer_type: AutoTokenizer
+is_falcon_derived_model: true
 load_in_8bit: false
 load_in_4bit: false
 gptq: false
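All three example configs point at the same base model and differ only in quantization settings; each now sets the new flag explicitly rather than relying on auto-detection. A quick way to confirm what a config carries, sketched with PyYAML (axolotl loads configs through its own CLI, not this snippet; the snippet assumes the repo root is the working directory):

# Sketch: inspect one of the example configs above. Assumes PyYAML is installed.
import yaml

with open("examples/falcon/config-7b-qlora.yml", encoding="utf-8") as handle:
    cfg = yaml.safe_load(handle)

print(cfg["base_model_config"])        # tiiuae/falcon-7b
print(cfg["is_falcon_derived_model"])  # True, so normalize_config does not need to infer it
print(cfg["load_in_4bit"])             # True for the QLoRA example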
requirements.txt
CHANGED
@@ -4,9 +4,9 @@ torch==2.0.1
 auto-gptq
 packaging
 peft @ git+https://github.com/huggingface/peft.git
-transformers @ git+https://github.com/huggingface/transformers.git
+transformers @ git+https://github.com/huggingface/transformers.git@0ac3875011d32dc85e0e83970507e3afe8f0febb
 bitsandbytes>=0.41.1
-accelerate @ git+https://github.com/huggingface/accelerate
+accelerate @ git+https://github.com/huggingface/accelerate@80da9cfb09bb3cc9f1b385cb55d6b90d025a5fd9
 deepspeed
 addict
 evaluate
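Both dependencies are now pinned to specific upstream commits rather than releases, presumably to pick up the FA2-related changes ahead of a tagged transformers version. A small sanity-check sketch for the installed environment (the check itself is not part of this PR; the package names come from requirements.txt):

# Sketch: verify the installed stack before training. Not part of axolotl;
# just a convenience check that the FA2 prerequisites are importable.
import importlib.metadata as md
import importlib.util

for pkg in ("transformers", "accelerate", "bitsandbytes"):
    print(pkg, md.version(pkg))  # raises PackageNotFoundError if missing

if importlib.util.find_spec("flash_attn") is None:
    raise SystemExit("flash-attn is not installed; use_flash_attention_2=True will fail")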
src/axolotl/monkeypatch/falcon_attn_hijack_flash.py
DELETED
@@ -1,101 +0,0 @@
-"""
-Flash Attention monkey patch for Falcon
-
-copied from https://github.com/pacman100/DHS-LLM-Workshop/blob/main/chat_assistant/training/falcon_flash_attn_monkey_patch.py
-"""
-
-from typing import Optional, Tuple
-
-import torch
-import transformers
-from flash_attn import flash_attn_func
-
-
-def forward(
-    self,
-    hidden_states: torch.Tensor,
-    alibi: Optional[torch.Tensor],
-    attention_mask: torch.Tensor,  # pylint: disable=unused-argument
-    layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
-    head_mask: Optional[torch.Tensor] = None,  # pylint: disable=unused-argument
-    use_cache: bool = False,
-    output_attentions: bool = False,  # pylint: disable=unused-argument
-):
-    fused_qkv = self.query_key_value(
-        hidden_states
-    )  # [batch_size, seq_length, 3 x hidden_size]
-    num_kv_heads = (
-        self.num_heads if self.new_decoder_architecture else self.num_kv_heads
-    )
-    # 3 x [batch_size, seq_length, num_heads, head_dim]
-    (
-        query_layer,
-        key_layer,
-        value_layer,
-    ) = self._split_heads(  # pylint: disable=protected-access
-        fused_qkv
-    )
-
-    batch_size, query_length, _, _ = query_layer.shape
-
-    query_layer = query_layer.transpose(1, 2).reshape(
-        batch_size * self.num_heads, query_length, self.head_dim
-    )
-    key_layer = key_layer.transpose(1, 2).reshape(
-        batch_size * num_kv_heads,
-        query_length,
-        self.head_dim,
-    )
-    value_layer = value_layer.transpose(1, 2).reshape(
-        batch_size * num_kv_heads, query_length, self.head_dim
-    )
-
-    past_kv_length = 0 if layer_past is None else layer_past[0].shape[1]
-    query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length)
-
-    if layer_past is not None:
-        past_key, past_value = layer_past
-        # concatenate along seq_length dimension:
-        #  - key: [batch_size * self.num_heads, kv_length, head_dim]
-        #  - value: [batch_size * self.num_heads, kv_length, head_dim]
-        key_layer = torch.cat((past_key, key_layer), dim=1)
-        value_layer = torch.cat((past_value, value_layer), dim=1)
-
-    # unused
-    # _, kv_length, _ = key_layer.shape
-    if use_cache:
-        present = (key_layer, value_layer)
-    else:
-        present = None
-    # unused
-    # attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, float("-1e9")).to(query_layer.dtype)
-    query_layer_ = (
-        query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
-        .transpose(1, 2)
-        .to(torch.bfloat16)
-    )
-    key_layer_ = (
-        key_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim)
-        .transpose(1, 2)
-        .to(torch.bfloat16)
-    )
-    value_layer_ = (
-        value_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim)
-        .transpose(1, 2)
-        .to(torch.bfloat16)
-    )
-
-    if alibi is not None:
-        raise ValueError("`alibi` is not supported when `use_flash_attn` is True")
-
-    # below output will have shape (batch_size, seqlen, nheads, headdim)
-    attn_output = flash_attn_func(query_layer_, key_layer_, value_layer_, causal=True)
-    attn_output = attn_output.reshape(
-        batch_size, query_length, self.num_heads * self.head_dim
-    )
-    output_tensor = self.dense(attn_output)
-    return output_tensor, present
-
-
-def replace_falcon_attn_with_flash_attn():
-    transformers.models.falcon.modeling_falcon.FalconAttention.forward = forward
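Everything the deleted patch did reduces to reshaping Falcon's fused QKV into the layout flash_attn_func expects and calling it with causal=True; upstream transformers now performs the equivalent internally. For reference, the core call in isolation, with toy shapes and a CUDA device assumed (not code from this repo):

# Sketch of the call at the heart of the removed patch: flash_attn_func takes
# q, k, v shaped (batch, seqlen, nheads, headdim) in fp16/bf16 on CUDA and
# returns an output of the same shape. Shapes below are arbitrary toy values.
import torch
from flash_attn import flash_attn_func

batch, seqlen, nheads, headdim = 2, 128, 8, 64
q = torch.randn(batch, seqlen, nheads, headdim, dtype=torch.bfloat16, device="cuda")
k = torch.randn(batch, seqlen, nheads, headdim, dtype=torch.bfloat16, device="cuda")
v = torch.randn(batch, seqlen, nheads, headdim, dtype=torch.bfloat16, device="cuda")

out = flash_attn_func(q, k, v, causal=True)  # causal mask, as in the deleted forward()
print(out.shape)  # torch.Size([2, 128, 8, 64])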
src/axolotl/utils/config.py
CHANGED
@@ -86,6 +86,22 @@ def normalize_config(cfg):
         or (cfg.model_type and "llama" in cfg.model_type.lower())
     )
 
+    # figure out if the model is falcon
+    cfg.is_falcon_derived_model = (
+        (
+            hasattr(model_config, "model_type")
+            and model_config.model_type
+            in [
+                "falcon",
+                "RefinedWebModel",
+                "RefinedWeb",
+            ]
+        )
+        or cfg.is_falcon_derived_model
+        or "falcon" in cfg.base_model
+        or (cfg.model_type and "rwforcausallm" in cfg.model_type.lower())
+    )
+
     log_gpu_memory_usage(LOG, "baseline", cfg.device)
 
 
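The new block in normalize_config is a plain boolean OR over the HF config's model_type, the user-supplied flag, the base model name, and the legacy RWForCausalLM model_type. Restated standalone for illustration (SimpleNamespace stands in for axolotl's DictDefault and the HF config object; that substitution is an assumption, not axolotl's API):

# Sketch: the falcon-detection rule from normalize_config, restated standalone.
# `cfg` and `model_config` here are stand-ins, not axolotl's actual objects.
from types import SimpleNamespace

def is_falcon_derived(cfg, model_config) -> bool:
    return bool(
        (
            hasattr(model_config, "model_type")
            and model_config.model_type in ["falcon", "RefinedWebModel", "RefinedWeb"]
        )
        or cfg.is_falcon_derived_model
        or "falcon" in cfg.base_model
        or (cfg.model_type and "rwforcausallm" in cfg.model_type.lower())
    )

cfg = SimpleNamespace(
    is_falcon_derived_model=None,
    base_model="tiiuae/falcon-7b",
    model_type="AutoModelForCausalLM",
)
model_config = SimpleNamespace(model_type="falcon")
print(is_falcon_derived(cfg, model_config))  # True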
src/axolotl/utils/models.py
CHANGED
@@ -114,25 +114,13 @@ def load_model(
 
     replace_btlm_attn_with_flash_attn(cfg.base_model)
 
-    if model_config.model_type in [
-        "falcon",
-        "RefinedWebModel",
-        "RefinedWeb",
-    ]:
-        if cfg.flash_attention:
-            from axolotl.monkeypatch.falcon_attn_hijack_flash import (
-                replace_falcon_attn_with_flash_attn,
-            )
-
-            replace_falcon_attn_with_flash_attn()
-
-    if cfg.is_llama_derived_model and cfg.flash_attention:
+    if cfg.is_llama_derived_model and cfg.flash_attention and cfg.sample_packing:
         if cfg.device not in ["mps", "cpu"] and not inference:
             from axolotl.monkeypatch.llama_attn_hijack_flash import (
                 replace_llama_attn_with_flash_attn,
             )
 
-            LOG.info("patching with flash attention")
+            LOG.info("patching with flash attention for sample packing")
             replace_llama_attn_with_flash_attn(packed=cfg.sample_packing)
     elif cfg.is_llama_derived_model and cfg.xformers_attention:
         from axolotl.monkeypatch.llama_attn_hijack_xformers import (

@@ -213,6 +201,10 @@ def load_model(
             bnb_4bit_use_double_quant=True,
             bnb_4bit_quant_type="nf4",
         )
+    # sample packing uses custom FA2 patch
+    if cfg.flash_attention and not cfg.sample_packing:
+        if cfg.is_llama_derived_model or cfg.is_falcon_derived_model:
+            model_kwargs["use_flash_attention_2"] = True
     try:
         if cfg.is_llama_derived_model and not cfg.trust_remote_code and not cfg.gptq:
             from transformers import LlamaForCausalLM
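Taken together, the two hunks leave a simple rule: llama models with sample packing keep axolotl's custom flash-attention patch, while other llama- or falcon-derived models get upstream FA2 via a from_pretrained kwarg. A compressed restatement of that branching (choose_attention is a hypothetical name, not a function in axolotl):

# Sketch of the attention selection after this PR. `choose_attention` is a
# hypothetical helper that mirrors the branches in load_model; the xformers
# path is omitted for brevity.
from types import SimpleNamespace

def choose_attention(cfg) -> dict:
    model_kwargs = {}
    if cfg.is_llama_derived_model and cfg.flash_attention and cfg.sample_packing:
        # custom axolotl patch: replace_llama_attn_with_flash_attn(packed=True)
        return {"patch": "llama_attn_hijack_flash", "model_kwargs": model_kwargs}
    if cfg.flash_attention and not cfg.sample_packing:
        if cfg.is_llama_derived_model or cfg.is_falcon_derived_model:
            # upstream FA2: handed straight to from_pretrained
            model_kwargs["use_flash_attention_2"] = True
    return {"patch": None, "model_kwargs": model_kwargs}

cfg = SimpleNamespace(
    is_llama_derived_model=False,
    is_falcon_derived_model=True,
    flash_attention=True,
    sample_packing=False,
)
print(choose_attention(cfg))  # {'patch': None, 'model_kwargs': {'use_flash_attention_2': True}}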