Adds chat templates (#1022)
Browse files- README.md +3 -0
- src/axolotl/utils/chat_templates.py +29 -0
- src/axolotl/utils/models.py +7 -0
README.md
CHANGED
|
@@ -589,6 +589,9 @@ datasets:
|
|
| 589 |
# For `completion` datasets only, uses the provided field instead of `text` column
|
| 590 |
field:
|
| 591 |
|
|
|
|
|
|
|
|
|
|
| 592 |
# Axolotl attempts to save the dataset as an arrow after packing the data together so
|
| 593 |
# subsequent training attempts load faster, relative path
|
| 594 |
dataset_prepared_path: data/last_run_prepared
|
|
|
|
| 589 |
# For `completion` datasets only, uses the provided field instead of `text` column
|
| 590 |
field:
|
| 591 |
|
| 592 |
+
# Saves the desired chat template to the tokenizer_config.json for easier inference
|
| 593 |
+
# Currently supports chatml and inst (mistral/mixtral)
|
| 594 |
+
chat_template: chatml
|
| 595 |
# Axolotl attempts to save the dataset as an arrow after packing the data together so
|
| 596 |
# subsequent training attempts load faster, relative path
|
| 597 |
dataset_prepared_path: data/last_run_prepared
|
src/axolotl/utils/chat_templates.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""
This module provides functionality for selecting chat templates based on user choices.
These templates are used for formatting messages in a conversation.
"""


def chat_templates(user_choice: str) -> str:
    """
    Finds the correct chat_template for the tokenizer_config.

    Args:
        user_choice (str): The user's choice of template.

    Returns:
        str: The chosen template string.

    Raises:
        ValueError: If the user_choice is not found in the templates.
    """

    templates = {
        # No widely agreed-upon name for this format; it is the [INST]-wrapped
        # style used by Mistral/Mixtral.
        "inst": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}",
        "chatml": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
    }

    try:
        # EAFP: a single dict lookup instead of an `in` check followed by indexing.
        return templates[user_choice]
    except KeyError as exc:
        # Chain the cause so the original KeyError stays visible in tracebacks.
        raise ValueError(f"Template '{user_choice}' not found.") from exc
|
src/axolotl/utils/models.py
CHANGED
|
@@ -26,6 +26,7 @@ from transformers.deepspeed import is_deepspeed_zero3_enabled
|
|
| 26 |
from axolotl.models.mamba import fix_mamba_attn_for_loss
|
| 27 |
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
|
| 28 |
from axolotl.utils.bench import log_gpu_memory_usage
|
|
|
|
| 29 |
from axolotl.utils.dict import DictDefault
|
| 30 |
|
| 31 |
LOG = logging.getLogger("axolotl")
|
|
@@ -186,6 +187,12 @@ def load_tokenizer(cfg):
|
|
| 186 |
LOG.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
|
| 187 |
LOG.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")
|
| 188 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 189 |
return tokenizer
|
| 190 |
|
| 191 |
|
|
|
|
| 26 |
from axolotl.models.mamba import fix_mamba_attn_for_loss
|
| 27 |
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
|
| 28 |
from axolotl.utils.bench import log_gpu_memory_usage
|
| 29 |
+
from axolotl.utils.chat_templates import chat_templates
|
| 30 |
from axolotl.utils.dict import DictDefault
|
| 31 |
|
| 32 |
LOG = logging.getLogger("axolotl")
|
|
|
|
| 187 |
LOG.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
|
| 188 |
LOG.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")
|
| 189 |
|
| 190 |
+
if cfg.chat_template:
|
| 191 |
+
tokenizer.chat_template = chat_templates(cfg.chat_template)
|
| 192 |
+
else:
|
| 193 |
+
LOG.info(
|
| 194 |
+
"No Chat template selected. Consider adding a chat template for easier inference."
|
| 195 |
+
)
|
| 196 |
return tokenizer
|
| 197 |
|
| 198 |
|