split completion text to sequence_len (#616)
Browse files
src/axolotl/datasets.py
CHANGED
|
@@ -38,10 +38,15 @@ class TokenizedPromptDataset(Dataset):
|
|
| 38 |
def process(self, dataset):
    """Tokenize every example in *dataset* with the configured strategy.

    The original dataset columns are dropped so the mapped dataset
    contains only the fields emitted by ``tokenize_prompt``.
    """
    features = dataset.features.keys()
    # os.cpu_count() may return None on platforms where the count is
    # undetermined; fall back to 1 so min() never compares against None.
    num_proc = min(64, os.cpu_count() or 1)
    return dataset.map(
        self.prompt_tokenizer.tokenize_prompt,
        num_proc=num_proc,
        remove_columns=features,
    )
|
| 46 |
|
| 47 |
|
|
|
|
| 38 |
def process(self, dataset):
    """Tokenize *dataset* with the configured strategy.

    When the strategy advertises batched support, map in batches of 100
    column-oriented examples so a strategy may emit more (or fewer) rows
    than it receives — e.g. splitting long completions into
    sequence_len-sized chunks. Original columns are dropped.
    """
    features = dataset.features.keys()
    # os.cpu_count() may return None on platforms where the count is
    # undetermined; fall back to 1 so min() never compares against None.
    num_proc = min(64, os.cpu_count() or 1)
    map_kwargs = {}
    if self.prompt_tokenizer.supports_batched:
        map_kwargs["batched"] = True
        map_kwargs["batch_size"] = 100
    return dataset.map(
        self.prompt_tokenizer.tokenize_prompt,
        num_proc=num_proc,
        remove_columns=features,
        **map_kwargs,
    )
|
| 51 |
|
| 52 |
|
src/axolotl/prompt_strategies/completion.py
CHANGED
|
@@ -1,10 +1,81 @@
|
|
| 1 |
"""
|
| 2 |
Basic completion text
|
| 3 |
"""
|
| 4 |
-
from
|
|
|
|
| 5 |
|
| 6 |
-
from axolotl.prompt_tokenizers import
|
| 7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
|
| 9 |
|
| 10 |
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
|
@@ -13,6 +84,7 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
|
| 13 |
tokenizer,
|
| 14 |
cfg.train_on_inputs,
|
| 15 |
cfg.sequence_len,
|
|
|
|
| 16 |
)
|
| 17 |
if ds_cfg and "field" in ds_cfg:
|
| 18 |
strat.field = ds_cfg["field"]
|
|
|
|
| 1 |
"""
|
| 2 |
Basic completion text
|
| 3 |
"""
|
| 4 |
+
from collections import defaultdict
|
| 5 |
+
from typing import Any, Dict, Generator, Optional, Tuple
|
| 6 |
|
| 7 |
+
from axolotl.prompt_tokenizers import InstructionPromptTokenizingStrategy
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class CompletionPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
    """
    Tokenizing strategy for Completion prompts.

    Tokenizes the raw text of each example and splits the token stream
    into chunks of at most ``sequence_len`` tokens, so one long
    completion text yields several training rows instead of being
    truncated to a single row.
    """

    # Dataset column holding the completion text; override via `field`.
    _field: str = "text"

    def __init__(self, *args, max_length=None, **kwargs):
        """Forward to the base strategy; *max_length*, when given,
        overrides the tokenizer truncation limit (so text longer than
        sequence_len survives tokenization and can be chunked)."""
        super().__init__(*args, **kwargs)
        if max_length is not None:
            self.max_length = max_length

    @property
    def supports_batched(self):
        """This strategy consumes and emits batched column-oriented dicts."""
        return True

    @property
    def field(self) -> str:
        return self._field

    @field.setter
    def field(self, new_field: str):
        self._field = new_field

    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
        # Completion data has no input/response parts — only the text field.
        return (
            prompt[self.field],
            "",
            "",
        )

    def tokenize_prompt(self, prompt):
        """Tokenize a batched prompt dict and chunk every tokenized field.

        *prompt* maps feature names to equal-length lists (one entry per
        example). Returns a dict of lists whose rows are
        sequence_len-sized slices of each tokenized example — the output
        may therefore contain more rows than the input.
        """
        res = defaultdict(list)
        feature_names = list(prompt.keys())
        for row in zip(*prompt.values()):
            prompt_row = dict(zip(feature_names, row))
            (
                instruction,
                _,
                _,
            ) = self.parse_instruction_fields(prompt_row)

            full_prompt = self._build_full_prompt(instruction, None, None)
            tokenized_full_prompt = self._tokenize(full_prompt)

            # Split each tokenized field into sequence_len-sized chunks;
            # the final chunk of an example may be shorter.
            for key, val in tokenized_full_prompt.items():
                for i in range(0, len(val), self.sequence_len):
                    res[key].append(val[i : i + self.sequence_len])

        return dict(res)

    def _build_full_prompt(
        self, instruction, input, response
    ):  # pylint: disable=redefined-builtin
        # CompletionPrompter yields exactly one string — take the first.
        return next(iter(self.prompter.build_prompt(instruction, input, response)))
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class CompletionPrompter:
    """
    Prompter for completion: passes the raw text through untouched.
    """

    def build_prompt(
        self,
        instruction: str,
        input=None,  # pylint: disable=redefined-builtin, unused-argument
        output=None,  # pylint: disable=unused-argument
    ) -> Generator[str, None, None]:
        # Completion training uses the text verbatim — no template applied.
        yield from (instruction,)
