feat(train): overhead from 70% to 1% 🥳
tools/train/train.py (+21 -5)
@@ -777,9 +777,10 @@ def main():
     def train_step(state, batch, delta_time):
         # batch is (gradient_accumulation_steps, minibatch_size, ...)
         # check correct batch shape during compilation
-        assert batch["labels"].shape[0:
+        assert batch["labels"].shape[0:3] == (
             training_args.gradient_accumulation_steps,
-
+            training_args.dp_devices,
+            training_args.per_device_train_batch_size,
         ), f"Expected label batch of shape dp_devices x gradient_acculumation x batch_per_device and got {batch['labels'].shape}"
 
         # get a minibatch (one gradient accumulation slice)
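Note: the comment "check correct batch shape during compilation" works because train_step is compiled and array shapes are static while tracing, so a plain Python assert on .shape runs once at trace time rather than on every step. A minimal sketch of that behaviour under jax.jit; the sizes (4 accumulation steps, 8 data-parallel devices, 2 examples per device, sequence length 16) are illustrative, not values from this commit.

import jax
import jax.numpy as jnp

@jax.jit
def check_shapes(batch):
    # shapes are static during tracing, so this assert fires at compile time,
    # not on every training step
    assert batch["labels"].shape[0:3] == (4, 8, 2), batch["labels"].shape
    return jax.tree_util.tree_map(jnp.sum, batch)

# illustrative batch: (grad_accum_steps, dp_devices, per_device_batch, seq_len)
batch = {"labels": jnp.zeros((4, 8, 2, 16), dtype=jnp.int32)}
check_shapes(batch)  # traced (and checked) once; later calls reuse the compiled step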
@@ -801,13 +802,27 @@ def main():
         grad_fn = jax.value_and_grad(compute_loss)
 
         def loss_and_grad(grad_idx, dropout_rng):
+            # minibatch at grad_idx, shape (dp_devices, per_device_train_batch_size, ...)
             minibatch = get_minibatch(batch, grad_idx)
             # ensure batch is sharded over devices
             minibatch = jax.tree_map(
                 lambda x: with_sharding_constraint(x, PartitionSpec("batch")), minibatch
             )
+            # calculate loss and grads independently per dp_device
+            loss_grads = jax.vmap(grad_fn, in_axes=(None, 0, None), out_axes=(0, 0))(
+                state.params, minibatch, dropout_rng
+            )
+            # ensure they are sharded over devices
+            loss_grads = jax.tree_map(
+                lambda x: with_sharding_constraint(x, PartitionSpec("batch")),
+                loss_grads,
+            )
+
+            # average across all devices
+            loss_grads = jax.tree_map(lambda x: jnp.mean(x, axis=0), loss_grads)
+
             # return loss and grads
-            return
+            return loss_grads
 
         # create a new rng
         dropout_rng, _ = jax.random.split(state.dropout_rng)
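The core of this hunk is computing loss and gradients per data-parallel slice with jax.vmap (params broadcast via in_axes=None, data mapped over the leading dp_devices axis) and then averaging over that axis. A self-contained sketch of the same pattern on a toy squared-error loss; the model, names, and sizes below are illustrative, and the sharding constraints and dropout rng of the real code are omitted because they require a device mesh.

import jax
import jax.numpy as jnp

def compute_loss(params, minibatch):
    # toy loss over one per-device slice, stands in for the real compute_loss
    preds = minibatch["x"] @ params["w"]
    return jnp.mean((preds - minibatch["y"]) ** 2)

grad_fn = jax.value_and_grad(compute_loss)

def loss_and_grad(params, minibatch):
    # minibatch has a leading dp_devices axis; params are shared (in_axes=None)
    loss_grads = jax.vmap(grad_fn, in_axes=(None, 0), out_axes=(0, 0))(params, minibatch)
    # average the loss and every gradient leaf across the device axis
    return jax.tree_util.tree_map(lambda x: jnp.mean(x, axis=0), loss_grads)

params = {"w": jnp.ones((3, 1))}
minibatch = {
    "x": jnp.ones((4, 2, 3)),   # (dp_devices, per_device_batch, features)
    "y": jnp.zeros((4, 2, 1)),
}
loss, grads = loss_and_grad(params, minibatch)
print(loss, grads["w"].shape)   # scalar loss, gradient shaped like params["w"]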
@@ -1061,12 +1076,13 @@ def main():
        delta_time = new_time - last_time
        last_time = new_time
 
-        # reshape data into (gradient_accumulation_steps,
+        # reshape data into (gradient_accumulation_steps, dp_devices, batch_per_dp, ...)
         batch = jax.tree_map(
             lambda x: x.reshape(
                 (
                     training_args.gradient_accumulation_steps,
-
+                    training_args.dp_devices,
+                    training_args.per_device_train_batch_size,
                 )
                 + x.shape[1:]
             ),
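For completeness, a sketch of the reshape the last hunk applies to every array in the batch: the loader's flat leading batch axis is split into (gradient_accumulation_steps, dp_devices, per_device_train_batch_size) while all trailing dimensions are kept. Sizes are again illustrative (4 x 8 x 2 = 64 examples per optimizer step).

import jax
import numpy as np

# illustrative sizes only: 4 * 8 * 2 = 64 examples per optimizer step
gradient_accumulation_steps, dp_devices, per_device_train_batch_size = 4, 8, 2
seq_len = 16

batch = {
    "input_ids": np.zeros((64, seq_len), dtype=np.int32),
    "labels": np.zeros((64, seq_len), dtype=np.int32),
}

# same tree_map reshape as the diff: split the leading axis, keep the rest
batch = jax.tree_util.tree_map(
    lambda x: x.reshape(
        (gradient_accumulation_steps, dp_devices, per_device_train_batch_size)
        + x.shape[1:]
    ),
    batch,
)
print(batch["labels"].shape)  # (4, 8, 2, 16)

The assert in the first hunk checks exactly the three leading axes that this reshape produces.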