|  | """pretraining prompt strategies""" | 
					
						
						|  | from typing import Generator | 
					
						
						|  |  | 
					
						
						|  | from transformers import BatchEncoding | 
					
						
						|  |  | 
					
						
						|  | from axolotl.prompt_tokenizers import PromptTokenizingStrategy | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						


class PretrainTokenizer:
    """basic tokenization class for pretraining"""

    def build_prompt(self, prompt) -> Generator[str, None, None]:
        # raw pretraining text needs no template; yield it unchanged
        yield prompt
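
# Example (assumed usage, not from the original module): consuming the prompter directly
#   prompter = PretrainTokenizer()
#   next(prompter.build_prompt("raw pretraining text"))  # -> "raw pretraining text"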
					
						
class PretrainTokenizationStrategy(PromptTokenizingStrategy):
    """handles tokenization for pretraining with strides"""

    @property
    def supports_batched(self):
        # rows can be tokenized a batch at a time (e.g. datasets.map(batched=True))
        return True

    def __init__(self, *args, max_length=None, text_column="text", **kwargs):
        super().__init__(*args, **kwargs)
        # an explicit max_length overrides the value set by the base strategy
        if max_length:
            self.max_length = max_length
        self.text_column = text_column

    def _tokenize(
        self, prompt: str, add_eos_token: bool = True, strip_bos_token: bool = False
    ) -> BatchEncoding:
        res = self.tokenizer(
            prompt,
            truncation=True,
            # reserve one position so the EOS appended below still fits
            max_length=self.max_length - 1,
            add_special_tokens=True,
            # split long documents into overlapping windows instead of dropping text
            return_overflowing_tokens=True,
            stride=256,
        )
        if add_eos_token:
            # close every overflow window with EOS so each chunk is a complete sample
            res["input_ids"] = [
                seq + [self.tokenizer.eos_token_id] for seq in res["input_ids"]
            ]
            res["attention_mask"] = [seq + [1] for seq in res["attention_mask"]]
        if strip_bos_token:
            # drop the leading BOS that add_special_tokens=True prepends to each window
            res["input_ids"] = [
                ids[1:] if ids[:1] == [self.tokenizer.bos_token_id] else ids
                for ids in res["input_ids"]
            ]
            res["attention_mask"] = [
                mask[: len(ids)]
                for ids, mask in zip(res["input_ids"], res["attention_mask"])
            ]

        return res
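
    # Illustration (hypothetical numbers, not part of the strategy): with an
    # effective window of 7 tokens (max_length=8 minus the reserved EOS slot) and
    # stride=2, a 20-token document tokenizes into overlapping windows like
    #   window 0: tokens[0:7]   + [eos]
    #   window 1: tokens[5:12]  + [eos]   # the last `stride` tokens repeat
    #   window 2: tokens[10:17] + [eos]
    # so nothing is lost and adjacent windows share `stride` tokens of context.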
					
						
    def tokenize_prompt(self, prompt):
        # `prompt` is a dataset row (or batched columns); the tokenizer accepts
        # either a single string or a list of strings from the configured column
        return self._tokenize(prompt[self.text_column])
					
						
def load(tokenizer, cfg):
    strat = PretrainTokenizationStrategy(
        PretrainTokenizer(),
        tokenizer,
        cfg.train_on_inputs,
        cfg.sequence_len,
        # fall back to the "text" column when the dataset config does not name one
        text_column=cfg.pretraining_dataset[0].get("text_column") or "text",
        # tokenize in long windows (64x the training sequence length) so few
        # documents are truncated outright; stride/overflow splits them further
        max_length=cfg.sequence_len * 64,
    )
    return strat
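
# Minimal usage sketch (hypothetical values; in practice `cfg` is axolotl's loaded
# config object):
#
#   from types import SimpleNamespace
#   from transformers import AutoTokenizer
#
#   tokenizer = AutoTokenizer.from_pretrained("gpt2")  # any causal-LM tokenizer
#   cfg = SimpleNamespace(
#       train_on_inputs=True,
#       sequence_len=2048,
#       pretraining_dataset=[{"path": "c4", "text_column": "text"}],
#   )
#   strategy = load(tokenizer, cfg)
#   batch = strategy.tokenize_prompt({"text": "a very long document ..."})
#   # batch["input_ids"] is a list of overlapping windows, each ending in EOS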