Merge branch 'main' into quadratic-warmup
- .github/workflows/base.yml +2 -1
- .github/workflows/main.yml +3 -2
- .github/workflows/tests.yml +1 -0
- .pre-commit-config.yaml +1 -1
- README.md +52 -10
- data/README.md +4 -4
- docker/Dockerfile-base +2 -2
- examples/openllama-3b/config.yml +6 -5
- examples/pythia-12b/README.md +9 -0
- examples/pythia-12b/config.yml +49 -0
- examples/redpajama/config-3b.yml +1 -1
- requirements.txt +1 -0
- scripts/finetune.py +40 -17
- src/axolotl/datasets.py +1 -0
- src/axolotl/prompt_strategies/alpaca_chat.py +42 -6
- src/axolotl/prompt_strategies/alpaca_instruct.py +10 -1
- src/axolotl/prompt_strategies/alpaca_w_system.py +120 -0
- src/axolotl/prompt_tokenizers.py +21 -17
- src/axolotl/prompters.py +29 -37
- src/axolotl/utils/callbacks.py +38 -1
- src/axolotl/utils/data.py +144 -11
- src/axolotl/utils/models.py +33 -7
- src/axolotl/utils/tokenization.py +2 -0
- src/axolotl/utils/trainer.py +20 -1
- src/axolotl/utils/validation.py +43 -1
- tests/test_prompt_tokenizers.py +83 -3
- tests/test_prompters.py +68 -1
- tests/test_tokenizers.py +31 -0
- tests/test_validation.py +101 -0
.github/workflows/base.yml CHANGED

@@ -12,6 +12,7 @@ jobs:
     # this job needs to be run on self-hosted GPU runners...
     runs-on: self-hosted
     strategy:
+      fail-fast: false
       matrix:
         include:
           - cuda: "118"
@@ -25,7 +26,7 @@ jobs:
             pytorch: 2.0.0
             axolotl_extras:
           - cuda: "117"
-            cuda_version: 11.7.
+            cuda_version: 11.7.1
             python_version: "3.9"
             pytorch: 1.13.1
             axolotl_extras:
.github/workflows/main.yml CHANGED

@@ -11,6 +11,7 @@ jobs:
     if: github.repository_owner == 'OpenAccess-AI-Collective'
     # this job needs to be run on self-hosted GPU runners...
     strategy:
+      fail-fast: false
       matrix:
         include:
           - cuda: cu118
@@ -29,7 +30,7 @@ jobs:
             pytorch: 2.0.0
             axolotl_extras: gptq
           - cuda: cu117
-            cuda_version: 11.7.
+            cuda_version: 11.7.1
             python_version: "3.9"
             pytorch: 1.13.1
             axolotl_extras:
@@ -84,7 +85,7 @@ jobs:
             pytorch: 2.0.0
             axolotl_extras: gptq
           - cuda: cu117
-            cuda_version: 11.7.
+            cuda_version: 11.7.1
             python_version: "3.9"
             pytorch: 1.13.1
             axolotl_extras:
.github/workflows/tests.yml CHANGED

@@ -7,6 +7,7 @@ jobs:
   test:
     runs-on: ubuntu-latest
     strategy:
+      fail-fast: false
       matrix:
         python_version: ["3.9", "3.10"]
     timeout-minutes: 10
.pre-commit-config.yaml CHANGED

@@ -1,5 +1,5 @@
 default_language_version:
-  python: python3
+  python: python3
 
 repos:
   - repo: https://github.com/pre-commit/pre-commit-hooks
README.md CHANGED

@@ -138,7 +138,7 @@ Have dataset(s) in one of the following format (JSONL recommended):
   ```json
   {"instruction": "...", "input": "...", "output": "..."}
   ```
-- `sharegpt`: conversations
+- `sharegpt:chat`: conversations
   ```json
   {"conversations": [{"from": "...", "value": "..."}]}
   ```
@@ -195,6 +195,10 @@ Have dataset(s) in one of the following format (JSONL recommended):
   ```json
   {"message_1": "...", "message_2": "..."}
   ```
+- `alpaca_w_system.load_open_orca`: support for open orca datasets with included system prompts, instruct
+  ```json
+  {"system_prompt": "...", "question": "...", "response": "..."}
+  ```
 - `context_qa`: in context question answering from an article
   ```json
   {"article": "...", "question": "...", "answer": "..."}
@@ -233,7 +237,7 @@ Have dataset(s) in one of the following format (JSONL recommended):
 #### How to add custom prompts
 
 1. Add your method to a file in [prompt_strategies](src/axolotl/prompt_strategies). Please see other files as example.
-2. Use your custom file name as the dataset type
+2. Use your custom file name as the dataset type `<prompt_strategies_file>.load_<load_fn>`.
 
 Optionally, download some datasets, see [data/README.md](data/README.md)
 
@@ -251,10 +255,18 @@
 
 - dataset
   ```yaml
+  sequence_len: 2048 # max token length for prompt
+
+  # huggingface repo
   datasets:
     - path: vicgalle/alpaca-gpt4
       type: alpaca # format from earlier
-  sequence_len: 2048 # max token length / prompt
+
+  # local
+  datasets:
+    - path: json
+      data_files: data.jsonl # or json
+      type: alpaca # format from earlier
   ```
 
 - loading
@@ -264,6 +276,8 @@
   bf16: true # require >=ampere
   fp16: true
   tf32: true # require >=ampere
+  bfloat16: true # require >=ampere, use instead of bf16 when you don't want AMP (automatic mixed precision)
+  float16: true # use instead of fp16 when you don't want AMP
   ```
 Note: Repo does not do 4-bit quantization.
 
@@ -300,6 +314,8 @@ model_type: AutoModelForCausalLM
 tokenizer_type: AutoTokenizer
 # Trust remote code for untrusted source
 trust_remote_code:
+# use_fast option for tokenizer loading from_pretrained, default to True
+tokenizer_use_fast:
 
 # whether you are training a 4-bit GPTQ quantized model
 gptq: true
@@ -320,10 +336,10 @@ tf32: true # require >=ampere
 
 # a list of one or more datasets to finetune the model with
 datasets:
-  #
+  # hf dataset repo | "json" for local dataset, make sure to fill data_files
   - path: vicgalle/alpaca-gpt4
     # The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection]
-    type: alpaca # format
+    type: alpaca # format | format:<prompt_style> (chat/instruct) | <prompt_strategies>.load_<load_fn>
     data_files: # path to source data files
     shards: # number of shards to split data into
 
@@ -332,6 +348,8 @@ datasets:
 dataset_prepared_path: data/last_run_prepared
 # push prepared dataset to hub
 push_dataset_to_hub: # repo path
+# push checkpoints to hub
+hub_model_id: # repo path
 # whether to use hf `use_auth_token` for loading datasets. Useful for fetching private datasets
 # required to be true when used in combination with `push_dataset_to_hub`
 hf_use_auth_token: # boolean
@@ -420,7 +438,15 @@ log_sweep_max_lr:
 optimizer:
 # specify weight decay
 weight_decay:
-
+# adamw hyperparams
+adam_beta1:
+adam_beta2:
+adam_epsilon:
+# Gradient clipping max norm
+max_grad_norm:
+
+# whether to bettertransformers
+flash_optimum:
 # whether to use xformers attention patch https://github.com/facebookresearch/xformers:
 xformers_attention:
 # whether to use flash attention patch https://github.com/HazyResearch/flash-attention:
@@ -500,16 +526,16 @@ Pass the appropriate flag to the train command:
 
 - Pretrained LORA:
   ```bash
-  --inference --lora_model_dir
+  --inference --lora_model_dir="./lora-output-dir"
   ```
 - Full weights finetune:
   ```bash
-  --inference --base_model
+  --inference --base_model="./completed-model"
   ```
 - Full weights finetune w/ a prompt from a text file:
   ```bash
   cat /tmp/prompt.txt | python scripts/finetune.py configs/your_config.yml \
-    --base_model
+    --base_model="./completed-model" --inference --prompter=None --load_in_8bit=True
   ```
 
 ### Merge LORA to base
@@ -520,6 +546,12 @@ Add below flag to train command above
 --merge_lora --lora_model_dir="./completed-model" --load_in_8bit=False --load_in_4bit=False
 ```
 
+If you run out of CUDA memory, you can try to merge in system RAM with
+
+```bash
+CUDA_VISIBLE_DEVICES="" python3 scripts/finetune.py ...
+```
+
 ## Common Errors 🧰
 
 > Cuda out of memory
@@ -552,6 +584,16 @@ Building something cool with Axolotl? Consider adding a badge to your model card
 
 [<img src="https://raw.githubusercontent.com/OpenAccess-AI-Collective/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/OpenAccess-AI-Collective/axolotl)
 
+## Community Showcase
+
+Open Access AI Collective
+- [Minotaur 13b](https://huggingface.co/openaccess-ai-collective/minotaur-13b)
+- [Manticore 13b](https://huggingface.co/openaccess-ai-collective/manticore-13b)
+- [Hippogriff 30b](https://huggingface.co/openaccess-ai-collective/hippogriff-30b-chat)
+
+PocketDoc Labs
+- [Dan's PersonalityEngine 13b LoRA](https://huggingface.co/PocketDoc/Dans-PersonalityEngine-13b-LoRA)
+
 ## Contributing 🤝
 
 Bugs? Please check for open issue else create a new [Issue](https://github.com/OpenAccess-AI-Collective/axolotl/issues/new).
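Note on the new `<prompt_strategies_file>.load_<load_fn>` dataset types documented above: the resolution logic itself is not part of this diff, so the snippet below is only a minimal sketch of the idea. The helper name `resolve_dataset_type`, the default package path, and the fallback to a plain `load` function are illustrative assumptions, not axolotl's actual API.

```python
import importlib
from typing import Callable, Optional


def resolve_dataset_type(
    type_str: str, package: str = "axolotl.prompt_strategies"
) -> Optional[Callable]:
    # "alpaca_w_system.load_open_orca" -> module "alpaca_w_system", function "load_open_orca";
    # a bare module name falls back to its plain "load" function.
    module_name, _, fn_name = type_str.partition(".")
    module = importlib.import_module(f"{package}.{module_name}")
    return getattr(module, fn_name or "load", None)


# e.g. resolve_dataset_type("alpaca_w_system.load_open_orca")(tokenizer, cfg)
# would return the OpenOrca tokenizing strategy added later in this PR.
```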
data/README.md CHANGED

@@ -10,10 +10,10 @@
 ## Convert the JSON data files to JSONL.
 
 ```shell
-python3 ./scripts/alpaca_json_to_jsonl.py --
-python3 ./scripts/alpaca_json_to_jsonl.py --
-python3 ./scripts/alpaca_json_to_jsonl.py --
-python3 ./scripts/alpaca_json_to_jsonl.py --
+python3 ./scripts/alpaca_json_to_jsonl.py --file data/alpaca_data_gpt4.json --output data/alpaca_data_gpt4.jsonl
+python3 ./scripts/alpaca_json_to_jsonl.py --file data/raw/vicuna_cleaned.json --output data/vicuna_cleaned.jsonl
+python3 ./scripts/alpaca_json_to_jsonl.py --file data/raw/roleplay-similarity_0.6-instruct-dataset.json --output data/roleplay-similarity_0.6-instruct-dataset.jsonl
+python3 ./scripts/alpaca_json_to_jsonl.py --file data/raw/gpt4-instruct-similarity-0.6-dataset.json --output data/gpt4-instruct-similarity-0.6-dataset.jsonl
 ```
 ---
 
docker/Dockerfile-base CHANGED

@@ -77,7 +77,7 @@ FROM base-builder
 RUN python3 -m pip uninstall -y apex
 RUN git clone https://github.com/NVIDIA/apex
 # `MAX_JOBS=1` disables parallel building to avoid cpu memory OOM when building image on GitHub Action (standard) runners
-RUN cd apex && MAX_JOBS=1 python3 -m pip install
+RUN cd apex && MAX_JOBS=1 python3 -m pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./
 
 RUN mkdir -p /workspace/builds
 COPY --from=bnb-builder /workspace/bitsandbytes /workspace/builds/bitsandbytes
@@ -97,4 +97,4 @@ RUN cd /workspace/builds/bitsandbytes && python3 setup.py install
 RUN git lfs install --skip-repo
 RUN pip3 install awscli && \
     # The base image ships with `pydantic==1.8.2` which is not working
-    pip3 install -U --no-cache-dir pydantic
+    pip3 install -U --no-cache-dir pydantic==1.10.10
examples/openllama-3b/config.yml CHANGED

@@ -26,17 +26,18 @@ wandb_watch:
 wandb_run_id:
 wandb_log_model:
 output_dir: ./openllama-out
-
-micro_batch_size:
+gradient_accumulation_steps: 1
+micro_batch_size: 1
 num_epochs: 3
 optimizer: adamw_bnb_8bit
 torchdistx_path:
 lr_scheduler: cosine
-learning_rate: 0.
+learning_rate: 0.00001
 train_on_inputs: false
 group_by_length: false
+float16: true
 bf16: false
-fp16:
+fp16: false
 tf32: false
 gradient_checkpointing: true
 early_stopping_patience:
@@ -52,7 +53,7 @@ eval_steps: 50
 save_steps:
 debug:
 deepspeed:
-weight_decay: 0.
+weight_decay: 0.1
 fsdp:
 fsdp_config:
 special_tokens:
examples/pythia-12b/README.md ADDED

@@ -0,0 +1,9 @@
+# Pythia 12B
+
+- Single-GPU A100 only (?)
+
+```shell
+python scripts/finetune.py examples/pythia-12b/config.yml
+```
+
+⚠️ Multiple-GPU A100 - Doesn't seem to work with multi-gpu without causing OOM! ⚠️
examples/pythia-12b/config.yml ADDED

@@ -0,0 +1,49 @@
+base_model: EleutherAI/pythia-12b-deduped
+base_model_config: EleutherAI/pythia-12b-deduped
+base_model_ignore_patterns: pytorch*  # prefer safetensors
+model_type: GPTNeoXForCausalLM
+tokenizer_type: AutoTokenizer
+load_in_8bit: false
+load_in_4bit: false
+gptq: false
+device_map: auto
+datasets:
+  - path: vicgalle/alpaca-gpt4
+    type: alpaca
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.05
+adapter:
+lora_model_dir:
+sequence_len: 2048
+max_packed_sequence_len: 2048
+lora_r: 64
+lora_alpha: 32
+lora_dropout: 0.0
+lora_target_modules:
+lora_target_linear: true
+lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
+wandb_project:
+wandb_watch:
+wandb_run_id:
+wandb_log_model:
+output_dir: ./pythia-12b
+gradient_accumulation_steps: 1
+micro_batch_size: 1
+num_epochs: 5
+learning_rate: 0.00003
+optimizer: adamw_bnb_8bit
+lr_scheduler: cosine
+train_on_inputs: false
+group_by_length: false
+bf16: false
+fp16: false
+float16: true
+tf32: true
+flash_optimum: true
+early_stopping_patience:
+resume_from_checkpoint:
+local_rank:
+gradient_checkpointing: true
+fsdp:
+fsdp_config:
+collator_pad_to_longest: true
examples/redpajama/config-3b.yml CHANGED

@@ -1,7 +1,7 @@
 base_model: togethercomputer/RedPajama-INCITE-Chat-3B-v1
 base_model_config: togethercomputer/RedPajama-INCITE-Chat-3B-v1
 model_type: GPTNeoXForCausalLM
-tokenizer_type:
+tokenizer_type: AutoTokenizer
 trust_remote_code:
 load_in_8bit: false
 datasets:
requirements.txt CHANGED

@@ -11,6 +11,7 @@ sentencepiece
 wandb
 einops
 xformers
+optimum
 # qlora things
 bert-score==0.3.13
 evaluate==0.4.0
scripts/finetune.py CHANGED

@@ -12,13 +12,14 @@ from typing import Any, Dict, List, Optional, Union
 import fire
 import torch
 import yaml
+
+# add src to the pythonpath so we don't need to pip install this
+from optimum.bettertransformer import BetterTransformer
 from transformers import GenerationConfig, TextStreamer
 
-from axolotl.utils.data import load_prepare_datasets
+from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.models import load_model, load_tokenizer
-
-# add src to the pythonpath so we don't need to pip install this
 from axolotl.utils.tokenization import check_dataset_labels
 from axolotl.utils.trainer import setup_trainer
 from axolotl.utils.validation import validate_config
@@ -63,7 +64,7 @@ def get_multi_line_input() -> Optional[str]:
     return instruction
 
 
-def do_inference(cfg, model, tokenizer, prompter
+def do_inference(cfg, model, tokenizer, prompter: Optional[str]):
     default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
 
     for token, symbol in default_tokens.items():
@@ -217,9 +218,20 @@ def train(
     if (
         check_not_in(["shard", "merge_lora"], kwargs) and not cfg.inference
     ):  # don't need to load dataset for these
-        train_dataset, eval_dataset = load_prepare_datasets(
-            tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
-        )
+        if not cfg.pretraining_dataset:
+            train_dataset, eval_dataset = load_prepare_datasets(
+                tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
+            )
+        else:
+            train_dataset = load_pretraining_dataset(
+                cfg.pretraining_dataset,
+                tokenizer,
+                max_tokens=cfg.sequence_len,
+                seed=cfg.seed,
+            )
+            # https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
+            train_dataset = train_dataset.with_format("torch")
+            eval_dataset = None
 
     if cfg.debug or "debug" in kwargs:
         logging.info("check_dataset_labels...")
@@ -257,13 +269,13 @@ def train(
 
     if cfg.inference:
         logging.info("calling do_inference function")
-
+        prompter: Optional[str] = "AlpacaPrompter"
         if "prompter" in kwargs:
             if kwargs["prompter"] == "None":
-
+                prompter = None
             else:
-
-        do_inference(cfg, model, tokenizer,
+                prompter = kwargs["prompter"]
+        do_inference(cfg, model, tokenizer, prompter=prompter)
         return
 
     if "shard" in kwargs:
@@ -285,12 +297,15 @@ def train(
 
     # In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
     if cfg.local_rank == 0:
+
+        def terminate_handler(_, __, model):
+            if cfg.flash_optimum:
+                model = BetterTransformer.reverse(model)
+            model.save_pretrained(cfg.output_dir)
+            sys.exit(0)
+
         signal.signal(
-            signal.SIGINT,
-            lambda signal, frame: (
-                model.save_pretrained(cfg.output_dir),
-                sys.exit(0),
-            ),
+            signal.SIGINT, lambda signum, frame: terminate_handler(signum, frame, model)
         )
 
     logging.info("Starting trainer...")
@@ -313,13 +328,21 @@ def train(
 
     if not Path(cfg.output_dir).is_dir():
         os.makedirs(cfg.output_dir, exist_ok=True)
-    trainer.train(resume_from_checkpoint=resume_from_checkpoint)
+    if cfg.flash_optimum:
+        with torch.backends.cuda.sdp_kernel(
+            enable_flash=True, enable_math=True, enable_mem_efficient=True
+        ):
+            trainer.train(resume_from_checkpoint=resume_from_checkpoint)
+    else:
+        trainer.train(resume_from_checkpoint=resume_from_checkpoint)
 
     logging.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
 
     # TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
     # only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
     if cfg.local_rank == 0:
+        if cfg.flash_optimum:
+            model = BetterTransformer.reverse(model)
         model.save_pretrained(cfg.output_dir)
 
     # trainer.save_model(cfg.output_dir)  # TODO this may be needed for deepspeed to work? need to review another time
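The `cfg.pretraining_dataset` branch above hands the trainer a streaming dataset. A hedged sketch of how such a pipeline can be assembled: `load_pretraining_dataset` itself is not shown in this diff, so the body below is an assumption pieced together from `encode_pretraining` and the `with_format("torch")` call above, not the repo's actual implementation; the `"text"` column name comes from `encode_pretraining`.

```python
import functools

from datasets import load_dataset

from axolotl.utils.data import encode_pretraining  # added later in this PR


def streaming_pretraining_dataset(path, tokenizer, max_tokens=2048):
    # Stream the raw corpus and pack it on the fly into max_tokens-sized rows.
    dataset = load_dataset(path, streaming=True)["train"]
    return dataset.map(
        functools.partial(encode_pretraining, tokenizer, max_tokens),
        batched=True,
        remove_columns=["text"],
    )
```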
src/axolotl/datasets.py CHANGED

@@ -126,6 +126,7 @@ class ConstantLengthDataset(IterableDataset):
                 buffer_len = 0
 
             if example:
+                # FIXME
                 # just going to drop data points that are too long
                 if len(example["input_ids"]) <= self.seq_length:
                     input_ids = example["input_ids"]
src/axolotl/prompt_strategies/alpaca_chat.py CHANGED

@@ -6,7 +6,7 @@ from axolotl.prompt_tokenizers import (
     AlpacaPromptTokenizingStrategy,
     InstructionPromptTokenizingStrategy,
 )
-from axolotl.prompters import AlpacaPrompter, PromptStyle
+from axolotl.prompters import AlpacaPrompter, PromptStyle, UnpromptedPrompter
 
 
 def load(tokenizer, cfg):
@@ -20,11 +20,38 @@ def load(tokenizer, cfg):
 
 class AlpacaConcisePrompter(AlpacaPrompter):
     """
-    Alpaca Prompter extending the system prompt to ask for concise answers
+    Alpaca Prompter extending the system prompt to ask for concise chat-instruct answers
     """
 
-    system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context.
-    system_no_input_prompt = "Below is an instruction that describes a task.
+    system_prompt = "Below is an instruction from a USER that describes a task, paired with an input that provides further context. The ASSISTANT writes a response that concisely and appropriately completes the request.\n\n"
+    system_no_input_prompt = "Below is an instruction from a USER that describes a task. The ASSISTANT writes a response that appropriately and concisely completes the request.\n\n"
+
+
+class AlpacaChatPrompter(AlpacaPrompter):
+    """
+    Alpaca Chat Prompter extending the system prompt to for chat-instruct answers
+    """
+
+    system_prompt = "Below is an instruction from a USER that describes a task, paired with an input that provides further context. The ASSISTANT writes a response that concisely and appropriately completes the request.\n\n"
+    system_no_input_prompt = "Below is an instruction from a USER that describes a task. The ASSISTANT writes a response that appropriately and concisely completes the request.\n\n"
+
+    def __init__(self):  # pylint: disable=super-init-not-called
+        self.prompt_style = PromptStyle.CHAT.value
+        self.match_prompt_style()
+
+
+class NoSystemPrompter(AlpacaPrompter):
+    """
+    Null Prompter with no system prompts
+    """
+
+    system_prompt = ""
+    system_no_input_prompt = ""
+    turn_format = "{instruction} {input} "
+    turn_no_input_format = "{instruction} "
+
+    def __init__(self):  # pylint: disable=super-init-not-called
+        pass
 
 
 class AlpacaQAPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
@@ -64,7 +91,7 @@ def load_concise(tokenizer, cfg):
 
 def load_qa(tokenizer, cfg):
     return AlpacaQAPromptTokenizingStrategy(
-
+        AlpacaChatPrompter(),
         tokenizer,
         cfg.train_on_inputs,
         cfg.sequence_len,
@@ -73,7 +100,16 @@ def load_qa(tokenizer, cfg):
 
 def load_camel_ai(tokenizer, cfg):
     return CamelAIPromptTokenizingStrategy(
-
+        AlpacaChatPrompter(),
+        tokenizer,
+        cfg.train_on_inputs,
+        cfg.sequence_len,
+    )
+
+
+def load_no_prompt(tokenizer, cfg):
+    return AlpacaPromptTokenizingStrategy(
+        UnpromptedPrompter(PromptStyle.CHAT.value),
         tokenizer,
         cfg.train_on_inputs,
         cfg.sequence_len,
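For reference, what the two new prompters in this file produce — an illustrative snippet, assuming the axolotl package from this branch is importable; the instruction/input strings are made up.

```python
from axolotl.prompt_strategies.alpaca_chat import AlpacaChatPrompter, NoSystemPrompter

# Chat-style system prompt followed by a USER/ASSISTANT turn
print(next(AlpacaChatPrompter().build_prompt("Name a color.", "Any color.")))
# Below is an instruction from a USER that describes a task, paired with an input that provides further context. The ASSISTANT writes a response that concisely and appropriately completes the request.
#
# USER: Name a color.
# Any color.
# ASSISTANT:

# NoSystemPrompter passes the fields through with no system prompt at all
print(next(NoSystemPrompter().build_prompt("Name a color.", "Any color.")))
# Name a color. Any color.
```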
src/axolotl/prompt_strategies/alpaca_instruct.py CHANGED

@@ -1,7 +1,7 @@
 """Module loading the AlpacaInstructPromptTokenizingStrategy class"""
 
 from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy
-from axolotl.prompters import AlpacaPrompter, PromptStyle
+from axolotl.prompters import AlpacaPrompter, PromptStyle, UnpromptedPrompter
 
 
 def load(tokenizer, cfg):
@@ -11,3 +11,12 @@ def load(tokenizer, cfg):
         cfg.train_on_inputs,
         cfg.sequence_len,
     )
+
+
+def load_no_prompt(tokenizer, cfg):
+    return AlpacaPromptTokenizingStrategy(
+        UnpromptedPrompter(PromptStyle.INSTRUCT.value),
+        tokenizer,
+        cfg.train_on_inputs,
+        cfg.sequence_len,
+    )
src/axolotl/prompt_strategies/alpaca_w_system.py ADDED

@@ -0,0 +1,120 @@
+"""
+Prompt strategies loader for alpaca instruction datasets with system prompts
+"""
+from typing import Generator, Tuple, Union
+
+from axolotl.prompt_tokenizers import PromptTokenizingStrategy
+from axolotl.prompters import AlpacaPrompter, PromptStyle
+
+
+class InstructionWSystemPromptTokenizingStrategy(PromptTokenizingStrategy):
+    """
+    Tokenizing strategy for instruction-based prompts.
+    """
+
+    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str, str]:
+        return (
+            prompt["instruction"],
+            prompt["input"] if "input" in prompt else "",
+            prompt["output"],
+            prompt["system"],
+        )
+
+    def tokenize_prompt(self, prompt):
+        # pylint: disable=duplicate-code
+        (
+            instruction,
+            input,  # pylint: disable=redefined-builtin
+            response,
+            system,
+        ) = self.parse_instruction_fields(prompt)
+        user_prompt = next(
+            iter(
+                self.prompter.build_prompt_w_system(
+                    system,
+                    instruction,
+                    input,
+                )
+            )
+        )
+        tokenized_prompt = self._tokenize(user_prompt, add_eos_token=False)
+        if not self.train_on_inputs:
+            user_prompt_len = len(tokenized_prompt["input_ids"])
+            # TODO this could be sped up using numpy array slicing
+            tokenized_prompt["labels"] = [-100] * user_prompt_len
+        tokenized_res_prompt = self._tokenize(
+            response, strip_bos_token=True, add_eos_token=True
+        )
+        tokenized_prompt["input_ids"] += tokenized_res_prompt["input_ids"]
+        tokenized_prompt["attention_mask"] += tokenized_res_prompt["attention_mask"]
+        tokenized_prompt["labels"] += tokenized_res_prompt["input_ids"]
+
+        return tokenized_prompt
+
+
+class SystemDataPrompter(AlpacaPrompter):
+    """
+    Alpaca Style Prompter that uses system prompts from the dataset
+    """
+
+    def build_prompt_w_system(
+        self,
+        system: str,
+        instruction: str,
+        input: Union[None, str] = None,  # pylint: disable=redefined-builtin
+        output: Union[None, str] = None,
+    ) -> Generator[str, None, None]:
+        # returns the full prompt from instruction and optional input
+        # if a label (=response, =output) is provided, it's also appended.
+        if input:
+            res = system + self.turn_format.format(instruction=instruction, input=input)
+        else:
+            res = system + self.turn_no_input_format.format(instruction=instruction)
+        if output:
+            res = f"{res}{output}"
+        yield res
+
+
+class OpenOrcaPromptTokenizingStrategy(InstructionWSystemPromptTokenizingStrategy):
+    """
+    Tokenizing strategy for OpenOrca datasets
+    """
+
+    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str, str]:
+        return (
+            prompt["question"],
+            "",
+            prompt["response"],
+            prompt["system_prompt"],
+        )
+
+
+def load(tokenizer, cfg):
+    return load_chat(tokenizer, cfg)
+
+
+def load_instruct(tokenizer, cfg):
+    return InstructionWSystemPromptTokenizingStrategy(
+        SystemDataPrompter(PromptStyle.INSTRUCT.value),
+        tokenizer,
+        cfg.train_on_inputs,
+        cfg.sequence_len,
+    )
+
+
+def load_chat(tokenizer, cfg):
+    return InstructionWSystemPromptTokenizingStrategy(
+        SystemDataPrompter(PromptStyle.CHAT.value),
+        tokenizer,
+        cfg.train_on_inputs,
+        cfg.sequence_len,
+    )
+
+
+def load_open_orca(tokenizer, cfg):
+    return OpenOrcaPromptTokenizingStrategy(
+        SystemDataPrompter(PromptStyle.INSTRUCT.value),
+        tokenizer,
+        cfg.train_on_inputs,
+        cfg.sequence_len,
+    )
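An illustrative sketch of how an OpenOrca-style row renders through the new `SystemDataPrompter`; the row contents are made up and the trailing blank line in the system prompt is an assumption for readability, not something the strategy adds.

```python
from axolotl.prompt_strategies.alpaca_w_system import SystemDataPrompter
from axolotl.prompters import PromptStyle

row = {
    "system_prompt": "You are a helpful assistant.\n\n",
    "question": "What is 2 + 2?",
    "response": "4",
}

prompter = SystemDataPrompter(PromptStyle.INSTRUCT.value)
user_prompt = next(prompter.build_prompt_w_system(row["system_prompt"], row["question"], ""))
print(user_prompt)
# You are a helpful assistant.
#
# ### Instruction:
# What is 2 + 2?
#
# ### Response:
```

The matching tokenizing strategy then appends the tokenized response with an EOS token and masks the prompt tokens with -100, as in the class above.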
src/axolotl/prompt_tokenizers.py CHANGED

@@ -87,7 +87,9 @@ class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
     Tokenizing strategy for instruction-based prompts.
     """
 
-    def parse_instruction_fields(
+    def parse_instruction_fields(
+        self, prompt
+    ) -> Union[Tuple[str, str, str], Tuple[str, str, str, str]]:
         raise NotImplementedError
 
     def tokenize_prompt(self, prompt):
@@ -96,25 +98,27 @@ class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
             input,  # pylint: disable=redefined-builtin
             response,
         ) = self.parse_instruction_fields(prompt)
-
-
-
-
-
-            self.prompter.build_prompt(
-                instruction,
-                input,
-            )
+        user_prompt = next(
+            iter(
+                self.prompter.build_prompt(
+                    instruction,
+                    input,
                 )
             )
-
-
+        )
+        tokenized_prompt = self._tokenize(user_prompt, add_eos_token=False)
+        if not self.train_on_inputs:
+            user_prompt_len = len(tokenized_prompt["input_ids"])
             # TODO this could be sped up using numpy array slicing
-
-
-
+            tokenized_prompt["labels"] = [-100] * user_prompt_len
+        tokenized_res_prompt = self._tokenize(
+            response, strip_bos_token=True, add_eos_token=True
+        )
+        tokenized_prompt["input_ids"] += tokenized_res_prompt["input_ids"]
+        tokenized_prompt["attention_mask"] += tokenized_res_prompt["attention_mask"]
+        tokenized_prompt["labels"] += tokenized_res_prompt["input_ids"]
 
-        return
+        return tokenized_prompt
 
     def _build_full_prompt(
         self, instruction, input, response  # pylint: disable=redefined-builtin
@@ -436,7 +440,7 @@ def parse_tokenized_to_result(
     result: Dict[str, List[int]],
     current_len: int,
     res: Dict[str, List[int]],
-    labels:
+    labels: List[int],
     pad_token_id: Union[int, None] = None,
 ) -> Tuple[Dict[str, List[int]], int]:
     """
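The rewritten `tokenize_prompt` supervises only the response: prompt tokens get -100 labels, response tokens (plus EOS) serve as both inputs and labels. A toy, dependency-free illustration of that labeling scheme — the token ids below are made up, not output of any real tokenizer.

```python
# Toy illustration of the masking scheme above (not axolotl code).
IGNORE_INDEX = -100

prompt_ids = [1, 15043, 29871]   # e.g. "<s> USER: hi ASSISTANT:"  (made-up ids)
response_ids = [3869, 2]         # e.g. "yes </s>"                 (made-up ids)

input_ids = prompt_ids + response_ids
labels = [IGNORE_INDEX] * len(prompt_ids) + response_ids
attention_mask = [1] * len(input_ids)

# The loss is computed only where labels != -100, i.e. on the response.
assert len(input_ids) == len(labels) == len(attention_mask)
```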
src/axolotl/prompters.py CHANGED

@@ -24,6 +24,8 @@ class AlpacaPrompter:
 
     system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n"
     system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
+    turn_format: str
+    turn_no_input_format: str
     prompt_style: Optional[PromptStyle] = None
 
     def __init__(self, prompt_style=PromptStyle.INSTRUCT.value):
@@ -32,23 +34,13 @@ class AlpacaPrompter:
 
     def match_prompt_style(self):
         if self.prompt_style == PromptStyle.INSTRUCT.value:
-            self.
-
-
-            )
-            self.prompt_no_input = (
-                self.system_no_input_prompt
-                + "### Instruction:\n{instruction}\n\n### Response:\n"
+            self.turn_format = "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
+            self.turn_no_input_format = (
+                "### Instruction:\n{instruction}\n\n### Response:\n"
             )
-            self.response_split = "### Response:"
         if self.prompt_style == PromptStyle.CHAT.value:
-            self.
-
-            )
-            self.prompt_no_input = (
-                self.system_no_input_prompt + "USER: {instruction}\nASSISTANT:"
-            )
-            self.response_split = "ASSISTANT:"
+            self.turn_format = "USER: {instruction}\n{input}\nASSISTANT:"
+            self.turn_no_input_format = "USER: {instruction}\nASSISTANT:"
 
     def build_prompt(
         self,
@@ -59,16 +51,17 @@ class AlpacaPrompter:
         # returns the full prompt from instruction and optional input
         # if a label (=response, =output) is provided, it's also appended.
         if input:
-            res = self.
+            res = self.system_prompt + self.turn_format.format(
+                instruction=instruction, input=input
+            )
         else:
-            res = self.
+            res = self.system_no_input_prompt + self.turn_no_input_format.format(
+                instruction=instruction
+            )
         if output:
             res = f"{res}{output}"
         yield res
 
-    def get_response(self, output: str) -> str:
-        return output.split(self.response_split)[1].strip()
-
 
 class UnpromptedPrompter(AlpacaPrompter):
     """
@@ -93,7 +86,10 @@ class MultipleChoiceExplainPrompter(AlpacaPrompter):
     """
 
     system_prompt = (
-        "Choose the answer that best answers the question. Explain your reasoning
+        "Choose the answer that best answers the question. Explain your reasoning.\n"
+    )
+    system_no_input_prompt = (
+        "Choose the answer that best answers the question. Explain your reasoning.\n"
     )
 
 
@@ -102,7 +98,12 @@ class MultipleChoiceConcisePrompter(AlpacaPrompter):
     Prompter for multiple choice concise
     """
 
-
+    system_prompt = "Choose the answer that best answers the question. Be concise in your response.\n\n"
+    system_no_input_prompt = "Choose the answer that best answers the question. Be concise in your response.\n\n"
+
+    def match_prompt_style(self):
+        self.turn_format = "USER: {instruction}\n{input}\nASSISTANT:"
+        self.turn_no_input_format = "USER: {instruction}\nASSISTANT:"
 
 
 class SummarizeTLDRPrompter(AlpacaPrompter):
@@ -110,9 +111,12 @@ class SummarizeTLDRPrompter(AlpacaPrompter):
     Prompter for summarize TLDR
     """
 
-
-
-
+    system_prompt = ""
+    system_no_input_prompt = ""
+
+    def match_prompt_style(self):
+        self.turn_format = "USER: Summarize the following article as a TL;DR.\n{instruction}\n{input}\nASSISTANT:"
+        self.turn_no_input_format = "USER: Summarize the following article as a TL;DR.\n{instruction}\nASSISTANT:"
 
 
 class CompletionPrompter:
@@ -128,9 +132,6 @@ class CompletionPrompter:
     ) -> Generator[str, None, None]:
         yield instruction
 
-    def get_response(self, output: str) -> str:
-        return output.strip()
-
 
 class GPTeacherPrompter(AlpacaPrompter):
     """
@@ -210,9 +211,6 @@ class ReflectAlpacaPrompter:
             res = f"{res}{label}"
         yield res
 
-    def get_response(self, output: str) -> str:
-        return output.split(self.response_split)[1].strip()
-
 
 class SeparatorStyle(Enum):
     """Different separator style."""
@@ -289,12 +287,6 @@ class ShareGPTPrompter:  # pylint: disable=too-few-public-methods
         sep2=" ",
     )
 
-    # def match_prompt_style(self):
-    #     if self.prompt_style == PromptStyle.chat.value:
-    #         self.prompt_input = self.system_prompt + "USER: {instruction}\n{input}\nASSISTANT:"
-    #         self.prompt_no_input = self.system_no_input_prompt + "USER: {instruction}\nASSISTANT:"
-    #         self.response_split = "ASSISTANT:"
-
     def build_prompt(self, source) -> Generator[str, None, None]:
         # ignore the system prompt if provided
         if source[0]["from"] == "system":
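With `turn_format`/`turn_no_input_format` factored out, the same instruction renders as follows under the two built-in styles. Illustrative snippet only, assuming the axolotl package from this branch is importable; the instruction text is made up.

```python
from axolotl.prompters import AlpacaPrompter, PromptStyle

print(next(AlpacaPrompter(PromptStyle.INSTRUCT.value).build_prompt("Name a color.")))
# Below is an instruction that describes a task. Write a response that appropriately completes the request.
#
# ### Instruction:
# Name a color.
#
# ### Response:

print(next(AlpacaPrompter(PromptStyle.CHAT.value).build_prompt("Name a color.")))
# Below is an instruction that describes a task. Write a response that appropriately completes the request.
#
# USER: Name a color.
# ASSISTANT:
```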
src/axolotl/utils/callbacks.py CHANGED

@@ -2,13 +2,14 @@
 
 import os
 
+from optimum.bettertransformer import BetterTransformer
 from transformers import (
     TrainerCallback,
     TrainerControl,
     TrainerState,
     TrainingArguments,
 )
-from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
+from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
 
 
 class SavePeftModelCallback(TrainerCallback):  # pylint: disable=too-few-public-methods
@@ -30,3 +31,39 @@
             kwargs["model"].save_pretrained(peft_model_path)
 
         return control
+
+
+class SaveBetterTransformerModelCallback(
+    TrainerCallback
+):  # pylint: disable=too-few-public-methods
+    """Callback to save the BetterTransformer wrapped model"""
+
+    def on_step_end(
+        self,
+        args: TrainingArguments,
+        state: TrainerState,
+        control: TrainerControl,
+        **kwargs,
+    ):
+        # Save
+        if (
+            args.save_strategy == IntervalStrategy.STEPS
+            and args.save_steps > 0
+            and state.global_step % args.save_steps == 0
+        ):
+            control.should_save = True
+
+        if control.should_save:
+            checkpoint_folder = os.path.join(
+                args.output_dir,
+                f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
+            )
+
+            model = BetterTransformer.reverse(kwargs["model"])
+            model.save_pretrained(checkpoint_folder)
+            # FIXME - need to cleanup old checkpoints
+
+            # since we're saving here, we don't need the trainer loop to attempt to save too b/c
+            # the trainer will raise an exception since it can't save a BetterTransformer wrapped model
+            control.should_save = False
+        return control
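A minimal sketch of how the new callback is meant to be wired up when `flash_optimum` is enabled: wrap the model with BetterTransformer for training, let the callback unwrap it at each checkpoint, and unwrap once more before the final save. The `Trainer`/`TrainingArguments` values, the `model`, and the `train_dataset` below are illustrative assumptions, not taken from this diff.

```python
from optimum.bettertransformer import BetterTransformer
from transformers import Trainer, TrainingArguments

from axolotl.utils.callbacks import SaveBetterTransformerModelCallback

model = BetterTransformer.transform(model)  # `model` is a loaded causal LM (assumption)

trainer = Trainer(
    model=model,
    args=TrainingArguments(output_dir="./out", save_strategy="steps", save_steps=50),
    train_dataset=train_dataset,             # assumed to exist
    callbacks=[SaveBetterTransformerModelCallback()],
)
trainer.train()

model = BetterTransformer.reverse(model)     # unwrap before the final save_pretrained
model.save_pretrained("./out")
```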
src/axolotl/utils/data.py CHANGED

@@ -1,10 +1,11 @@
 """Module containing data utilities"""
-
+import functools
 import logging
 from hashlib import md5
 from pathlib import Path
 from typing import List, Tuple, Union
 
+import torch
 from datasets import Dataset, DatasetDict, load_dataset, load_from_disk
 from huggingface_hub import hf_hub_download
 from transformers import PreTrainedTokenizerBase
@@ -101,13 +102,26 @@ def load_tokenized_prepared_datasets(
                 pass
 
         # prefer local dataset, even if hub exists
-
-
-
-
-
-
-
+        local_path = Path(d.path)
+        if local_path.exists():
+            if local_path.is_dir():
+                ds = load_dataset(
+                    d.path,
+                    data_files=d.data_files,
+                    streaming=False,
+                    split=None,
+                )
+            elif local_path.is_file():
+                ds = load_dataset(
+                    "json",
+                    data_files=d.path,
+                    streaming=False,
+                    split=None,
+                )
+            else:
+                raise ValueError(
+                    "unhandled dataset load: local path exists, but is neither a directory or a file"
+                )
         elif ds_from_hub:
             if d.data_files:
                 ds = load_dataset(
@@ -394,8 +408,127 @@ def load_prepare_datasets(
            index=cfg.dataset_shard_idx,
        )
 
-
-
-
+    if cfg.val_set_size:
+        dataset = dataset.train_test_split(test_size=cfg.val_set_size, shuffle=False)
+        train_dataset = dataset["train"]
+        eval_dataset = dataset["test"]
+    else:
+        train_dataset = dataset
+        eval_dataset = None
 
     return train_dataset, eval_dataset
+
+
+def encode_pretraining(tokenizer, max_tokens, examples):
+    res = tokenizer(
+        examples["text"],
+        truncation=True,
+        max_length=max_tokens - 2,
+        add_special_tokens=True,
+    )
+    # Convert to PyTorch tensors
+    input_ids = [torch.tensor(seq) for seq in res["input_ids"]]
+    attention_mask = [torch.tensor(seq) for seq in res["attention_mask"]]
+    new_input_ids = []
+    new_attention_mask = []
+    # Append EOS and PAD tokens to input_ids, and correct attention_mask
+    for i, _ in enumerate(input_ids):
+        input_ids[i] = torch.cat(
+            (
+                input_ids[i],
+                torch.tensor([tokenizer.eos_token_id, tokenizer.pad_token_id]),
+            ),
+            dim=0,
+        )
+        attention_mask[i] = torch.cat((attention_mask[i], torch.tensor([1, 0])), dim=0)
+
+    # Concatenate tokens so that their lengths are less than max_tokens
+    buffer_input_ids = torch.tensor([], dtype=torch.long)
+    buffer_attention_mask = torch.tensor([], dtype=torch.long)
+
+    for ids, mask in zip(input_ids, attention_mask):
+        if buffer_input_ids.numel() == max_tokens:
+            new_input_ids.append(buffer_input_ids)
+            new_attention_mask.append(buffer_attention_mask)
+            buffer_input_ids = torch.tensor([], dtype=torch.long)
+            buffer_attention_mask = torch.tensor([], dtype=torch.long)
+            buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
+            buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)
+        elif buffer_input_ids.numel() + ids.numel() <= max_tokens:
+            buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
+            buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)
+        else:
+            buffer_input_ids = torch.cat(
+                (
+                    buffer_input_ids,
+                    torch.full(
+                        (max_tokens - buffer_input_ids.numel(),),
+                        tokenizer.pad_token_id,
+                        dtype=torch.long,
+                    ),
+                ),
+                dim=0,
+            )
+            buffer_attention_mask = torch.cat(
+                (
+                    buffer_attention_mask,
+                    torch.full(
+                        (max_tokens - buffer_attention_mask.numel(),),
+                        0,
+                        dtype=torch.long,
+                    ),
+                ),
+                dim=0,
+            )
+            new_input_ids.append(buffer_input_ids)
+            new_attention_mask.append(buffer_attention_mask)
|
| 484 |
+
new_attention_mask.append(buffer_attention_mask)
|
| 485 |
+
buffer_input_ids = torch.tensor([], dtype=torch.long)
|
| 486 |
+
buffer_attention_mask = torch.tensor([], dtype=torch.long)
|
| 487 |
+
|
| 488 |
+
buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
|
| 489 |
+
buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)
|
| 490 |
+
|
| 491 |
+
if buffer_input_ids.numel() > 0: # for any leftover tokens
|
| 492 |
+
while buffer_input_ids.numel() < max_tokens: # make all sequences equal in size
|
| 493 |
+
buffer_input_ids = torch.cat(
|
| 494 |
+
(
|
| 495 |
+
buffer_input_ids,
|
| 496 |
+
torch.full(
|
| 497 |
+
(max_tokens - buffer_input_ids.numel(),),
|
| 498 |
+
tokenizer.pad_token_id,
|
| 499 |
+
dtype=torch.long,
|
| 500 |
+
),
|
| 501 |
+
),
|
| 502 |
+
dim=0,
|
| 503 |
+
)
|
| 504 |
+
buffer_attention_mask = torch.cat(
|
| 505 |
+
(
|
| 506 |
+
buffer_attention_mask,
|
| 507 |
+
torch.full(
|
| 508 |
+
(max_tokens - buffer_attention_mask.numel(),),
|
| 509 |
+
0,
|
| 510 |
+
dtype=torch.long,
|
| 511 |
+
),
|
| 512 |
+
),
|
| 513 |
+
dim=0,
|
| 514 |
+
)
|
| 515 |
+
new_input_ids.append(buffer_input_ids)
|
| 516 |
+
new_attention_mask.append(buffer_attention_mask)
|
| 517 |
+
|
| 518 |
+
ret = {
|
| 519 |
+
"input_ids": [seq.tolist() for seq in new_input_ids],
|
| 520 |
+
"labels": [seq.tolist() for seq in new_input_ids],
|
| 521 |
+
"attention_mask": [seq.tolist() for seq in new_attention_mask],
|
| 522 |
+
}
|
| 523 |
+
|
| 524 |
+
logging.debug(len(ret["input_ids"]))
|
| 525 |
+
return ret
|
| 526 |
+
|
| 527 |
+
|
| 528 |
+
def load_pretraining_dataset(path, tokenizer, max_tokens=2048, seed=42):
|
| 529 |
+
encode = functools.partial(encode_pretraining, tokenizer, max_tokens)
|
| 530 |
+
dataset = load_dataset(path, streaming=True, split="train")
|
| 531 |
+
dataset = dataset.shuffle(seed=seed, buffer_size=10_000)
|
| 532 |
+
# TODO dynamically figure out which columns/features to remove
|
| 533 |
+
dataset = dataset.map(encode, batched=True, remove_columns=["text", "meta"])
|
| 534 |
+
return dataset
|
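Note (editor's illustration, not part of this diff): encode_pretraining packs tokenized text into fixed-length rows of max_tokens, padding the final buffer. A hedged usage sketch, assuming the axolotl package containing this change is installed; the sample strings and max_tokens=32 are made up:

from transformers import AutoTokenizer

from axolotl.utils.data import encode_pretraining

tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
tokenizer.pad_token = tokenizer.eos_token  # encode_pretraining needs a pad token id
examples = {"text": ["first pretraining document", "second pretraining document"]}
batch = encode_pretraining(tokenizer, 32, examples)
print(len(batch["input_ids"]), len(batch["input_ids"][0]))  # packed rows, each padded to 32 tokens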
src/axolotl/utils/models.py
CHANGED
|
@@ -10,13 +10,15 @@ from typing import TYPE_CHECKING, Optional, Tuple # noqa: F401
|
|
| 10 |
import bitsandbytes as bnb
|
| 11 |
import torch
|
| 12 |
import transformers
|
| 13 |
-
from
|
| 14 |
from transformers import ( # noqa: F401
|
| 15 |
AutoConfig,
|
| 16 |
AutoModelForCausalLM,
|
| 17 |
AutoTokenizer,
|
| 18 |
BitsAndBytesConfig,
|
| 19 |
LlamaConfig,
|
|
|
|
|
|
|
| 20 |
)
|
| 21 |
|
| 22 |
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN
|
|
@@ -32,15 +34,20 @@ def load_tokenizer(
|
|
| 32 |
tokenizer_type,
|
| 33 |
cfg,
|
| 34 |
):
|
| 35 |
if tokenizer_type:
|
| 36 |
tokenizer = getattr(transformers, tokenizer_type).from_pretrained(
|
| 37 |
tokenizer_config,
|
| 38 |
trust_remote_code=cfg.trust_remote_code or False,
|
|
|
|
| 39 |
)
|
| 40 |
else:
|
| 41 |
tokenizer = AutoTokenizer.from_pretrained(
|
| 42 |
tokenizer_config,
|
| 43 |
trust_remote_code=cfg.trust_remote_code or False,
|
|
|
|
| 44 |
)
|
| 45 |
|
| 46 |
logging.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
|
|
@@ -70,7 +77,7 @@ def load_tokenizer(
|
|
| 70 |
def load_model(
|
| 71 |
base_model, base_model_config, model_type, tokenizer, cfg, adapter="lora"
|
| 72 |
):
|
| 73 |
-
# type: (str, str, str,
|
| 74 |
"""
|
| 75 |
Load a model from a base model and a model type.
|
| 76 |
"""
|
|
@@ -121,9 +128,9 @@ def load_model(
|
|
| 121 |
logging.info("patching with xpos rope")
|
| 122 |
replace_llama_rope_with_xpos_rope()
|
| 123 |
|
| 124 |
-
if cfg.bf16:
|
| 125 |
torch_dtype = torch.bfloat16
|
| 126 |
-
elif cfg.load_in_8bit or cfg.fp16:
|
| 127 |
torch_dtype = torch.float16
|
| 128 |
else:
|
| 129 |
torch_dtype = torch.float32
|
|
@@ -195,7 +202,7 @@ def load_model(
|
|
| 195 |
else True,
|
| 196 |
)
|
| 197 |
load_in_8bit = False
|
| 198 |
-
elif cfg.is_llama_derived_model:
|
| 199 |
from transformers import LlamaForCausalLM
|
| 200 |
|
| 201 |
config = LlamaConfig.from_pretrained(base_model_config)
|
|
@@ -234,7 +241,7 @@ def load_model(
|
|
| 234 |
# device=cfg.device,
|
| 235 |
# )
|
| 236 |
# model.train() # sets to train instead of eval mode
|
| 237 |
-
elif model_type:
|
| 238 |
model = getattr(transformers, model_type).from_pretrained(
|
| 239 |
base_model,
|
| 240 |
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
|
@@ -251,11 +258,16 @@ def load_model(
|
|
| 251 |
)
|
| 252 |
# Shouldn't be a problem most of the time. will obviously error if the model doesn't support this
|
| 253 |
# when training starts
|
| 254 |
-
if
|
| 255 |
config.max_seq_len = cfg.sequence_len
|
| 256 |
logging.warning(f"increasing context length to {cfg.sequence_len}")
|
| 257 |
elif (
|
| 258 |
hasattr(config, "max_sequence_length")
|
|
|
|
| 259 |
and cfg.sequence_len > config.max_sequence_length
|
| 260 |
):
|
| 261 |
config.max_sequence_length = cfg.sequence_len
|
|
@@ -278,6 +290,7 @@ def load_model(
|
|
| 278 |
model = AutoModelForCausalLM.from_pretrained(
|
| 279 |
base_model,
|
| 280 |
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
|
|
|
| 281 |
torch_dtype=torch_dtype,
|
| 282 |
device_map=cfg.device_map,
|
| 283 |
trust_remote_code=cfg.trust_remote_code or False,
|
|
@@ -287,6 +300,16 @@ def load_model(
|
|
| 287 |
embeddings_len = math.ceil(len(tokenizer) / 32) * 32
|
| 288 |
model.resize_token_embeddings(embeddings_len)
|
| 289 |
|
| 290 |
if not cfg.gptq and (
|
| 291 |
(cfg.adapter == "lora" and load_in_8bit)
|
| 292 |
or (cfg.adapter == "qlora" and cfg.load_in_4bit)
|
|
@@ -332,6 +355,9 @@ def load_model(
|
|
| 332 |
logging.warning("there are no parameters that require gradient updates")
|
| 333 |
model.config.use_cache = False
|
| 334 |
|
| 335 |
# TODO resume_from_checkpoint handling
|
| 336 |
return model, lora_config
|
| 337 |
|
|
|
|
| 10 |
import bitsandbytes as bnb
|
| 11 |
import torch
|
| 12 |
import transformers
|
| 13 |
+
from optimum.bettertransformer import BetterTransformer
|
| 14 |
from transformers import ( # noqa: F401
|
| 15 |
AutoConfig,
|
| 16 |
AutoModelForCausalLM,
|
| 17 |
AutoTokenizer,
|
| 18 |
BitsAndBytesConfig,
|
| 19 |
LlamaConfig,
|
| 20 |
+
PreTrainedModel,
|
| 21 |
+
PreTrainedTokenizerBase,
|
| 22 |
)
|
| 23 |
|
| 24 |
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN
|
|
|
|
| 34 |
tokenizer_type,
|
| 35 |
cfg,
|
| 36 |
):
|
| 37 |
+
use_fast = True # this is the default
|
| 38 |
+
if cfg.tokenizer_use_fast is not None:
|
| 39 |
+
use_fast = cfg.tokenizer_use_fast
|
| 40 |
if tokenizer_type:
|
| 41 |
tokenizer = getattr(transformers, tokenizer_type).from_pretrained(
|
| 42 |
tokenizer_config,
|
| 43 |
trust_remote_code=cfg.trust_remote_code or False,
|
| 44 |
+
use_fast=use_fast,
|
| 45 |
)
|
| 46 |
else:
|
| 47 |
tokenizer = AutoTokenizer.from_pretrained(
|
| 48 |
tokenizer_config,
|
| 49 |
trust_remote_code=cfg.trust_remote_code or False,
|
| 50 |
+
use_fast=use_fast,
|
| 51 |
)
|
| 52 |
|
| 53 |
logging.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
|
|
|
|
| 77 |
def load_model(
|
| 78 |
base_model, base_model_config, model_type, tokenizer, cfg, adapter="lora"
|
| 79 |
):
|
| 80 |
+
# type: (str, str, str, PreTrainedTokenizerBase, DictDefault, Optional[str]) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
|
| 81 |
"""
|
| 82 |
Load a model from a base model and a model type.
|
| 83 |
"""
|
|
|
|
| 128 |
logging.info("patching with xpos rope")
|
| 129 |
replace_llama_rope_with_xpos_rope()
|
| 130 |
|
| 131 |
+
if cfg.bf16 or cfg.bfloat16:
|
| 132 |
torch_dtype = torch.bfloat16
|
| 133 |
+
elif cfg.load_in_8bit or cfg.fp16 or cfg.float16:
|
| 134 |
torch_dtype = torch.float16
|
| 135 |
else:
|
| 136 |
torch_dtype = torch.float32
|
|
|
|
| 202 |
else True,
|
| 203 |
)
|
| 204 |
load_in_8bit = False
|
| 205 |
+
elif cfg.is_llama_derived_model and not cfg.trust_remote_code:
|
| 206 |
from transformers import LlamaForCausalLM
|
| 207 |
|
| 208 |
config = LlamaConfig.from_pretrained(base_model_config)
|
|
|
|
| 241 |
# device=cfg.device,
|
| 242 |
# )
|
| 243 |
# model.train() # sets to train instead of eval mode
|
| 244 |
+
elif model_type and not cfg.trust_remote_code:
|
| 245 |
model = getattr(transformers, model_type).from_pretrained(
|
| 246 |
base_model,
|
| 247 |
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
|
|
|
| 258 |
)
|
| 259 |
# Shouldn't be a problem most of the time. will obviously error if the model doesn't support this
|
| 260 |
# when training starts
|
| 261 |
+
if (
|
| 262 |
+
hasattr(config, "max_seq_len")
|
| 263 |
+
and config.max_seq_len
|
| 264 |
+
and cfg.sequence_len > config.max_seq_len
|
| 265 |
+
):
|
| 266 |
config.max_seq_len = cfg.sequence_len
|
| 267 |
logging.warning(f"increasing context length to {cfg.sequence_len}")
|
| 268 |
elif (
|
| 269 |
hasattr(config, "max_sequence_length")
|
| 270 |
+
and config.max_sequence_length
|
| 271 |
and cfg.sequence_len > config.max_sequence_length
|
| 272 |
):
|
| 273 |
config.max_sequence_length = cfg.sequence_len
|
|
|
|
| 290 |
model = AutoModelForCausalLM.from_pretrained(
|
| 291 |
base_model,
|
| 292 |
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
| 293 |
+
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
| 294 |
torch_dtype=torch_dtype,
|
| 295 |
device_map=cfg.device_map,
|
| 296 |
trust_remote_code=cfg.trust_remote_code or False,
|
|
|
|
| 300 |
embeddings_len = math.ceil(len(tokenizer) / 32) * 32
|
| 301 |
model.resize_token_embeddings(embeddings_len)
|
| 302 |
|
| 303 |
+
if (
|
| 304 |
+
hasattr(model.config, "max_position_embeddings")
|
| 305 |
+
and model.config.max_position_embeddings
|
| 306 |
+
and cfg.sequence_len >= model.config.max_position_embeddings
|
| 307 |
+
):
|
| 308 |
+
logging.warning(
|
| 309 |
+
f"increasing model.config.max_position_embeddings to {cfg.sequence_len}"
|
| 310 |
+
)
|
| 311 |
+
model.config.max_position_embeddings = cfg.sequence_len
|
| 312 |
+
|
| 313 |
if not cfg.gptq and (
|
| 314 |
(cfg.adapter == "lora" and load_in_8bit)
|
| 315 |
or (cfg.adapter == "qlora" and cfg.load_in_4bit)
|
|
|
|
| 355 |
logging.warning("there are no parameters that require gradient updates")
|
| 356 |
model.config.use_cache = False
|
| 357 |
|
| 358 |
+
if cfg.flash_optimum:
|
| 359 |
+
model = BetterTransformer.transform(model)
|
| 360 |
+
|
| 361 |
# TODO resume_from_checkpoint handling
|
| 362 |
return model, lora_config
|
| 363 |
|
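Note (editor's illustration, not part of this diff): the new tokenizer_use_fast option is forwarded verbatim to the use_fast argument of from_pretrained, which is what tests/test_tokenizers.py below asserts on. A minimal sketch of the difference it makes:

from transformers import AutoTokenizer

fast = AutoTokenizer.from_pretrained("huggyllama/llama-7b", use_fast=True)   # default behaviour
slow = AutoTokenizer.from_pretrained("huggyllama/llama-7b", use_fast=False)  # tokenizer_use_fast: false
print(type(fast).__name__)  # expected to end in "Fast", e.g. LlamaTokenizerFast
print(type(slow).__name__)  # e.g. LlamaTokenizer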
src/axolotl/utils/tokenization.py
CHANGED
|
@@ -34,3 +34,5 @@ def check_example_labels(example, tokenizer):
|
|
| 34 |
|
| 35 |
logging.info(" ".join(colored_tokens))
|
| 36 |
logging.info("\n\n\n")
|
|
|
|
|
|
|
|
|
| 34 |
|
| 35 |
logging.info(" ".join(colored_tokens))
|
| 36 |
logging.info("\n\n\n")
|
| 37 |
+
|
| 38 |
+
return " ".join(colored_tokens)
|
src/axolotl/utils/trainer.py
CHANGED
|
@@ -17,7 +17,10 @@ from torch.optim.lr_scheduler import OneCycleLR
|
|
| 17 |
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
|
| 18 |
from transformers.trainer_pt_utils import get_parameter_names
|
| 19 |
|
| 20 |
-
from axolotl.utils.callbacks import
|
|
| 21 |
from axolotl.utils.schedulers import (
|
| 22 |
InterpolatingLogScheduler,
|
| 23 |
get_cosine_schedule_with_quadratic_warmup,
|
|
@@ -166,6 +169,19 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
|
| 166 |
# TODO search Path("./") for one
|
| 167 |
training_arguments_kwargs["deepspeed"] = "./ds_config.json"
|
| 168 |
|
| 169 |
training_args = AxolotlTrainingArguments(
|
| 170 |
per_device_train_batch_size=cfg.micro_batch_size,
|
| 171 |
per_device_eval_batch_size=cfg.eval_batch_size
|
|
@@ -282,6 +298,9 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
|
| 282 |
]: # only save in rank 0
|
| 283 |
callbacks.append(SavePeftModelCallback)
|
| 284 |
|
|
|
| 285 |
data_collator_kwargs = {
|
| 286 |
"padding": True,
|
| 287 |
}
|
|
|
|
| 17 |
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
|
| 18 |
from transformers.trainer_pt_utils import get_parameter_names
|
| 19 |
|
| 20 |
+
from axolotl.utils.callbacks import (
|
| 21 |
+
SaveBetterTransformerModelCallback,
|
| 22 |
+
SavePeftModelCallback,
|
| 23 |
+
)
|
| 24 |
from axolotl.utils.schedulers import (
|
| 25 |
InterpolatingLogScheduler,
|
| 26 |
get_cosine_schedule_with_quadratic_warmup,
|
|
|
|
| 169 |
# TODO search Path("./") for one
|
| 170 |
training_arguments_kwargs["deepspeed"] = "./ds_config.json"
|
| 171 |
|
| 172 |
+
if cfg.adam_beta1:
|
| 173 |
+
training_arguments_kwargs["adam_beta1"] = cfg.adam_beta1
|
| 174 |
+
if cfg.adam_beta2:
|
| 175 |
+
training_arguments_kwargs["adam_beta2"] = cfg.adam_beta2
|
| 176 |
+
if cfg.adam_epsilon:
|
| 177 |
+
training_arguments_kwargs["adam_epsilon"] = cfg.adam_epsilon
|
| 178 |
+
if cfg.max_grad_norm:
|
| 179 |
+
training_arguments_kwargs["max_grad_norm"] = cfg.max_grad_norm
|
| 180 |
+
|
| 181 |
+
if cfg.hub_model_id:
|
| 182 |
+
training_arguments_kwargs["hub_model_id"] = cfg.hub_model_id
|
| 183 |
+
training_arguments_kwargs["push_to_hub"] = True
|
| 184 |
+
|
| 185 |
training_args = AxolotlTrainingArguments(
|
| 186 |
per_device_train_batch_size=cfg.micro_batch_size,
|
| 187 |
per_device_eval_batch_size=cfg.eval_batch_size
|
|
|
|
| 298 |
]: # only save in rank 0
|
| 299 |
callbacks.append(SavePeftModelCallback)
|
| 300 |
|
| 301 |
+
if hasattr(model, "use_bettertransformer") and model.use_bettertransformer is True:
|
| 302 |
+
callbacks.append(SaveBetterTransformerModelCallback)
|
| 303 |
+
|
| 304 |
data_collator_kwargs = {
|
| 305 |
"padding": True,
|
| 306 |
}
|
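Note (editor's illustration, not part of this diff): the block above forwards several new config keys straight into TrainingArguments. A hypothetical config fragment, written as a Python dict standing in for the YAML config; the key names come from the diff, the values are placeholders rather than recommendations:

cfg = {
    "optimizer": "adamw_bnb_8bit",  # an adamw variant, so the beta/epsilon keys apply
    "adam_beta1": 0.9,
    "adam_beta2": 0.999,
    "adam_epsilon": 1e-8,
    "max_grad_norm": 1.0,
    "hub_model_id": "your-user/your-model",  # placeholder; setting it also flips push_to_hub to True
}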
src/axolotl/utils/validation.py
CHANGED
|
@@ -2,6 +2,8 @@
|
|
| 2 |
|
| 3 |
import logging
|
| 4 |
|
|
|
|
|
|
|
| 5 |
|
| 6 |
def validate_config(cfg):
|
| 7 |
if cfg.gradient_accumulation_steps and cfg.batch_size:
|
|
@@ -62,7 +64,47 @@ def validate_config(cfg):
|
|
| 62 |
) and cfg.gradient_checkpointing:
|
| 63 |
raise ValueError("gradient_checkpointing is not supported for MPT models")
|
| 64 |
|
| 65 |
# TODO
|
| 66 |
# MPT 7b
|
| 67 |
# https://github.com/facebookresearch/bitsandbytes/issues/25
|
| 68 |
-
# no 8bit
|
|
| 2 |
|
| 3 |
import logging
|
| 4 |
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
|
| 8 |
def validate_config(cfg):
|
| 9 |
if cfg.gradient_accumulation_steps and cfg.batch_size:
|
|
|
|
| 64 |
) and cfg.gradient_checkpointing:
|
| 65 |
raise ValueError("gradient_checkpointing is not supported for MPT models")
|
| 66 |
|
| 67 |
+
if cfg.flash_optimum is True:
|
| 68 |
+
if cfg.adapter:
|
| 69 |
+
logging.warning(
|
| 70 |
+
"BetterTransformers probably doesn't work with PEFT adapters"
|
| 71 |
+
)
|
| 72 |
+
if cfg.fp16 or cfg.bf16:
|
| 73 |
+
raise ValueError("AMP is not supported with BetterTransformer")
|
| 74 |
+
if cfg.float16 is not True and cfg.bfloat16 is not True:
|
| 75 |
+
logging.warning(
|
| 76 |
+
"You should probably set bfloat16 or float16 to true to "
|
| 77 |
+
"load the model in float16 for BetterTransformers"
|
| 78 |
+
)
|
| 79 |
+
if int(torch.__version__.split(".")[0]) < 2:
|
| 80 |
+
logging.warning("torch>=2.0.0 required")
|
| 81 |
+
raise ValueError(
|
| 82 |
+
f"flash_optimum for BetterTransformers may not be used with {torch.__version__}"
|
| 83 |
+
)
|
| 84 |
+
|
| 85 |
+
if cfg.pretraining_dataset and cfg.group_by_length:
|
| 86 |
+
logging.warning(
|
| 87 |
+
"You probably want to disable group_by_length as it will force a streamed dataset to download completely."
|
| 88 |
+
)
|
| 89 |
+
|
| 90 |
+
if any([cfg.adam_beta1, cfg.adam_beta2, cfg.adam_epsilon]) and (
|
| 91 |
+
not cfg.optimizer or "adamw" not in cfg.optimizer
|
| 92 |
+
):
|
| 93 |
+
logging.warning("adamw hyperparameters found, but no adamw optimizer set")
|
| 94 |
+
|
| 95 |
+
if cfg.push_to_hub_model_id:
|
| 96 |
+
raise ValueError(
|
| 97 |
+
"push_to_hub_model_id is deprecated. Please use hub_model_id instead."
|
| 98 |
+
)
|
| 99 |
+
|
| 100 |
# TODO
|
| 101 |
# MPT 7b
|
| 102 |
# https://github.com/facebookresearch/bitsandbytes/issues/25
|
| 103 |
+
# no 8bit adamw w bf16
|
| 104 |
+
|
| 105 |
+
# GPT-NeoX
|
| 106 |
+
# evals broken when extending context len
|
| 107 |
+
# File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/transformers/models/gpt_neox/modeling_gpt_neox.py", line 162, in forward attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
|
| 108 |
+
# File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/optimum/bettertransformer/models/attention.py", line 74, in gpt2_wrapped_scaled_dot_product
|
| 109 |
+
# attention_mask = causal_mask + attention_mask
|
| 110 |
+
# RuntimeError: The size of tensor a (2048) must match the size of tensor b (8132) at non-singleton dimension 3
|
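Note (editor's illustration, not part of this diff): a short sketch of how the new validation rules surface, assuming the axolotl package containing this change is installed; it mirrors the flash_optimum + fp16 case covered in tests/test_validation.py below:

from axolotl.utils.dict import DictDefault
from axolotl.utils.validation import validate_config

cfg = DictDefault({"flash_optimum": True, "fp16": True})
try:
    validate_config(cfg)
except ValueError as err:
    print(err)  # "AMP is not supported with BetterTransformer"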
tests/test_prompt_tokenizers.py
CHANGED
|
@@ -6,8 +6,16 @@ from pathlib import Path
|
|
| 6 |
|
| 7 |
from transformers import AutoTokenizer
|
| 8 |
|
| 9 |
-
from axolotl.
|
| 10 |
-
from axolotl.
|
| 11 |
|
| 12 |
logging.basicConfig(level="INFO")
|
| 13 |
|
|
@@ -29,7 +37,6 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
|
|
| 29 |
)
|
| 30 |
|
| 31 |
def test_sharegpt_integration(self):
|
| 32 |
-
print(Path(__file__).parent)
|
| 33 |
with open(
|
| 34 |
Path(__file__).parent / "fixtures/conversation.json", encoding="utf-8"
|
| 35 |
) as fin:
|
|
@@ -53,6 +60,79 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
|
|
| 53 |
self.assertEqual(len(example[fields]), len(tokenized_conversation[fields]))
|
| 54 |
self.assertEqual(example[fields], tokenized_conversation[fields])
|
| 55 |
|
| 56 |
|
| 57 |
if __name__ == "__main__":
|
| 58 |
unittest.main()
|
|
|
|
| 6 |
|
| 7 |
from transformers import AutoTokenizer
|
| 8 |
|
| 9 |
+
from axolotl.prompt_strategies.alpaca_chat import NoSystemPrompter
|
| 10 |
+
from axolotl.prompt_strategies.alpaca_w_system import (
|
| 11 |
+
InstructionWSystemPromptTokenizingStrategy,
|
| 12 |
+
SystemDataPrompter,
|
| 13 |
+
)
|
| 14 |
+
from axolotl.prompt_tokenizers import (
|
| 15 |
+
AlpacaPromptTokenizingStrategy,
|
| 16 |
+
ShareGPTPromptTokenizingStrategy,
|
| 17 |
+
)
|
| 18 |
+
from axolotl.prompters import AlpacaPrompter, PromptStyle, ShareGPTPrompter
|
| 19 |
|
| 20 |
logging.basicConfig(level="INFO")
|
| 21 |
|
|
|
|
| 37 |
)
|
| 38 |
|
| 39 |
def test_sharegpt_integration(self):
|
|
|
|
| 40 |
with open(
|
| 41 |
Path(__file__).parent / "fixtures/conversation.json", encoding="utf-8"
|
| 42 |
) as fin:
|
|
|
|
| 60 |
self.assertEqual(len(example[fields]), len(tokenized_conversation[fields]))
|
| 61 |
self.assertEqual(example[fields], tokenized_conversation[fields])
|
| 62 |
|
| 63 |
+
def test_no_sys_prompt(self):
|
| 64 |
+
"""
|
| 65 |
+
tests the interface between the user and assistant parts
|
| 66 |
+
"""
|
| 67 |
+
prompter = NoSystemPrompter()
|
| 68 |
+
# pylint: disable=duplicate-code
|
| 69 |
+
strat = AlpacaPromptTokenizingStrategy(
|
| 70 |
+
prompter,
|
| 71 |
+
self.tokenizer,
|
| 72 |
+
False,
|
| 73 |
+
2048,
|
| 74 |
+
)
|
| 75 |
+
sample = {
|
| 76 |
+
"instruction": "hello cruel. lorem ipsum dolor sit amet.",
|
| 77 |
+
"output": "world!",
|
| 78 |
+
}
|
| 79 |
+
example = strat.tokenize_prompt(sample)
|
| 80 |
+
world_idx = example["input_ids"].index(3186)
|
| 81 |
+
assert example["labels"][world_idx] == 3186
|
| 82 |
+
assert example["labels"][world_idx - 1] == -100
|
| 83 |
+
|
| 84 |
+
def test_alpaca(self):
|
| 85 |
+
"""
|
| 86 |
+
tests the interface between the user and assistant parts
|
| 87 |
+
"""
|
| 88 |
+
# pylint: disable=duplicate-code
|
| 89 |
+
prompter = AlpacaPrompter()
|
| 90 |
+
strat = AlpacaPromptTokenizingStrategy(
|
| 91 |
+
prompter,
|
| 92 |
+
self.tokenizer,
|
| 93 |
+
False,
|
| 94 |
+
2048,
|
| 95 |
+
)
|
| 96 |
+
sample = {"instruction": "hello!", "output": "Hi! How can I help?"}
|
| 97 |
+
example = strat.tokenize_prompt(sample)
|
| 98 |
+
world_idx = example["input_ids"].index(6324)
|
| 99 |
+
assert example["labels"][world_idx] == 6324
|
| 100 |
+
assert example["labels"][world_idx - 1] == -100
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
class InstructionWSystemPromptTokenizingStrategyTest(unittest.TestCase):
|
| 104 |
+
"""
|
| 105 |
+
Test class for prompt tokenization strategies with sys prompt from the dataset
|
| 106 |
+
"""
|
| 107 |
+
|
| 108 |
+
def setUp(self) -> None:
|
| 109 |
+
# pylint: disable=duplicate-code
|
| 110 |
+
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
|
| 111 |
+
self.tokenizer.add_special_tokens(
|
| 112 |
+
{
|
| 113 |
+
"bos_token": "<s>",
|
| 114 |
+
"eos_token": "</s>",
|
| 115 |
+
"unk_token": "<unk>",
|
| 116 |
+
}
|
| 117 |
+
)
|
| 118 |
+
|
| 119 |
+
def test_system_alpaca(self):
|
| 120 |
+
prompter = SystemDataPrompter(PromptStyle.CHAT.value)
|
| 121 |
+
strat = InstructionWSystemPromptTokenizingStrategy(
|
| 122 |
+
prompter,
|
| 123 |
+
self.tokenizer,
|
| 124 |
+
False,
|
| 125 |
+
2048,
|
| 126 |
+
)
|
| 127 |
+
sample = {
|
| 128 |
+
"system": "use cot",
|
| 129 |
+
"instruction": "hello!",
|
| 130 |
+
"output": "Hi! How can I help?",
|
| 131 |
+
}
|
| 132 |
+
example = strat.tokenize_prompt(sample)
|
| 133 |
+
assert example["input_ids"][0:3] == [1, 671, 20118] # <s>use cot
|
| 134 |
+
assert example["input_ids"][3] == 11889 # USER
|
| 135 |
+
|
| 136 |
|
| 137 |
if __name__ == "__main__":
|
| 138 |
unittest.main()
|
tests/test_prompters.py
CHANGED
|
@@ -2,7 +2,13 @@
|
|
| 2 |
|
| 3 |
import unittest
|
| 4 |
|
| 5 |
-
from axolotl.
|
| 6 |
|
| 7 |
|
| 8 |
class AlpacaPrompterTest(unittest.TestCase):
|
|
@@ -55,3 +61,64 @@ class AlpacaPrompterTest(unittest.TestCase):
|
|
| 55 |
assert "### Response:" not in res
|
| 56 |
assert "USER:" in res
|
| 57 |
assert "ASSISTANT:" in res
|
|
| 2 |
|
| 3 |
import unittest
|
| 4 |
|
| 5 |
+
from axolotl.prompt_strategies.alpaca_w_system import SystemDataPrompter
|
| 6 |
+
from axolotl.prompters import (
|
| 7 |
+
AlpacaPrompter,
|
| 8 |
+
MultipleChoiceExplainPrompter,
|
| 9 |
+
PromptStyle,
|
| 10 |
+
UnpromptedPrompter,
|
| 11 |
+
)
|
| 12 |
|
| 13 |
|
| 14 |
class AlpacaPrompterTest(unittest.TestCase):
|
|
|
|
| 61 |
assert "### Response:" not in res
|
| 62 |
assert "USER:" in res
|
| 63 |
assert "ASSISTANT:" in res
|
| 64 |
+
|
| 65 |
+
def test_system_prompt(self):
|
| 66 |
+
prompter = SystemDataPrompter(prompt_style=PromptStyle.CHAT.value)
|
| 67 |
+
res = next(
|
| 68 |
+
prompter.build_prompt_w_system(
|
| 69 |
+
"use cot", "tell me a joke about the following", "alpacas"
|
| 70 |
+
)
|
| 71 |
+
)
|
| 72 |
+
assert "use cot" in res
|
| 73 |
+
assert res.startswith("use cot")
|
| 74 |
+
assert "### Instruction:" not in res
|
| 75 |
+
assert "### Input:" not in res
|
| 76 |
+
assert "alpacas" in res
|
| 77 |
+
assert "### Response:" not in res
|
| 78 |
+
assert "USER:" in res
|
| 79 |
+
assert "ASSISTANT:" in res
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
class UnpromptedPrompterTest(unittest.TestCase):
|
| 83 |
+
"""
|
| 84 |
+
Test class for UnpromptedPrompter with no system prompts
|
| 85 |
+
"""
|
| 86 |
+
|
| 87 |
+
def test_prompt_style_w_none(self):
|
| 88 |
+
prompter = UnpromptedPrompter(prompt_style=None)
|
| 89 |
+
res = next(prompter.build_prompt("tell me a joke"))
|
| 90 |
+
assert "### Instruction:" in res
|
| 91 |
+
assert "tell me a joke" in res
|
| 92 |
+
assert res.startswith("###")
|
| 93 |
+
|
| 94 |
+
def test_prompt_style_w_instruct(self):
|
| 95 |
+
prompter = UnpromptedPrompter(prompt_style=PromptStyle.INSTRUCT.value)
|
| 96 |
+
res = next(
|
| 97 |
+
prompter.build_prompt("tell me a joke about the following", "alpacas")
|
| 98 |
+
)
|
| 99 |
+
assert "### Instruction:" in res
|
| 100 |
+
assert "tell me a joke" in res
|
| 101 |
+
assert res.startswith("###")
|
| 102 |
+
|
| 103 |
+
def test_prompt_style_w_chat(self):
|
| 104 |
+
prompter = UnpromptedPrompter(prompt_style=PromptStyle.CHAT.value)
|
| 105 |
+
res = next(
|
| 106 |
+
prompter.build_prompt("tell me a joke about the following", "alpacas")
|
| 107 |
+
)
|
| 108 |
+
assert "USER:" in res
|
| 109 |
+
assert "tell me a joke" in res
|
| 110 |
+
assert res.startswith("USER:")
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
class MultipleChoiceExplainPrompterTest(unittest.TestCase):
|
| 114 |
+
"""
|
| 115 |
+
Test class for MultipleChoiceExplainPrompter
|
| 116 |
+
"""
|
| 117 |
+
|
| 118 |
+
def test_prompt_style_w_chat(self):
|
| 119 |
+
prompter = MultipleChoiceExplainPrompter(prompt_style=PromptStyle.CHAT.value)
|
| 120 |
+
res = next(prompter.build_prompt("choose one", "- A\n- B\n- C", "C"))
|
| 121 |
+
assert "USER:" in res
|
| 122 |
+
assert "choose one" in res
|
| 123 |
+
assert "Choose the answer that best answers the question." in res
|
| 124 |
+
assert "- A\n- B\n- C" in res
|
tests/test_tokenizers.py
ADDED
|
@@ -0,0 +1,31 @@
|
| 1 |
+
"""
|
| 2 |
+
Test cases for the tokenizer loading
|
| 3 |
+
"""
|
| 4 |
+
import unittest
|
| 5 |
+
|
| 6 |
+
from axolotl.utils.dict import DictDefault
|
| 7 |
+
from axolotl.utils.models import load_tokenizer
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class TestTokenizers(unittest.TestCase):
|
| 11 |
+
"""
|
| 12 |
+
test class for the load_tokenizer fn
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
def test_default_use_fast(self):
|
| 16 |
+
cfg = DictDefault({})
|
| 17 |
+
tokenizer = load_tokenizer("huggyllama/llama-7b", None, cfg)
|
| 18 |
+
assert "Fast" in tokenizer.__class__.__name__
|
| 19 |
+
|
| 20 |
+
def test_dont_use_fast(self):
|
| 21 |
+
cfg = DictDefault(
|
| 22 |
+
{
|
| 23 |
+
"tokenizer_use_fast": False,
|
| 24 |
+
}
|
| 25 |
+
)
|
| 26 |
+
tokenizer = load_tokenizer("huggyllama/llama-7b", None, cfg)
|
| 27 |
+
assert "Fast" not in tokenizer.__class__.__name__
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
if __name__ == "__main__":
|
| 31 |
+
unittest.main()
|
tests/test_validation.py
CHANGED
|
@@ -212,3 +212,104 @@ class ValidationTest(unittest.TestCase):
|
|
| 212 |
|
| 213 |
with pytest.raises(ValueError, match=regex_exp):
|
| 214 |
validate_config(cfg)
|
|
| 212 |
|
| 213 |
with pytest.raises(ValueError, match=regex_exp):
|
| 214 |
validate_config(cfg)
|
| 215 |
+
|
| 216 |
+
def test_flash_optimum(self):
|
| 217 |
+
cfg = DictDefault(
|
| 218 |
+
{
|
| 219 |
+
"flash_optimum": True,
|
| 220 |
+
"adapter": "lora",
|
| 221 |
+
}
|
| 222 |
+
)
|
| 223 |
+
|
| 224 |
+
with self._caplog.at_level(logging.WARNING):
|
| 225 |
+
validate_config(cfg)
|
| 226 |
+
assert any(
|
| 227 |
+
"BetterTransformers probably doesn't work with PEFT adapters"
|
| 228 |
+
in record.message
|
| 229 |
+
for record in self._caplog.records
|
| 230 |
+
)
|
| 231 |
+
|
| 232 |
+
cfg = DictDefault(
|
| 233 |
+
{
|
| 234 |
+
"flash_optimum": True,
|
| 235 |
+
}
|
| 236 |
+
)
|
| 237 |
+
|
| 238 |
+
with self._caplog.at_level(logging.WARNING):
|
| 239 |
+
validate_config(cfg)
|
| 240 |
+
assert any(
|
| 241 |
+
"probably set bfloat16 or float16" in record.message
|
| 242 |
+
for record in self._caplog.records
|
| 243 |
+
)
|
| 244 |
+
|
| 245 |
+
cfg = DictDefault(
|
| 246 |
+
{
|
| 247 |
+
"flash_optimum": True,
|
| 248 |
+
"fp16": True,
|
| 249 |
+
}
|
| 250 |
+
)
|
| 251 |
+
regex_exp = r".*AMP is not supported.*"
|
| 252 |
+
|
| 253 |
+
with pytest.raises(ValueError, match=regex_exp):
|
| 254 |
+
validate_config(cfg)
|
| 255 |
+
|
| 256 |
+
cfg = DictDefault(
|
| 257 |
+
{
|
| 258 |
+
"flash_optimum": True,
|
| 259 |
+
"bf16": True,
|
| 260 |
+
}
|
| 261 |
+
)
|
| 262 |
+
regex_exp = r".*AMP is not supported.*"
|
| 263 |
+
|
| 264 |
+
with pytest.raises(ValueError, match=regex_exp):
|
| 265 |
+
validate_config(cfg)
|
| 266 |
+
|
| 267 |
+
def test_adamw_hyperparams(self):
|
| 268 |
+
cfg = DictDefault(
|
| 269 |
+
{
|
| 270 |
+
"optimizer": None,
|
| 271 |
+
"adam_epsilon": 0.0001,
|
| 272 |
+
}
|
| 273 |
+
)
|
| 274 |
+
|
| 275 |
+
with self._caplog.at_level(logging.WARNING):
|
| 276 |
+
validate_config(cfg)
|
| 277 |
+
assert any(
|
| 278 |
+
"adamw hyperparameters found, but no adamw optimizer set"
|
| 279 |
+
in record.message
|
| 280 |
+
for record in self._caplog.records
|
| 281 |
+
)
|
| 282 |
+
|
| 283 |
+
cfg = DictDefault(
|
| 284 |
+
{
|
| 285 |
+
"optimizer": "adafactor",
|
| 286 |
+
"adam_beta1": 0.0001,
|
| 287 |
+
}
|
| 288 |
+
)
|
| 289 |
+
|
| 290 |
+
with self._caplog.at_level(logging.WARNING):
|
| 291 |
+
validate_config(cfg)
|
| 292 |
+
assert any(
|
| 293 |
+
"adamw hyperparameters found, but no adamw optimizer set"
|
| 294 |
+
in record.message
|
| 295 |
+
for record in self._caplog.records
|
| 296 |
+
)
|
| 297 |
+
|
| 298 |
+
cfg = DictDefault(
|
| 299 |
+
{
|
| 300 |
+
"optimizer": "adamw_bnb_8bit",
|
| 301 |
+
"adam_beta1": 0.9,
|
| 302 |
+
"adam_beta2": 0.99,
|
| 303 |
+
"adam_epsilon": 0.0001,
|
| 304 |
+
}
|
| 305 |
+
)
|
| 306 |
+
|
| 307 |
+
validate_config(cfg)
|
| 308 |
+
|
| 309 |
+
cfg = DictDefault(
|
| 310 |
+
{
|
| 311 |
+
"optimizer": "adafactor",
|
| 312 |
+
}
|
| 313 |
+
)
|
| 314 |
+
|
| 315 |
+
validate_config(cfg)
|