feat(train): refactor learning rate params
Changed file: tools/train/train.py (+53 −35)
@@ -246,9 +246,29 @@ class TrainingArguments:
         },
     )
 
-    use_decay: bool = field(
+    lr_decay: str = field(
+        default=None,
+        metadata={
+            "help": "Decay to be used in the learning rate scheduler. Can be None (default), linear or exponential."
+        },
+    )
+    lr_transition_steps: int = field(
+        default=None,
+        metadata={
+            "help": "Number of transition steps associated with learning rate decay when using exponential decay."
+        },
+    )
+    lr_decay_rate: float = field(
+        default=None,
+        metadata={
+            "help": "Decay rate associated with learning rate when using exponential decay."
+        },
+    )
+    lr_staircase: bool = field(
         default=False,
-        metadata={
+        metadata={
+            "help": "Whether to use staircase or continuous learning rate when using exponential decay."
+        },
     )
 
     num_train_epochs: float = field(
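The four new fields above configure an optax.exponential_decay schedule in the last hunk of this diff. As a standalone sketch of what lr_transition_steps, lr_decay_rate and lr_staircase control, using illustrative values rather than anything taken from this patch:

# Standalone sketch (not part of the patch): how the new exponential-decay
# arguments map onto optax.exponential_decay. All values are made up.
import optax

schedule = optax.exponential_decay(
    init_value=5e-3,        # plays the role of training_args.learning_rate
    transition_steps=1000,  # plays the role of training_args.lr_transition_steps
    decay_rate=0.5,         # plays the role of training_args.lr_decay_rate
    staircase=False,        # plays the role of training_args.lr_staircase
)

# Continuous decay multiplies the rate by decay_rate every transition_steps:
#   step 0    -> 5e-3
#   step 500  -> 5e-3 * 0.5**0.5 ≈ 3.54e-3
#   step 1000 -> 2.5e-3
# With staircase=True the rate only drops at multiples of transition_steps.
for step in (0, 500, 1000, 2000):
    print(step, float(schedule(step)))

Note that in the patch itself the decay is joined after the warmup phase (see the last hunk), so the decay schedule sees step counts relative to the end of warmup rather than absolute steps.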
@@ -321,33 +341,6 @@ class TrainState(train_state.TrainState):
     )
 
 
-def create_learning_rate_fn(
-    num_warmup_steps: int,
-    learning_rate: float,
-    use_decay: bool,
-    num_train_steps: int = None,  # used only with `use_decay`, typically train_size // batch_size * num_epochs
-) -> Callable[[int], jnp.array]:
-    """Returns a linear warmup, linear_decay learning rate function."""
-    if use_decay:
-        assert (
-            num_train_steps is not None
-        ), "Learning rate with decay requires number of training steps"
-    warmup_fn = optax.linear_schedule(
-        init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps
-    )
-    if not use_decay:
-        return warmup_fn
-    decay_fn = optax.linear_schedule(
-        init_value=learning_rate,
-        end_value=0,
-        transition_steps=num_train_steps - num_warmup_steps,
-    )
-    schedule_fn = optax.join_schedules(
-        schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps]
-    )
-    return schedule_fn
-
-
 class MetricsLogger:
     def __init__(self, state):
         self.step = state.step
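The helper removed here built a linear warmup and then optionally joined a linear decay at num_warmup_steps; the new in-line version in main() (last hunk) keeps the same join pattern. As a standalone sketch of how optax.join_schedules stitches the two pieces together, with made-up numbers:

# Standalone sketch of the warmup-then-decay join used by both the removed
# helper and its replacement. optax.join_schedules switches schedules at the
# given boundaries and hands each one a step count relative to its start.
# warmup_steps, total_steps and peak_lr are illustrative values only.
import optax

warmup_steps, total_steps, peak_lr = 100, 1000, 1e-3

warmup_fn = optax.linear_schedule(
    init_value=0.0, end_value=peak_lr, transition_steps=warmup_steps
)
decay_fn = optax.linear_schedule(
    init_value=peak_lr, end_value=0.0, transition_steps=total_steps - warmup_steps
)
schedule_fn = optax.join_schedules(
    schedules=[warmup_fn, decay_fn], boundaries=[warmup_steps]
)

# Expected values: step 0 -> 0.0, step 50 -> 5e-4, step 100 -> 1e-3 (peak),
# step 550 -> 5e-4 (halfway through the 900-step decay), step 1000 -> 0.0.
for step in (0, 50, 100, 550, 1000):
    print(step, float(schedule_fn(step)))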
@@ -541,12 +534,37 @@ def main():
     num_params = model.num_params
 
     # Create learning rate schedule
-    [6 removed lines: the old call that built learning_rate_fn; its exact text is not captured in this view]
+    def create_learning_rate_fn() -> Callable[[int], jnp.array]:
+        """Create the learning rate function."""
+        warmup_fn = optax.linear_schedule(
+            init_value=0.0,
+            end_value=training_args.learning_rate,
+            transition_steps=training_args.warmup_steps,
+        )
+        if training_args.lr_decay is None:
+            return warmup_fn
+        elif training_args.lr_decay == "linear":
+            assert (
+                num_train_steps is not None
+            ), "linear decay requires knowing the dataset length"
+            decay_fn = optax.linear_schedule(
+                init_value=training_args.learning_rate,
+                end_value=0,
+                transition_steps=num_train_steps - training_args.warmup_steps,
+            )
+        elif training_args.lr_decay == "exponential":
+            decay_fn = optax.exponential_decay(
+                init_value=training_args.learning_rate,
+                transition_steps=training_args.lr_transition_steps,
+                decay_rate=training_args.lr_decay_rate,
+                staircase=training_args.lr_staircase,
+            )
+        schedule_fn = optax.join_schedules(
+            schedules=[warmup_fn, decay_fn], boundaries=[training_args.warmup_steps]
+        )
+        return schedule_fn
+
+    learning_rate_fn = create_learning_rate_fn()
 
     # We use Optax's "masking" functionality to not apply weight decay
     # to bias and LayerNorm scale parameters. decay_mask_fn returns a
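The diff does not show how learning_rate_fn is consumed further down in main(). As a hedged sketch only, assuming the usual optax wiring rather than the script's actual optimizer choice, a schedule function like the one returned above can be passed wherever a constant learning rate is accepted and evaluated directly for logging:

# Hedged sketch, not taken from train.py: an optax schedule can be handed to an
# optimizer in place of a constant learning rate and queried at any step.
import jax.numpy as jnp
import optax

# Stand-in for the schedule returned by create_learning_rate_fn() above.
learning_rate_fn = optax.linear_schedule(
    init_value=0.0, end_value=1e-3, transition_steps=100
)

optimizer = optax.adamw(learning_rate=learning_rate_fn, weight_decay=0.0)

params = {"w": jnp.zeros((3,))}  # toy parameters
opt_state = optimizer.init(params)

step = 42
print("lr at step", step, "=", float(learning_rate_fn(step)))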