|  | """ | 
					
						
						|  | Basic completion text | 
					
						
						|  | """ | 
					
						
						|  | from collections import defaultdict | 
					
						
						|  | from typing import Any, Dict, Generator, Optional, Tuple | 
					
						
						|  |  | 
					
						
						|  | from axolotl.prompt_tokenizers import InstructionPromptTokenizingStrategy | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
class CompletionPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
    """
    Tokenizing strategy for Completion prompts.
    """

    _field: str = "text"

    def __init__(self, *args, max_length=None, **kwargs):
        super().__init__(*args, **kwargs)
        if max_length is not None:
            self.max_length = max_length

    @property
    def supports_batched(self):
        return True

    @property
    def field(self) -> str:
        return self._field

    @field.setter
    def field(self, new_field: str):
        self._field = new_field

    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
        return (
            prompt[self.field],
            "",
            "",
        )
					
						
    def tokenize_prompt(self, prompt):
        # Batched input: `prompt` maps feature names to lists of per-row values.
        res = defaultdict(lambda: [])
        feature_names = list(prompt.keys())
        for row in zip(*prompt.values()):
            prompt_row = dict(zip(feature_names, row))
            (
                instruction,
                _,
                _,
            ) = self.parse_instruction_fields(prompt_row)

            full_prompt = self._build_full_prompt(instruction, None, None)
            tokenized_full_prompt = self._tokenize(full_prompt)

            # Split each tokenized feature into sequence_len-sized chunks so long
            # documents become multiple training rows.
            for key, val in tokenized_full_prompt.items():
                for i in range(0, len(val), self.sequence_len):
                    res[key].append(val[i : i + self.sequence_len])

        return dict(res)

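    # Illustrative sketch (hypothetical numbers): with sequence_len=2048, a row whose
    # text tokenizes to 5000 ids comes out of tokenize_prompt as three chunks of
    # 2048, 2048, and 904 ids, with every tokenized feature (typically input_ids,
    # attention_mask, and labels) split at the same boundaries.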
					
						
    def _build_full_prompt(self, instruction, input, response):
        return next(iter(self.prompter.build_prompt(instruction, input, response)))

					
						
class CompletionPrompter:
    """
    Prompter for completion
    """

    def build_prompt(
        self,
        instruction: str,
        input=None,
        output=None,
    ) -> Generator[str, None, None]:
        yield instruction

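# Illustrative note: the prompter is a pass-through, so
# next(iter(CompletionPrompter().build_prompt("raw text"))) returns "raw text"
# unchanged and completion examples are tokenized without any prompt template.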
					
						
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
    strat = CompletionPromptTokenizingStrategy(
        CompletionPrompter(),
        tokenizer,
        cfg.train_on_inputs,
        cfg.sequence_len,
        max_length=cfg.sequence_len * 64,
    )
    if ds_cfg and "field" in ds_cfg:
        strat.field = ds_cfg["field"]

    return strat
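
# Usage sketch (assumption; exact config keys may differ between axolotl versions):
# this strategy is typically selected from a dataset entry in the YAML config, e.g.
#
#   datasets:
#     - path: my_org/my_raw_text_dataset   # hypothetical dataset name
#       type: completion
#       field: text                        # optional; overrides the default "text" column
#
# The `field` value reaches load() via ds_cfg and is applied through `strat.field`.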