Fix(message): Improve error message for bad format (#365)
Browse files
src/axolotl/prompt_strategies/llama2_chat.py
CHANGED
|
@@ -29,7 +29,7 @@ from dataclasses import dataclass, field
|
|
| 29 |
from typing import Generator, List, Sequence
|
| 30 |
|
| 31 |
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
|
| 32 |
-
from axolotl.prompters import IGNORE_TOKEN_ID
|
| 33 |
|
| 34 |
|
| 35 |
@dataclass
|
|
@@ -190,7 +190,7 @@ class Llama2ChatPrompter: # pylint: disable=too-few-public-methods
|
|
| 190 |
conv.messages = [] # pylint: disable=R0801
|
| 191 |
for j, sentence in enumerate(source):
|
| 192 |
role = roles[sentence["from"]]
|
| 193 |
-
assert role == conv.roles[j % 2]
|
| 194 |
if sentence["value"]:
|
| 195 |
conv.append_message(role, sentence["value"])
|
| 196 |
yield conv
|
|
|
|
| 29 |
from typing import Generator, List, Sequence
|
| 30 |
|
| 31 |
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
|
| 32 |
+
from axolotl.prompters import IGNORE_TOKEN_ID, SHAREGPT_ASSERTION_FAILED_ROLE
|
| 33 |
|
| 34 |
|
| 35 |
@dataclass
|
|
|
|
| 190 |
conv.messages = [] # pylint: disable=R0801
|
| 191 |
for j, sentence in enumerate(source):
|
| 192 |
role = roles[sentence["from"]]
|
| 193 |
+
assert role == conv.roles[j % 2], SHAREGPT_ASSERTION_FAILED_ROLE
|
| 194 |
if sentence["value"]:
|
| 195 |
conv.append_message(role, sentence["value"])
|
| 196 |
yield conv
|
src/axolotl/prompters.py
CHANGED
|
@@ -260,6 +260,11 @@ class Conversation:
|
|
| 260 |
self.messages.append([role, message])
|
| 261 |
|
| 262 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 263 |
class ShareGPTPrompter: # pylint: disable=too-few-public-methods
|
| 264 |
"""
|
| 265 |
A prompter that generates prompts for the ShareGPT
|
|
@@ -316,7 +321,7 @@ class ShareGPTPrompter: # pylint: disable=too-few-public-methods
|
|
| 316 |
conv.messages = []
|
| 317 |
for j, sentence in enumerate(source):
|
| 318 |
role = roles[sentence["from"]]
|
| 319 |
-
assert role == conv.roles[j % 2]
|
| 320 |
conv.append_message(role, sentence["value"])
|
| 321 |
|
| 322 |
for part in conv.get_prompt():
|
|
|
|
| 260 |
self.messages.append([role, message])
|
| 261 |
|
| 262 |
|
| 263 |
+
SHAREGPT_ASSERTION_FAILED_ROLE = (
|
| 264 |
+
"Role did not alternate between turns (gpt and human). Please check your data."
|
| 265 |
+
)
|
| 266 |
+
|
| 267 |
+
|
| 268 |
class ShareGPTPrompter: # pylint: disable=too-few-public-methods
|
| 269 |
"""
|
| 270 |
A prompter that generates prompts for the ShareGPT
|
|
|
|
| 321 |
conv.messages = []
|
| 322 |
for j, sentence in enumerate(source):
|
| 323 |
role = roles[sentence["from"]]
|
| 324 |
+
assert role == conv.roles[j % 2], SHAREGPT_ASSERTION_FAILED_ROLE
|
| 325 |
conv.append_message(role, sentence["value"])
|
| 326 |
|
| 327 |
for part in conv.get_prompt():
|