# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import json
import logging
import pathlib
from dataclasses import dataclass, field
from typing import Dict, Optional, Sequence

import numpy as np
import torch
import transformers
from torch.utils.data import Dataset
from transformers import Trainer

IGNORE_INDEX = -100

# Control-token strings that must never be injected as a "random" replacement token.
_BERT_CONTROL_TOKENS = frozenset({"[CLS]", "[SEP]", "[MASK]", "[PAD]", "[UNK]"})


@dataclass
class ModelArguments:
    # NOTE(review): not read by train(), which builds the model from ./config.json.
    model_name_or_path: Optional[str] = field(default="facebook/opt-125m")


@dataclass
class DataArguments:
    # Text file with one sample per line of whitespace-separated tokens.
    data_path: Optional[str] = field(default=None, metadata={"help": "Path to the training data."})


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    # NOTE(review): cache_dir and model_max_length are not read anywhere in this file.
    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    model_max_length: int = field(
        default=8192,
        metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
    )


# Local process rank. train() sets it from TrainingArguments.local_rank (-1 when not distributed).
local_rank = None


def rank0_print(*args):
    """Print ``args`` only on the local main process.

    The local main process is local rank 0, or -1 for a single, non-distributed process.
    Nothing is printed while ``local_rank`` is still unset (None).
    """
    if local_rank in (-1, 0):
        print(*args)


def bert_masking(input_ids, random_tokens, mask_token_id, mask_prob=0.15, rng=None):
    """Apply BERT-style masking to one tokenized sequence.

    A random subset of ``ceil(mask_prob * (len(input_ids) - 2))`` positions is selected.
    The first and last positions are never selected (presumably [CLS]/[SEP] added by the
    tokenizer -- TODO confirm against its post-processor). For each selected position:
    ~80% are replaced by ``mask_token_id``, ~10% by a token drawn uniformly from
    ``random_tokens``, and ~10% are left unchanged. Every selected position is labeled
    with its original token.

    Args:
        input_ids: 1-D list, ``np.ndarray`` or ``torch.Tensor`` of token ids. Never modified.
        random_tokens: candidate ids for the random-replacement case.
        mask_token_id: id written at masked positions.
        mask_prob: fraction of maskable (interior) positions to select.
        rng: optional ``np.random.Generator`` for reproducible masking; defaults to
            NumPy's global random state.

    Returns:
        ``dict(input_ids=..., labels=...)`` as torch tensors. ``input_ids`` is a masked
        copy. ``labels`` holds the original id at selected positions and
        ``IGNORE_INDEX`` elsewhere.

    Raises:
        ValueError: if the sequence has fewer than two tokens.
    """
    if len(input_ids) < 2:
        raise ValueError(f"bert_masking needs at least 2 tokens, got {len(input_ids)}")
    if rng is None:
        rng = np.random
    if isinstance(input_ids, torch.Tensor):
        input_ids = input_ids.detach().cpu().numpy()
    # Always copy: Tensor.numpy() shares memory and ndarray inputs would otherwise be
    # overwritten in place (e.g. corrupting a caller's token cache).
    input_ids = np.array(input_ids, copy=True)
    labels = np.full_like(input_ids, IGNORE_INDEX)

    # Exclude the first and last positions from masking.
    valid_indices = np.arange(1, len(input_ids) - 1)
    num_mask = int(np.ceil(mask_prob * len(valid_indices)))
    mask_indices = rng.choice(valid_indices, num_mask, replace=False)

    # Record the originals before any replacement so every selected position is predicted.
    labels[mask_indices] = input_ids[mask_indices]

    probs = rng.random(num_mask)
    replace_with_mask = mask_indices[probs < 0.8]
    replace_with_random = mask_indices[(probs >= 0.8) & (probs < 0.9)]
    # The remaining ~10% (probs >= 0.9) keep their original token.
    input_ids[replace_with_mask] = mask_token_id
    if len(replace_with_random) > 0:
        input_ids[replace_with_random] = rng.choice(random_tokens, size=len(replace_with_random))

    return dict(input_ids=torch.from_numpy(input_ids), labels=torch.from_numpy(labels))