|
| 79 |
|
| 80 |
|
| 81 |
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
|
|
|
| 84 |
tokenizer,
|
| 85 |
cfg.train_on_inputs,
|
| 86 |
cfg.sequence_len,
|
| 87 |
+
max_length=cfg.sequence_len * 64,
|
| 88 |
)
|
| 89 |
if ds_cfg and "field" in ds_cfg:
|
| 90 |
strat.field = ds_cfg["field"]
|
src/axolotl/prompt_tokenizers.py
CHANGED
|
@@ -41,11 +41,16 @@ class PromptTokenizingStrategy(abc.ABC):
|
|
| 41 |
self.tokenizer: PreTrainedTokenizer = tokenizer
|
| 42 |
self.train_on_inputs = train_on_inputs
|
| 43 |
self.sequence_len = sequence_len
|
|
|
|
| 44 |
|
| 45 |
@abc.abstractmethod
|
| 46 |
def tokenize_prompt(self, prompt):
|
| 47 |
pass
|
| 48 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
@functools.lru_cache(maxsize=128)
|
| 50 |
def _get_user_token(self):
|
| 51 |
try:
|
|
@@ -77,7 +82,7 @@ class PromptTokenizingStrategy(abc.ABC):
|
|
| 77 |
result = self.tokenizer(
|
| 78 |
prompt,
|
| 79 |
truncation=True,
|
| 80 |
-
max_length=self.
|
| 81 |
padding=False,
|
| 82 |
return_tensors=None,
|
| 83 |
)
|
|
@@ -86,7 +91,7 @@ class PromptTokenizingStrategy(abc.ABC):
|
|
| 86 |
if (
|
| 87 |
len(result["input_ids"]) > 0
|
| 88 |
and result["input_ids"][-1] != self.tokenizer.eos_token_id
|
| 89 |
-
and len(result["input_ids"]) < self.
|
| 90 |
and add_eos_token
|
| 91 |
):
|
| 92 |
result["input_ids"].append(self.tokenizer.eos_token_id)
|
|
@@ -247,46 +252,6 @@ class NomicGPT4AllPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
|
|
| 247 |
)
|
| 248 |
|
| 249 |
|
| 250 |
-
class CompletionPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
    """
    Tokenizing strategy for Completion prompts.

    A completion example is a single text field; it is tokenized as-is
    with no instruction/input/response templating.
    """

    # Dataset column holding the completion text; override via `field`.
    _field: str = "text"

    @property
    def field(self) -> str:
        return self._field

    @field.setter
    def field(self, new_field: str):
        self._field = new_field

    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
        # Only the text column matters; input and response are empty.
        return (prompt[self.field], "", "")

    def tokenize_prompt(self, prompt):
        """Tokenize one example's text field into model inputs."""
        instruction, _, _ = self.parse_instruction_fields(prompt)
        full_prompt = self._build_full_prompt(instruction, None, None)
        return self._tokenize(full_prompt)

    def _build_full_prompt(
        self, instruction, input, response
    ):  # pylint: disable=redefined-builtin
        # The prompter yields exactly one string — take the first.
        return next(iter(self.prompter.build_prompt(instruction, input, response)))
|
| 288 |
-
|
| 289 |
-
|
| 290 |
class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
|
| 291 |
"""
|
| 292 |
Tokenizing strategy for Reflection prompts.
|
|
|
|
| 41 |
self.tokenizer: PreTrainedTokenizer = tokenizer
|
| 42 |
self.train_on_inputs = train_on_inputs
|
| 43 |
self.sequence_len = sequence_len
|
| 44 |
+
self.max_length = sequence_len
|
| 45 |
|
| 46 |
@abc.abstractmethod
|
| 47 |
def tokenize_prompt(self, prompt):
|
| 48 |
pass
|
| 49 |
|
| 50 |
+
@property
def supports_batched(self) -> bool:
    """Whether tokenize_prompt accepts batched (column-oriented) input.

    Defaults to False — strategies tokenize one example at a time unless
    a subclass overrides this.
    """
    return False
|
| 53 |
+
|
| 54 |
@functools.lru_cache(maxsize=128)
|
| 55 |
def _get_user_token(self):
|
| 56 |
try:
|
|
|
|
| 82 |
result = self.tokenizer(
|
| 83 |
prompt,
|
| 84 |
truncation=True,
|
| 85 |
+
max_length=self.max_length,
|
| 86 |
padding=False,
|
| 87 |
return_tensors=None,
|
| 88 |
)
|
|
|
|
| 91 |
if (
|
| 92 |
len(result["input_ids"]) > 0
|
| 93 |
and result["input_ids"][-1] != self.tokenizer.eos_token_id
|
| 94 |
+
and len(result["input_ids"]) < self.max_length
|
| 95 |
and add_eos_token
|
| 96 |
):
|
| 97 |
result["input_ids"].append(self.tokenizer.eos_token_id)
|
|
|
|
| 252 |
)
|
| 253 |
|
| 254 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 255 |
class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
|
| 256 |
"""
|
| 257 |
Tokenizing strategy for Reflection prompts.
|
src/axolotl/prompters.py
CHANGED
|
@@ -135,20 +135,6 @@ class SummarizeTLDRPrompter(AlpacaPrompter):
|
|
| 135 |
self.turn_no_input_format = "USER: Summarize the following article as a TL;DR.\n{instruction}\nASSISTANT:"
|
| 136 |
|
| 137 |
|
| 138 |
-
class CompletionPrompter:
    """
    Prompter for completion
    """

    def build_prompt(
        self,
        instruction: str,
        input=None,  # pylint: disable=redefined-builtin, unused-argument
        output=None,  # pylint: disable=unused-argument
    ) -> Generator[str, None, None]:
        # No templating for completion data: emit the text unchanged.
        yield from (instruction,)
|
| 150 |
-
|
| 151 |
-
|
| 152 |
class GPTeacherPrompter(AlpacaPrompter):
|
| 153 |
"""
|
| 154 |
Prompter for GPTeacher
|
|
|
|
| 135 |
self.turn_no_input_format = "USER: Summarize the following article as a TL;DR.\n{instruction}\nASSISTANT:"
|
| 136 |
|
| 137 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 138 |
class GPTeacherPrompter(AlpacaPrompter):
|
| 139 |
"""
|
| 140 |
Prompter for GPTeacher
|