|  | """Module containing Dataset functionality""" | 
					
						
						|  |  | 
					
						
						|  | import logging | 
					
						
						|  | import os | 
					
						
						|  | from typing import List, Optional | 
					
						
						|  |  | 
					
						
						|  | import torch | 
					
						
						|  | from datasets import Dataset, IterableDataset | 
					
						
						|  |  | 
					
						
						|  | from .prompt_tokenizers import PromptTokenizingStrategy | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
# Module logger, registered under the fixed logger name "axolotl".
LOG = logging.getLogger("axolotl")
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class TokenizedPromptDataset(Dataset): | 
					
						
						|  | """ | 
					
						
						|  | Dataset that returns tokenized prompts from a stream of text files. | 
					
						
						|  | Args: | 
					
						
						|  | prompt_tokenizer (PromptTokenizingStrategy): The prompt tokenizing method for processing the data. | 
					
						
						|  | dataset (dataset.Dataset): Dataset with text files. | 
					
						
						|  | process_count (int): Number of processes to use for tokenizing. | 
					
						
						|  | keep_in_memory (bool): Whether to keep the tokenized dataset in memory. | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  | def __init__( | 
					
						
						|  | self, | 
					
						
						|  | prompt_tokenizer: PromptTokenizingStrategy, | 
					
						
						|  | dataset: IterableDataset, | 
					
						
						|  | process_count: Optional[int] = None, | 
					
						
						|  | keep_in_memory: Optional[bool] = False, | 
					
						
						|  | **kwargs, | 
					
						
						|  | ): | 
					
						
						|  | self.prompt_tokenizer = prompt_tokenizer | 
					
						
						|  | self.process_count = process_count | 
					
						
						|  | self.keep_in_memory = keep_in_memory | 
					
						
						|  | super().__init__( | 
					
						
						|  | self.process(dataset).data, | 
					
						
						|  | **kwargs, | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | def process(self, dataset): | 
					
						
						|  | features = dataset.features.keys() | 
					
						
						|  | num_proc = min(64, self.process_count if self.process_count else os.cpu_count()) | 
					
						
						|  |  | 
					
						
						|  | map_kwargs = {} | 
					
						
						|  | if self.prompt_tokenizer.supports_batched: | 
					
						
						|  | map_kwargs["batched"] = True | 
					
						
						|  | map_kwargs["batch_size"] = 100 | 
					
						
						|  | return dataset.map( | 
					
						
						|  | self.prompt_tokenizer.tokenize_prompt, | 
					
						
						|  | num_proc=num_proc, | 
					
						
						|  | remove_columns=features, | 
					
						
						|  | keep_in_memory=self.keep_in_memory, | 
					
						
						|  | **map_kwargs, | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class ConstantLengthDataset(IterableDataset): | 
					
						
						|  | """ | 
					
						
						|  | Iterable dataset that returns constant length chunks of tokens from stream of text files. | 
					
						
						|  | Args: | 
					
						
						|  | tokenizer (Tokenizer): The processor used for processing the data. | 
					
						
						|  | dataset (dataset.Dataset): Dataset with text files. | 
					
						
						|  | seq_length (int): Length of token sequences to return. | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  | def __init__( | 
					
						
						|  | self, | 
					
						
						|  | tokenizer, | 
					
						
						|  | datasets, | 
					
						
						|  | seq_length=2048, | 
					
						
						|  | ): | 
					
						
						|  | self.tokenizer = tokenizer | 
					
						
						|  | self.concat_token_id = tokenizer.eos_token_id | 
					
						
						|  | self.datasets: List[IterableDataset] = datasets | 
					
						
						|  | self.seq_length = seq_length | 
					
						
						|  |  | 
					
						
						|  | vocab_size = len(tokenizer.get_vocab()) | 
					
						
						|  |  | 
					
						
						|  | if vocab_size <= torch.iinfo(torch.int16).max: | 
					
						
						|  | self.tokens_dtype = torch.int16 | 
					
						
						|  | elif vocab_size <= torch.iinfo(torch.int32).max: | 
					
						
						|  | self.tokens_dtype = torch.int32 | 
					
						
						|  | else: | 
					
						
						|  | self.tokens_dtype = torch.int64 | 
					
						
						|  |  | 
					
						
						|  | def __iter__(self): | 
					
						
						|  | buffer = { | 
					
						
						|  | "input_ids": [], | 
					
						
						|  | "attention_mask": [], | 
					
						
						|  | "labels": [], | 
					
						
						|  | "position_ids": [], | 
					
						
						|  | } | 
					
						
						|  | buffer_len = 0 | 
					
						
						|  | for dataset in self.datasets: | 
					
						
						|  | idx = 0 | 
					
						
						|  | iterator = iter(dataset) | 
					
						
						|  | more_examples = True | 
					
						
						|  | while more_examples: | 
					
						
						|  | try: | 
					
						
						|  | example = next(iterator) | 
					
						
						|  | idx += 1 | 
					
						
						|  | except StopIteration: | 
					
						
						|  | more_examples = False | 
					
						
						|  | example = None | 
					
						
						|  |  | 
					
						
						|  | add_concat_token = False | 
					
						
						|  | if example: | 
					
						
						|  | example_len = len(example["input_ids"]) | 
					
						
						|  | add_concat_token = example["input_ids"][-1] != self.concat_token_id | 
					
						
						|  | else: | 
					
						
						|  | example_len = 0 | 
					
						
						|  |  | 
					
						
						|  | if not example_len or ( | 
					
						
						|  | buffer_len + int(add_concat_token) + example_len > self.seq_length | 
					
						
						|  | ): | 
					
						
						|  | if buffer["input_ids"]: | 
					
						
						|  | input_ids = torch.cat(buffer["input_ids"], dim=-1)[ | 
					
						
						|  | : self.seq_length | 
					
						
						|  | ] | 
					
						
						|  | attention_mask = torch.cat(buffer["attention_mask"], dim=-1)[ | 
					
						
						|  | : self.seq_length | 
					
						
						|  | ] | 
					
						
						|  | position_ids = torch.cat(buffer["position_ids"], dim=-1)[ | 
					
						
						|  | : self.seq_length | 
					
						
						|  | ] | 
					
						
						|  | labels = torch.cat(buffer["labels"], dim=-1)[: self.seq_length] | 
					
						
						|  | if labels.size() == input_ids.size() and ( | 
					
						
						|  | attention_mask.size() == input_ids.size() | 
					
						
						|  | ): | 
					
						
						|  | yield { | 
					
						
						|  | "input_ids": input_ids, | 
					
						
						|  | "labels": labels, | 
					
						
						|  | "attention_mask": attention_mask, | 
					
						
						|  | "position_ids": position_ids, | 
					
						
						|  | } | 
					
						
						|  | else: | 
					
						
						|  | LOG.warning( | 
					
						
						|  | f"dropping batch due to tensor size mismatch input_ids: {input_ids.size()}, labels: {labels.size()}, attention_mask: {attention_mask.size()}" | 
					
						
						|  | ) | 
					
						
						|  | buffer = { | 
					
						
						|  | "input_ids": [], | 
					
						
						|  | "attention_mask": [], | 
					
						
						|  | "labels": [], | 
					
						
						|  | "position_ids": [], | 
					
						
						|  | } | 
					
						
						|  | buffer_len = 0 | 
					
						
						|  | idx = 1 | 
					
						
						|  |  | 
					
						
						|  | if example: | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if len(example["input_ids"]) <= self.seq_length: | 
					
						
						|  | input_ids = example["input_ids"] | 
					
						
						|  | attention_mask = example["attention_mask"] | 
					
						
						|  | labels = example["labels"] | 
					
						
						|  |  | 
					
						
						|  | if add_concat_token: | 
					
						
						|  | input_ids.append(self.concat_token_id) | 
					
						
						|  | attention_mask.append(1) | 
					
						
						|  | labels.append(self.concat_token_id) | 
					
						
						|  |  | 
					
						
						|  | input_ids_with_concat = torch.tensor( | 
					
						
						|  | input_ids, dtype=self.tokens_dtype | 
					
						
						|  | ) | 
					
						
						|  | attention_mask_with_concat = torch.tensor( | 
					
						
						|  | [idx * m for m in attention_mask], dtype=torch.int16 | 
					
						
						|  | ) | 
					
						
						|  | labels_with_concat = torch.tensor( | 
					
						
						|  | labels, dtype=self.tokens_dtype | 
					
						
						|  | ) | 
					
						
						|  | position_ids = torch.arange( | 
					
						
						|  | len(input_ids), dtype=self.tokens_dtype | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | buffer["input_ids"].append(input_ids_with_concat) | 
					
						
						|  | buffer["attention_mask"].append(attention_mask_with_concat) | 
					
						
						|  | buffer["labels"].append(labels_with_concat) | 
					
						
						|  | buffer["position_ids"].append(position_ids) | 
					
						
						|  | buffer_len += len(input_ids) | 
					
						
						|  |  |