def is_not_special_token(token_name):
    """Return True if ``token_name`` may be used as a random replacement token.

    Rejects BERT control tokens ([CLS], [SEP], [MASK], [PAD], [UNK]) and reserved
    "unused" slots, in both the bare ``unused*`` and the bracketed ``[unused*]`` spelling.
    """
    if token_name.startswith(("unused", "[unused")):
        return False
    return token_name not in _BERT_CONTROL_TOKENS


class SupervisedDataset(Dataset):
    """Masked-LM training dataset over whitespace-separated tokens (one sample per line).

    Only lines with at least ``max_length`` tokens are kept, truncated to their first
    ``max_length`` tokens; shorter lines are dropped. Tokenization is cached per index.
    Masking is re-drawn on every ``__getitem__`` call.
    """

    def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizerFast, max_length: int = 64):
        super().__init__()
        logging.warning("Loading data...")
        self.tokenizer = tokenizer
        self.max_length = max_length  # max number of genes per sample
        self.list_data = []
        with open(data_path, encoding="utf-8") as f:
            for line in f:
                genes = line.split()
                if len(genes) >= max_length:
                    self.list_data.append(genes[:max_length])
        self.cached_input_ids = {}
        # Candidates for the "replace with a random token" case: exclude control/reserved
        # names plus anything the tokenizer itself declares as special.
        special_tokens = set(tokenizer.all_special_tokens)
        self.random_tokens = [
            token_id
            for token_name, token_id in tokenizer.vocab.items()
            if is_not_special_token(token_name) and token_name not in special_tokens
        ]

    def __len__(self):
        return len(self.list_data)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        input_ids = self.cached_input_ids.get(i)
        if input_ids is None:
            input_ids = self.tokenizer(self.list_data[i], is_split_into_words=True)["input_ids"]
            self.cached_input_ids[i] = input_ids
        # bert_masking works on a copy, so the cached ids stay unmasked.
        return bert_masking(input_ids, self.random_tokens, self.tokenizer.mask_token_id)


@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate masked-LM examples into a right-padded batch with an attention mask."""

    tokenizer: transformers.PreTrainedTokenizerFast

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        input_ids = [instance["input_ids"] for instance in instances]
        labels = [instance["labels"] for instance in instances]
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
        )
        # Padded label positions use IGNORE_INDEX so they contribute nothing to the loss.
        labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
        return dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=input_ids.ne(self.tokenizer.pad_token_id).long(),
        )


def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizerFast, data_args) -> Dict:
    """Build the Trainer kwargs: training dataset, no eval dataset, and the batch collator."""
    train_dataset = SupervisedDataset(tokenizer=tokenizer, data_path=data_args.data_path)
    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
    return dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator)


def train():
    """Train a ``BertForMaskedLM`` from scratch with the HF ``Trainer``.

    The model config is read from the relative path ``config.json`` and the tokenizer
    from ``tokenizer``. Training resumes automatically when ``output_dir`` already holds
    ``checkpoint-*`` entries. The final state and model are saved to ``output_dir``.
    """
    global local_rank

    parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    # Without this assignment rank0_print would never print anything.
    local_rank = training_args.local_rank

    config = transformers.AutoConfig.from_pretrained("config.json")
    model = transformers.BertForMaskedLM(config)
    num_params_m = sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6
    rank0_print(model)
    rank0_print(f"model_size: {num_params_m:.3f} M trainable parameters")

    tokenizer = transformers.PreTrainedTokenizerFast.from_pretrained("tokenizer")
    data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
    trainer = Trainer(model=model, tokenizer=tokenizer, args=training_args, **data_module)

    if any(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
        trainer.train(resume_from_checkpoint=True)
    else:
        trainer.train()
    trainer.save_state()
    trainer.save_model(output_dir=training_args.output_dir)


if __name__ == "__main__":
    train()