fixes for alpaca w chatml, and don't include attention_mask w mistral for flash attention (#728)
src/axolotl/prompt_strategies/alpaca_chat.py
CHANGED
@@ -1,6 +1,6 @@
-"""Module
+"""Module for Alpaca prompt strategy classes"""
 
-from typing import Tuple
+from typing import Any, Dict, Optional, Tuple
 
 from axolotl.prompt_tokenizers import (
     AlpacaPromptTokenizingStrategy,
@@ -9,9 +9,13 @@ from axolotl.prompt_tokenizers import (
 from axolotl.prompters import AlpacaPrompter, PromptStyle, UnpromptedPrompter
 
 
-def load(tokenizer, cfg):
+def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
+    prompt_style = PromptStyle.CHAT.value
+    if ds_cfg and "conversation" in ds_cfg:
+        prompt_style = ds_cfg["conversation"]
+
     return AlpacaPromptTokenizingStrategy(
-        AlpacaPrompter(PromptStyle.CHAT.value),
+        AlpacaPrompter(prompt_style),
         tokenizer,
         cfg.train_on_inputs,
         cfg.sequence_len,
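The new ds_cfg parameter lets a dataset entry override the prompt style (e.g. a "conversation" key in the dataset config), falling back to PromptStyle.CHAT.value when no override is given. Below is a minimal, self-contained sketch of that selection logic; pick_prompt_style and DEFAULT_STYLE are hypothetical stand-ins for the inline code and the enum value in the diff:

from typing import Any, Dict, Optional

# Hypothetical stand-in for PromptStyle.CHAT.value, the default in load().
DEFAULT_STYLE = "chat"

def pick_prompt_style(ds_cfg: Optional[Dict[str, Any]] = None) -> str:
    # Mirrors the branch the diff adds to load(): an explicit
    # "conversation" key wins, otherwise fall back to the chat default.
    prompt_style = DEFAULT_STYLE
    if ds_cfg and "conversation" in ds_cfg:
        prompt_style = ds_cfg["conversation"]
    return prompt_style

assert pick_prompt_style() == "chat"                              # no ds_cfg
assert pick_prompt_style({}) == "chat"                            # empty ds_cfg
assert pick_prompt_style({"conversation": "chatml"}) == "chatml"  # override

The chosen style is then handed to AlpacaPrompter in place of the previously hard-coded PromptStyle.CHAT.value.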
src/axolotl/utils/trainer.py
CHANGED
@@ -423,7 +423,9 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):
     )
 
     # Phi doesn't want the attention_mask feature when training
-    if "CodeGenTokenizer" in tokenizer.__class__.__name__:
+    if "CodeGenTokenizer" in tokenizer.__class__.__name__ or (
+        cfg.is_mistral_derived_model and cfg.flash_attention
+    ):
         train_dataset = train_dataset.remove_columns("attention_mask")
         if eval_dataset:
             eval_dataset = eval_dataset.remove_columns("attention_mask")
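The guard previously fired only for Phi's CodeGenTokenizer; it now also drops the attention_mask column for Mistral-derived models when flash attention is enabled, presumably because the flash-attention path does not consume the padded mask. A self-contained sketch of the updated condition follows; CfgLike and should_drop_attention_mask are hypothetical stand-ins for axolotl's cfg object and the inline check:

from dataclasses import dataclass

@dataclass
class CfgLike:
    # Stand-ins for the two cfg flags the diff consults.
    is_mistral_derived_model: bool = False
    flash_attention: bool = False

def should_drop_attention_mask(tokenizer_cls_name: str, cfg: CfgLike) -> bool:
    # Mirrors the condition in process_datasets_for_packing().
    return "CodeGenTokenizer" in tokenizer_cls_name or (
        cfg.is_mistral_derived_model and cfg.flash_attention
    )

# Phi's CodeGenTokenizer always drops the mask.
assert should_drop_attention_mask("CodeGenTokenizer", CfgLike())
# A Mistral-derived model drops it only when flash attention is on...
assert should_drop_attention_mask(
    "LlamaTokenizerFast", CfgLike(is_mistral_derived_model=True, flash_attention=True)
)
# ...and keeps it otherwise.
assert not should_drop_attention_mask(
    "LlamaTokenizerFast", CfgLike(is_mistral_derived_model=True, flash_attention=False)
)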