experimental expansion of ctx len
- scripts/finetune.py +26 -18
- src/axolotl/utils/data.py +31 -1
scripts/finetune.py
CHANGED
@@ -6,22 +6,20 @@ import os
 import random
 import signal
 import sys
-from functools import partial
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Union
 
 import fire
 import torch
 import yaml
-from transformers import GenerationConfig, TextStreamer
-
-from axolotl.utils.data import load_prepare_datasets
-from axolotl.utils.dict import DictDefault
-from axolotl.utils.models import load_model, load_tokenizer
 
 # add src to the pythonpath so we don't need to pip install this
 from optimum.bettertransformer import BetterTransformer
+from transformers import GenerationConfig, TextStreamer
 
+from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset
+from axolotl.utils.dict import DictDefault
+from axolotl.utils.models import load_model, load_tokenizer
 from axolotl.utils.tokenization import check_dataset_labels
 from axolotl.utils.trainer import setup_trainer
 from axolotl.utils.validation import validate_config
@@ -204,9 +202,19 @@ def train(
     if check_not_in(
         ["inference", "shard", "merge_lora"], kwargs
     ):  # don't need to load dataset for these
-        train_dataset, eval_dataset = load_prepare_datasets(
-            tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
-        )
+        if not cfg.pretraining_dataset:
+            train_dataset, eval_dataset = load_prepare_datasets(
+                tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
+            )
+        else:
+            if cfg.pretraining_dataset is True:
+                pretraining_dataset = "togethercomputer/RedPajama-Data-1T"
+            else:
+                pretraining_dataset = cfg.pretraining_dataset
+            train_dataset = load_pretraining_dataset(
+                pretraining_dataset, tokenizer, max_tokens=cfg.sequence_len
+            )
+            eval_dataset = None
 
     if cfg.debug or "debug" in kwargs:
         logging.info("check_dataset_labels...")
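Note on the hunk above: any truthy `pretraining_dataset` value in the config switches training onto the streaming pretraining path, a bare `true` falls back to the RedPajama corpus, and `max_tokens` is driven by `cfg.sequence_len`, which is what ties this path to the expanded context length. A minimal sketch of the resolution logic (illustrative only, with a stand-in value instead of the real `cfg` object):

    # illustrative sketch, not part of the commit
    def resolve_pretraining_dataset(pretraining_dataset_cfg):
        if pretraining_dataset_cfg is True:
            # bare `pretraining_dataset: true` falls back to RedPajama
            return "togethercomputer/RedPajama-Data-1T"
        # any other truthy value is used as the dataset path/repo id
        return pretraining_dataset_cfg

    assert resolve_pretraining_dataset(True) == "togethercomputer/RedPajama-Data-1T"
    assert resolve_pretraining_dataset("my-org/my-corpus") == "my-org/my-corpus"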
@@ -256,7 +264,7 @@ def train(
         logging.info("check_dataset_labels...")
         check_dataset_labels(
             train_dataset.select(
-                [random.randrange(0, len(train_dataset) - 1) for i in range(5)]
+                [random.randrange(0, len(train_dataset) - 1) for i in range(5)]  # nosec
             ),
             tokenizer,
         )
@@ -265,10 +273,7 @@ def train(
         logging.info("Finished preparing dataset. Exiting...")
         return
 
-    try:
-        model.train()
-    except:
-        pass
+    model.train()
 
     trainer = setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer)
 
@@ -285,14 +290,15 @@ def train(
 
     # In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
     if cfg.local_rank == 0:
-        def terminate_handler(_, __, model):
+
+        def terminate_handler(_, __, model):
             if cfg.flash_optimum:
                 model = BetterTransformer.reverse(model)
             model.save_pretrained(cfg.output_dir)
             sys.exit(0)
+
         signal.signal(
-            signal.SIGINT,
-            lambda signum, frame: terminate_handler(signum, frame, model)
+            signal.SIGINT, lambda signum, frame: terminate_handler(signum, frame, model)
         )
 
     logging.info("Starting trainer...")
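The hunk above is mostly a formatting pass over the existing Ctrl+C handling: the nested `terminate_handler` saves the model (reversing the BetterTransformer patch when `flash_optimum` is set) and exits. The same save-on-interrupt pattern in isolation, as a generic sketch with a `save` callback standing in for the commit's handler:

    # generic sketch of the save-on-SIGINT pattern, not the commit's code
    import signal
    import sys

    def install_sigint_save_hook(save):
        def handler(signum, frame):  # signum/frame are unused here
            save()
            sys.exit(0)

        signal.signal(signal.SIGINT, handler)

    install_sigint_save_hook(lambda: print("saving checkpoint..."))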
@@ -316,7 +322,9 @@ def train(
     if not Path(cfg.output_dir).is_dir():
         os.makedirs(cfg.output_dir, exist_ok=True)
     if cfg.flash_optimum:
-        with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True):
+        with torch.backends.cuda.sdp_kernel(
+            enable_flash=True, enable_math=True, enable_mem_efficient=True
+        ):
             trainer.train(resume_from_checkpoint=resume_from_checkpoint)
     else:
         trainer.train(resume_from_checkpoint=resume_from_checkpoint)
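For reference, `torch.backends.cuda.sdp_kernel(...)` in the last hunk is the PyTorch 2.0-era context manager that restricts which scaled-dot-product-attention backends may be used inside the block; with all three flags enabled, PyTorch picks the fastest kernel available (newer releases expose `torch.nn.attention.sdpa_kernel` instead). A standalone sketch of the same call outside the trainer:

    # standalone sketch of the sdp_kernel context manager used above
    import torch
    import torch.nn.functional as F

    q = k = v = torch.randn(1, 8, 16, 64)  # (batch, heads, seq, head_dim)

    with torch.backends.cuda.sdp_kernel(
        enable_flash=True, enable_math=True, enable_mem_efficient=True
    ):
        out = F.scaled_dot_product_attention(q, k, v)

    print(out.shape)  # torch.Size([1, 8, 16, 64])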
src/axolotl/utils/data.py
CHANGED
@@ -5,7 +5,8 @@ from hashlib import md5
 from pathlib import Path
 from typing import List, Tuple, Union
 
-from datasets import Dataset, DatasetDict, load_dataset, load_from_disk
+import torch
+from datasets import Dataset, DatasetDict, IterableDataset, load_dataset, load_from_disk
 from huggingface_hub import hf_hub_download
 from transformers import PreTrainedTokenizerBase
 
@@ -392,3 +393,32 @@ def load_prepare_datasets(
         eval_dataset = dataset["test"]
 
     return train_dataset, eval_dataset
+
+
+class PretrainingDatasetWrapper(IterableDataset):
+    """
+    Wrapper for pretraining dataset that avoids loading the dataset into memory
+    """
+
+    def __init__(self, tokenizer, dataset_path, max_tokens=2048):
+        self.tokenizer = tokenizer
+        self.dataset_path = dataset_path
+        self.max_tokens = max_tokens
+
+    def __iter__(self):
+        buffer = []
+        for sample in load_dataset(
+            self.dataset_path,
+            name="all",
+            split="train",
+            streaming=True,
+        ).shuffle(buffer_size=10000):
+            buffer += self.tokenizer(sample["text"])["input_ids"]
+            buffer += [self.tokenizer.eos_token_id]
+            while len(buffer) > self.max_tokens:
+                yield torch.tensor(buffer[: self.max_tokens])
+                buffer = buffer[self.max_tokens :]
+
+
+def load_pretraining_dataset(path, tokenizer, max_tokens=2048):
+    return PretrainingDatasetWrapper(tokenizer, path, max_tokens=max_tokens)
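Usage sketch for the new streaming path (not from the commit; the tokenizer checkpoint and sequence length are arbitrary examples): `PretrainingDatasetWrapper` tokenizes each streamed document, appends an EOS token, and packs the ids into a running buffer, yielding blocks of exactly `max_tokens` ids, so no padding is required and the usable context is bounded only by `sequence_len`. Because it streams with `streaming=True`, the corpus is never materialized locally; only the shuffle buffer and the current token buffer are held in memory.

    # illustrative usage, assuming Hub access; checkpoint and length are examples
    from transformers import AutoTokenizer

    from axolotl.utils.data import load_pretraining_dataset

    tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
    dataset = load_pretraining_dataset(
        "togethercomputer/RedPajama-Data-1T", tokenizer, max_tokens=4096
    )

    block = next(iter(dataset))  # 1-D LongTensor of packed token ids
    print(block.shape)  # torch.Size([4096])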