feat(pjit): follow t5x style
tools/train/train.py  +65 -58
@@ -765,6 +765,7 @@ def main():
     # define batch specs
     keys = ["attention_mask", "decoder_input_ids", "input_ids", "labels"]
     batch_spec = freeze({k: PartitionSpec("batch") for k in keys})
+    grad_batch_spec = freeze({k: PartitionSpec(None, "batch") for k in keys})
 
     # label smoothed cross entropy
     def loss_fn(logits, labels):
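Note: `batch_spec` shards the leading (per-step batch) axis across the "batch" mesh axis, while the new `grad_batch_spec` leaves a leading gradient-accumulation axis unsharded (`None`) and shards the second axis instead. A minimal standalone sketch of what that means for a stacked batch array, assuming JAX >= 0.4 (Mesh / PartitionSpec / NamedSharding in jax.sharding); the shapes and the single "batch" mesh axis are made up for illustration:

# Sketch: PartitionSpec("batch") vs PartitionSpec(None, "batch") on a stacked batch
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(np.array(jax.devices()), axis_names=("batch",))

# (gradient_accumulation_steps, minibatch_size, seq_len): axis 0 is replicated,
# axis 1 is sharded across the "batch" mesh axis.
x = jnp.zeros((4, 2 * jax.device_count(), 16), dtype=jnp.int32)
x = jax.device_put(x, NamedSharding(mesh, PartitionSpec(None, "batch")))

# Each device holds the full accumulation axis but only its slice of axis 1.
print(x.sharding)
print(x.addressable_shards[0].data.shape)  # (4, 2, 16)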
@@ -774,18 +775,22 @@ def main():
 
     # Define gradient update step fn
     def train_step(state, batch, delta_time):
+        # batch is (gradient_accumulation_steps, minibatch_size, ...)
         # check correct batch shape during compilation
-        assert batch["labels"].shape[0:3] == (
-            training_args.dp_devices,
+        assert batch["labels"].shape[0:2] == (
             training_args.gradient_accumulation_steps,
+            minibatch_size,
         ), f"Expected label batch of shape dp_devices x gradient_acculumation x batch_per_device and got {batch['labels'].shape}"
-        # create a new rng
-        dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
-        # use a different rng per node
-        dropout_rng = jax.random.fold_in(dropout_rng, jax.process_index())
 
+        # get a minibatch (one gradient accumulation slice)
+        def get_minibatch(batch, grad_idx):
+            return jax.tree_map(
+                lambda x: jax.lax.dynamic_index_in_dim(x, grad_idx, keepdims=False),
+                batch,
+            )
+
+        def compute_loss(params, minibatch, dropout_rng):
+            # minibatch has dim (batch_size, ...)
             minibatch = unfreeze(minibatch)
             labels = minibatch.pop("labels")
             logits = state.apply_fn(
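Note: `jax.lax.dynamic_index_in_dim(x, i, keepdims=False)` selects index `i` along the leading axis and drops that axis, and it accepts a traced index, so the same helper works for a constant 0 and for a `lax.fori_loop` counter. A standalone toy version of `get_minibatch` with made-up field names and shapes:

# Toy get_minibatch: slice one gradient-accumulation step out of a stacked batch pytree
import jax
import jax.numpy as jnp

batch = {
    "input_ids": jnp.arange(4 * 2 * 3).reshape(4, 2, 3),  # (grad_accum, minibatch, seq)
    "labels": jnp.ones((4, 2, 3), dtype=jnp.int32),
}

def get_minibatch(batch, grad_idx):
    # dynamic_index_in_dim(x, grad_idx, keepdims=False) == x[grad_idx] along axis 0,
    # and grad_idx may be a traced value (e.g. a fori_loop counter).
    return jax.tree_util.tree_map(
        lambda x: jax.lax.dynamic_index_in_dim(x, grad_idx, keepdims=False), batch
    )

minibatch = jax.jit(get_minibatch)(batch, 2)
print(jax.tree_util.tree_map(jnp.shape, minibatch))  # {'input_ids': (2, 3), 'labels': (2, 3)}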
@@ -795,58 +800,61 @@ def main():
 
         grad_fn = jax.value_and_grad(compute_loss)
 
-            loss, grads = grad_fn(state.params, minibatch)
-        else:
-            minibatch = jax.tree_map(
-                lambda x: x[i],
-                device_batch,
-            )
-            return jax.tree_map(
-                lambda x, y: x + y,
-                cumul_loss_grads,
-                grad_fn(state.params, minibatch),
-            )
+        def loss_and_grad(grad_idx, dropout_rng):
+            minibatch = get_minibatch(batch, grad_idx)
+            # ensure batch is sharded over devices
+            minibatch = jax.tree_map(
+                lambda x: with_sharding_constraint(x, PartitionSpec("batch")), minibatch
+            )
+            # return loss and grads
+            return grad_fn(state.params, minibatch, dropout_rng)
 
+        # create a new rng
+        dropout_rng, _ = jax.random.split(state.dropout_rng)
+        # use a different rng per node
+        dropout_rng = jax.random.fold_in(dropout_rng, jax.process_index())
 
+        if training_args.gradient_accumulation_steps == 1:
 
+            def batch_step(dropout_rng):
+                dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
+                loss_grad = loss_and_grad(0, dropout_rng)
+                return loss_grad, new_dropout_rng
+
+            loss_grad, dropout_rng = batch_step(dropout_rng)
+        else:
+            # create initial state for per_minibatch_step loop
+            init_cumul_loss_grad = (
+                0.0,
+                jax.tree_map(jnp.zeros_like, state.params),
+            )
+            init_minibatch_step = (init_cumul_loss_grad, dropout_rng)
+
+            # accumulate gradients
+            def cumul_minibatch_step(grad_idx, cumul_loss_grad_dropout):
+                cumul_loss_grad, dropout_rng = cumul_loss_grad_dropout
+                dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
+                loss_grad = loss_and_grad(grad_idx, dropout_rng)
+                cumul_loss_grad = jax.tree_map(jnp.add, cumul_loss_grad, loss_grad)
+                return cumul_loss_grad, new_dropout_rng
+
+            # loop over gradients
+            loss_grad, dropout_rng = jax.lax.fori_loop(
+                0,
+                training_args.gradient_accumulation_steps,
+                cumul_minibatch_step,
+                init_minibatch_step,
+            )
+            # sum -> mean
+            loss_grad = jax.tree_map(
+                lambda x: x / training_args.gradient_accumulation_steps, loss_grad
+            )
 
+        # update state
+        loss, grads = loss_grad
         state = state.apply_gradients(
             grads=grads,
-            dropout_rng=new_dropout_rng,
+            dropout_rng=dropout_rng,
             train_time=state.train_time + delta_time,
             train_samples=state.train_samples + batch_size_per_step,
         )
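Note: the accumulation branch sums `(loss, grads)` pytrees across `gradient_accumulation_steps` with `jax.lax.fori_loop`, threading the dropout rng through the loop carry, and then divides by the step count to recover a mean. A stripped-down standalone sketch of the same pattern, with a toy linear model and no rng or sharding (all names and shapes here are illustrative, not the script's real ones):

# Gradient accumulation with lax.fori_loop: sum (loss, grads), then average
import jax
import jax.numpy as jnp

grad_accum_steps = 4
params = {"w": jnp.ones((3,)), "b": jnp.zeros(())}
batch = {  # stacked: (grad_accum_steps, minibatch_size, ...)
    "x": jnp.ones((grad_accum_steps, 8, 3)),
    "y": jnp.zeros((grad_accum_steps, 8)),
}

def compute_loss(params, minibatch):
    pred = minibatch["x"] @ params["w"] + params["b"]
    return jnp.mean((pred - minibatch["y"]) ** 2)

grad_fn = jax.value_and_grad(compute_loss)

def loss_and_grad(grad_idx):
    minibatch = jax.tree_util.tree_map(
        lambda x: jax.lax.dynamic_index_in_dim(x, grad_idx, keepdims=False), batch
    )
    return grad_fn(params, minibatch)

# loop carry: running (loss, grads) sums, zero-initialized
init = (jnp.zeros(()), jax.tree_util.tree_map(jnp.zeros_like, params))

def cumul_step(grad_idx, cumul_loss_grad):
    return jax.tree_util.tree_map(jnp.add, cumul_loss_grad, loss_and_grad(grad_idx))

loss, grads = jax.lax.fori_loop(0, grad_accum_steps, cumul_step, init)
# each per-step loss/grad is already a mean over its minibatch, so dividing by
# the number of accumulation steps gives the mean over the full batch
loss, grads = jax.tree_util.tree_map(lambda x: x / grad_accum_steps, (loss, grads))
print(loss, grads["w"])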
@@ -872,7 +880,7 @@ def main():
     # Create parallel version of the train and eval step
     p_train_step = pjit(
         train_step,
-        in_axis_resources=(state_spec,
+        in_axis_resources=(state_spec, grad_batch_spec, None),
         out_axis_resources=(state_spec, None),
         donate_argnums=(0,),
     )
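Note: in current JAX the same partitioning can be expressed through `jax.jit`, where pjit's `in_axis_resources` / `out_axis_resources` are spelled `in_shardings` / `out_shardings`. A rough standalone analogue of the call above; the `train_step`, "state", mesh and shapes here are stand-ins, not the script's real ones:

# Rough jax.jit analogue of the pjit call above (stand-in train_step and state)
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(np.array(jax.devices()), axis_names=("batch",))
replicated = NamedSharding(mesh, PartitionSpec())               # stand-in for state_spec
grad_batch = NamedSharding(mesh, PartitionSpec(None, "batch"))  # like grad_batch_spec

def train_step(state, batch, delta_time):
    # toy update: "state" is just a parameter vector
    return state - 0.01 * jnp.mean(batch), delta_time

p_train_step = jax.jit(
    train_step,
    in_shardings=(replicated, grad_batch, None),  # None: leave delta_time to the compiler
    out_shardings=(replicated, None),
    donate_argnums=(0,),                          # reuse the old state's buffers
)

state = jnp.zeros((8,))
batch = jnp.ones((2, 2 * jax.device_count(), 8))  # (grad_accum, minibatch, features)
state, _ = p_train_step(state, batch, jnp.float32(0.5))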
@@ -1053,13 +1061,12 @@ def main():
             delta_time = new_time - last_time
             last_time = new_time
 
-            # reshape data into (
+            # reshape data into (gradient_accumulation_steps, minibatch_size, ...)
             batch = jax.tree_map(
                 lambda x: x.reshape(
                     (
-                        training_args.dp_devices,
                         training_args.gradient_accumulation_steps,
+                        minibatch_size,
                     )
                     + x.shape[1:]
                 ),