feat: validate sample packing requires flash_attention (#1465)
Browse files
* feat: validate sample packing requires flash_attention
* fix: check for sdp_attn per suggestion
* feat: add FA to tests
src/axolotl/utils/config/models/input/v0_4_1/__init__.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
| 1 |
"""
|
| 2 |
Module for pydantic models for configuration
|
| 3 |
"""
|
|
|
|
| 4 |
# pylint: disable=too-many-lines
|
| 5 |
|
| 6 |
import logging
|
|
@@ -655,6 +656,20 @@ class AxolotlInputConfig(
|
|
| 655 |
|
| 656 |
return data
|
| 657 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 658 |
@model_validator(mode="before")
|
| 659 |
@classmethod
|
| 660 |
def check_sample_packing_w_rl(cls, data):
|
|
|
|
| 1 |
"""
|
| 2 |
Module for pydantic models for configuration
|
| 3 |
"""
|
| 4 |
+
|
| 5 |
# pylint: disable=too-many-lines
|
| 6 |
|
| 7 |
import logging
|
|
|
|
| 656 |
|
| 657 |
return data
|
| 658 |
|
| 659 |
+
@model_validator(mode="before")
|
| 660 |
+
@classmethod
|
| 661 |
+
def check_sample_packing_wo_flash(cls, data):
|
| 662 |
+
if (
|
| 663 |
+
data.get("sample_packing")
|
| 664 |
+
and not data.get("flash_attention")
|
| 665 |
+
and not data.get("sdp_attention")
|
| 666 |
+
):
|
| 667 |
+
raise ValueError(
|
| 668 |
+
"sample_packing requires flash_attention or sdp_attention to be set to true"
|
| 669 |
+
)
|
| 670 |
+
|
| 671 |
+
return data
|
| 672 |
+
|
| 673 |
@model_validator(mode="before")
|
| 674 |
@classmethod
|
| 675 |
def check_sample_packing_w_rl(cls, data):
|
tests/test_validation.py
CHANGED
|
@@ -600,6 +600,7 @@ class TestValidation(BaseValidation):
|
|
| 600 |
{
|
| 601 |
"sample_packing": True,
|
| 602 |
"pad_to_sequence_len": None,
|
|
|
|
| 603 |
}
|
| 604 |
)
|
| 605 |
| minimal_cfg
|
|
@@ -901,6 +902,7 @@ class TestValidation(BaseValidation):
|
|
| 901 |
{
|
| 902 |
"sample_packing": True,
|
| 903 |
"eval_table_size": 100,
|
|
|
|
| 904 |
}
|
| 905 |
)
|
| 906 |
| minimal_cfg
|
|
@@ -916,6 +918,7 @@ class TestValidation(BaseValidation):
|
|
| 916 |
{
|
| 917 |
"sample_packing": True,
|
| 918 |
"eval_sample_packing": False,
|
|
|
|
| 919 |
}
|
| 920 |
)
|
| 921 |
| minimal_cfg
|
|
@@ -928,6 +931,7 @@ class TestValidation(BaseValidation):
|
|
| 928 |
{
|
| 929 |
"sample_packing": False,
|
| 930 |
"eval_table_size": 100,
|
|
|
|
| 931 |
}
|
| 932 |
)
|
| 933 |
| minimal_cfg
|
|
@@ -941,6 +945,7 @@ class TestValidation(BaseValidation):
|
|
| 941 |
"sample_packing": True,
|
| 942 |
"eval_table_size": 100,
|
| 943 |
"eval_sample_packing": False,
|
|
|
|
| 944 |
}
|
| 945 |
)
|
| 946 |
| minimal_cfg
|
|
|
|
| 600 |
{
|
| 601 |
"sample_packing": True,
|
| 602 |
"pad_to_sequence_len": None,
|
| 603 |
+
"flash_attention": True,
|
| 604 |
}
|
| 605 |
)
|
| 606 |
| minimal_cfg
|
|
|
|
| 902 |
{
|
| 903 |
"sample_packing": True,
|
| 904 |
"eval_table_size": 100,
|
| 905 |
+
"flash_attention": True,
|
| 906 |
}
|
| 907 |
)
|
| 908 |
| minimal_cfg
|
|
|
|
| 918 |
{
|
| 919 |
"sample_packing": True,
|
| 920 |
"eval_sample_packing": False,
|
| 921 |
+
"flash_attention": True,
|
| 922 |
}
|
| 923 |
)
|
| 924 |
| minimal_cfg
|
|
|
|
| 931 |
{
|
| 932 |
"sample_packing": False,
|
| 933 |
"eval_table_size": 100,
|
| 934 |
+
"flash_attention": True,
|
| 935 |
}
|
| 936 |
)
|
| 937 |
| minimal_cfg
|
|
|
|
| 945 |
"sample_packing": True,
|
| 946 |
"eval_table_size": 100,
|
| 947 |
"eval_sample_packing": False,
|
| 948 |
+
"flash_attention": True,
|
| 949 |
}
|
| 950 |
)
|
| 951 |
| minimal_cfg
|