support custom field for completion from yml (#580)
* support custom field for completion from yml
* remove legacy completion check and add doc
* update README docs
README.md
CHANGED
@@ -322,6 +322,7 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
   - path: EleutherAI/pile
     name: enron_emails
     type: completion # format from earlier
+    field: text # Optional[str] default: text, field to use for completion data
 
   # huggingface repo with multiple named configurations/subsets
   datasets:
@@ -444,6 +445,9 @@ datasets:
   # 'no_input_format' cannot include {input}
   no_input_format: "{instruction} "
 
+  # for completions datasets, uses the provided field if not `text`
+  field:
+
 # axolotl attempts to save the dataset as an arrow after packing the data together so
 # subsequent training attempts load faster, relative path
 dataset_prepared_path: data/last_run_prepared
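For readers mapping the YAML onto what axolotl actually receives, here is a hypothetical sketch (the dataset path, column name, and values below are made up) of a completion dataset whose text lives in a column other than the default `text`, together with the dict form of the dataset entry that the new `field` key produces.

from datasets import Dataset

# Hypothetical completion dataset whose text is stored in a "content" column
# instead of the default "text" column.
ds = Dataset.from_dict({"content": ["first document ...", "second document ..."]})

# The matching dataset entry from the YAML config, once parsed into a dict;
# `field` tells the completion strategy which column to tokenize.
ds_cfg = {"path": "my/dataset", "type": "completion", "field": "content"}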
src/axolotl/prompt_strategies/__init__.py
CHANGED
@@ -1,6 +1,7 @@
 """Module to load prompt strategies."""
 
 import importlib
+import inspect
 
 from axolotl.prompt_strategies.user_defined import UserDefinedDatasetConfig
 
@@ -16,6 +17,10 @@ def load(strategy, tokenizer, cfg, ds_cfg):
         load_kwargs = {}
         if strategy == "user_defined":
             load_kwargs["ds_cfg"] = UserDefinedDatasetConfig(**ds_cfg)
+        else:
+            sig = inspect.signature(func)
+            if "ds_cfg" in sig.parameters:
+                load_kwargs["ds_cfg"] = ds_cfg
         return func(tokenizer, cfg, **load_kwargs)
     except Exception:  # pylint: disable=broad-exception-caught
         return None
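The new else-branch lets a strategy module opt in to receiving the raw dataset config simply by declaring a `ds_cfg` parameter on its `load` function; older strategies without that parameter are called exactly as before. A minimal, self-contained sketch of that dispatch pattern (the two loader functions here are hypothetical):

import inspect

# Hypothetical strategy loaders: one that accepts ds_cfg, one that does not.
def load_with_cfg(tokenizer, cfg, ds_cfg=None):
    return ("with_cfg", ds_cfg)

def load_plain(tokenizer, cfg):
    return ("plain",)

def call(func, tokenizer, cfg, ds_cfg):
    load_kwargs = {}
    # Same idea as the new else-branch: only pass ds_cfg to loaders that declare it.
    if "ds_cfg" in inspect.signature(func).parameters:
        load_kwargs["ds_cfg"] = ds_cfg
    return func(tokenizer, cfg, **load_kwargs)

print(call(load_with_cfg, None, None, {"field": "content"}))  # ('with_cfg', {'field': 'content'})
print(call(load_plain, None, None, {"field": "content"}))     # ('plain',)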
src/axolotl/prompt_strategies/completion.py
ADDED
@@ -0,0 +1,20 @@
+"""
+Basic completion text
+"""
+from typing import Any, Dict, Optional
+
+from axolotl.prompt_tokenizers import CompletionPromptTokenizingStrategy
+from axolotl.prompters import CompletionPrompter
+
+
+def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
+    strat = CompletionPromptTokenizingStrategy(
+        CompletionPrompter(),
+        tokenizer,
+        cfg.train_on_inputs,
+        cfg.sequence_len,
+    )
+    if ds_cfg and "field" in ds_cfg:
+        strat.field = ds_cfg["field"]
+
+    return strat
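Putting the two pieces together, a dataset entry that sets `field` reaches this `load` as `ds_cfg`, and the returned strategy is pointed at that column. A rough usage sketch, assuming the strategy constructor only stores the tokenizer and config values (the `SimpleNamespace` config and the column name are stand-ins):

from types import SimpleNamespace

# Stand-ins for the real config and tokenizer, just to show how the `field`
# key on a dataset entry ends up on the strategy.
cfg = SimpleNamespace(train_on_inputs=False, sequence_len=2048)
tokenizer = None  # a real transformers tokenizer in practice

strat = load(tokenizer, cfg, ds_cfg={"type": "completion", "field": "content"})
assert strat.field == "content"  # rows are now tokenized from prompt["content"]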
src/axolotl/prompt_tokenizers.py
CHANGED
@@ -245,8 +245,31 @@ class CompletionPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
     Tokenizing strategy for Completion prompts.
     """
 
+    _field: str = "text"
+
+    @property
+    def field(self) -> str:
+        return self._field
+
+    @field.setter
+    def field(self, new_field: str):
+        self._field = new_field
+
+    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
+        return (
+            prompt[self.field],
+            "",
+            "",
+        )
+
     def tokenize_prompt(self, prompt):
-        full_prompt = self._build_full_prompt(prompt["text"], None, None)
+        (
+            instruction,
+            _,
+            _,
+        ) = self.parse_instruction_fields(prompt)
+
+        full_prompt = self._build_full_prompt(instruction, None, None)
         tokenized_full_prompt = self._tokenize(full_prompt)
 
         return tokenized_full_prompt
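The property pair is what the completion loader's `strat.field = ds_cfg["field"]` assignment relies on, and `parse_instruction_fields` is where the chosen column is actually read. A stand-alone toy illustrating the same pattern (class and column names here are illustrative, not the real strategy):

# Minimal sketch of the field property pattern: redirecting which column of a
# dataset row is treated as the completion text.
class FieldPicker:
    _field: str = "text"

    @property
    def field(self) -> str:
        return self._field

    @field.setter
    def field(self, new_field: str):
        self._field = new_field

    def parse_instruction_fields(self, prompt):
        return (prompt[self.field], "", "")

picker = FieldPicker()
row = {"text": "default column", "content": "custom column"}
print(picker.parse_instruction_fields(row)[0])  # default column
picker.field = "content"
print(picker.parse_instruction_fields(row)[0])  # custom column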
src/axolotl/utils/data.py
CHANGED
@@ -22,7 +22,6 @@ from axolotl.prompt_tokenizers import (
     AlpacaMultipleChoicePromptTokenizingStrategy,
     AlpacaPromptTokenizingStrategy,
     AlpacaReflectionPTStrategy,
-    CompletionPromptTokenizingStrategy,
     GPTeacherPromptTokenizingStrategy,
     JeopardyPromptTokenizingStrategy,
     OpenAssistantPromptTokenizingStrategy,
@@ -31,7 +30,6 @@ from axolotl.prompt_tokenizers import (
 )
 from axolotl.prompters import (
     AlpacaPrompter,
-    CompletionPrompter,
     GPTeacherPrompter,
     JeopardyPrompter,
     MultipleChoiceConcisePrompter,
@@ -327,15 +325,6 @@ def load_tokenized_prepared_datasets(
             )
             ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
             datasets.append(ds_wrapper)
-        elif d_base_type == "completion":
-            ds_strategy = CompletionPromptTokenizingStrategy(
-                CompletionPrompter(),
-                tokenizer,
-                cfg.train_on_inputs,
-                cfg.sequence_len,
-            )
-            ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
-            datasets.append(ds_wrapper)
         else:
             suffix = ""
             if ":load_" in d.type: