use mixins for orpo and kto configs so they work with axolotl customizations (#1674)
src/axolotl/core/trainer_builder.py
@@ -91,11 +91,12 @@ def _sanitize_kwargs_for_tagging(tag_names, kwargs=None):
 
 
 @dataclass
-class AxolotlTrainingArguments(TrainingArguments):
+class AxolotlTrainingMixins:
     """
-    Extra arguments specific to the axolotl training loop
+    Mixin class for the Axolotl training args.
     """
 
+    # pylint: disable=duplicate-code
     model_type: Optional[str] = field(
         default=None, metadata={"help": "HF model configuration model_type."}
     )
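For readers unfamiliar with the pattern: a `@dataclass` mixin contributes its fields to any dataclass that inherits from it, which is what lets the axolotl-specific arguments ride along on top of the TRL config classes introduced further down. A minimal, self-contained sketch of the mechanics (the class and field names here are illustrative stand-ins, not axolotl's real ones):

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class ExtraArgsMixin:
    """Axolotl-style mixin: every field carries a default so subclasses stay valid."""

    model_type: Optional[str] = field(
        default=None, metadata={"help": "HF model configuration model_type."}
    )


@dataclass
class BaseConfig:
    """Stand-in for an HF/TRL config class."""

    output_dir: str = "."
    learning_rate: float = 5e-5


@dataclass
class CombinedConfig(ExtraArgsMixin, BaseConfig):
    """Inherits fields from both parents; base-class fields come first in __init__."""


cfg = CombinedConfig(output_dir="./out", model_type="llama")
print(cfg.model_type, cfg.learning_rate)  # -> llama 5e-05
```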
@@ -227,6 +228,30 @@ class AxolotlTrainingArguments(TrainingArguments):
     )
 
 
+@dataclass
+class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):
+    """
+    Training arguments for Causal trainer
+
+    This code is duplicated due to HF TrainingArguments not setting output_dir with a default value
+    so it can't be used as a mixin.
+    """
+
+
+@dataclass
+class AxolotlORPOConfig(AxolotlTrainingMixins, ORPOConfig):
+    """
+    ORPO config for ORPO training
+    """
+
+
+@dataclass
+class AxolotlKTOConfig(AxolotlTrainingMixins, KTOConfig):
+    """
+    KTO config for KTO training
+    """
+
+
 class AxolotlTrainer(Trainer):
     """
     Extend the base Trainer for axolotl helpers
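The docstring's caveat refers to a dataclass constraint: in the generated `__init__`, a field without a default may not follow fields that have defaults, and HF's `TrainingArguments` declares `output_dir` without one. A short illustration of that failure mode, using hypothetical class names:

```python
from dataclasses import dataclass


@dataclass
class DefaultedMixin:
    beta: float = 0.1  # defaulted field, as in AxolotlTrainingMixins


try:

    @dataclass
    class Broken(DefaultedMixin):
        output_dir: str  # no default after a defaulted field

except TypeError as err:
    # e.g. "non-default argument 'output_dir' follows default argument"
    print(err)
```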
@@ -1583,14 +1608,14 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
 
         training_args_cls = AxolotlTrainingArguments
         if self.cfg.rl == "orpo":
-            training_args_cls = ORPOConfig
+            training_args_cls = AxolotlORPOConfig
             training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
             training_args_kwargs["max_length"] = self.cfg.sequence_len
             if self.cfg.max_prompt_len:
                 training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
 
         if self.cfg.rl == "kto":
-            training_args_cls = KTOConfig
+            training_args_cls = AxolotlKTOConfig
 
             training_args_kwargs["beta"] = self.cfg.rl_beta or 0.1
             training_args_kwargs["desirable_weight"] = (
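This swap is the heart of the fix: TRL's plain `ORPOConfig`/`KTOConfig` are dataclasses that know nothing about axolotl's extra keyword arguments, so passing axolotl customizations into them raised `TypeError`. A standalone sketch of the failure mode and the mixin fix (the classes below are stand-ins, not the real TRL ones):

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class PlainORPOLikeConfig:  # stand-in for trl.ORPOConfig
    output_dir: str = "."
    beta: float = 0.1


@dataclass
class AxolotlMixinStandIn:  # stand-in for AxolotlTrainingMixins
    model_type: Optional[str] = None


@dataclass
class PatchedORPOLikeConfig(AxolotlMixinStandIn, PlainORPOLikeConfig):
    """Stand-in for AxolotlORPOConfig: accepts both field sets."""


try:
    PlainORPOLikeConfig(output_dir="./out", model_type="llama")
except TypeError as err:
    print("plain config rejects axolotl kwargs:", err)

print(PatchedORPOLikeConfig(output_dir="./out", model_type="llama"))
```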
@@ -1605,12 +1630,12 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
         if self.cfg.max_prompt_len:
             training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
 
-        training_args = training_args_cls(
+        training_args = training_args_cls(  # pylint: disable=unexpected-keyword-arg
+            output_dir=self.cfg.output_dir,
             per_device_train_batch_size=self.cfg.micro_batch_size,
             max_steps=self.cfg.max_steps or total_num_steps,
             gradient_accumulation_steps=self.cfg.gradient_accumulation_steps,
             learning_rate=self.cfg.learning_rate,
-            output_dir=self.cfg.output_dir,
             warmup_steps=self.cfg.warmup_steps,
             logging_first_step=True,
             logging_steps=1,
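Two small details here, stated as a reading of the diff rather than documented intent: `output_dir` is hoisted to the top of the call, matching its position as the first (and, per the docstring above, non-defaulted) parameter of the HF/TRL config classes, and the pylint disable is needed because `training_args_cls` is only known at runtime, so the linter cannot check the keyword set against a single class. A sketch of that runtime dispatch, with hypothetical stand-in classes:

```python
from dataclasses import dataclass


@dataclass
class OrpoLikeConfig:  # hypothetical stand-in
    output_dir: str = "."
    max_length: int = 2048


@dataclass
class PlainLikeConfig:  # hypothetical stand-in
    output_dir: str = "."


def build_args(config_cls, **kwargs):
    # Keywords are validated only when the call executes, against whichever
    # class was selected; a static linter just sees `config_cls(...)`.
    return config_cls(output_dir="./out", **kwargs)


print(build_args(OrpoLikeConfig, max_length=4096))
print(build_args(PlainLikeConfig))
```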