"""
monkeypatch to add a get_turns method
"""
import logging
from typing import Generator, Tuple
from fastchat.conversation import SeparatorStyle
LOG = logging.getLogger("axolotl.monkeypatch.fastchat_conversation_turns")
def get_prompt(self) -> str:
    """Return the full prompt: every turn's role prefix and content, joined in order."""
    return "".join(prefix + content for prefix, content in self.get_turns())
def get_turns(  # pylint: disable=too-many-return-statements
    self,
) -> Generator[Tuple[str, str], None, None]:
    """Yield the prompt as a sequence of ``(role_prefix, content)`` pieces.

    Concatenating ``role_prefix + content`` over every yielded pair, in
    order, gives the full prompt string for the conversation.  The system
    prompt (when emitted) is yielded with an empty role prefix.  A message
    whose content is empty/``None`` yields only its role prefix together
    with an empty content string.

    This is a generator, so nothing is evaluated until iteration starts;
    that includes the ``ValueError`` raised for an unknown separator style.

    Yields:
        Tuples of ``(role_prefix, content)``.

    Raises:
        ValueError: if ``self.sep_style`` is not one of the handled styles.
    """
    system_prompt = self.system_template.format(system_message=self.system_message)

    if self.sep_style == SeparatorStyle.ADD_COLON_SINGLE:
        yield "", system_prompt + self.sep
        for role, message in self.messages:
            if message:
                yield role + ": ", message + self.sep
            else:
                yield role + ":", ""
        return

    if self.sep_style == SeparatorStyle.ADD_COLON_TWO:
        # Separators alternate by message position: sep, sep2, sep, ...
        seps = [self.sep, self.sep2]
        yield "", system_prompt + seps[0]
        for i, (role, message) in enumerate(self.messages):
            if message:
                yield role + ": ", message + seps[i % 2]
            else:
                yield role + ":", ""
        return

    if self.sep_style == SeparatorStyle.ADD_COLON_SPACE_SINGLE:
        yield "", system_prompt + self.sep
        for role, message in self.messages:
            if message:
                yield role + ": ", message + self.sep
            else:
                yield role + ": ", ""  # must be end with a space
        return

    if self.sep_style == SeparatorStyle.ADD_NEW_LINE_SINGLE:
        yield "", "" if system_prompt == "" else system_prompt + self.sep
        for role, message in self.messages:
            if message:
                yield role + "\n", message + self.sep
            else:
                yield role + "\n", ""
        return

    if self.sep_style == SeparatorStyle.NO_COLON_SINGLE:
        yield "", system_prompt
        for role, message in self.messages:
            if message:
                yield role, message + self.sep
            else:
                yield role, ""
        return

    if self.sep_style == SeparatorStyle.NO_COLON_TWO:
        seps = [self.sep, self.sep2]
        yield "", system_prompt
        for i, (role, message) in enumerate(self.messages):
            if message:
                yield role, message + seps[i % 2]
            else:
                yield role, ""
        return

    if self.sep_style == SeparatorStyle.RWKV:
        yield "", system_prompt
        for role, message in self.messages:
            if message:
                # Normalise line endings and collapse blank lines before
                # terminating the turn with a blank line.
                yield role + ": ", message.replace("\r\n", "\n").replace(
                    "\n\n", "\n"
                ) + "\n\n"
            else:
                yield role + ":", ""
        return

    if self.sep_style == SeparatorStyle.LLAMA2:
        seps = [self.sep, self.sep2]
        head = system_prompt if self.system_message else "[INST] "
        # The leading chunk already opens the first instruction, so the first
        # message is appended to it without a role tag instead of being
        # dropped.  An empty first message is left to the loop below.
        first_folded = bool(self.messages) and bool(self.messages[0][1])
        if first_folded:
            head += self.messages[0][1] + " "
        yield "", head
        # Enumerate the full list so that seps[i % 2] is chosen from each
        # message's absolute position in the conversation.
        for i, (role, message) in enumerate(self.messages):
            if i == 0 and first_folded:
                continue
            if message:
                yield role + " ", message + seps[i % 2]
            else:
                yield role, ""
        return

    if self.sep_style == SeparatorStyle.CHATGLM:
        # source: https://huggingface.co/THUDM/chatglm-6b/blob/1d240ba371910e9282298d4592532d7f0f3e9f3e/modeling_chatglm.py#L1302-L1308
        # source2: https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926
        # Round numbering starts at 1 when the conversation is named
        # "chatglm2" and at 0 otherwise.
        round_add_n = 1 if self.name == "chatglm2" else 0
        if system_prompt:
            yield "", system_prompt + self.sep
        for i, (role, message) in enumerate(self.messages):
            if i % 2 == 0:
                # A round header precedes every even-indexed message.
                yield "", f"[Round {i//2 + round_add_n}]{self.sep}"
            if message:
                yield f"{role}:", f"{message}{self.sep}"
            else:
                yield f"{role}:", ""
        return

    if self.sep_style == SeparatorStyle.CHATML:
        yield "", "" if system_prompt == "" else system_prompt + self.sep + "\n"
        for role, message in self.messages:
            if message:
                yield role + "\n", message + self.sep + "\n"
            else:
                yield role + "\n", ""
        return

    if self.sep_style == SeparatorStyle.CHATINTERN:
        # source: https://huggingface.co/internlm/internlm-chat-7b-8k/blob/bd546fa984b4b0b86958f56bf37f94aa75ab8831/modeling_internlm.py#L771
        seps = [self.sep, self.sep2]
        yield "", system_prompt
        for i, (role, message) in enumerate(self.messages):
            # Every even-indexed message opens with "<s>", whether or not it
            # has content yet.
            prefix = "<s>" if i % 2 == 0 else ""
            if message:
                yield prefix + role + ":", message + seps[i % 2] + "\n"
            else:
                yield prefix + role + ":", ""
        return

    if self.sep_style == SeparatorStyle.DOLLY:
        seps = [self.sep, self.sep2]
        yield "", system_prompt
        for i, (role, message) in enumerate(self.messages):
            if message:
                # Odd-indexed messages are followed by an extra blank line.
                suffix = "\n\n" if i % 2 == 1 else ""
                yield role + ":\n", message + seps[i % 2] + suffix
            else:
                yield role + ":\n", ""
        return

    if self.sep_style == SeparatorStyle.PHOENIX:
        yield "", system_prompt
        for role, message in self.messages:
            if message:
                yield role + ": ", "<s>" + message + "</s>"
            else:
                # The opening "<s>" is part of the role prefix here.
                yield role + ": " + "<s>", ""
        return

    if self.sep_style == SeparatorStyle.ROBIN:
        yield "", system_prompt + self.sep
        for role, message in self.messages:
            if message:
                yield role + ":\n", message + self.sep
            else:
                yield role + ":\n", ""
        return

    if self.sep_style == SeparatorStyle.FALCON_CHAT:
        # No system chunk at all when there is no system message.
        if self.system_message:
            yield "", system_prompt + self.sep
        for role, message in self.messages:
            if message:
                yield role + ": ", message + self.sep
            else:
                yield role + ":", ""
        return

    raise ValueError(f"Invalid style: {self.sep_style}")
def add_get_turns_to_conversation():
    """Install this module's get_turns and get_prompt on fastchat's Conversation class."""
    from fastchat.conversation import Conversation

    # Same assignment order as before: get_turns first, then get_prompt.
    patches = (("get_turns", get_turns), ("get_prompt", get_prompt))
    for attr_name, replacement in patches:
        setattr(Conversation, attr_name, replacement)
 |