support user defined prompters, pretokenized datasets in config, local parquet, local arrow files (#348)
Browse files
* support user defined prompters, pretokenized datasets in config, local parquet, local arrow files
* fix user defined dataset types
* fix for system prompts
* fix tests
* fix checks for parquet and arrow
* aha moment that d.data_files isn't used
* add documentation for ds_type to add support for parquet and arrow
README.md
CHANGED
|
@@ -392,6 +392,7 @@ datasets:
|
|
| 392 |
- path: vicgalle/alpaca-gpt4
|
| 393 |
# The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection]
|
| 394 |
type: alpaca # format | format:<prompt_style> (chat/instruct) | <prompt_strategies>.load_<load_fn>
|
|
|
|
| 395 |
data_files: # path to source data files
|
| 396 |
shards: # number of shards to split data into
|
| 397 |
name: # name of dataset configuration to load
|
|
|
|
| 392 |
- path: vicgalle/alpaca-gpt4
|
| 393 |
# The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection]
|
| 394 |
type: alpaca # format | format:<prompt_style> (chat/instruct) | <prompt_strategies>.load_<load_fn>
|
| 395 |
+
ds_type: # Optional[str] (json|arrow|parquet) defines the datatype when path is a file
|
| 396 |
data_files: # path to source data files
|
| 397 |
shards: # number of shards to split data into
|
| 398 |
name: # name of dataset configuration to load
|
src/axolotl/prompt_strategies/__init__.py
CHANGED
|
@@ -2,8 +2,10 @@
|
|
| 2 |
|
| 3 |
import importlib
|
| 4 |
|
|
|
|
| 5 |
|
| 6 |
-
|
|
|
|
| 7 |
try:
|
| 8 |
load_fn = "load"
|
| 9 |
if strategy.split(".")[-1].startswith("load_"):
|
|
@@ -11,6 +13,9 @@ def load(strategy, tokenizer, cfg):
|
|
| 11 |
strategy = ".".join(strategy.split(".")[:-1])
|
| 12 |
mod = importlib.import_module(f".{strategy}", "axolotl.prompt_strategies")
|
| 13 |
func = getattr(mod, load_fn)
|
| 14 |
-
|
|
|
|
|
|
|
|
|
|
| 15 |
except Exception: # pylint: disable=broad-exception-caught
|
| 16 |
return None
|
|
|
|
| 2 |
|
| 3 |
import importlib
|
| 4 |
|
| 5 |
+
from axolotl.prompt_strategies.user_defined import UserDefinedDatasetConfig
|
| 6 |
|
| 7 |
+
|
| 8 |
+
def load(strategy, tokenizer, cfg, ds_cfg):
|
| 9 |
try:
|
| 10 |
load_fn = "load"
|
| 11 |
if strategy.split(".")[-1].startswith("load_"):
|
|
|
|
| 13 |
strategy = ".".join(strategy.split(".")[:-1])
|
| 14 |
mod = importlib.import_module(f".{strategy}", "axolotl.prompt_strategies")
|
| 15 |
func = getattr(mod, load_fn)
|
| 16 |
+
load_kwargs = {}
|
| 17 |
+
if strategy == "user_defined":
|
| 18 |
+
load_kwargs["ds_cfg"] = UserDefinedDatasetConfig(**ds_cfg)
|
| 19 |
+
return func(tokenizer, cfg, **load_kwargs)
|
| 20 |
except Exception: # pylint: disable=broad-exception-caught
|
| 21 |
return None
|
src/axolotl/prompt_strategies/alpaca_w_system.py
CHANGED
|
@@ -57,6 +57,8 @@ class SystemDataPrompter(AlpacaPrompter):
|
|
| 57 |
Alpaca Style Prompter that uses system prompts from the dataset
|
| 58 |
"""
|
| 59 |
|
|
|
|
|
|
|
| 60 |
def build_prompt_w_system(
|
| 61 |
self,
|
| 62 |
system: str,
|
|
|
|
| 57 |
Alpaca Style Prompter that uses system prompts from the dataset
|
| 58 |
"""
|
| 59 |
|
| 60 |
+
system_format: str = "### System:\n{system}\n\n"
|
| 61 |
+
|
| 62 |
def build_prompt_w_system(
|
| 63 |
self,
|
| 64 |
system: str,
|
src/axolotl/prompt_strategies/user_defined.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
User Defined prompts with configuration from the YML config
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from dataclasses import dataclass
|
| 6 |
+
from functools import partial
|
| 7 |
+
from typing import Optional, Tuple
|
| 8 |
+
|
| 9 |
+
from axolotl.prompt_strategies.alpaca_w_system import (
|
| 10 |
+
InstructionWSystemPromptTokenizingStrategy,
|
| 11 |
+
SystemDataPrompter,
|
| 12 |
+
)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@dataclass
class UserDefinedDatasetConfig:
    """
    dataclass configuration representing a userdefined dataset type

    Every attribute can also be read with subscript syntax, e.g.
    ``cfg["field_input"]``, which forwards to ``getattr``.
    """

    # System prompt used by `load` when a row has no system column.
    system_prompt: str = ""

    # Names of the dataset columns each part of the prompt is read from.
    field_system: str = "system"
    field_instruction: str = "instruction"
    field_input: str = "input"
    field_output: str = "output"

    # Templates handed to the prompter. `format` is rendered with
    # {instruction} and {input}, `no_input_format` with {instruction} only,
    # and `system_format` with {system}.
    format: str = "{instruction} {input} "
    no_input_format: str = "{instruction} "
    system_format: str = "{system}"

    def __getitem__(self, item):
        """Return the attribute named ``item`` (dict-style read access)."""
        return getattr(self, item)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class UserDefinedPromptTokenizationStrategy(InstructionWSystemPromptTokenizingStrategy):
    """
    Prompt Tokenization Strategy for user defined prompts

    Adds nothing to the parent class itself; `load` in this module swaps in a
    per-instance `parse_instruction_fields` that reads the configured columns.
    """
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def load(tokenizer, cfg, ds_cfg: Optional[UserDefinedDatasetConfig] = None):
    """
    Build the tokenization strategy for a dataset whose prompt layout is
    described in the YML config.

    Args:
        tokenizer: tokenizer handed through to the strategy unchanged.
        cfg: run configuration; ``train_on_inputs`` and ``sequence_len`` are
            read from it.
        ds_cfg: column names and format templates for this dataset.

    Returns:
        A ``UserDefinedPromptTokenizationStrategy`` whose
        ``parse_instruction_fields`` reads the columns named in ``ds_cfg``.

    Raises:
        ValueError: if ``ds_cfg`` is missing.
    """
    if not ds_cfg:
        raise ValueError("Missing dataset prompt configuration")

    # Configured fallback used when a row carries no system column.
    system_prompt = ds_cfg.system_prompt or ""

    def parse_instruction_fields(
        field_instruction,
        field_input,
        field_output,
        field_system,
        system_prompt,
        prompt,
    ) -> Tuple[str, str, str, str]:
        # The instruction column is required; the other columns are optional
        # and fall back to "" (or to the configured system prompt).
        return (
            prompt[field_instruction],
            prompt[field_input] if field_input in prompt else "",
            prompt[field_output] if field_output in prompt else "",
            prompt[field_system] if field_system in prompt else system_prompt,
        )

    turn_format = ds_cfg.format
    turn_no_input_format = ds_cfg.no_input_format
    system_format = ds_cfg.system_format

    class UserDefinedPrompter(SystemDataPrompter):
        """
        Prompter for user defined prompts
        """

        def match_prompt_style(self):
            # Use the templates captured from ds_cfg for this prompter.
            self.turn_format = turn_format
            self.turn_no_input_format = turn_no_input_format
            self.system_format = system_format

    prompter = UserDefinedPrompter()

    strat = UserDefinedPromptTokenizationStrategy(
        prompter,
        tokenizer,
        cfg.train_on_inputs,
        cfg.sequence_len,
    )

    # Bind the configured column names up front so the strategy can call
    # parse_instruction_fields(prompt) with the row as its only argument.
    setattr(
        strat,
        "parse_instruction_fields",
        partial(
            parse_instruction_fields,
            ds_cfg.field_instruction,
            ds_cfg.field_input,
            ds_cfg.field_output,
            ds_cfg.field_system,
            system_prompt,
        ),
    )
    return strat
|
src/axolotl/prompters.py
CHANGED
|
@@ -26,7 +26,7 @@ class AlpacaPrompter:
|
|
| 26 |
|
| 27 |
system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n"
|
| 28 |
system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
|
| 29 |
-
system_format: str
|
| 30 |
turn_format: str
|
| 31 |
turn_no_input_format: str
|
| 32 |
prompt_style: Optional[PromptStyle] = None
|
|
@@ -63,13 +63,17 @@ class AlpacaPrompter:
|
|
| 63 |
# returns the full prompt from instruction and optional input
|
| 64 |
# if a label (=response, =output) is provided, it's also appended.
|
| 65 |
if input:
|
| 66 |
-
res =
|
| 67 |
-
|
| 68 |
-
|
|
|
|
|
|
|
| 69 |
else:
|
| 70 |
-
res =
|
| 71 |
-
|
| 72 |
-
|
|
|
|
|
|
|
| 73 |
if output:
|
| 74 |
res = f"{res}{output}"
|
| 75 |
yield res
|
|
|
|
| 26 |
|
| 27 |
system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n"
|
| 28 |
system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
|
| 29 |
+
system_format: str = "{system}"
|
| 30 |
turn_format: str
|
| 31 |
turn_no_input_format: str
|
| 32 |
prompt_style: Optional[PromptStyle] = None
|
|
|
|
| 63 |
# returns the full prompt from instruction and optional input
|
| 64 |
# if a label (=response, =output) is provided, it's also appended.
|
| 65 |
if input:
|
| 66 |
+
res = (
|
| 67 |
+
self.system_format.format(system=self.system_prompt)
|
| 68 |
+
if self.system_prompt
|
| 69 |
+
else ""
|
| 70 |
+
) + self.turn_format.format(instruction=instruction, input=input)
|
| 71 |
else:
|
| 72 |
+
res = (
|
| 73 |
+
self.system_format.format(system=self.system_no_input_prompt)
|
| 74 |
+
if self.system_prompt
|
| 75 |
+
else ""
|
| 76 |
+
) + self.turn_no_input_format.format(instruction=instruction)
|
| 77 |
if output:
|
| 78 |
res = f"{res}{output}"
|
| 79 |
yield res
|
src/axolotl/utils/data.py
CHANGED
|
@@ -41,6 +41,7 @@ from axolotl.prompters import (
|
|
| 41 |
ShareGPTPrompter,
|
| 42 |
SummarizeTLDRPrompter,
|
| 43 |
)
|
|
|
|
| 44 |
from axolotl.utils.distributed import is_main_process, zero_first
|
| 45 |
from axolotl.utils.trainer import (
|
| 46 |
calculate_total_num_steps,
|
|
@@ -160,8 +161,15 @@ def load_tokenized_prepared_datasets(
|
|
| 160 |
split=None,
|
| 161 |
)
|
| 162 |
elif local_path.is_file():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 163 |
ds = load_dataset(
|
| 164 |
-
|
| 165 |
name=d.name,
|
| 166 |
data_files=d.path,
|
| 167 |
streaming=False,
|
|
@@ -198,13 +206,27 @@ def load_tokenized_prepared_datasets(
|
|
| 198 |
)
|
| 199 |
else:
|
| 200 |
ds = ds.shuffle(seed=seed).shard(num_shards=d.shards, index=0)
|
|
|
|
|
|
|
| 201 |
d_type = d.type
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
|
|
|
| 205 |
if "train" in ds:
|
| 206 |
ds = ds["train"]
|
| 207 |
-
if
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 208 |
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
|
| 209 |
datasets.append(ds_wrapper)
|
| 210 |
elif d_base_type == "alpaca":
|
|
|
|
| 41 |
ShareGPTPrompter,
|
| 42 |
SummarizeTLDRPrompter,
|
| 43 |
)
|
| 44 |
+
from axolotl.utils.dict import DictDefault
|
| 45 |
from axolotl.utils.distributed import is_main_process, zero_first
|
| 46 |
from axolotl.utils.trainer import (
|
| 47 |
calculate_total_num_steps,
|
|
|
|
| 161 |
split=None,
|
| 162 |
)
|
| 163 |
elif local_path.is_file():
|
| 164 |
+
ds_type = "json"
|
| 165 |
+
if d.ds_type:
|
| 166 |
+
ds_type = d.ds_type
|
| 167 |
+
elif ".parquet" in d.path:
|
| 168 |
+
ds_type = "parquet"
|
| 169 |
+
elif ".arrow" in d.path:
|
| 170 |
+
ds_type = "arrow"
|
| 171 |
ds = load_dataset(
|
| 172 |
+
ds_type,
|
| 173 |
name=d.name,
|
| 174 |
data_files=d.path,
|
| 175 |
streaming=False,
|
|
|
|
| 206 |
)
|
| 207 |
else:
|
| 208 |
ds = ds.shuffle(seed=seed).shard(num_shards=d.shards, index=0)
|
| 209 |
+
|
| 210 |
+
d_base_type = d_prompt_style = None
|
| 211 |
d_type = d.type
|
| 212 |
+
if isinstance(d_type, str):
|
| 213 |
+
d_type_split = d_type.split(":")
|
| 214 |
+
d_base_type = d_type_split[0]
|
| 215 |
+
d_prompt_style = d_type_split[1] if len(d_type_split) > 1 else None
|
| 216 |
if "train" in ds:
|
| 217 |
ds = ds["train"]
|
| 218 |
+
if (
|
| 219 |
+
"input_ids" in ds.features
|
| 220 |
+
and "attention_mask" in ds.features
|
| 221 |
+
and "labels" in ds.features
|
| 222 |
+
):
|
| 223 |
+
# dataset is already tokenized, just drop it straight in
|
| 224 |
+
datasets.append(ds)
|
| 225 |
+
elif isinstance(d.type, DictDefault):
|
| 226 |
+
ds_strategy = load("user_defined", tokenizer, cfg, d.type.to_dict())
|
| 227 |
+
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
|
| 228 |
+
datasets.append(ds_wrapper)
|
| 229 |
+
elif ds_strategy := load(d.type, tokenizer, cfg, d):
|
| 230 |
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
|
| 231 |
datasets.append(ds_wrapper)
|
| 232 |
elif d_base_type == "alpaca":
|