FEAT: add tagging support to axolotl for DPOTrainer (#1209)
* Add AxolotlDPOTrainer
* chore: lint
---------
Co-authored-by: Wing Lian <[email protected]>
src/axolotl/core/trainer_builder.py
CHANGED
|
@@ -59,6 +59,22 @@ except ImportError:
|
|
| 59 |
LOG = logging.getLogger("axolotl.core.trainer_builder")
|
| 60 |
|
| 61 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
@dataclass
|
| 63 |
class AxolotlTrainingArguments(TrainingArguments):
|
| 64 |
"""
|
|
@@ -349,30 +365,13 @@ class AxolotlTrainer(Trainer):
|
|
| 349 |
# return (loss, outputs) if return_outputs else loss
|
| 350 |
return super().compute_loss(model, inputs, return_outputs=return_outputs)
|
| 351 |
|
| 352 |
-
def _sanitize_kwargs_for_tagging(self, tag_names, kwargs=None):
|
| 353 |
-
if isinstance(tag_names, str):
|
| 354 |
-
tag_names = [tag_names]
|
| 355 |
-
|
| 356 |
-
if kwargs is not None:
|
| 357 |
-
if "tags" not in kwargs:
|
| 358 |
-
kwargs["tags"] = tag_names
|
| 359 |
-
elif "tags" in kwargs and isinstance(kwargs["tags"], list):
|
| 360 |
-
kwargs["tags"].extend(tag_names)
|
| 361 |
-
elif "tags" in kwargs and isinstance(kwargs["tags"], str):
|
| 362 |
-
tag_names.append(kwargs["tags"])
|
| 363 |
-
kwargs["tags"] = tag_names
|
| 364 |
-
|
| 365 |
-
return kwargs
|
| 366 |
-
|
| 367 |
@wraps(Trainer.push_to_hub)
|
| 368 |
def push_to_hub(self, *args, **kwargs) -> str:
|
| 369 |
"""
|
| 370 |
Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
|
| 371 |
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
|
| 372 |
"""
|
| 373 |
-
kwargs = self.
|
| 374 |
-
tag_names=self.tag_names, kwargs=kwargs
|
| 375 |
-
)
|
| 376 |
|
| 377 |
return super().push_to_hub(*args, **kwargs)
|
| 378 |
|
|
@@ -471,6 +470,24 @@ class ReLoRATrainer(AxolotlTrainer):
|
|
| 471 |
return self.lr_scheduler
|
| 472 |
|
| 473 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 474 |
class TrainerBuilderBase(abc.ABC):
|
| 475 |
"""
|
| 476 |
Base class for trainer builder
|
|
@@ -1076,7 +1093,7 @@ class HFDPOTrainerBuilder(TrainerBuilderBase):
|
|
| 1076 |
dpo_trainer_kwargs[
|
| 1077 |
"precompute_ref_log_probs"
|
| 1078 |
] = self.cfg.precompute_ref_log_probs
|
| 1079 |
-
dpo_trainer =
|
| 1080 |
self.model,
|
| 1081 |
self.model_ref,
|
| 1082 |
args=training_args,
|
|
|
|
| 59 |
LOG = logging.getLogger("axolotl.core.trainer_builder")
|
| 60 |
|
| 61 |
|
| 62 |
+
def _sanitize_kwargs_for_tagging(tag_names, kwargs=None):
|
| 63 |
+
if isinstance(tag_names, str):
|
| 64 |
+
tag_names = [tag_names]
|
| 65 |
+
|
| 66 |
+
if kwargs is not None:
|
| 67 |
+
if "tags" not in kwargs:
|
| 68 |
+
kwargs["tags"] = tag_names
|
| 69 |
+
elif "tags" in kwargs and isinstance(kwargs["tags"], list):
|
| 70 |
+
kwargs["tags"].extend(tag_names)
|
| 71 |
+
elif "tags" in kwargs and isinstance(kwargs["tags"], str):
|
| 72 |
+
tag_names.append(kwargs["tags"])
|
| 73 |
+
kwargs["tags"] = tag_names
|
| 74 |
+
|
| 75 |
+
return kwargs
|
| 76 |
+
|
| 77 |
+
|
| 78 |
@dataclass
|
| 79 |
class AxolotlTrainingArguments(TrainingArguments):
|
| 80 |
"""
|
|
|
|
| 365 |
# return (loss, outputs) if return_outputs else loss
|
| 366 |
return super().compute_loss(model, inputs, return_outputs=return_outputs)
|
| 367 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 368 |
@wraps(Trainer.push_to_hub)
def push_to_hub(self, *args, **kwargs) -> str:
    """
    Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
    model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
    """
    # `_sanitize_kwargs_for_tagging` may append to the list it receives or
    # store it directly in kwargs, so hand it a copy rather than the shared
    # `tag_names` attribute itself. Non-list values are passed through as-is.
    tags = self.tag_names
    if isinstance(tags, list):
        tags = list(tags)
    kwargs = _sanitize_kwargs_for_tagging(tag_names=tags, kwargs=kwargs)

    return super().push_to_hub(*args, **kwargs)
|
| 377 |
|
|
|
|
| 470 |
return self.lr_scheduler
|
| 471 |
|
| 472 |
|
| 473 |
+
class AxolotlDPOTrainer(DPOTrainer):
    """
    Extend the base DPOTrainer for axolotl helpers
    """

    # Tags merged into the `tags` kwarg on every `push_to_hub` call.
    tag_names = ["axolotl", "dpo"]

    @wraps(DPOTrainer.push_to_hub)
    def push_to_hub(self, *args, **kwargs) -> str:
        """
        Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
        model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
        """
        # `tag_names` is a class-level list shared by every instance, and
        # `_sanitize_kwargs_for_tagging` may append to the list it receives
        # or store it directly in kwargs. Pass a copy so repeated pushes
        # cannot grow or alias the shared attribute.
        tags = self.tag_names
        if isinstance(tags, list):
            tags = list(tags)
        kwargs = _sanitize_kwargs_for_tagging(tag_names=tags, kwargs=kwargs)

        return super().push_to_hub(*args, **kwargs)
|
| 489 |
+
|
| 490 |
+
|
| 491 |
class TrainerBuilderBase(abc.ABC):
|
| 492 |
"""
|
| 493 |
Base class for trainer builder
|
|
|
|
| 1093 |
dpo_trainer_kwargs[
|
| 1094 |
"precompute_ref_log_probs"
|
| 1095 |
] = self.cfg.precompute_ref_log_probs
|
| 1096 |
+
dpo_trainer = AxolotlDPOTrainer(
|
| 1097 |
self.model,
|
| 1098 |
self.model_ref,
|
| 1099 |
args=training_args,
|