feat: add more config of distributed_shampoo
tools/train/train.py  (+13 -7)
@@ -220,15 +220,15 @@ class TrainingArguments:
         },
     )
     weight_decay: float = field(
-        default=None, metadata={"help": "Weight decay
+        default=None, metadata={"help": "Weight decay."}
     )
     beta1: float = field(
         default=0.9,
-        metadata={"help": "Beta1 for
+        metadata={"help": "Beta1 for Adam & Distributed Shampoo."},
     )
     beta2: float = field(
         default=0.999,
-        metadata={"help": "Beta2 for
+        metadata={"help": "Beta2 for for Adam & Distributed Shampoo."},
     )
     adam_epsilon: float = field(
         default=1e-8, metadata={"help": "Epsilon for AdamW optimizer."}

@@ -236,13 +236,19 @@ class TrainingArguments:
     max_grad_norm: float = field(
         default=1.0, metadata={"help": "Max gradient norm for Adafactor."}
     )
+    block_size: int = field(
+        default=1024, metadata={"help": "Chunked size for large layers with Distributed Shampoo."}
+    )
     preconditioning_compute_steps: int = field(
         default=10, metadata={"help": "Number of steps to update preconditioner."}
     )
+    skip_preconditioning_dim_size_gt: int = field(
+        default=4096, metadata={"help": "Max size for preconditioning with Distributed Shampoo."}
+    )
     optim_quantized: bool = field(
         default=False,
         metadata={
-            "help": "Whether to quantize optimizer (only supported with
+            "help": "Whether to quantize optimizer (only supported with Distributed Shampoo)."
         },
     )

@@ -594,7 +600,7 @@ def main():
         # - mask for weight decay is not implemented
         optimizer = distributed_shampoo(
             learning_rate_fn,
-            block_size=
+            block_size=training_args.block_size,
             beta1=training_args.beta1,
             beta2=training_args.beta2,
             diagonal_epsilon=1e-10,

@@ -602,7 +608,7 @@ def main():
             weight_decay=training_args.weight_decay
             if training_args.weight_decay is not None
             else 0.0,
-            start_preconditioning_step=
+            start_preconditioning_step=training_args.warmup_steps,
             preconditioning_compute_steps=training_args.preconditioning_compute_steps,
             statistics_compute_steps=1,
             best_effort_shape_interpretation=True,

@@ -612,7 +618,7 @@ def main():
             batch_axis_name="batch",
             inverse_failure_threshold=0.1,
             moving_average_for_momentum=True,
-            skip_preconditioning_dim_size_gt=
+            skip_preconditioning_dim_size_gt=training_args.skip_preconditioning_dim_size_gt,
             clip_by_scaled_gradient_norm=None,
             precision=jax.lax.Precision.HIGHEST,
             best_effort_memory_usage_reduction=training_args.optim_quantized,
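With these fields on TrainingArguments, the new Shampoo knobs become command-line flags. Below is a minimal, illustrative sketch of that mapping, assuming the script parses its argument dataclasses with transformers.HfArgumentParser as the Hugging Face training examples do; the ShampooArguments class and script name are hypothetical and not part of this diff.

# Illustrative sketch only: dataclass fields like the ones added above become
# CLI flags when parsed with transformers.HfArgumentParser (assumption: train.py
# uses this parser; ShampooArguments is a hypothetical stand-in).
from dataclasses import dataclass, field

from transformers import HfArgumentParser


@dataclass
class ShampooArguments:
    block_size: int = field(
        default=1024, metadata={"help": "Chunked size for large layers with Distributed Shampoo."}
    )
    preconditioning_compute_steps: int = field(
        default=10, metadata={"help": "Number of steps to update preconditioner."}
    )
    skip_preconditioning_dim_size_gt: int = field(
        default=4096, metadata={"help": "Max size for preconditioning with Distributed Shampoo."}
    )
    optim_quantized: bool = field(
        default=False,
        metadata={"help": "Whether to quantize optimizer (only supported with Distributed Shampoo)."},
    )


if __name__ == "__main__":
    parser = HfArgumentParser(ShampooArguments)
    # e.g. python shampoo_args_demo.py --block_size 512 --optim_quantized True
    (shampoo_args,) = parser.parse_args_into_dataclasses()
    print(shampoo_args)

Running the sketch with --help prints the metadata help strings, which is why the diff also tidies them up.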
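On the optimizer side, block_size and skip_preconditioning_dim_size_gt trade preconditioner quality against memory and compute. A rough, self-contained illustration of one reading of these knobs follows; it is not the library's actual partitioning code, and the function preconditioner_plan is purely hypothetical.

# Rough illustration, under the assumption that Distributed Shampoo skips full
# preconditioning for any parameter with a dimension above
# skip_preconditioning_dim_size_gt and otherwise splits each dimension into
# chunks of at most block_size. Not the library's actual code.
def preconditioner_plan(shape, block_size=1024, skip_preconditioning_dim_size_gt=4096):
    """Return 'skip' (diagonal fallback) or per-dimension block sizes."""
    if any(dim > skip_preconditioning_dim_size_gt for dim in shape):
        return "skip"
    plan = []
    for dim in shape:
        n_blocks = -(-dim // block_size)  # ceiling division
        plan.append([min(block_size, dim - i * block_size) for i in range(n_blocks)])
    return plan


print(preconditioner_plan((50264, 1024)))  # 'skip': the first dimension exceeds 4096
print(preconditioner_plan((2048, 3072)))   # [[1024, 1024], [1024, 1024, 1024]]

The diff also ties start_preconditioning_step to training_args.warmup_steps, so the full preconditioned updates only begin once learning-rate warmup has finished.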