let hf trainer handle torch compile (#516)
Browse files* let hf trainer handle torch compile
* remove torch compile checks, include option for backend
* suppress torch errors to get further
* require min torch version of 2.1.0 for torch compile to work
---------
Co-authored-by: Aman Karmani <[email protected]>
- README.md +4 -0
- src/axolotl/train.py +0 -4
- src/axolotl/utils/trainer.py +16 -0
README.md
CHANGED
|
@@ -519,6 +519,10 @@ wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_step
|
|
| 519 |
# where to save the finished model to
|
| 520 |
output_dir: ./completed-model
|
| 521 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 522 |
# training hyperparameters
|
| 523 |
gradient_accumulation_steps: 1
|
| 524 |
micro_batch_size: 2
|
|
|
|
| 519 |
# where to save the finished model to
|
| 520 |
output_dir: ./completed-model
|
| 521 |
|
| 522 |
+
# whether to use torch.compile and which backend to use
|
| 523 |
+
torch_compile: # bool
|
| 524 |
+
torch_compile_backend: # Optional[str]
|
| 525 |
+
|
| 526 |
# training hyperparameters
|
| 527 |
gradient_accumulation_steps: 1
|
| 528 |
micro_batch_size: 2
|
src/axolotl/train.py
CHANGED
|
@@ -80,10 +80,6 @@ def train(
|
|
| 80 |
|
| 81 |
model.config.use_cache = False
|
| 82 |
|
| 83 |
-
if torch.__version__ >= "2" and sys.platform != "win32":
|
| 84 |
-
LOG.info("Compiling torch model")
|
| 85 |
-
model = torch.compile(model)
|
| 86 |
-
|
| 87 |
# go ahead and presave, so we have the adapter config available to inspect
|
| 88 |
if peft_config:
|
| 89 |
LOG.info(f"Pre-saving adapter config to {cfg.output_dir}")
|
|
|
|
| 80 |
|
| 81 |
model.config.use_cache = False
|
| 82 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
# go ahead and presave, so we have the adapter config available to inspect
|
| 84 |
if peft_config:
|
| 85 |
LOG.info(f"Pre-saving adapter config to {cfg.output_dir}")
|
src/axolotl/utils/trainer.py
CHANGED
|
@@ -11,6 +11,7 @@ from pathlib import Path
|
|
| 11 |
from typing import Optional, Union
|
| 12 |
|
| 13 |
import numpy as np
|
|
|
|
| 14 |
import torch.cuda
|
| 15 |
import transformers
|
| 16 |
from datasets import Dataset, set_caching_enabled
|
|
@@ -604,6 +605,21 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
|
|
| 604 |
if cfg.greater_is_better:
|
| 605 |
training_arguments_kwargs["greater_is_better"] = cfg.greater_is_better
|
| 606 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 607 |
# DDP Config
|
| 608 |
if cfg.ddp_timeout:
|
| 609 |
training_arguments_kwargs["ddp_timeout"] = cfg.ddp_timeout
|
|
|
|
| 11 |
from typing import Optional, Union
|
| 12 |
|
| 13 |
import numpy as np
|
| 14 |
+
import torch
|
| 15 |
import torch.cuda
|
| 16 |
import transformers
|
| 17 |
from datasets import Dataset, set_caching_enabled
|
|
|
|
| 605 |
if cfg.greater_is_better:
|
| 606 |
training_arguments_kwargs["greater_is_better"] = cfg.greater_is_better
|
| 607 |
|
| 608 |
+
if cfg.torch_compile:
|
| 609 |
+
if torch.__version__ < "2.1.0": # pylint: disable=protected-access
|
| 610 |
+
LOG.warning("torch>=2.1.0 required for torch_compile to work properly")
|
| 611 |
+
else:
|
| 612 |
+
import torch._dynamo # pylint: disable=redefined-outer-name
|
| 613 |
+
|
| 614 |
+
torch._dynamo.config.suppress_errors = ( # pylint: disable=protected-access
|
| 615 |
+
True
|
| 616 |
+
)
|
| 617 |
+
training_arguments_kwargs["torch_compile"] = cfg.torch_compile
|
| 618 |
+
if cfg.torch_compile_backend:
|
| 619 |
+
training_arguments_kwargs[
|
| 620 |
+
"torch_compile_backend"
|
| 621 |
+
] = cfg.torch_compile_backend
|
| 622 |
+
|
| 623 |
# DDP Config
|
| 624 |
if cfg.ddp_timeout:
|
| 625 |
training_arguments_kwargs["ddp_timeout"] = cfg.ddp_timeout
|