feat: log more metrics
tools/train/train.py  CHANGED  (+41 −20)
@@ -331,14 +331,37 @@ def create_learning_rate_fn(
     return schedule_fn


-def wandb_log(metrics, step=None, prefix=None):
-    if jax.process_index() == 0:
-        log_metrics = {
-            f"{prefix}/{k}" if prefix is not None else k: v for k, v in metrics.items()
-        }
-        if step is not None:
-            log_metrics["train/step"] = step
-        wandb.log(log_metrics)
+class MetricsLogger:
+    def __init__(self, state):
+        self.step = state.step
+        self.time = time.perf_counter()
+
+    def get_all_train_metrics(self, train_metrics, state):
+        """Make a dict of training metrics to be logged"""
+        metrics = unreplicate(train_metrics)
+        # get state parameters
+        state_dict = {
+            k.split("_")[-1]: unreplicate(getattr(state, k))
+            for k in ["epoch", "train_time", "train_samples"]
+        }
+        # timing metrics
+        new_step = int(unreplicate(state.step))
+        new_time = time.perf_counter()
+        time_per_step = (new_time - self.time) / (new_step - self.step)
+        self.step = new_step
+        self.time = new_time
+        return {**metrics, **state_dict, "time_per_step": time_per_step}
+
+    @staticmethod
+    def log(metrics, step=None, prefix=None):
+        if jax.process_index() == 0:
+            log_metrics = {
+                f"{prefix}/{k}" if prefix is not None else k: v
+                for k, v in metrics.items()
+            }
+            if step is not None:
+                log_metrics["train/step"] = step
+            wandb.log(log_metrics)


 def main():
@@ -628,9 +651,10 @@ def main():
         range(state.epoch, num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0
     )

+    metrics_logger = MetricsLogger(state)
     if jax.process_index() == 0:
         # set default x-axis as 'train/step'
-        wandb_log({}, step=state.step)
+        metrics_logger.log({}, step=state.step)
         wandb.define_metric("*", step_metric="train/step")

         # add interesting config parameters
@@ -672,7 +696,9 @@ def main():
         eval_metrics = jax.tree_map(jnp.mean, eval_metrics)

         # log metrics
-        wandb_log(eval_metrics, step=unreplicate(state.step), prefix="eval")
+        metrics_logger.log(
+            eval_metrics, step=unreplicate(state.step), prefix="eval"
+        )

         # Print metrics and update progress bar
         desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']})"
@@ -772,7 +798,7 @@ def main():
     for epoch in epochs:
         state.replace(epoch=jax_utils.replicate(epoch))
         # ======================== Training ================================
-        wandb_log({"train/epoch": epoch}, step=unreplicate(state.step))
+        metrics_logger.log({"train/epoch": epoch}, step=unreplicate(state.step))

         # Generate an epoch by shuffling sampling indices from the train dataset
         train_loader = dataset.dataloader("train", train_batch_size)
@@ -797,17 +823,12 @@ def main():
             step = unreplicate(state.step)

             if step % training_args.logging_steps == 0 and jax.process_index() == 0:
-                # log metrics
-                metrics = unreplicate(train_metrics)
-                # log state parameters
-                state_dict = {
-                    k.split("_")[-1]: unreplicate(getattr(state, k))
-                    for k in ["epoch", "train_time", "train_samples"]
-                }
-                wandb_log({**metrics, **state_dict}, step=step, prefix="train")
+                all_metrics = metrics_logger.get_all_train_metrics(train_metrics, state)
+                metrics_logger.log(all_metrics, step=step, prefix="train")

             eval_metrics = None
             if training_args.eval_steps and step % training_args.eval_steps == 0:
+                return
                 eval_metrics = run_evaluation()

             if step % training_args.save_steps == 0:
@@ -815,8 +836,8 @@ def main():

         # log final train metrics
         if train_metrics is not None:
-            metrics = unreplicate(train_metrics)
-            wandb_log(metrics, step=step, prefix="train")
+            all_metrics = metrics_logger.get_all_train_metrics(train_metrics, state)
+            metrics_logger.log(all_metrics, step=step, prefix="train")

         epochs.write(
             f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metrics['loss']}, Learning Rate: {train_metrics['learning_rate']})"