feat(train): load model on CPU
tools/train/train.py  CHANGED  (+24 -23)
@@ -679,39 +679,40 @@ def main():
     devices = np.asarray(jax.devices()).reshape(*mesh_shape)
     mesh = maps.Mesh(devices, ("batch", "mp"))

-    # create training state
-    def init_state(params, opt_state):
-        return TrainState(
-            apply_fn=model.__call__,
-            tx=optimizer,
-            params=params,
-            opt_state=opt_state,
-            dropout_rng=dropout_rng,
-            step=0,
-        )
-
-    state_spec = init_state(param_spec, opt_state_spec)
-    state_spec = state_spec.replace(
+    # Create state spec
+    state_spec = TrainState(
+        params=param_spec,
+        opt_state=opt_state_spec,
         dropout_rng=None,
         step=None,
         epoch=None,
         train_time=None,
         train_samples=None,
+        apply_fn=model.__call__,
+        tx=optimizer,
     )

+    # create training state
+    def init_state(params):
+        state = TrainState.create(
+            apply_fn=model.__call__,
+            tx=optimizer,
+            params=freeze(params),
+            dropout_rng=dropout_rng,
+        )
+        return state
+
+    # hack: move the inital params to CPU to free up device memory
+    # TODO: allow loading weights on CPU in pre-trained model
+    model.params = jax.tree_map(lambda x: np.asarray(x), model.params)
+
     with maps.mesh(mesh.devices, mesh.axis_names):
-        # move params & init opt_state over specified devices
-        params, opt_state = pjit(
-            lambda x: (x, optimizer.init(x)),
-            in_axis_resources=None,
-            out_axis_resources=(param_spec, opt_state_spec),
-        )(freeze(model.params))
-        # create training state
         state = pjit(
             init_state,
-            in_axis_resources=(param_spec, opt_state_spec),
+            in_axis_resources=None,
             out_axis_resources=state_spec,
-        )(params, opt_state)
+            donate_argnums=(0,),
+        )(freeze(model.params))

     if training_args.resume_from_checkpoint is not None:
         # restore optimizer state and other parameters

@@ -793,7 +794,7 @@ def main():
     # Create parallel version of the train and eval step
     p_train_step = pjit(
         train_step,
-        in_axis_resources=(state_spec, None, None),
+        in_axis_resources=(state_spec, PartitionSpec("batch", None), None),
         out_axis_resources=(state_spec, None),
         donate_argnums=(0,),
     )