Qwen2 (#1166)

* qwen2 multipack support
* fix qwen derived model check so it doesn't break qwen2
* fixes to ensure qwen2 packing works
* bump requirements for qwen2
* requirements typo
- requirements.txt +2 -2
- src/axolotl/core/trainer_builder.py +1 -1
- src/axolotl/monkeypatch/qwen2/__init__.py +12 -0
- src/axolotl/utils/config.py +6 -11
- src/axolotl/utils/models.py +10 -2
requirements.txt CHANGED

@@ -1,10 +1,10 @@
 --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
 packaging==23.2
 peft==0.7.0
-transformers
+transformers==4.37.0
 tokenizers==0.15.0
 bitsandbytes>=0.41.1
-accelerate
+accelerate==0.26.1
 deepspeed
 addict
 fire
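Qwen2 modeling only landed in transformers 4.37.0, so the pins above are what the rest of this PR relies on. A quick check (illustrative, not part of the change) that the environment actually resolved to them:

```python
# Illustrative sanity check, not part of the PR: confirm the versions pinned
# in requirements.txt are what is actually installed.
import accelerate
import transformers

assert transformers.__version__ == "4.37.0", transformers.__version__
assert accelerate.__version__ == "0.26.1", accelerate.__version__
```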
src/axolotl/core/trainer_builder.py CHANGED

@@ -905,7 +905,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             ]
         ]
         if use_batch_sampler_collator:
-            if self.cfg.model_config_type == "mixtral":
+            if self.cfg.model_config_type in ["mixtral", "qwen2"]:
                 collator = V2BatchSamplerDataCollatorForSeq2Seq
             else:
                 collator = BatchSamplerDataCollatorForSeq2Seq
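For reference, a minimal sketch of how the selection above evaluates for a qwen2 run; `DictDefault` is axolotl's config wrapper, and the cfg contents here are made up for illustration:

```python
# Sketch only: mirrors the collator choice in trainer_builder.py for a qwen2 config.
from axolotl.utils.dict import DictDefault

cfg = DictDefault({"model_config_type": "qwen2", "sample_packing": True})
use_v2 = cfg.model_config_type in ["mixtral", "qwen2"]
# True -> V2BatchSamplerDataCollatorForSeq2Seq, otherwise BatchSamplerDataCollatorForSeq2Seq
print(use_v2)
```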
src/axolotl/monkeypatch/qwen2/__init__.py ADDED

@@ -0,0 +1,12 @@
+"""
+Patches to support multipack for qwen2
+"""
+import transformers
+
+from axolotl.monkeypatch.utils import get_unpad_data
+
+
+def replace_qwen2_attn_with_multipack_flash_attn():
+    transformers.models.qwen2.modeling_qwen2._get_unpad_data = (  # pylint: disable=protected-access
+        get_unpad_data
+    )
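This patch points transformers' `_get_unpad_data` for qwen2 at axolotl's multipack-aware `get_unpad_data`. The sketch below is not that implementation, only an illustration of the shape of the problem under the multipack convention: the attention mask carries per-sample ids (1, 1, 2, 2, 2, 0, ...) rather than 0/1, and the flash-attention varlen path needs cumulative sequence lengths per packed sample.

```python
# Hedged sketch (not axolotl's code), assuming a single packed row whose
# attention_mask holds sample ids with 0 marking padding. Returns the
# (indices, cu_seqlens, max_seqlen) triple that _get_unpad_data is expected
# to provide to the flash-attention varlen kernels.
import torch
import torch.nn.functional as F


def get_unpad_data_sketch(attention_mask: torch.Tensor):
    flat = attention_mask.flatten()
    seqlens = torch.bincount(flat)[1:].to(torch.int32)  # tokens per sample id, dropping padding id 0
    seqlens = seqlens[seqlens > 0]
    indices = torch.nonzero(flat, as_tuple=False).flatten()
    cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
    return indices, cu_seqlens, int(seqlens.max())
```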
src/axolotl/utils/config.py CHANGED

@@ -142,17 +142,12 @@ def normalize_config(cfg):
     )

     cfg.is_qwen_derived_model = (
-        (
-            hasattr(model_config, "model_type")
-            and model_config.model_type
-            in [
-                "qwen",
-            ]
-        )
-        or cfg.is_qwen_derived_model
-        or "qwen" in cfg.base_model.lower()
-        or (cfg.model_type and "qwen" in cfg.model_type.lower())
-    )
+        hasattr(model_config, "model_type")
+        and model_config.model_type
+        in [
+            "qwen",
+        ]
+    ) or cfg.is_qwen_derived_model

     if isinstance(cfg.learning_rate, str):
         cfg.learning_rate = float(cfg.learning_rate)
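The dropped substring heuristics are what broke qwen2: any base_model containing "qwen" was flagged as a (qwen1-style) derived model, and Qwen2/Qwen1.5 checkpoints match that substring even though their model_type is qwen2. A trivial illustration:

```python
# Illustrative only: the removed heuristic misfires on qwen2-architecture checkpoints.
base_model = "Qwen/Qwen1.5-7B"  # config.model_type == "qwen2"
print("qwen" in base_model.lower())  # True -> was treated as qwen(1)-derived before this PR
```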
src/axolotl/utils/models.py CHANGED

@@ -334,6 +334,14 @@ def load_model(
         LOG.info("patching mixtral with flash attention")
         replace_mixtral_attn_with_multipack_flash_attn()

+    if cfg.model_config_type == "qwen2" and cfg.flash_attention and cfg.sample_packing:
+        from axolotl.monkeypatch.qwen2 import (
+            replace_qwen2_attn_with_multipack_flash_attn,
+        )
+
+        LOG.info("patching qwen2 with flash attention")
+        replace_qwen2_attn_with_multipack_flash_attn()
+
     if cfg.is_llama_derived_model and cfg.sample_packing and not inference:
         from axolotl.monkeypatch.llama_expand_mask import hijack_expand_mask

@@ -426,14 +434,14 @@ def load_model(
                 cfg.is_llama_derived_model
                 or cfg.is_falcon_derived_model
                 or cfg.is_mistral_derived_model
-                or model_config.model_type == "mixtral"
+                or model_config.model_type in ["mixtral", "qwen2"]
             ):
                 model_kwargs["attn_implementation"] = "flash_attention_2"
                 model_config._attn_implementation = (  # pylint: disable=protected-access
                     "flash_attention_2"
                 )
         else:
-            if model_config.model_type == "mixtral":
+            if model_config.model_type in ["mixtral", "qwen2"]:
                 model_kwargs["attn_implementation"] = "flash_attention_2"
                 model_config._attn_implementation = (  # pylint: disable=protected-access
                     "flash_attention_2"
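A quick way to confirm the monkeypatch took effect (illustrative; assumes transformers 4.37.0, which ships the qwen2 module):

```python
# Not part of the PR: sanity-check that qwen2's flash-attention unpad helper now
# points at axolotl's multipack-aware implementation.
import transformers

from axolotl.monkeypatch.qwen2 import replace_qwen2_attn_with_multipack_flash_attn
from axolotl.monkeypatch.utils import get_unpad_data

replace_qwen2_attn_with_multipack_flash_attn()
assert (
    transformers.models.qwen2.modeling_qwen2._get_unpad_data  # pylint: disable=protected-access
    is get_unpad_data
)
```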