fix: state.step type

dev/seq2seq/run_seq2seq_flax.py (+14 -12)
@@ -416,7 +416,7 @@ def wandb_log(metrics, step=None, prefix=None):
         f"{prefix}/{k}" if prefix is not None else k: v for k, v in metrics.items()
     }
     if step is not None:
-        log_metrics["train/step"] =
+        log_metrics["train/step"] = step
     wandb.log(log_metrics)


@@ -846,7 +846,7 @@ def main():
         f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}"
     )
     logger.info(
-        f" Total train batch size (w. parallel &
+        f" Total train batch size (w. parallel, distributed & gradient accumulation) = {batch_size_per_update}"
     )
     logger.info(f" Total global steps = {total_steps}")
     logger.info(f" Total optimization steps = {total_optimization_steps}")

@@ -854,7 +854,7 @@ def main():
     epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)

     # set default x-axis as 'train/step'
-    wandb_log({}, step=state.step)
+    wandb_log({}, step=unreplicate(state.step))
     wandb.define_metric("*", step_metric="train/step")

     # add interesting config parameters

@@ -893,7 +893,7 @@ def main():
         eval_metrics = jax.tree_map(jnp.mean, eval_metrics)

         # log metrics
-        wandb_log(eval_metrics, step=state.step, prefix="eval")
+        wandb_log(eval_metrics, step=unreplicate(state.step), prefix="eval")

         # Print metrics and update progress bar
         desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']})"

@@ -956,7 +956,7 @@ def main():
         )
         # save some space
         c = wandb.wandb_sdk.wandb_artifacts.get_artifacts_cache()
-        c.cleanup(
+        c.cleanup("5GB")

         wandb.run.log_artifact(artifact)


@@ -972,7 +972,8 @@ def main():

     for epoch in epochs:
         # ======================== Training ================================
-
+        step = unreplicate(state.step)
+        wandb_log({"train/epoch": epoch}, step=step)

         # Create sampling rng
         rng, input_rng = jax.random.split(rng)

@@ -994,19 +995,20 @@ def main():
             total=steps_per_epoch,
         ):
             state, train_metric = p_train_step(state, batch)
+            step = unreplicate(state.step)

-            if
+            if step % data_args.log_interval == 0 and jax.process_index() == 0:
                 # log metrics
-                wandb_log(unreplicate(train_metric), step=
+                wandb_log(unreplicate(train_metric), step=step, prefix="train")

-            if training_args.eval_steps and
+            if training_args.eval_steps and step % training_args.eval_steps == 0:
                 run_evaluation()

-            if
-                run_save_model(state,
+            if step % data_args.save_model_steps == 0:
+                run_save_model(state, step, epoch)

         # log final train metrics
-        wandb_log(unreplicate(train_metric), step=
+        wandb_log(unreplicate(train_metric), step=step, prefix="train")

         train_metric = unreplicate(train_metric)
         epochs.write(
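The pattern the commit converges on: once the train state has been replicated across devices with `flax.jax_utils.replicate`, `state.step` is a per-device array rather than a plain scalar, so it is unreplicated once per iteration and the resulting `step` is reused both for the `% log_interval` / `% eval_steps` / `% save_model_steps` checks and as the `train/step` value passed to `wandb_log`. The sketch below is not code from the repository; it only illustrates that type difference. The toy params, the `optax.sgd` optimizer, and the `log_interval` constant are stand-ins for the real model, optimizer, and `data_args.log_interval`.

```python
# Illustrative only: why state.step must be unreplicated before it is used as a step value.
import jax
import jax.numpy as jnp
import optax
from flax import jax_utils
from flax.training.train_state import TrainState

# Placeholder state; apply_fn/params/tx are irrelevant to the point being made.
state = TrainState.create(
    apply_fn=lambda params, x: x @ params["w"],
    params={"w": jnp.ones((4,))},
    tx=optax.sgd(1e-3),
)

state = jax_utils.replicate(state)        # one copy of the state per local device
print(state.step.shape)                   # (jax.local_device_count(),): not a scalar

step = jax_utils.unreplicate(state.step)  # scalar step taken from the first device
step = int(step)                          # plain int: safe for `%` tests and wandb.log

log_interval = 40                         # stand-in for data_args.log_interval
if step % log_interval == 0 and jax.process_index() == 0:
    # wandb.log({"train/loss": 0.0, "train/step": step})  # roughly what wandb_log() does
    pass
```

Because the script calls `wandb.define_metric("*", step_metric="train/step")`, every metric logged afterwards is charted against the `train/step` value included in each `wandb.log` call rather than W&B's internal step counter, which is why passing a consistent scalar step matters here.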