feat(train): another 25% faster
tools/train/train.py (+21 -21)
@@ -36,10 +36,10 @@ import transformers
 import wandb
 from datasets import Dataset
 from distributed_shampoo import GraftingType, distributed_shampoo
-from flax.core.frozen_dict import FrozenDict
+from flax.core.frozen_dict import FrozenDict, freeze
 from flax.serialization import from_bytes, to_bytes
 from flax.training import train_state
-from flax.training.common_utils import onehot, stack_forest
+from flax.training.common_utils import onehot
 from jax.experimental import PartitionSpec, maps
 from jax.experimental.pjit import pjit, with_sharding_constraint
 from tqdm import tqdm
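Note on the import change: freeze wraps a plain dict into Flax's immutable FrozenDict, which the hunks below use to wrap batches before passing them into pjit-compiled steps. A minimal sketch of the semantics (standalone, not code from this commit):

    from flax.core.frozen_dict import FrozenDict, freeze

    batch = freeze({"input_ids": [[0, 1]], "labels": [[2, 3]]})
    assert isinstance(batch, FrozenDict)  # immutable pytree; safe to hand to jax transforms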
@@ -382,7 +382,7 @@ class TrainState(train_state.TrainState):
 
 class MetricsLogger:
     def __init__(self, state):
-        self.step = state.step
+        self.step = int(state.step)
         self.time = time.perf_counter()
 
     def get_all_train_metrics(self, train_metrics, state):
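Under pjit, state.step lives on device; int(...) copies it to the host and therefore blocks until the value is ready. Doing this once in the constructor is harmless; the rest of the commit is about keeping such blocking reads out of the per-step path. A toy illustration of the blocking behavior:

    import jax.numpy as jnp

    x = jnp.arange(1024).sum()  # dispatched asynchronously, returns immediately
    step = int(x)               # device-to-host transfer: execution blocks here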
@@ -792,8 +792,7 @@ def main():
 
     def compute_loss(params, minibatch, dropout_rng):
         # minibatch has dim (batch_size, ...)
-        minibatch =
-        labels = minibatch.pop("labels")
+        minibatch, labels = minibatch.pop("labels")
         logits = state.apply_fn(
             **minibatch, params=params, dropout_rng=dropout_rng, train=True
         )[0]
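The single-line split works because FrozenDict.pop, unlike dict.pop, does not mutate: it returns a (dict_without_key, value) pair. For instance:

    from flax.core.frozen_dict import freeze

    fd = freeze({"labels": 1, "x": 2})
    fd2, v = fd.pop("labels")  # fd is untouched; fd2 == {"x": 2}, v == 1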
@@ -883,14 +882,10 @@ def main():
 
     # Define eval fn
     def eval_step(params, batch):
-        batch =
-        labels = batch.pop("labels")
+        batch, labels = batch.pop("labels")
         logits = model(**batch, params=params, train=False)[0]
         loss = loss_fn(logits, labels)
-
-        # summarize metrics
-        metrics = {"loss": loss}
-        return metrics
+        return loss
 
     # Create parallel version of the train and eval step
     p_train_step = pjit(
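eval_step now returns the bare scalar loss instead of a one-entry metrics dict, so pjit has less to assemble and return per batch. For context, a cross-entropy loss_fn matching the call above could look like the following sketch (the actual loss_fn in train.py may differ in details):

    import optax
    from flax.training.common_utils import onehot

    def loss_fn(logits, labels):
        # logits: (batch, seq_len, vocab_size); labels: (batch, seq_len)
        loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1]))
        return loss.mean()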
@@ -940,7 +935,6 @@ def main():
 
     def run_evaluation():
         # ======================== Evaluating ==============================
-        eval_metrics = []
         if training_args.do_eval:
             eval_loader = dataset.dataloader("eval", eval_batch_size)
             eval_steps = (
@@ -948,6 +942,7 @@ def main():
                 if len_eval_dataset is not None
                 else None
             )
+            eval_loss = []
             for batch in tqdm(
                 eval_loader,
                 desc="Evaluating...",
@@ -955,13 +950,15 @@ def main():
                 leave=False,
                 total=eval_steps,
             ):
-                # Model forward
-                metrics = p_eval_step(state.params, batch)
-                eval_metrics.append(metrics)
+                # freeze batch to pass safely to JAX transforms
+                batch = freeze(batch)
+                # accumulate losses async
+                eval_loss.append(p_eval_step(state.params, batch))
 
-            # normalize eval metrics
-            eval_metrics = stack_forest(eval_metrics)
-            eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
+            # get the mean of the loss
+            eval_loss = jnp.stack(eval_loss)
+            eval_loss = jnp.mean(eval_loss)
+            eval_metrics = {"loss": eval_loss}
 
             # log metrics
             metrics_logger.log(eval_metrics, step=state.step, prefix="eval")
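Since JAX dispatches asynchronously, each p_eval_step call returns a device scalar immediately; appending costs the host nothing, and the single jnp.stack / jnp.mean pair reduces everything at the end. The pattern in isolation:

    import jax.numpy as jnp

    eval_loss = []
    for x in (0.9, 0.7, 0.8):             # stand-ins for p_eval_step outputs
        eval_loss.append(jnp.asarray(x))  # device values; no host sync here
    mean_loss = jnp.mean(jnp.stack(eval_loss))  # one reduction at the end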
@@ -1050,6 +1047,7 @@ def main():
     # init variables
     last_time = time.perf_counter()
     train_metrics = None
+    step = int(state.step)
 
     with maps.mesh(mesh.devices, mesh.axis_names):
         for epoch in epochs:
@@ -1088,10 +1086,12 @@ def main():
                     ),
                     batch,
                 )
+                # freeze batch to pass safely to jax transforms
+                batch = freeze(batch)
 
                 # train step
-                state, train_metrics = p_train_step(state,
-                step = int(state.step)
+                state, train_metrics = p_train_step(state, batch, delta_time)
+                step += 1
 
                 if step % training_args.logging_steps == 0 and jax.process_index() == 0:
                     all_metrics = metrics_logger.get_all_train_metrics(
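This hunk carries most of the speedup: rather than reading state.step back from the device on every iteration (a blocking round-trip that drains the dispatch queue), the loop keeps a host-side Python counter and increments it locally. A self-contained toy version, with jax.jit standing in for the pjit-ed train step (names here are illustrative):

    import jax
    import jax.numpy as jnp

    @jax.jit
    def train_step(state, x):
        return state + 1, (x * 2.0).sum()

    state = jnp.int32(0)
    step = int(state)                 # one blocking read, before the loop
    for _ in range(100):
        state, _ = train_step(state, jnp.ones(8))  # queued asynchronously
        step += 1                     # host-side counter: no device sync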
@@ -1100,7 +1100,7 @@ def main():
                     metrics_logger.log(all_metrics, step=step, prefix="train")
 
                 eval_metrics = None
-                if
+                if step % training_args.eval_steps == 0:
                     eval_metrics = run_evaluation()
 
                 if step % training_args.save_steps == 0: