update prompts for open orca to match the paper (#317)
Browse files
src/axolotl/prompt_strategies/alpaca_w_system.py
CHANGED
|
@@ -66,15 +66,34 @@ class SystemDataPrompter(AlpacaPrompter):
|
|
| 66 |
) -> Generator[str, None, None]:
|
| 67 |
# returns the full prompt from instruction and optional input
|
| 68 |
# if a label (=response, =output) is provided, it's also appended.
|
|
|
|
| 69 |
if input:
|
| 70 |
-
res =
|
|
|
|
|
|
|
| 71 |
else:
|
| 72 |
-
res =
|
|
|
|
|
|
|
| 73 |
if output:
|
| 74 |
res = f"{res}{output}"
|
| 75 |
yield res
|
| 76 |
|
| 77 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
class OpenOrcaPromptTokenizingStrategy(InstructionWSystemPromptTokenizingStrategy):
|
| 79 |
"""
|
| 80 |
Tokenizing strategy for OpenOrca datasets
|
|
@@ -113,7 +132,7 @@ def load_chat(tokenizer, cfg):
|
|
| 113 |
|
| 114 |
def load_open_orca(tokenizer, cfg):
|
| 115 |
return OpenOrcaPromptTokenizingStrategy(
|
| 116 |
-
|
| 117 |
tokenizer,
|
| 118 |
cfg.train_on_inputs,
|
| 119 |
cfg.sequence_len,
|
|
|
|
| 66 |
) -> Generator[str, None, None]:
|
| 67 |
# returns the full prompt from instruction and optional input
|
| 68 |
# if a label (=response, =output) is provided, it's also appended.
|
| 69 |
+
formatted_sys_prompt = f"### System:\n{system}\n\n" if system else ""
|
| 70 |
if input:
|
| 71 |
+
res = formatted_sys_prompt + self.turn_format.format(
|
| 72 |
+
instruction=instruction, input=input
|
| 73 |
+
)
|
| 74 |
else:
|
| 75 |
+
res = formatted_sys_prompt + self.turn_no_input_format.format(
|
| 76 |
+
instruction=instruction
|
| 77 |
+
)
|
| 78 |
if output:
|
| 79 |
res = f"{res}{output}"
|
| 80 |
yield res
|
| 81 |
|
| 82 |
|
| 83 |
+
class OpenOrcaSystemDataPrompter(SystemDataPrompter):
|
| 84 |
+
"""
|
| 85 |
+
Alpaca Style Prompter that uses system prompts from the dataset, with OpenOrca prompts
|
| 86 |
+
"""
|
| 87 |
+
|
| 88 |
+
def match_prompt_style(self):
|
| 89 |
+
if self.prompt_style == PromptStyle.INSTRUCT.value:
|
| 90 |
+
self.turn_format = "### User:\n{instruction}\n\n### Additional Context:\n{input}\n\n### Assistant:\n"
|
| 91 |
+
self.turn_no_input_format = "### User:\n{instruction}\n\n### Assistant:\n"
|
| 92 |
+
if self.prompt_style == PromptStyle.CHAT.value:
|
| 93 |
+
self.turn_format = "USER: {instruction}\n{input}\nASSISTANT:"
|
| 94 |
+
self.turn_no_input_format = "USER: {instruction}\nASSISTANT:"
|
| 95 |
+
|
| 96 |
+
|
| 97 |
class OpenOrcaPromptTokenizingStrategy(InstructionWSystemPromptTokenizingStrategy):
|
| 98 |
"""
|
| 99 |
Tokenizing strategy for OpenOrca datasets
|
|
|
|
| 132 |
|
| 133 |
def load_open_orca(tokenizer, cfg):
|
| 134 |
return OpenOrcaPromptTokenizingStrategy(
|
| 135 |
+
OpenOrcaSystemDataPrompter(PromptStyle.INSTRUCT.value),
|
| 136 |
tokenizer,
|
| 137 |
cfg.train_on_inputs,
|
| 138 |
cfg.sequence_len,
|
tests/test_prompt_tokenizers.py
CHANGED
|
@@ -130,8 +130,9 @@ class InstructionWSystemPromptTokenizingStrategyTest(unittest.TestCase):
|
|
| 130 |
"output": "Hi! How can I help?",
|
| 131 |
}
|
| 132 |
example = strat.tokenize_prompt(sample)
|
| 133 |
-
assert example["input_ids"][0:
|
| 134 |
-
assert example["input_ids"][
|
|
|
|
| 135 |
|
| 136 |
|
| 137 |
if __name__ == "__main__":
|
|
|
|
| 130 |
"output": "Hi! How can I help?",
|
| 131 |
}
|
| 132 |
example = strat.tokenize_prompt(sample)
|
| 133 |
+
assert example["input_ids"][0:4] == [1, 835, 2184, 29901] # "<s>### System:"
|
| 134 |
+
assert example["input_ids"][5:7] == [1509, 20118] # "use cot"
|
| 135 |
+
assert example["input_ids"][9] == 11889 # USER
|
| 136 |
|
| 137 |
|
| 138 |
if __name__ == "__main__":
|
tests/test_prompters.py
CHANGED
|
@@ -70,7 +70,7 @@ class AlpacaPrompterTest(unittest.TestCase):
|
|
| 70 |
)
|
| 71 |
)
|
| 72 |
assert "use cot" in res
|
| 73 |
-
assert res.startswith("
|
| 74 |
assert "### Instruction:" not in res
|
| 75 |
assert "### Input:" not in res
|
| 76 |
assert "alpacas" in res
|
|
|
|
| 70 |
)
|
| 71 |
)
|
| 72 |
assert "use cot" in res
|
| 73 |
+
assert res.startswith("### System:")
|
| 74 |
assert "### Instruction:" not in res
|
| 75 |
assert "### Input:" not in res
|
| 76 |
assert "alpacas" in res
|