feat: no decay option
- seq2seq/run_seq2seq_flax.py +7 -1
- seq2seq/sweep.yaml +1 -0
seq2seq/run_seq2seq_flax.py
CHANGED
@@ -162,6 +162,9 @@ class DataTrainingArguments:
             "than this will be truncated, sequences shorter will be padded."
         },
     )
+    no_decay: bool = field(
+        default=False, metadata={"help": "Whether to use decay in the learning rate scheduler."}
+    )
     max_target_length: Optional[int] = field(
         default=OUTPUT_LENGTH,
         metadata={
@@ -332,12 +335,14 @@ def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
 
 
 def create_learning_rate_fn(
-    train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
+    train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float, no_decay: bool
 ) -> Callable[[int], jnp.array]:
     """Returns a linear warmup, linear_decay learning rate function."""
     steps_per_epoch = train_ds_size // train_batch_size
     num_train_steps = steps_per_epoch * num_train_epochs
     warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
+    if no_decay:
+        return warmup_fn
     decay_fn = optax.linear_schedule(
         init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
     )
@@ -610,6 +615,7 @@ def main():
         training_args.num_train_epochs,
         training_args.warmup_steps,
         training_args.learning_rate,
+        data_args.no_decay
     )
 
     # We use Optax's "masking" functionality to not apply weight decay
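The second hunk ends before create_learning_rate_fn returns. A minimal sketch of the resulting behavior, assuming the unchanged remainder of the function joins the warmup and decay pieces with optax.join_schedules at the num_warmup_steps boundary (that code sits outside this hunk, so this is only an illustration):

import optax

# Illustrative values; the real ones come from the training arguments.
learning_rate, num_warmup_steps, num_train_steps = 3e-4, 100, 1_000

warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
decay_fn = optax.linear_schedule(
    init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
)
# Assumption: the unmodified tail of the function concatenates the two schedules
# like this; with --no_decay it now returns warmup_fn on its own instead.
full_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])

# optax.linear_schedule holds its end_value after transition_steps, so returning
# warmup_fn alone means the rate ramps up and then stays flat at the peak.
print(warmup_fn(50), warmup_fn(100), warmup_fn(500))  # ~1.5e-4, 3e-4, 3e-4
print(full_fn(50), full_fn(100), full_fn(500))        # ~1.5e-4, 3e-4, ~1.67e-4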
seq2seq/sweep.yaml
CHANGED
@@ -37,6 +37,7 @@ command:
   - 56
   - "--preprocessing_num_workers"
   - 80
+  - "--no_decay"
   - "--do_train"
   - "--do_eval"
   - ${args}
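For reference, a minimal sketch of how the sweep's bare --no_decay flag would reach the new dataclass field, assuming the training script parses its command line with transformers' HfArgumentParser (the parsing code is not part of this diff):

from dataclasses import dataclass, field
from transformers import HfArgumentParser

@dataclass
class DataTrainingArguments:
    # Same field as added in run_seq2seq_flax.py above.
    no_decay: bool = field(
        default=False, metadata={"help": "Whether to use decay in the learning rate scheduler."}
    )

parser = HfArgumentParser(DataTrainingArguments)
# Passing the bare "--no_decay" flag sets the bool field to True.
(data_args,) = parser.parse_args_into_dataclasses(["--no_decay"])
print(data_args.no_decay)  # True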