Feat/chatml add system message (#1117)
* add system message to template
* readme update
* added code to register new system message
* register chatml template for test
---------
Co-authored-by: Mads Henrichsen <[email protected]>
Co-authored-by: Wing Lian <[email protected]>
README.md
CHANGED

@@ -613,6 +613,8 @@ rl:
 # Saves the desired chat template to the tokenizer_config.json for easier inferencing
 # Currently supports chatml and inst (mistral/mixtral)
 chat_template: chatml
+# Changes the default system message
+default_system_message: You are a helpful assistant. Please give a long and detailed answer. # Currently only supports chatml.
 # Axolotl attempts to save the dataset as an arrow after packing the data together so
 # subsequent training attempts load faster, relative path
 dataset_prepared_path: data/last_run_prepared
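Once the template is saved into tokenizer_config.json, inference code can pick it up without knowing anything about the training run. A minimal sketch (not part of this PR; "./out" stands in for the run's actual output_dir):

from transformers import AutoTokenizer

# The tokenizer saved by axolotl carries the chatml template, with the
# custom default system message baked in as the fallback.
tokenizer = AutoTokenizer.from_pretrained("./out")
prompt = tokenizer.apply_chat_template(
    [{"role": "user", "content": "What is Axolotl?"}],
    tokenize=False,
    add_generation_prompt=True,
)
print(prompt)
# <|im_start|>system
# You are a helpful assistant. Please give a long and detailed answer.<|im_end|>
# <|im_start|>user
# What is Axolotl?<|im_end|>
# <|im_start|>assistant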
src/axolotl/cli/preprocess.py
CHANGED

@@ -18,6 +18,7 @@ from axolotl.cli import (
 )
 from axolotl.common.cli import PreprocessCliArgs
 from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
+from axolotl.prompt_strategies.sharegpt import register_chatml_template

 LOG = logging.getLogger("axolotl.cli.preprocess")

@@ -34,6 +35,12 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
         return_remaining_strings=True
     )

+    if parsed_cfg.chat_template == "chatml" and parsed_cfg.default_system_message:
+        LOG.info(
+            f"ChatML set. Adding default system message: {parsed_cfg.default_system_message}"
+        )
+        register_chatml_template(parsed_cfg.default_system_message)
+
     if not parsed_cfg.dataset_prepared_path:
         msg = (
             Fore.RED
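The registration runs before any dataset work because ShareGPT tokenization later resolves "chatml" through FastChat's conversation registry. A sketch of that dependency (assuming the pinned FastChat build ships no "chatml" template of its own, so the lookup fails until the helper runs):

from fastchat.conversation import get_conv_template

from axolotl.prompt_strategies.sharegpt import register_chatml_template

try:
    get_conv_template("chatml")  # KeyError until the template is registered
except KeyError:
    register_chatml_template("Answer briefly.")

print(get_conv_template("chatml").system_message)  # -> Answer briefly.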
src/axolotl/cli/train.py
CHANGED

@@ -18,6 +18,7 @@ from axolotl.cli import (
     print_axolotl_text_art,
 )
 from axolotl.common.cli import TrainerCliArgs
+from axolotl.prompt_strategies.sharegpt import register_chatml_template
 from axolotl.train import train

 LOG = logging.getLogger("axolotl.cli.train")

@@ -37,7 +38,12 @@ def do_train(cfg, cli_args) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
     print_axolotl_text_art()
     check_accelerate_default_config()
     check_user_token()
-    if cfg.rl:
+    if cfg.chat_template == "chatml" and cfg.default_system_message:
+        LOG.info(
+            f"ChatML set. Adding default system message: {cfg.default_system_message}"
+        )
+        register_chatml_template(cfg.default_system_message)
+    if cfg.rl:
         dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
     else:
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
src/axolotl/prompt_strategies/sharegpt.py
CHANGED

@@ -6,16 +6,19 @@ from fastchat.conversation import Conversation, SeparatorStyle, register_conv_template
 from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
 from axolotl.prompters import ShareGPTPrompterV2

-register_conv_template(
-    Conversation(
-        name="chatml",
-        system_template="<|im_start|>system\n{system_message}",
-        system_message="You are a helpful assistant.",
-        roles=["<|im_start|>user", "<|im_start|>assistant"],
-        sep_style=SeparatorStyle.CHATML,
-        sep="<|im_end|>",
+
+def register_chatml_template(system_message=None):
+    system_message = system_message or "You are a helpful assistant."
+    register_conv_template(
+        Conversation(
+            name="chatml",
+            system_template="<|im_start|>system\n{system_message}",
+            system_message=system_message,
+            roles=["<|im_start|>user", "<|im_start|>assistant"],
+            sep_style=SeparatorStyle.CHATML,
+            sep="<|im_end|>",
+        )
     )
-)


 def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
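For reference, a quick sketch (not part of the diff) of the prompt the registered template yields through FastChat:

from fastchat.conversation import get_conv_template

from axolotl.prompt_strategies.sharegpt import register_chatml_template

register_chatml_template("You are a terse assistant.")  # register once per process
conv = get_conv_template("chatml")          # returns a copy of the template
conv.append_message(conv.roles[0], "Hi!")   # user turn
conv.append_message(conv.roles[1], None)    # open assistant turn
print(conv.get_prompt())
# <|im_start|>system
# You are a terse assistant.<|im_end|>
# <|im_start|>user
# Hi!<|im_end|>
# <|im_start|>assistant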
src/axolotl/utils/chat_templates.py
CHANGED

@@ -20,7 +20,7 @@ def chat_templates(user_choice: str):

     templates = {
         "inst": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # I don't know what this one is called. Used by Mistral/Mixtral.
-        "chatml": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
+        "chatml": "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = 'You are a helpful assistant.' %}{% endif %}{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in loop_messages %}{% if loop.index0 == 0 %}{{'<|im_start|>system\n' + system_message + '<|im_end|>\n'}}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
     }

     if user_choice in templates:
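The new Jinja string mirrors the FastChat template: an explicit system turn in the messages wins, otherwise the hardcoded "You are a helpful assistant." fallback is emitted. A sketch rendering it directly with jinja2 (already a transformers dependency), bypassing the tokenizer on purpose:

from jinja2 import Template

from axolotl.utils.chat_templates import chat_templates

template = Template(chat_templates("chatml"))

# No system turn supplied -> the fallback system message appears.
print(template.render(
    messages=[{"role": "user", "content": "Hello"}],
    add_generation_prompt=True,
))

# A leading system turn overrides the fallback.
print(template.render(
    messages=[
        {"role": "system", "content": "Reply in French."},
        {"role": "user", "content": "Hello"},
    ],
    add_generation_prompt=True,
))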
src/axolotl/utils/models.py
CHANGED

@@ -219,7 +219,13 @@ def load_tokenizer(cfg):
     LOG.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")

     if cfg.chat_template:
-        tokenizer.chat_template = chat_templates(cfg.chat_template)
+        chat_template_string = chat_templates(cfg.chat_template)
+        if cfg.default_system_message and cfg.chat_template == "chatml":
+            chat_template_string = chat_template_string.replace(
+                "You are a helpful assistant.", cfg.default_system_message
+            )
+
+        tokenizer.chat_template = chat_template_string
     else:
         LOG.info(
             "No Chat template selected. Consider adding a chat template for easier inference."
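The override is a plain substring replace on the template source, which is safe here because the stock default appears exactly once in the chatml string. A sketch of the effect in isolation:

from axolotl.utils.chat_templates import chat_templates

template = chat_templates("chatml")
custom = template.replace("You are a helpful assistant.", "You are a pirate.")

assert "You are a pirate." in custom
assert "You are a helpful assistant." not in custom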
tests/prompt_strategies/test_sharegpt.py
CHANGED

@@ -7,9 +7,14 @@ from tokenizers import AddedToken
 from transformers import AutoTokenizer

 from axolotl.datasets import TokenizedPromptDataset
-from axolotl.prompt_strategies.sharegpt import SimpleShareGPTPromptTokenizingStrategy
+from axolotl.prompt_strategies.sharegpt import (
+    SimpleShareGPTPromptTokenizingStrategy,
+    register_chatml_template,
+)
 from axolotl.prompters import ShareGPTPrompterV2

+register_chatml_template()
+

 @pytest.fixture(name="sharegpt_dataset")
 def fixture_sharegpt_dataset():