feat: add adafactor
seq2seq/run_seq2seq_flax.py  CHANGED  (+16 -9)
@@ -623,17 +623,24 @@ def main():
         return traverse_util.unflatten_dict(flat_mask)
 
     # create adam optimizer
-    adamw = optax.adamw(
-        learning_rate=linear_decay_lr_schedule_fn,
-        b1=training_args.adam_beta1,
-        b2=training_args.adam_beta2,
-        eps=training_args.adam_epsilon,
-        weight_decay=training_args.weight_decay,
-        mask=decay_mask_fn,
-    )
+    if training_args.adafactor:
+        # We use the default parameters here to initialize adafactor,
+        # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
+        optimizer = optax.adafactor(
+            learning_rate=linear_decay_lr_schedule_fn,
+        )
+    else:
+        optimizer = optax.adamw(
+            learning_rate=linear_decay_lr_schedule_fn,
+            b1=training_args.adam_beta1,
+            b2=training_args.adam_beta2,
+            eps=training_args.adam_epsilon,
+            weight_decay=training_args.weight_decay,
+            mask=decay_mask_fn,
+        )
 
     # Setup train state
-    state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng)
+    state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer, dropout_rng=dropout_rng)
 
     # label smoothed cross entropy
     def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0):
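The hunk reads `training_args.adafactor`, whose declaration falls outside this diff. As a hedged sketch only (the real field is defined elsewhere in run_seq2seq_flax.py and may be worded differently), a boolean flag in an HF-style arguments dataclass would look like this:

    # Hypothetical sketch -- not the actual declaration in run_seq2seq_flax.py.
    # HF Flax example scripts collect such flags in a dataclass parsed by
    # HfArgumentParser; `adafactor` is the flag the hunk above reads.
    from dataclasses import dataclass, field

    @dataclass
    class TrainingArguments:
        adafactor: bool = field(
            default=False,
            metadata={"help": "Use optax.adafactor instead of optax.adamw."},
        )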
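What the change does at runtime, as a self-contained sketch: the optimizer becomes selectable, and whichever one is built is handed to the Flax TrainState. The schedule, flag value, toy parameters, and apply_fn below are stand-ins for the script's `linear_decay_lr_schedule_fn`, `training_args.adafactor`, `model.params`, and `model.__call__`; the real script also extends TrainState with a `dropout_rng` field, omitted here.

    # Self-contained sketch of the optimizer switch; all values are stand-ins.
    import jax.numpy as jnp
    import optax
    from flax.training import train_state

    # Stand-in for the script's linear_decay_lr_schedule_fn.
    schedule = optax.linear_schedule(init_value=3e-4, end_value=0.0, transition_steps=10_000)

    use_adafactor = True  # stand-in for training_args.adafactor

    if use_adafactor:
        # Adafactor keeps factored second-moment statistics, so it needs far
        # less optimizer memory than AdamW on large seq2seq models.
        tx = optax.adafactor(learning_rate=schedule)
    else:
        tx = optax.adamw(learning_rate=schedule, b1=0.9, b2=0.999, eps=1e-8, weight_decay=0.0)

    # Toy params / apply_fn; the real script passes model.params and model.__call__.
    params = {"dense": {"kernel": jnp.ones((4, 4))}}
    state = train_state.TrainState.create(apply_fn=lambda params, x: x, params=params, tx=tx)

Adafactor is typically chosen here as a drop-in, memory-cheaper alternative to AdamW; note that the commit passes it only the learning-rate schedule and otherwise relies on optax's defaults, as the in-code comment says.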