fix: `train_on_inputs: true` ignored for sharegpt (#1045) [skip ci]
* fix: `train_on_inputs: true` ignored for sharegpt
* enable unit test for train_on_inputs for sharegpt
---------
Co-authored-by: Wing Lian <[email protected]>
src/axolotl/prompt_tokenizers.py
CHANGED
|
@@ -379,10 +379,12 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
|
|
| 379 |
add_eos_token=False,
|
| 380 |
strip_bos_token=True,
|
| 381 |
)
|
| 382 |
-
|
| 383 |
-
|
|
|
|
|
|
|
|
|
|
| 384 |
elif assistant in role:
|
| 385 |
-
# TODO label assistant token/tokens w/ IGNORE_TOKEN_ID
|
| 386 |
role = (
|
| 387 |
role.replace(role_remap[1]["from"], role_remap[1]["to"])
|
| 388 |
if role_remap
|
|
@@ -406,18 +408,24 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
|
|
| 406 |
add_eos_token=False,
|
| 407 |
strip_bos_token=True,
|
| 408 |
)
|
| 409 |
-
# not masked out from labels
|
| 410 |
labels = copy.deepcopy(res["input_ids"])
|
| 411 |
-
|
| 412 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 413 |
elif role == "":
|
| 414 |
turn = content
|
| 415 |
# this is only ever the first part, should include the bos token and the user query
|
| 416 |
res = self._tokenize(
|
| 417 |
turn, add_eos_token=False, strip_bos_token=False
|
| 418 |
)
|
| 419 |
-
|
| 420 |
-
|
|
|
|
|
|
|
|
|
|
| 421 |
else:
|
| 422 |
LOG.warning(f"unhandled role: {role}")
|
| 423 |
continue
|
|
|
|
| 379 |
add_eos_token=False,
|
| 380 |
strip_bos_token=True,
|
| 381 |
)
|
| 382 |
+
if self.train_on_inputs:
|
| 383 |
+
labels = copy.deepcopy(res["input_ids"])
|
| 384 |
+
else:
|
| 385 |
+
# everything from this is masked out from the labels
|
| 386 |
+
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
|
| 387 |
elif assistant in role:
|
|
|
|
| 388 |
role = (
|
| 389 |
role.replace(role_remap[1]["from"], role_remap[1]["to"])
|
| 390 |
if role_remap
|
|
|
|
| 408 |
add_eos_token=False,
|
| 409 |
strip_bos_token=True,
|
| 410 |
)
|
|
|
|
| 411 |
labels = copy.deepcopy(res["input_ids"])
|
| 412 |
+
if not self.train_on_inputs:
|
| 413 |
+
# mask out role tokens from the labels
|
| 414 |
+
len_role = len(role_res["input_ids"])
|
| 415 |
+
labels[:len_role] = [IGNORE_TOKEN_ID] * min(
|
| 416 |
+
len_role, len(labels)
|
| 417 |
+
)
|
| 418 |
elif role == "":
|
| 419 |
turn = content
|
| 420 |
# this is only ever the first part, should include the bos token and the user query
|
| 421 |
res = self._tokenize(
|
| 422 |
turn, add_eos_token=False, strip_bos_token=False
|
| 423 |
)
|
| 424 |
+
if self.train_on_inputs:
|
| 425 |
+
labels = copy.deepcopy(res["input_ids"])
|
| 426 |
+
else:
|
| 427 |
+
# everything from this is masked out from the labels
|
| 428 |
+
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
|
| 429 |
else:
|
| 430 |
LOG.warning(f"unhandled role: {role}")
|
| 431 |
continue
|
tests/prompt_strategies/test_sharegpt.py
CHANGED
|
@@ -104,7 +104,7 @@ class TestSharegpt:
|
|
| 104 |
role_key_human=None,
|
| 105 |
),
|
| 106 |
tokenizer,
|
| 107 |
-
|
| 108 |
2048, # sequence_len
|
| 109 |
)
|
| 110 |
|
|
@@ -124,30 +124,30 @@ class TestSharegpt:
|
|
| 124 |
]
|
| 125 |
# fmt: on
|
| 126 |
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
|
|
|
| 104 |
role_key_human=None,
|
| 105 |
),
|
| 106 |
tokenizer,
|
| 107 |
+
False, # train_on_inputs
|
| 108 |
2048, # sequence_len
|
| 109 |
)
|
| 110 |
|
|
|
|
| 124 |
]
|
| 125 |
# fmt: on
|
| 126 |
|
| 127 |
+
def test_no_train_on_input(self, sharegpt_dataset, tokenizer):
|
| 128 |
+
strategy = SimpleShareGPTPromptTokenizingStrategy(
|
| 129 |
+
ShareGPTPrompterV2(
|
| 130 |
+
conversation="chatml",
|
| 131 |
+
role_key_model=None,
|
| 132 |
+
role_key_human=None,
|
| 133 |
+
),
|
| 134 |
+
tokenizer,
|
| 135 |
+
True, # train_on_inputs
|
| 136 |
+
2048, # sequence_len
|
| 137 |
+
)
|
| 138 |
+
|
| 139 |
+
dataset_wrapper = TokenizedPromptDataset(
|
| 140 |
+
strategy, sharegpt_dataset, process_count=1
|
| 141 |
+
)
|
| 142 |
+
|
| 143 |
+
labels = dataset_wrapper[0]["labels"]
|
| 144 |
+
# fmt: off
|
| 145 |
+
assert labels == [
|
| 146 |
+
1, # bos
|
| 147 |
+
32001, 1587, 13, 25997, 32000, 28705, 13, # system
|
| 148 |
+
32001, 2188, 13, 21558, 32000, 28705, 13, # human
|
| 149 |
+
32001, 13892, 13, 21558, 32000, 28705, 13, # gpt
|
| 150 |
+
32001, 2188, 13, 12684, 17664, 32000, 28705, 13, # human
|
| 151 |
+
32001, 13892, 13, 12684, 17664, 32000, 28705, 13, # gpt
|
| 152 |
+
]
|
| 153 |
+
# fmt: on
|