feat: enable trl's autounwrap (#1060)
* feat: test trl's autounwrap
* fix: add check for adapter
* feat: add config to disable autounwrap
* chore: fix lint
- .vscode/launch.json +1 -1
- devtools/README.md +1 -1
- docs/debugging.md +4 -4
- docs/rlhf.md +9 -0
- src/axolotl/train.py +9 -4
.vscode/launch.json CHANGED

```diff
@@ -11,7 +11,7 @@
       "request": "launch",
       "args": [
         "-m", "axolotl.cli.train", "dev_sharegpt.yml",
-        // The flags below simplify debugging by overriding the axolotl config
+        // The flags below simplify debugging by overriding the axolotl config
         // with the debugging tips above. Modify as needed.
         "--dataset_processes=1", // limits data preprocessing to one process
         "--max_steps=1", // limits training to just one step
```
devtools/README.md CHANGED

```diff
@@ -1 +1 @@
-This directory contains example config files that might be useful for debugging. Please see [docs/debugging.md](../docs/debugging.md) for more information.
+This directory contains example config files that might be useful for debugging. Please see [docs/debugging.md](../docs/debugging.md) for more information.
```
docs/debugging.md CHANGED

```diff
@@ -30,13 +30,13 @@ While debugging it's helpful to simplify your test scenario as much as possible.
 3. **Use a small model**: A good example of a small model is [TinyLlama/TinyLlama-1.1B-Chat-v1.0](https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v1.0).
 4. **Minimize iteration time**: Make sure the training loop finishes as fast as possible, with these settings.
     - `micro_batch_size: 1`
-    - `max_steps: 1`
+    - `max_steps: 1`
     - `val_set_size: 0`
 5. **Clear Caches:** Axolotl caches certain steps and so does the underlying HuggingFace trainer. You may want to clear some of these caches when debugging.
     - Data preprocessing: When debugging data preprocessing, which includes prompt template formation, you may want to delete the directory set in `dataset_prepared_path:` in your axolotl config. If you didn't set this value, the default is `last_run_prepared`.
     - HF Hub: If you are debugging data preprocessing, you should clear the relevant HF cache [HuggingFace cache](https://huggingface.co/docs/datasets/cache), by deleting the appropriate `~/.cache/huggingface/datasets/...` folder(s).
     - **The recommended approach is to redirect all outputs and caches to a temporary folder and delete selected subfolders before each run. This is demonstrated in the example configuration below.**
-
+
 
 ## Debugging with VSCode
 
@@ -74,7 +74,7 @@ For example, to mimic the command `cd devtools && CUDA_VISIBLE_DEVICES=0 acceler
       "request": "launch",
       "args": [
         "-m", "axolotl.cli.train", "dev_sharegpt.yml",
-        // The flags below simplify debugging by overriding the axolotl config
+        // The flags below simplify debugging by overriding the axolotl config
         // with the debugging tips above. Modify as needed.
         "--dataset_processes=1", // limits data preprocessing to one process
         "--max_steps=1", // limits training to just one step
@@ -101,7 +101,7 @@ For example, to mimic the command `cd devtools && CUDA_VISIBLE_DEVICES=0 acceler
 
 - The argument `justMyCode` is set to `true` such that you step through only the axolotl code. If you want to step into dependencies, set this to `false`.
 - The `preLaunchTask`: `cleanup-for-dataprep` is defined in [.vscode/tasks.json](../.vscode/tasks.json) and is used to delete the following folders before debugging, which is essential to ensure that the data pre-processing code is run from scratch:
-    - `./devtools/temp_debug/axolotl_outputs`
+    - `./devtools/temp_debug/axolotl_outputs`
     - `./devtools/temp_debug/.hf-cache/datasets`
 
 >[!Tip]
```
docs/rlhf.md CHANGED

````diff
@@ -33,3 +33,12 @@ datasets:
 ```yaml
 rl: ipo
 ```
+
+#### Trl autounwrap for peft
+
+Trl supports autounwrapping peft models, so a ref model does not need to be loaded separately, reducing the VRAM needed. This is on by default. To turn it off, pass the following config.
+
+```yaml
+# load ref model when adapter training.
+rl_adapter_ref_model: true
+```
````
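For orientation, here is a sketch of how the new option could sit in a full LoRA + DPO config. The base model, dataset, and dataset type below are illustrative placeholders and are not taken from this PR; only `rl`, `adapter`, and `rl_adapter_ref_model` are the keys under discussion.

```yaml
# Illustrative config sketch (placeholder model and dataset), not from this PR.
base_model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
adapter: lora
rl: dpo
datasets:
  - path: Intel/orca_dpo_pairs        # placeholder preference dataset
    split: train
    type: chatml.intel                # placeholder dataset type
# Default behaviour: leave unset (or false) so trl autounwraps the peft adapter
# and reuses the frozen base weights as the reference model (no second model in VRAM).
rl_adapter_ref_model: false
# Set to true to opt out of autounwrap and load a separate reference model:
# rl_adapter_ref_model: true
```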
src/axolotl/train.py CHANGED

```diff
@@ -63,10 +63,15 @@ def train(
     model, peft_config = load_model(cfg, tokenizer, inference=cli_args.inference)
     model_ref = None
     if cfg.rl:
-        # load the model again for model_ref/baseline
-        model_ref, _ = load_model(
-            cfg, tokenizer, inference=cli_args.inference, reference_model=True
-        )
+        if cfg.adapter and not cfg.rl_adapter_ref_model:
+            # use built-in trl autounwrap
+            LOG.debug("Passing model_ref: None to RL trainer")
+            model_ref = None  # explicit setting to None
+        else:
+            # load the model again for model_ref/baseline
+            model_ref, _ = load_model(
+                cfg, tokenizer, inference=cli_args.inference, reference_model=True
+            )
 
     safe_serialization = cfg.save_safetensors is True
 
```
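As background on what the `model_ref = None` branch relies on, here is a minimal sketch of trl's autounwrap idea. It is not code from this PR, and the model name and LoRA settings are placeholders: a peft-wrapped model can serve as its own reference model, because disabling the adapter leaves only the frozen base weights.

```python
# Minimal sketch of the mechanism behind trl's peft autounwrap (illustrative only).
# With the LoRA adapter disabled, the same weights behave as the frozen reference
# model, so no separate ref model has to be loaded into VRAM.
import torch
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"  # small placeholder model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = get_peft_model(
    AutoModelForCausalLM.from_pretrained(model_name),
    LoraConfig(task_type="CAUSAL_LM"),
)

inputs = tokenizer("hello axolotl", return_tensors="pt")
with torch.no_grad():
    policy_logits = model(**inputs).logits   # adapter active: policy model
    with model.disable_adapter():            # adapter bypassed: stands in for model_ref
        ref_logits = model(**inputs).logits
```

Passing `model_ref=None` to the RL trainer lets trl apply this trick internally for adapter runs; setting `rl_adapter_ref_model: true` keeps the previous behaviour of loading a second copy of the model as an explicit reference.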