feat: log everything through wandb
seq2seq/run_seq2seq_flax.py  CHANGED  (+17 -53)
@@ -57,7 +57,6 @@ from transformers import (
     FlaxBartForConditionalGeneration,
     HfArgumentParser,
     TrainingArguments,
-    is_tensorboard_available,
 )
 from transformers.models.bart.modeling_flax_bart import *
 from transformers.file_utils import is_offline_mode
@@ -226,10 +225,10 @@ class DataTrainingArguments:
             "value if set."
         },
     )
-
+    eval_steps: Optional[int] = field(
         default=400,
         metadata={
-            "help": "Evaluation will be performed every
+            "help": "Evaluation will be performed every eval_steps"
         },
     )
     log_model: bool = field(
@@ -324,19 +323,6 @@ def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuf
         yield batch
 
 
-def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
-    summary_writer.scalar("train_time", train_time, step)
-
-    train_metrics = get_metrics(train_metrics)
-    for key, vals in train_metrics.items():
-        tag = f"train_epoch/{key}"
-        for i, val in enumerate(vals):
-            summary_writer.scalar(tag, val, step - len(vals) + i + 1)
-
-    for metric_name, value in eval_metrics.items():
-        summary_writer.scalar(f"eval/{metric_name}", value, step)
-
-
 def create_learning_rate_fn(
     train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
 ) -> Callable[[int], jnp.array]:
@@ -351,6 +337,14 @@ def create_learning_rate_fn(
     return schedule_fn
 
 
+def wandb_log(metrics, step=None, prefix=None):
+    if jax.process_index() == 0:
+        log_metrics = {f'{prefix}/{k}' if prefix is not None else k: jax.device_get(v) for k, v in metrics.items()}
+        if step is not None:
+            log_metrics = {**log_metrics, 'train/step': step}
+        wandb.log(log_metrics)
+
+
 def main():
     # See all possible arguments in src/transformers/training_args.py
     # or by passing the --help flag to this script.
@@ -377,7 +371,6 @@ def main():
 
     # Set up wandb run
     wandb.init(
-        sync_tensorboard=True,
         entity='wandb',
         project='hf-flax-dalle-mini',
         job_type='Seq2SeqVQGAN',
@@ -578,24 +571,6 @@ def main():
         result = {k: round(v, 4) for k, v in result.items()}
         return result
 
-    # Enable tensorboard only on the master node
-    has_tensorboard = is_tensorboard_available()
-    if has_tensorboard and jax.process_index() == 0:
-        try:
-            from flax.metrics.tensorboard import SummaryWriter
-
-            summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
-        except ImportError as ie:
-            has_tensorboard = False
-            logger.warning(
-                f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
-            )
-    else:
-        logger.warning(
-            "Unable to display metrics through TensorBoard because the package is not installed: "
-            "Please run pip install tensorboard to enable."
-        )
-
     # Initialize our training
     rng = jax.random.PRNGKey(training_args.seed)
     rng, dropout_rng = jax.random.split(rng)
@@ -774,10 +749,8 @@ def main():
         eval_metrics = get_metrics(eval_metrics)
         eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
 
-
-
-        wandb.log({"eval/step": global_step})
-        wandb.log({f"eval/{k}": jax.device_get(v)})
+        # log metrics
+        wandb_log(eval_metrics, step=global_step, prefix='eval')
 
         # compute ROUGE metrics
         rouge_desc = ""
@@ -790,6 +763,7 @@
         desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']} | {rouge_desc})"
         epochs.write(desc)
         epochs.desc = desc
+
         return eval_metrics
 
     for epoch in epochs:
@@ -798,7 +772,6 @@
 
         # Create sampling rng
         rng, input_rng = jax.random.split(rng)
-        train_metrics = []
 
         # Generate an epoch by shuffling sampling indices from the train dataset
         train_loader = data_loader(input_rng, train_dataset, train_batch_size, shuffle=True)
@@ -808,32 +781,23 @@
             global_step +=1
             batch = next(train_loader)
             state, train_metric = p_train_step(state, batch)
-            train_metrics.append(train_metric)
 
             if global_step % data_args.log_interval == 0 and jax.process_index() == 0:
-
-
-                wandb.log({"train/step": global_step})
-                wandb.log({f"train/{k}": jax.device_get(v)})
+                # log metrics
+                wandb_log(unreplicate(train_metric), step=global_step, prefix='train')
 
-            if global_step % data_args.
+            if global_step % data_args.eval_steps == 0:
                 run_evaluation()
 
         train_time += time.time() - train_start
-
         train_metric = unreplicate(train_metric)
-
         epochs.write(
             f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
         )
 
+        # Final evaluation
         eval_metrics = run_evaluation()
 
-        # Save metrics
-        if has_tensorboard and jax.process_index() == 0:
-            cur_step = epoch * (len(train_dataset) // train_batch_size)
-            write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step)
-
         # save checkpoint after each epoch and push checkpoint to the hub
         if jax.process_index() == 0:
             params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
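
For context, the logging pattern this commit introduces can be exercised on its own, outside the training script. The following is a minimal, self-contained sketch and not part of the commit: the loop and metric values are dummies standing in for p_train_step and run_evaluation, the entity/project/job_type values are simply copied from the wandb.init call in the diff, and running it assumes wandb is installed and logged in (or WANDB_MODE=offline).

import jax
import wandb


def wandb_log(metrics, step=None, prefix=None):
    # Log a dict of metrics from the main process only, optionally
    # namespaced with a prefix and keyed to the training step.
    if jax.process_index() == 0:
        log_metrics = {f'{prefix}/{k}' if prefix is not None else k: jax.device_get(v) for k, v in metrics.items()}
        if step is not None:
            log_metrics = {**log_metrics, 'train/step': step}
        wandb.log(log_metrics)


if __name__ == '__main__':
    # Hypothetical run configuration, mirroring the values shown in the diff.
    wandb.init(entity='wandb', project='hf-flax-dalle-mini', job_type='Seq2SeqVQGAN')

    # Dummy loop standing in for the real training and evaluation steps.
    for step in range(1, 11):
        train_metric = {'loss': 1.0 / step, 'learning_rate': 3e-4}
        wandb_log(train_metric, step=step, prefix='train')
        if step % 5 == 0:
            wandb_log({'loss': 1.2 / step}, step=step, prefix='eval')

    wandb.finish()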