Early stopping metric (#537)
Browse files
* set early stopping metric to check
* tweak how load_best_model_at_end gets set for early stopping
* add validation for early stopping patience
* remove negation
* save results to metrics in callback
* move early stopping callback after the benchmark evals
* broadcast metrics so early stopping works
src/axolotl/utils/callbacks.py
CHANGED
|
@@ -25,6 +25,7 @@ from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
|
|
| 25 |
from axolotl.utils.bench import log_gpu_memory_usage
|
| 26 |
from axolotl.utils.distributed import (
|
| 27 |
barrier,
|
|
|
|
| 28 |
gather_scalar_from_all_ranks,
|
| 29 |
get_world_size,
|
| 30 |
is_distributed,
|
|
@@ -271,6 +272,7 @@ def bench_eval_callback_factory(trainer, tokenizer):
|
|
| 271 |
lambda: len(data_loader), get_world_size()
|
| 272 |
)
|
| 273 |
|
|
|
|
| 274 |
if is_distributed() and not is_main_process():
|
| 275 |
dist.gather_object(local_bench_names, dst=0)
|
| 276 |
else:
|
|
@@ -316,4 +318,8 @@ def bench_eval_callback_factory(trainer, tokenizer):
|
|
| 316 |
)["accuracy"]
|
| 317 |
trainer.log(results)
|
| 318 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 319 |
return BenchEvalCallback
|
|
|
|
| 25 |
from axolotl.utils.bench import log_gpu_memory_usage
|
| 26 |
from axolotl.utils.distributed import (
|
| 27 |
barrier,
|
| 28 |
+
broadcast_dict,
|
| 29 |
gather_scalar_from_all_ranks,
|
| 30 |
get_world_size,
|
| 31 |
is_distributed,
|
|
|
|
| 272 |
lambda: len(data_loader), get_world_size()
|
| 273 |
)
|
| 274 |
|
| 275 |
+
results = {}
|
| 276 |
if is_distributed() and not is_main_process():
|
| 277 |
dist.gather_object(local_bench_names, dst=0)
|
| 278 |
else:
|
|
|
|
| 318 |
)["accuracy"]
|
| 319 |
trainer.log(results)
|
| 320 |
|
| 321 |
+
results = broadcast_dict(results)
|
| 322 |
+
for key, val in results.items():
|
| 323 |
+
metrics[key] = val
|
| 324 |
+
|
| 325 |
return BenchEvalCallback
|
src/axolotl/utils/config.py
CHANGED
|
@@ -220,6 +220,15 @@ def validate_config(cfg):
|
|
| 220 |
"sample_packing not compatible with xformers_attention. Use flash_attention"
|
| 221 |
)
|
| 222 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 223 |
# TODO
|
| 224 |
# MPT 7b
|
| 225 |
# https://github.com/facebookresearch/bitsandbytes/issues/25
|
|
|
|
| 220 |
"sample_packing not compatible with xformers_attention. Use flash_attention"
|
| 221 |
)
|
| 222 |
|
| 223 |
+
if cfg.early_stopping_patience:
|
| 224 |
+
if not cfg.save_steps or not cfg.eval_steps:
|
| 225 |
+
raise ValueError(
|
| 226 |
+
"`early_stopping_patience` requires save_steps and eval_steps to be set. eval_steps should evenly divide save_steps."
|
| 227 |
+
)
|
| 228 |
+
if cfg.save_steps % cfg.eval_steps != 0:
|
| 229 |
+
raise ValueError(
|
| 230 |
+
"`early_stopping_patience` requires that eval_steps should evenly divide save_steps."
|
| 231 |
+
)
|
| 232 |
# TODO
|
| 233 |
# MPT 7b
|
| 234 |
# https://github.com/facebookresearch/bitsandbytes/issues/25
|
src/axolotl/utils/distributed.py
CHANGED
|
@@ -2,6 +2,7 @@
|
|
| 2 |
utility helpers for distributed checks
|
| 3 |
"""
|
| 4 |
import os
|
|
|
|
| 5 |
from contextlib import contextmanager
|
| 6 |
|
| 7 |
import torch
|
|
@@ -93,3 +94,30 @@ def gather_scalar_from_all_ranks(fn, world_size=1): # pylint: disable=invalid-n
|
|
| 93 |
gathered_values.append(float(tensor.item()))
|
| 94 |
return gathered_values
|
| 95 |
return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
utility helpers for distributed checks
|
| 3 |
"""
|
| 4 |
import os
|
| 5 |
+
import pickle # nosec
|
| 6 |
from contextlib import contextmanager
|
| 7 |
|
| 8 |
import torch
|
|
|
|
| 94 |
gathered_values.append(float(tensor.item()))
|
| 95 |
return gathered_values
|
| 96 |
return None
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def broadcast_dict(vals: dict):
|
| 100 |
+
if not is_distributed():
|
| 101 |
+
return vals
|
| 102 |
+
|
| 103 |
+
if is_main_process():
|
| 104 |
+
data_byte = pickle.dumps(vals)
|
| 105 |
+
data_tensor = torch.ByteTensor(list(data_byte)).to("cuda")
|
| 106 |
+
data_size = torch.IntTensor([len(data_byte)]).to("cuda")
|
| 107 |
+
else:
|
| 108 |
+
data_tensor = torch.empty([1024], dtype=torch.uint8, device="cuda")
|
| 109 |
+
data_size = torch.IntTensor([0]).to("cuda")
|
| 110 |
+
|
| 111 |
+
dist.broadcast(data_size, 0)
|
| 112 |
+
if not is_main_process():
|
| 113 |
+
# resize
|
| 114 |
+
data_tensor = data_tensor.new_empty([data_size.item()])
|
| 115 |
+
|
| 116 |
+
dist.broadcast(data_tensor, 0)
|
| 117 |
+
|
| 118 |
+
if not is_main_process():
|
| 119 |
+
data_list = data_tensor.cpu().tolist()
|
| 120 |
+
data_byte = bytes(data_list[: data_size.item()])
|
| 121 |
+
vals = pickle.loads(data_byte) # nosec
|
| 122 |
+
|
| 123 |
+
return vals
|
src/axolotl/utils/trainer.py
CHANGED
|
@@ -576,6 +576,10 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
|
|
| 576 |
training_arguments_kwargs["do_bench_eval"] = cfg.do_bench_eval
|
| 577 |
if cfg.bench_dataset:
|
| 578 |
training_arguments_kwargs["bench_dataset"] = cfg.bench_dataset
|
|
|
|
|
|
|
|
|
|
|
|
|
| 579 |
|
| 580 |
# DDP Config
|
| 581 |
if cfg.ddp_timeout:
|
|
@@ -601,11 +605,10 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
|
|
| 601 |
output_dir=cfg.output_dir,
|
| 602 |
save_total_limit=cfg.save_total_limit if cfg.save_total_limit else 4,
|
| 603 |
load_best_model_at_end=(
|
| 604 |
-
cfg.load_best_model_at_end is not False
|
| 605 |
and cfg.val_set_size > 0
|
| 606 |
and cfg.save_steps
|
| 607 |
and cfg.save_steps % cfg.eval_steps == 0
|
| 608 |
-
and cfg.load_in_8bit is not True
|
| 609 |
)
|
| 610 |
or False,
|
| 611 |
ddp_find_unused_parameters=False if cfg.ddp else None,
|
|
@@ -637,13 +640,6 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
|
|
| 637 |
if cfg.relora_steps:
|
| 638 |
callbacks.append(ReLoRACallback(cfg))
|
| 639 |
|
| 640 |
-
# TODO on_save callback to sync checkpoints to GCP/AWS in background
|
| 641 |
-
if cfg.early_stopping_patience:
|
| 642 |
-
early_stop_cb = EarlyStoppingCallback(
|
| 643 |
-
cfg.early_stopping_patience,
|
| 644 |
-
)
|
| 645 |
-
callbacks.append(early_stop_cb)
|
| 646 |
-
|
| 647 |
if cfg.local_rank == 0 and cfg.adapter in [
|
| 648 |
"lora",
|
| 649 |
"qlora",
|
|
@@ -710,4 +706,11 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
|
|
| 710 |
if cfg.do_bench_eval:
|
| 711 |
trainer.add_callback(bench_eval_callback_factory(trainer, tokenizer))
|
| 712 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 713 |
return trainer
|
|
|
|
| 576 |
training_arguments_kwargs["do_bench_eval"] = cfg.do_bench_eval
|
| 577 |
if cfg.bench_dataset:
|
| 578 |
training_arguments_kwargs["bench_dataset"] = cfg.bench_dataset
|
| 579 |
+
if cfg.metric_for_best_model:
|
| 580 |
+
training_arguments_kwargs["metric_for_best_model"] = cfg.metric_for_best_model
|
| 581 |
+
if cfg.greater_is_better:
|
| 582 |
+
training_arguments_kwargs["greater_is_better"] = cfg.greater_is_better
|
| 583 |
|
| 584 |
# DDP Config
|
| 585 |
if cfg.ddp_timeout:
|
|
|
|
| 605 |
output_dir=cfg.output_dir,
|
| 606 |
save_total_limit=cfg.save_total_limit if cfg.save_total_limit else 4,
|
| 607 |
load_best_model_at_end=(
|
| 608 |
+
(cfg.load_best_model_at_end is not False or cfg.early_stopping_patience)
|
| 609 |
and cfg.val_set_size > 0
|
| 610 |
and cfg.save_steps
|
| 611 |
and cfg.save_steps % cfg.eval_steps == 0
|
|
|
|
| 612 |
)
|
| 613 |
or False,
|
| 614 |
ddp_find_unused_parameters=False if cfg.ddp else None,
|
|
|
|
| 640 |
if cfg.relora_steps:
|
| 641 |
callbacks.append(ReLoRACallback(cfg))
|
| 642 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 643 |
if cfg.local_rank == 0 and cfg.adapter in [
|
| 644 |
"lora",
|
| 645 |
"qlora",
|
|
|
|
| 706 |
if cfg.do_bench_eval:
|
| 707 |
trainer.add_callback(bench_eval_callback_factory(trainer, tokenizer))
|
| 708 |
|
| 709 |
+
# TODO on_save callback to sync checkpoints to GCP/AWS in background
|
| 710 |
+
if cfg.early_stopping_patience:
|
| 711 |
+
early_stop_cb = EarlyStoppingCallback(
|
| 712 |
+
cfg.early_stopping_patience,
|
| 713 |
+
)
|
| 714 |
+
trainer.add_callback(early_stop_cb)
|
| 715 |
+
|
| 716 |
return trainer
|