Spaces:
Running
Running
feat(train): handle multi-hosts
Browse files
- tools/train/train.py +95 -71
tools/train/train.py
CHANGED
|
@@ -389,15 +389,19 @@ def main():
|
|
| 389 |
)
|
| 390 |
|
| 391 |
# Set up wandb run
|
| 392 |
-
|
| 393 |
-
|
| 394 |
-
|
| 395 |
-
|
| 396 |
-
|
| 397 |
-
|
|
|
|
| 398 |
|
| 399 |
if training_args.resume_from_checkpoint is not None:
|
| 400 |
-
|
|
|
|
|
|
|
|
|
|
| 401 |
artifact_dir = artifact.download()
|
| 402 |
|
| 403 |
# load model
|
|
@@ -462,14 +466,23 @@ def main():
|
|
| 462 |
|
| 463 |
# Store some constant
|
| 464 |
num_epochs = int(training_args.num_train_epochs)
|
|
|
|
| 465 |
train_batch_size = (
|
| 466 |
-
int(training_args.per_device_train_batch_size) * jax.device_count()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 467 |
)
|
| 468 |
-
batch_size_per_update = train_batch_size * training_args.gradient_accumulation_steps
|
| 469 |
-
eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
|
| 470 |
len_train_dataset, len_eval_dataset = dataset.length
|
| 471 |
steps_per_epoch = (
|
| 472 |
-
len_train_dataset // train_batch_size
|
|
|
|
|
|
|
| 473 |
)
|
| 474 |
num_train_steps = (
|
| 475 |
steps_per_epoch * num_epochs if steps_per_epoch is not None else None
|
|
@@ -568,7 +581,7 @@ def main():
|
|
| 568 |
grads=grads,
|
| 569 |
dropout_rng=new_dropout_rng,
|
| 570 |
train_time=state.train_time + delta_time,
|
| 571 |
-
train_samples=state.train_samples + train_batch_size,
|
| 572 |
)
|
| 573 |
|
| 574 |
metrics = {
|
|
@@ -600,6 +613,7 @@ def main():
|
|
| 600 |
logger.info(
|
| 601 |
f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}"
|
| 602 |
)
|
|
|
|
| 603 |
logger.info(
|
| 604 |
f" Total train batch size (w. parallel, distributed & gradient accumulation) = {batch_size_per_update}"
|
| 605 |
)
|
|
@@ -608,19 +622,20 @@ def main():
|
|
| 608 |
range(state.epoch, num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0
|
| 609 |
)
|
| 610 |
|
| 611 |
-
|
| 612 |
-
|
| 613 |
-
|
| 614 |
-
|
| 615 |
-
|
| 616 |
-
|
| 617 |
-
|
| 618 |
-
|
| 619 |
-
|
| 620 |
-
|
| 621 |
-
|
| 622 |
-
|
| 623 |
-
|
|
|
|
| 624 |
|
| 625 |
# replicate state on each device
|
| 626 |
state = state.replicate()
|
|
@@ -688,52 +703,61 @@ def main():
|
|
| 688 |
f,
|
| 689 |
)
|
| 690 |
|
| 691 |
-
|
| 692 |
-
|
| 693 |
-
|
| 694 |
-
|
| 695 |
-
|
| 696 |
-
|
| 697 |
-
|
| 698 |
-
|
| 699 |
-
|
| 700 |
-
|
| 701 |
-
|
| 702 |
-
|
| 703 |
-
|
| 704 |
-
|
| 705 |
-
|
| 706 |
-
|
| 707 |
-
|
| 708 |
-
|
| 709 |
-
|
| 710 |
-
|
| 711 |
-
|
| 712 |
-
|
| 713 |
-
|
| 714 |
-
|
| 715 |
-
|
| 716 |
-
|
| 717 |
-
|
| 718 |
-
|
| 719 |
-
|
| 720 |
-
|
| 721 |
-
|
| 722 |
-
|
| 723 |
-
|
| 724 |
-
|
| 725 |
-
|
| 726 |
-
|
| 727 |
-
|
| 728 |
-
|
| 729 |
-
|
| 730 |
-
|
| 731 |
-
|
| 732 |
-
|
| 733 |
-
|
| 734 |
-
|
| 735 |
-
|
| 736 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 737 |
|
| 738 |
# init variables
|
| 739 |
last_time = time.perf_counter()
|
|
|
|
| 389 |
)
|
| 390 |
|
| 391 |
# Set up wandb run
|
| 392 |
+
if jax.process_index() == 0:
|
| 393 |
+
wandb.init(
|
| 394 |
+
entity="dalle-mini",
|
| 395 |
+
project="dalle-mini",
|
| 396 |
+
job_type="Seq2Seq",
|
| 397 |
+
config=parser.parse_args(),
|
| 398 |
+
)
|
| 399 |
|
| 400 |
if training_args.resume_from_checkpoint is not None:
|
| 401 |
+
if jax.process_index() == 0:
|
| 402 |
+
artifact = wandb.run.use_artifact(training_args.resume_from_checkpoint)
|
| 403 |
+
else:
|
| 404 |
+
artifact = wandb.Api().artifact(training_args.resume_from_checkpoint)
|
| 405 |
artifact_dir = artifact.download()
|
| 406 |
|
| 407 |
# load model
|
|
|
|
| 466 |
|
| 467 |
# Store some constant
|
| 468 |
num_epochs = int(training_args.num_train_epochs)
|
| 469 |
+
# batch size per node
|
| 470 |
train_batch_size = (
|
| 471 |
+
int(training_args.per_device_train_batch_size) * jax.local_device_count()
|
| 472 |
+
)
|
| 473 |
+
batch_size_per_update = (
|
| 474 |
+
train_batch_size
|
| 475 |
+
* training_args.gradient_accumulation_steps
|
| 476 |
+
* jax.process_count()
|
| 477 |
+
)
|
| 478 |
+
eval_batch_size = (
|
| 479 |
+
int(training_args.per_device_eval_batch_size) * jax.local_device_count()
|
| 480 |
)
|
|
|
|
|
|
|
| 481 |
len_train_dataset, len_eval_dataset = dataset.length
|
| 482 |
steps_per_epoch = (
|
| 483 |
+
len_train_dataset // (train_batch_size * jax.process_count())
|
| 484 |
+
if len_train_dataset is not None
|
| 485 |
+
else None
|
| 486 |
)
|
| 487 |
num_train_steps = (
|
| 488 |
steps_per_epoch * num_epochs if steps_per_epoch is not None else None
|
|
|
|
| 581 |
grads=grads,
|
| 582 |
dropout_rng=new_dropout_rng,
|
| 583 |
train_time=state.train_time + delta_time,
|
| 584 |
+
train_samples=state.train_samples + train_batch_size * jax.process_count(),
|
| 585 |
)
|
| 586 |
|
| 587 |
metrics = {
|
|
|
|
| 613 |
logger.info(
|
| 614 |
f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}"
|
| 615 |
)
|
| 616 |
+
logger.info(f" Number of devices = {jax.device_count()}")
|
| 617 |
logger.info(
|
| 618 |
f" Total train batch size (w. parallel, distributed & gradient accumulation) = {batch_size_per_update}"
|
| 619 |
)
|
|
|
|
| 622 |
range(state.epoch, num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0
|
| 623 |
)
|
| 624 |
|
| 625 |
+
if jax.process_index() == 0:
|
| 626 |
+
# set default x-axis as 'train/step'
|
| 627 |
+
wandb_log({}, step=state.step)
|
| 628 |
+
wandb.define_metric("*", step_metric="train/step")
|
| 629 |
+
|
| 630 |
+
# add interesting config parameters
|
| 631 |
+
wandb.config.update(
|
| 632 |
+
{
|
| 633 |
+
"len_train_dataset": len_train_dataset,
|
| 634 |
+
"len_eval_dataset": len_eval_dataset,
|
| 635 |
+
"batch_size_per_update": batch_size_per_update,
|
| 636 |
+
"num_params": num_params,
|
| 637 |
+
}
|
| 638 |
+
)
|
| 639 |
|
| 640 |
# replicate state on each device
|
| 641 |
state = state.replicate()
|
|
|
|
| 703 |
f,
|
| 704 |
)
|
| 705 |
|
| 706 |
+
if jax.process_index() == 0:
|
| 707 |
+
# save to W&B
|
| 708 |
+
if training_args.log_model:
|
| 709 |
+
# save some space
|
| 710 |
+
c = wandb.wandb_sdk.wandb_artifacts.get_artifacts_cache()
|
| 711 |
+
c.cleanup(wandb.util.from_human_size("10GB"))
|
| 712 |
+
|
| 713 |
+
metadata = dict(state_dict)
|
| 714 |
+
metadata["num_params"] = num_params
|
| 715 |
+
if eval_metrics is not None:
|
| 716 |
+
metadata["eval"] = eval_metrics
|
| 717 |
+
artifact = wandb.Artifact(
|
| 718 |
+
name=f"model-{wandb.run.id}",
|
| 719 |
+
type="bart_model",
|
| 720 |
+
metadata=metadata,
|
| 721 |
+
)
|
| 722 |
+
artifact.add_file(
|
| 723 |
+
str(Path(training_args.output_dir) / "flax_model.msgpack")
|
| 724 |
+
)
|
| 725 |
+
artifact.add_file(
|
| 726 |
+
str(Path(training_args.output_dir) / "config.json")
|
| 727 |
+
)
|
| 728 |
+
artifact.add_file(
|
| 729 |
+
str(Path(training_args.output_dir) / "tokenizer.json")
|
| 730 |
+
)
|
| 731 |
+
artifact.add_file(
|
| 732 |
+
str(Path(training_args.output_dir) / "tokenizer_config.json")
|
| 733 |
+
)
|
| 734 |
+
artifact.add_file(
|
| 735 |
+
str(Path(training_args.output_dir) / "vocab.json")
|
| 736 |
+
)
|
| 737 |
+
artifact.add_file(
|
| 738 |
+
str(Path(training_args.output_dir) / "merges.txt")
|
| 739 |
+
)
|
| 740 |
+
artifact.add_file(
|
| 741 |
+
str(Path(training_args.output_dir) / "special_tokens_map.json")
|
| 742 |
+
)
|
| 743 |
+
artifact.add_file(
|
| 744 |
+
str(Path(training_args.output_dir) / "opt_state.msgpack")
|
| 745 |
+
)
|
| 746 |
+
artifact.add_file(
|
| 747 |
+
str(Path(training_args.output_dir) / "training_state.json")
|
| 748 |
+
)
|
| 749 |
+
|
| 750 |
+
wandb.run.log_artifact(artifact)
|
| 751 |
+
|
| 752 |
+
# save to the hub
|
| 753 |
+
if training_args.push_to_hub:
|
| 754 |
+
model.save_pretrained(
|
| 755 |
+
training_args.output_dir,
|
| 756 |
+
params=params,
|
| 757 |
+
push_to_hub=training_args.push_to_hub,
|
| 758 |
+
commit_message=f"Saving weights and logs at step {unreplicate(state.step)+1}",
|
| 759 |
+
temp_dir=True, # avoid issues with being in a repository
|
| 760 |
+
)
|
| 761 |
|
| 762 |
# init variables
|
| 763 |
last_time = time.perf_counter()
|