|  | """ | 
					
						
Monkeypatch fastchat's Conversation to add a get_turns method (and a get_prompt built on top of it).
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  | import logging | 
					
						
						|  | from typing import Generator, Tuple | 
					
						
						|  |  | 
					
						
						|  | from fastchat.conversation import SeparatorStyle | 
					
						
						|  |  | 
					
						
						|  | LOG = logging.getLogger("axolotl.monkeypatch.fastchat_conversation_turns") | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def get_prompt(self) -> str: | 
					
						
						|  | ret = "" | 
					
						
						|  | for role, msg in self.get_turns(): | 
					
						
						|  | ret += role + msg | 
					
						
						|  | return ret | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def get_turns( | 
					
						
						|  | self, | 
					
						
						|  | ) -> Generator[Tuple[str, str], None, None]: | 
					
						
						|  | """Get the prompt for generation.""" | 
					
						
						|  | system_prompt = self.system_template.format(system_message=self.system_message) | 
					
						
						|  | if self.sep_style == SeparatorStyle.ADD_COLON_SINGLE: | 
					
						
						|  | yield "", system_prompt + self.sep | 
					
						
						|  | for role, message in self.messages: | 
					
						
						|  | if message: | 
					
						
						|  | yield role + ": ", message + self.sep | 
					
						
						|  | else: | 
					
						
						|  | yield role + ":", "" | 
					
						
						|  | return | 
					
						
						|  | if self.sep_style == SeparatorStyle.ADD_COLON_TWO: | 
					
						
						|  | seps = [self.sep, self.sep2] | 
					
						
						|  | yield "", system_prompt + seps[0] | 
					
						
						|  | for i, (role, message) in enumerate(self.messages): | 
					
						
						|  | if message: | 
					
						
						|  | yield role + ": ", message + seps[i % 2] | 
					
						
						|  | else: | 
					
						
						|  | yield role + ":", "" | 
					
						
						|  | return | 
					
						
						|  | if self.sep_style == SeparatorStyle.ADD_COLON_SPACE_SINGLE: | 
					
						
						|  | yield "", system_prompt + self.sep | 
					
						
						|  | for role, message in self.messages: | 
					
						
						|  | if message: | 
					
						
						|  | yield role + ": ", message + self.sep | 
					
						
						|  | else: | 
					
						
						|  | yield role + ": ", "" | 
					
						
						|  | return | 
					
						
						|  | if self.sep_style == SeparatorStyle.ADD_NEW_LINE_SINGLE: | 
					
						
						|  | yield "", "" if system_prompt == "" else system_prompt + self.sep | 
					
						
						|  | for role, message in self.messages: | 
					
						
						|  | if message: | 
					
						
						|  | yield role + "\n", message + self.sep | 
					
						
						|  | else: | 
					
						
						|  | yield role + "\n", "" | 
					
						
						|  | return | 
					
						
						|  | if self.sep_style == SeparatorStyle.NO_COLON_SINGLE: | 
					
						
						|  | yield "", system_prompt | 
					
						
						|  | for role, message in self.messages: | 
					
						
						|  | if message: | 
					
						
						|  | yield role, message + self.sep | 
					
						
						|  | else: | 
					
						
						|  | yield role, "" | 
					
						
						|  | return | 
					
						
						|  | if self.sep_style == SeparatorStyle.NO_COLON_TWO: | 
					
						
						|  | seps = [self.sep, self.sep2] | 
					
						
						|  | yield "", system_prompt | 
					
						
						|  | for i, (role, message) in enumerate(self.messages): | 
					
						
						|  | if message: | 
					
						
						|  | yield role, message + seps[i % 2] | 
					
						
						|  | else: | 
					
						
						|  | yield role, "" | 
					
						
						|  | return | 
					
						
						|  | if self.sep_style == SeparatorStyle.RWKV: | 
					
						
						|  | yield "", system_prompt | 
					
						
						|  | for i, (role, message) in enumerate(self.messages): | 
					
						
						|  | if message: | 
					
						
						|  | yield role + ": ", message.replace("\r\n", "\n").replace( | 
					
						
						|  | "\n\n", "\n" | 
					
						
						|  | ) + "\n\n" | 
					
						
						|  | else: | 
					
						
						|  | yield role + ":", "" | 
					
						
						|  | return | 
					
						
						|  | if self.sep_style == SeparatorStyle.LLAMA2 and self.name != "mistral": | 
					
						
						|  | if self.system_message: | 
					
						
						|  | if self.messages: | 
					
						
						|  |  | 
					
						
						|  | first_role, first_msg = self.messages[0] | 
					
						
						|  | if first_role == self.roles[0]: | 
					
						
						|  | system_prompt += first_msg | 
					
						
						|  | self.messages.pop(0) | 
					
						
						|  | yield "", system_prompt | 
					
						
						|  | for i, (role, message) in enumerate(self.messages): | 
					
						
						|  | if message: | 
					
						
						|  | if (i % 2 == 0 and not self.system_message) or ( | 
					
						
						|  | i % 2 != 0 and self.system_message | 
					
						
						|  | ): | 
					
						
						|  | role = "<s> " + role | 
					
						
						|  | yield role + " ", message | 
					
						
						|  | else: | 
					
						
						|  | yield role, "" | 
					
						
						|  | return | 
					
						
						|  | if self.sep_style == SeparatorStyle.LLAMA2 and self.name == "mistral": | 
					
						
						|  | contains_sys_msg = False | 
					
						
						|  | if self.system_message: | 
					
						
						|  | contains_sys_msg = True | 
					
						
						|  | if self.messages: | 
					
						
						|  |  | 
					
						
						|  | first_role, first_msg = self.messages[0] | 
					
						
						|  | if first_role == self.roles[0]: | 
					
						
						|  | system_prompt = self.system_template.format( | 
					
						
						|  | system_message=" " + self.system_message | 
					
						
						|  | ) | 
					
						
						|  | system_prompt += first_msg | 
					
						
						|  | self.messages.pop(0) | 
					
						
						|  | yield "", system_prompt | 
					
						
						|  | for i, (role, message) in enumerate(self.messages): | 
					
						
						|  | if message and i == 0 and not contains_sys_msg: | 
					
						
						|  | yield "", system_prompt.strip() + " " + message | 
					
						
						|  | elif message: | 
					
						
						|  | yield role + " ", message | 
					
						
						|  | else: | 
					
						
						|  | yield role, "" | 
					
						
						|  | return | 
					
						
						|  | if self.sep_style == SeparatorStyle.CHATGLM: | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | round_add_n = 1 if self.name == "chatglm2" else 0 | 
					
						
						|  | if system_prompt: | 
					
						
						|  | yield "", system_prompt + self.sep | 
					
						
						|  |  | 
					
						
						|  | for i, (role, message) in enumerate(self.messages): | 
					
						
						|  | if i % 2 == 0: | 
					
						
						|  | yield "", f"[Round {i//2 + round_add_n}]{self.sep}" | 
					
						
						|  |  | 
					
						
						|  | if message: | 
					
						
						|  | yield f"{role}:", f"{message}{self.sep}" | 
					
						
						|  | else: | 
					
						
						|  | yield f"{role}:", "" | 
					
						
						|  | return | 
					
						
						|  | if self.sep_style == SeparatorStyle.CHATML: | 
					
						
						|  | yield "", "" if system_prompt == "" else system_prompt + self.sep + "\n" | 
					
						
						|  | for role, message in self.messages: | 
					
						
						|  | if message: | 
					
						
						|  | yield role + "\n", message + self.sep + "\n" | 
					
						
						|  | else: | 
					
						
						|  | yield role + "\n", "" | 
					
						
						|  | return | 
					
						
						|  | if self.sep_style == SeparatorStyle.CHATGLM3: | 
					
						
						|  | if self.system_message: | 
					
						
						|  | yield "", system_prompt | 
					
						
						|  | for role, message in self.messages: | 
					
						
						|  | if message: | 
					
						
						|  | yield role + "\n", " " + message | 
					
						
						|  | else: | 
					
						
						|  | yield role | 
					
						
						|  | return | 
					
						
						|  | if self.sep_style == SeparatorStyle.CHATINTERN: | 
					
						
						|  |  | 
					
						
						|  | seps = [self.sep, self.sep2] | 
					
						
						|  | yield "", system_prompt | 
					
						
						|  | for i, (role, message) in enumerate(self.messages): | 
					
						
						|  | prefix = "<s>" if i % 2 == 0 else "" | 
					
						
						|  | if message: | 
					
						
						|  | yield prefix + role + ":", message + seps[i % 2] + "\n" | 
					
						
						|  | else: | 
					
						
						|  | yield role + ":", "" | 
					
						
						|  | return | 
					
						
						|  | if self.sep_style == SeparatorStyle.DOLLY: | 
					
						
						|  | seps = [self.sep, self.sep2] | 
					
						
						|  | yield "", system_prompt | 
					
						
						|  | for i, (role, message) in enumerate(self.messages): | 
					
						
						|  | if message: | 
					
						
						|  | suffix = "\n\n" if i % 2 == 1 else "" | 
					
						
						|  | yield role + ":\n", message + seps[i % 2] + suffix | 
					
						
						|  | else: | 
					
						
						|  | yield role + ":\n", "" | 
					
						
						|  | return | 
					
						
						|  | if self.sep_style == SeparatorStyle.PHOENIX: | 
					
						
						|  | yield "", system_prompt | 
					
						
						|  | for role, message in self.messages: | 
					
						
						|  | if message: | 
					
						
						|  | yield role + ": ", "<s>" + message + "</s>" | 
					
						
						|  | else: | 
					
						
						|  | yield role + ": " + "<s>", "" | 
					
						
						|  | return | 
					
						
						|  | if self.sep_style == SeparatorStyle.ROBIN: | 
					
						
						|  | yield "", system_prompt + self.sep | 
					
						
						|  | for role, message in self.messages: | 
					
						
						|  | if message: | 
					
						
						|  | yield role + ":\n", message + self.sep | 
					
						
						|  | else: | 
					
						
						|  | yield role + ":\n", "" | 
					
						
						|  | return | 
					
						
						|  | if self.sep_style == SeparatorStyle.FALCON_CHAT: | 
					
						
						|  | if self.system_message: | 
					
						
						|  | yield "", system_prompt + self.sep | 
					
						
						|  | for role, message in self.messages: | 
					
						
						|  | if message: | 
					
						
						|  | yield role + ": ", message + self.sep | 
					
						
						|  | else: | 
					
						
						|  | yield role + ":", "" | 
					
						
						|  | else: | 
					
						
						|  | raise ValueError(f"Invalid style: {self.sep_style}") | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def add_get_turns_to_conversation(): | 
					
						
						|  | import fastchat.conversation | 
					
						
						|  |  | 
					
						
						|  | fastchat.conversation.Conversation.get_turns = get_turns | 
					
						
						|  | fastchat.conversation.Conversation.get_prompt = get_prompt | 
					
						
						|  |  |