Add a flash attention context for efficient training and attempt to set the model to train mode:
Browse files
- scripts/finetune.py +23 -1
scripts/finetune.py
CHANGED
|
@@ -252,6 +252,24 @@ def train(
|
|
| 252 |
model.save_pretrained(cfg.output_dir)
|
| 253 |
return
|
| 254 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 255 |
trainer = setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer)
|
| 256 |
|
| 257 |
model.config.use_cache = False
|
|
@@ -297,7 +315,11 @@ def train(
|
|
| 297 |
|
| 298 |
if not Path(cfg.output_dir).is_dir():
|
| 299 |
os.makedirs(cfg.output_dir, exist_ok=True)
|
| 300 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 301 |
|
| 302 |
logging.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
|
| 303 |
|
|
|
|
| 252 |
model.save_pretrained(cfg.output_dir)
|
| 253 |
return
|
| 254 |
|
| 255 |
+
if cfg.debug:
|
| 256 |
+
logging.info("check_dataset_labels...")
|
| 257 |
+
check_dataset_labels(
|
| 258 |
+
train_dataset.select(
|
| 259 |
+
[random.randrange(0, len(train_dataset) - 1) for i in range(5)]
|
| 260 |
+
),
|
| 261 |
+
tokenizer,
|
| 262 |
+
)
|
| 263 |
+
|
| 264 |
+
if prepare_ds_only:
|
| 265 |
+
logging.info("Finished preparing dataset. Exiting...")
|
| 266 |
+
return
|
| 267 |
+
|
| 268 |
+
try:
|
| 269 |
+
model.train()
|
| 270 |
+
except:
|
| 271 |
+
pass
|
| 272 |
+
|
| 273 |
trainer = setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer)
|
| 274 |
|
| 275 |
model.config.use_cache = False
|
|
|
|
| 315 |
|
| 316 |
if not Path(cfg.output_dir).is_dir():
|
| 317 |
os.makedirs(cfg.output_dir, exist_ok=True)
|
| 318 |
+
if cfg.flash_optimum:
|
| 319 |
+
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True):
|
| 320 |
+
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
| 321 |
+
else:
|
| 322 |
+
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
| 323 |
|
| 324 |
logging.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
|
| 325 |
|