enable loraplus setting for dpo trainer (#1646)

src/axolotl/core/trainer_builder.py CHANGED

@@ -798,6 +798,40 @@ class AxolotlDPOTrainer(DPOTrainer):
 
     tag_names = ["axolotl", "dpo"]
 
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.optimizer = None
+
+    def create_optimizer(self):
+        if self.args.loraplus_lr_ratio is None:
+            return super().create_optimizer()
+
+        opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
+        if self.optimizer is None:  # pylint: disable=access-member-before-definition
+            optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
+                self.args,
+                opt_model,
+            )
+
+            loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
+            if loraplus_lr_ratio:
+                print("Using lora+")
+            loraplus_lr_embedding = getattr(self.args, "loraplus_lr_embedding", None)
+            self.optimizer = create_loraplus_optimizer(  # pylint: disable=attribute-defined-outside-init
+                opt_model,
+                optimizer_cls,
+                optimizer_kwargs,
+                loraplus_lr_ratio,
+                loraplus_lr_embedding,
+            )
+
+            if is_sagemaker_mp_enabled():
+                self.optimizer = smp.DistributedOptimizer(  # pylint: disable=attribute-defined-outside-init
+                    self.optimizer
+                )
+
+        return self.optimizer
+
     @wraps(DPOTrainer.push_to_hub)
     def push_to_hub(self, *args, **kwargs) -> str:
         """
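
Note: the create_loraplus_optimizer helper called above implements the LoRA+ scheme, in which the LoRA B matrices (and, optionally, the LoRA embedding weights) are trained at a higher learning rate than the A matrices, controlled by loraplus_lr_ratio and loraplus_lr_embedding. The sketch below only illustrates that parameter grouping, assuming standard peft parameter names (lora_A / lora_B / lora_embedding_*); build_loraplus_param_groups is a hypothetical name, not the function used in the diff.

# Illustration only: roughly what a LoRA+ optimizer setup does.
from typing import Optional

from torch import nn


def build_loraplus_param_groups(  # hypothetical helper, for illustration
    model: nn.Module,
    base_lr: float,
    loraplus_lr_ratio: float,
    loraplus_lr_embedding: Optional[float] = None,
):
    groups = {"a_and_other": [], "b": [], "embedding": []}
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if "lora_embedding" in name:
            groups["embedding"].append(param)
        elif "lora_B" in name:
            groups["b"].append(param)
        else:
            groups["a_and_other"].append(param)
    return [
        {"params": groups["a_and_other"], "lr": base_lr},
        # LoRA+: B matrices step faster than A by the configured ratio
        {"params": groups["b"], "lr": base_lr * loraplus_lr_ratio},
        {"params": groups["embedding"], "lr": loraplus_lr_embedding or base_lr},
    ]


# usage (illustrative): optimizer = torch.optim.AdamW(
#     build_loraplus_param_groups(model, base_lr=5e-5, loraplus_lr_ratio=16)
# )
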
@@ -1483,6 +1517,8 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
         if self.cfg.bf16 or self.cfg.bfloat16:
             training_args_kwargs["bf16"] = True
 
+        training_args_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio
+        training_args_kwargs["loraplus_lr_embedding"] = self.cfg.loraplus_lr_embedding
         training_args_kwargs["lr_scheduler_type"] = (
             self.cfg.lr_scheduler if self.cfg.lr_scheduler else "cosine"
         )
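
Note: passing loraplus_lr_ratio and loraplus_lr_embedding through training_args_kwargs only works if the training-arguments class accepts those fields, which is why the builder switches its default training-arguments class to AxolotlTrainingArguments in the next hunk. A rough sketch of what such fields look like, assuming a dataclass that extends transformers.TrainingArguments; the class name and help strings are illustrative, not the actual AxolotlTrainingArguments definition.

# Illustrative only -- AxolotlTrainingArguments itself is defined elsewhere in axolotl.
from dataclasses import dataclass, field
from typing import Optional

from transformers import TrainingArguments


@dataclass
class LoraPlusArgsSketch(TrainingArguments):  # stand-in name for illustration
    loraplus_lr_ratio: Optional[float] = field(
        default=None,
        metadata={"help": "learning-rate ratio applied to LoRA B matrices; None disables LoRA+"},
    )
    loraplus_lr_embedding: Optional[float] = field(
        default=None,
        metadata={"help": "separate learning rate for LoRA embedding weights"},
    )
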
@@ -1535,7 +1571,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
             # trl does some odd mapping of alpha to beta to reuse the beta parameter ???
             training_args_kwargs["beta"] = self.cfg.orpo_alpha
 
-        training_args_cls = TrainingArguments
+        training_args_cls = AxolotlTrainingArguments
         if self.cfg.rl == "orpo":
             training_args_cls = ORPOConfig
             training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
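
End to end, the flow is: the user's config sets the LoRA+ keys, HFRLTrainerBuilder copies them from self.cfg into training_args_kwargs (previous hunk), AxolotlTrainingArguments carries them into the trainer, and AxolotlDPOTrainer.create_optimizer acts on them. In Python terms the relevant settings look roughly like the snippet below; the values are examples, not recommendations, and in practice they live in the YAML config file.

# Example config values (illustrative); in axolotl these come from the YAML config.
cfg = {
    "rl": "dpo",
    "adapter": "lora",              # LoRA+ only applies when LoRA adapters are trained
    "learning_rate": 5e-5,          # base lr, used for the LoRA A matrices
    "loraplus_lr_ratio": 16,        # lora_B lr = 16 * learning_rate
    "loraplus_lr_embedding": 1e-6,  # lr for lora embedding weights
}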