support adamw and grad norm hyperparams
Browse files
src/axolotl/utils/trainer.py
CHANGED
|
@@ -115,6 +115,15 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
|
| 115 |
# TODO search Path("./") for one
|
| 116 |
training_arguments_kwargs["deepspeed"] = "./ds_config.json"
|
| 117 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 118 |
training_args = transformers.TrainingArguments(
|
| 119 |
per_device_train_batch_size=cfg.micro_batch_size,
|
| 120 |
per_device_eval_batch_size=cfg.eval_batch_size
|
|
|
|
| 115 |
# TODO search Path("./") for one
|
| 116 |
training_arguments_kwargs["deepspeed"] = "./ds_config.json"
|
| 117 |
|
| 118 |
+
if cfg.adam_beta1:
|
| 119 |
+
training_arguments_kwargs["adam_beta1"] = cfg.adam_beta1
|
| 120 |
+
if cfg.adam_beta2:
|
| 121 |
+
training_arguments_kwargs["adam_beta2"] = cfg.adam_beta2
|
| 122 |
+
if cfg.adam_epsilon:
|
| 123 |
+
training_arguments_kwargs["adam_epsilon"] = cfg.adam_epsilon
|
| 124 |
+
if cfg.max_grad_norm:
|
| 125 |
+
training_arguments_kwargs["max_grad_norm"] = cfg.max_grad_norm
|
| 126 |
+
|
| 127 |
training_args = transformers.TrainingArguments(
|
| 128 |
per_device_train_batch_size=cfg.micro_batch_size,
|
| 129 |
per_device_eval_batch_size=cfg.eval_batch_size
|