make sure to capture non-null defaults from config validation (#1415)
src/axolotl/utils/config/__init__.py
@@ -208,11 +208,11 @@ def validate_config(cfg: DictDefault, capabilities: Optional[dict] = None):
             dict(
                 AxolotlConfigWCapabilities(
                     **cfg.to_dict(), capabilities=capabilities
-                ).model_dump(
+                ).model_dump(exclude_none=True)
             )
         )
     return DictDefault(
-        dict(AxolotlInputConfig(**cfg.to_dict()).model_dump(
+        dict(AxolotlInputConfig(**cfg.to_dict()).model_dump(exclude_none=True))
     )
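The behavioural difference is easiest to see on a toy pydantic model (a minimal sketch with made-up fields, not axolotl code): model_dump(exclude_none=True) keeps fields whose defaults are real values and drops fields that are still None, so the validated config picks up non-null defaults without being padded with explicit nulls.

from typing import Optional

from pydantic import BaseModel


class ExampleConfig(BaseModel):
    """Toy stand-in used only to illustrate the dump flag, not an axolotl class."""

    train_on_inputs: Optional[bool] = False  # non-null default: should be captured
    eval_steps: Optional[int] = None  # null default: should be dropped


print(ExampleConfig().model_dump())
# {'train_on_inputs': False, 'eval_steps': None}
print(ExampleConfig().model_dump(exclude_none=True))
# {'train_on_inputs': False}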
src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -151,12 +151,6 @@ class PeftConfig(BaseModel):
     loftq_config: Optional[LoftQConfig] = None
 
 
-class AutoType(str, Enum):
-    """auto type string configuration subset - used for bf16"""
-
-    AUTO = "auto"
-
-
 class SpecialTokensConfig(BaseModel):
     """Special tokens configuration subset"""
 
@@ -307,12 +301,14 @@ class HyperparametersConfig(BaseModel):
         },
     )
 
-    train_on_inputs: Optional[bool] =
+    train_on_inputs: Optional[bool] = False
     group_by_length: Optional[bool] = None
 
     learning_rate: Union[str, float]
-    weight_decay: Optional[float] =
-    optimizer: Optional[
+    weight_decay: Optional[float] = 0.0
+    optimizer: Optional[
+        Union[OptimizerNames, Literal["lion_pytorch"]]
+    ] = OptimizerNames.ADAMW_HF.value
     optim_args: Optional[Union[str, Dict[str, Any]]] = Field(
         default=None, metadata={"help": "Optional arguments to supply to optimizer."}
     )
@@ -323,7 +319,7 @@ class HyperparametersConfig(BaseModel):
         },
     )
     torchdistx_path: Optional[str] = None
-    lr_scheduler: Optional[SchedulerType] =
+    lr_scheduler: Optional[SchedulerType] = "cosine"
     lr_scheduler_kwargs: Optional[Dict[str, Any]] = None
     lr_quadratic_warmup: Optional[bool] = None
     cosine_min_lr_ratio: Optional[float] = None
@@ -473,7 +469,7 @@ class AxolotlInputConfig(
     loss_watchdog_threshold: Optional[float] = None
     loss_watchdog_patience: Optional[int] = None
 
-    bf16: Optional[Union[
+    bf16: Optional[Union[Literal["auto"], bool]] = "auto"
     fp16: Optional[bool] = None
     bfloat16: Optional[bool] = None  # for non-AMP cases
     float16: Optional[bool] = None  # for non-AMP cases
@@ -487,7 +483,7 @@ class AxolotlInputConfig(
 
     unfrozen_parameters: Optional[List[str]] = None
 
-    sequence_len: int = Field(default=
+    sequence_len: int = Field(default=512)
     sample_packing: Optional[bool] = None
     eval_sample_packing: Optional[bool] = None
     pad_to_sequence_len: Optional[bool] = None
@@ -548,10 +544,10 @@ class AxolotlInputConfig(
     sample_packing_eff_est: Optional[float] = None
    axolotl_config_path: Optional[str] = None
 
-    is_falcon_derived_model: Optional[bool] = Field(default=
-    is_llama_derived_model: Optional[bool] = Field(default=
-    is_mistral_derived_model: Optional[bool] = Field(default=
-    is_qwen_derived_model: Optional[bool] = Field(default=
+    is_falcon_derived_model: Optional[bool] = Field(default=None)
+    is_llama_derived_model: Optional[bool] = Field(default=None)
+    is_mistral_derived_model: Optional[bool] = Field(default=None)
+    is_qwen_derived_model: Optional[bool] = Field(default=None)
 
     @field_validator("datasets", mode="before")
     @classmethod
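With the AutoType enum removed, the "auto" sentinel is expressed directly in the bf16 annotation. A rough sketch of how the Literal-based union validates (the model below is hypothetical and mirrors only the bf16 field, not the real AxolotlInputConfig):

from typing import Literal, Optional, Union

from pydantic import BaseModel, ValidationError


class PrecisionSketch(BaseModel):
    """Hypothetical model mirroring only the new bf16 annotation."""

    bf16: Optional[Union[Literal["auto"], bool]] = "auto"


print(PrecisionSketch().bf16)  # 'auto' -- a non-null default, so it now survives the dump
print(PrecisionSketch(bf16=True).bf16)  # True
try:
    PrecisionSketch(bf16="bfloat16")  # neither "auto" nor a bool
except ValidationError:
    print("rejected by the union validator")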
tests/test_validation.py
@@ -54,6 +54,18 @@ class TestValidation(BaseValidation):
     Test the validation module
     """
 
+    def test_defaults(self, minimal_cfg):
+        test_cfg = DictDefault(
+            {
+                "weight_decay": None,
+            }
+            | minimal_cfg
+        )
+        cfg = validate_config(test_cfg)
+
+        assert cfg.train_on_inputs is False
+        assert cfg.weight_decay is None
+
     def test_datasets_min_length(self):
         cfg = DictDefault(
             {