gather/broadcast the max value of the packing efficiency automatically (#463)
- src/axolotl/utils/distributed.py +88 -0
- src/axolotl/utils/trainer.py +32 -12
src/axolotl/utils/distributed.py
CHANGED
@@ -121,3 +121,91 @@ def broadcast_dict(vals: dict):
         vals = pickle.loads(data_byte)  # nosec

     return vals
+
+
+def compute_and_broadcast(fn):  # pylint: disable=invalid-name
+    """
+    Compute a value using the function 'fn' on the main process (rank 0) only.
+    The value is then broadcast to all other ranks.
+
+    Args:
+    - fn (callable): A function that computes the value. This should not have any side effects.
+
+    Returns:
+    - The computed value (int or float).
+    """
+    if is_main_process():
+        value_scalar = fn()
+        value_tensor = torch.tensor(value_scalar, device=dist.get_rank()).float()
+    else:
+        value_tensor = torch.tensor(0.0, device=dist.get_rank())  # Placeholder tensor
+
+    # Broadcast the tensor to all processes.
+    barrier()
+    dist.broadcast(value_tensor, src=0)
+
+    # Convert the tensor back to its original type (int or float)
+    if value_tensor == value_tensor.int():
+        return int(value_tensor.item())
+    return float(value_tensor.item())
+
+
+def gather_from_all_ranks(fn, world_size=1):  # pylint: disable=invalid-name
+    """
+    Run a callable 'fn' on all ranks and gather the results on the main rank (rank 0).
+
+    Args:
+    - fn (callable): A function that computes the value. This should not have any side effects.
+    - world_size (int, optional): Total number of processes in the current distributed setup.
+
+    Returns:
+    - A list of computed values from all ranks if on the gathering rank, otherwise None.
+    """
+    value_scalar = fn()
+    value_tensor = torch.tensor(value_scalar, device=dist.get_rank()).float()
+
+    # Placeholder tensors for gathering results
+    if is_main_process():
+        gathered_tensors = [torch.zeros_like(value_tensor) for _ in range(world_size)]
+    else:
+        gathered_tensors = None
+
+    dist.gather(value_tensor, gather_list=gathered_tensors, dst=0)
+
+    if is_main_process():
+        # Convert tensors back to their original type (int or float)
+        gathered_values = []
+        for tensor in gathered_tensors:
+            if tensor == tensor.int():
+                gathered_values.append(int(tensor.item()))
+            else:
+                gathered_values.append(float(tensor.item()))
+        return gathered_values
+    return None
+
+
+def reduce_and_broadcast(fn1, fn2):
+    """
+    Run a callable 'fn1' on all ranks, gather the results, reduce them using 'fn2',
+    and then broadcast the reduced result to all ranks.
+
+    Args:
+    - fn1 (callable): A function that computes the value on each rank.
+    - fn2 (callable): A reduction function that takes a list of values and returns a single value.
+
+    Returns:
+    - The reduced and broadcast value.
+    """
+
+    # Gather values from all ranks using fn1
+    if not is_distributed():
+        return fn2([fn1()])
+
+    gathered_values = gather_from_all_ranks(fn1, world_size=dist.get_world_size())
+
+    # Use compute_and_broadcast to compute the reduced value on the main process
+    # and then broadcast it to all ranks
+    return compute_and_broadcast(lambda: fn2(gathered_values))
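Taken together, these helpers turn a rank-local measurement into a single value every rank agrees on. The sketch below is a usage illustration, not part of the diff; the file name and launch command are assumptions, and one CUDA device per rank is assumed because the helpers place tensors on device index `dist.get_rank()`.

# demo_reduce.py (hypothetical), launched with:
#   torchrun --nproc_per_node=2 demo_reduce.py
import torch.distributed as dist

from axolotl.utils.distributed import reduce_and_broadcast

dist.init_process_group(backend="nccl")

def local_packing_eff() -> float:
    # Stand-in for a rank-local measurement such as the dataloader's
    # packing efficiency; each rank may observe a different value.
    return 0.90 + 0.01 * dist.get_rank()

# fn1 runs on every rank; rank 0 gathers the per-rank values, reduces them
# with max(), and broadcasts the single result back to all ranks.
agreed = reduce_and_broadcast(local_packing_eff, max)
print(f"rank {dist.get_rank()}: {agreed}")  # same value on every rank

dist.destroy_process_group()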
src/axolotl/utils/trainer.py
CHANGED
@@ -8,11 +8,12 @@ from contextlib import contextmanager
 from dataclasses import dataclass, field
 from functools import partial
 from pathlib import Path
-from typing import Optional, Union
+from typing import List, Optional, Union

 import numpy as np
 import torch
 import torch.cuda
+import torch.distributed as dist
 import transformers
 from datasets import Dataset, set_caching_enabled
 from torch.optim.lr_scheduler import OneCycleLR
@@ -35,7 +36,12 @@ from axolotl.utils.callbacks import (
 )
 from axolotl.utils.collators import DataCollatorForSeq2Seq
 from axolotl.utils.dataloader import MultipackDistributedDataloader
-from axolotl.utils.distributed import is_main_process, zero_first
+from axolotl.utils.distributed import (
+    is_distributed,
+    is_main_process,
+    reduce_and_broadcast,
+    zero_first,
+)
 from axolotl.utils.schedulers import get_cosine_schedule_with_quadratic_warmup

 LOG = logging.getLogger("axolotl")
@@ -456,7 +462,16 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
                 f"total_num_tokens: {cfg.total_num_tokens}, total_num_steps: {total_num_steps}"
             )
         else:
-            sampler = RandomSampler(train_dataset)
+            if cfg.world_size > 1 and is_distributed():
+                sampler = DistributedSampler(
+                    train_dataset,
+                    num_replicas=cfg.world_size,
+                    rank=dist.get_rank(),
+                    seed=cfg.seed or 42,
+                )
+            else:
+                sampler = RandomSampler(train_dataset)
+
             data_loader = MultipackDistributedDataloader(
                 train_dataset,
                 batch_size=cfg.micro_batch_size,
@@ -474,18 +489,23 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
             data_loader_len = data_loader.len_w_stats()
             actual_eff = data_loader.efficiency()
             LOG.info(f"data_loader_len: {data_loader_len}")
-            total_num_steps = int(
-                ...
+            total_num_steps = int(math.floor(data_loader_len * cfg.num_epochs))
+
+            def calc_sample_packing_eff_est(estimates: List[float]):
+                LOG.info(f"sample_packing_eff_est across ranks: {repr(estimates)}")
+                return max(estimates)
+
+            sample_packing_actual_eff_all = reduce_and_broadcast(
+                lambda: actual_eff,
+                calc_sample_packing_eff_est,
+            )
+            sample_packing_eff_est = (
+                math.ceil(sample_packing_actual_eff_all * 100.0) / 100.0
             )
             LOG.info(
-                f"π UPDATE CONFIG WITH: `sample_packing_eff_est: {...}`"
+                f"π UPDATE CONFIG WITH: `sample_packing_eff_est: {sample_packing_eff_est}`"
             )
-            cfg.sample_packing_eff_est = ...
         else:
             total_num_steps = int(
                 math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)