# iseeek-bert-nano / train.py
# Uploaded by lixiangchun ("initial upload", commit c9e5de4, verified).
# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import logging
from dataclasses import dataclass, field
import pathlib
from typing import Dict, Optional, Sequence
import torch
import transformers
from torch.utils.data import Dataset
from transformers import Trainer
import numpy as np
import json
# Label value the loss ignores: bert_masking fills every non-target position with it,
# and the collator uses it to pad label sequences.
IGNORE_INDEX = -100
@dataclass
class ModelArguments:
    """Model-selection options parsed from the command line."""

    # NOTE(review): train() in this file builds the model from a local config.json
    # and does not read this value.
    model_name_or_path: Optional[str] = "facebook/opt-125m"
@dataclass
class DataArguments:
    """Data-loading options parsed from the command line."""

    # Whitespace-separated token file, one sample per line (read by SupervisedDataset).
    # Annotated Optional because the default is None; it must still be supplied in
    # practice, since the dataset opens the path unconditionally.
    data_path: Optional[str] = field(default=None, metadata={"help": "Path to the training data."})
@dataclass
class TrainingArguments(transformers.TrainingArguments):
    """``transformers.TrainingArguments`` extended with a few project-specific options."""

    cache_dir: Optional[str] = None
    # Pin the optimizer explicitly instead of relying on the parent's default.
    optim: str = "adamw_torch"
    # NOTE(review): not read anywhere in this file; SupervisedDataset uses its own length limit.
    model_max_length: int = field(
        default=8192,
        metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
    )
# Process-local rank consulted by rank0_print. Output only appears once this is set to 0;
# while it stays None, rank0_print prints nothing.
local_rank = None


def rank0_print(*args):
    """Forward *args* to ``print`` only on the process whose ``local_rank`` is 0."""
    if local_rank != 0:
        return
    print(*args)
def bert_masking(input_ids, random_tokens, mask_token_id, mask_prob=0.15):
    """Apply BERT-style dynamic masking to one tokenized sequence.

    ``ceil(mask_prob * (len(input_ids) - 2))`` distinct positions are drawn from the
    interior of the sequence; the first and last tokens are never selected. Every
    selected position becomes a prediction target. Its input token is:

    * replaced by ``mask_token_id`` with probability 0.8,
    * replaced by a token drawn uniformly from ``random_tokens`` with probability 0.1,
    * left unchanged otherwise.

    Args:
        input_ids: 1-D token ids as a list, tuple, ``np.ndarray`` or ``torch.Tensor``.
            The caller's object is never modified.
        random_tokens: Candidate ids for the random-replacement branch.
        mask_token_id: Id written at positions chosen for masking.
        mask_prob: Fraction of interior positions to select.

    Returns:
        ``dict(input_ids=..., labels=...)`` of int64 tensors with the input's length.
        ``labels`` holds the original id at selected positions and ``IGNORE_INDEX``
        everywhere else.
    """
    assert len(input_ids) > 1
    if isinstance(input_ids, torch.Tensor):
        # Tensor.numpy() shares memory with the tensor; the copy below detaches from it.
        input_ids = input_ids.detach().cpu().numpy()
    # Always work on a private int64 copy: masking writes in place, and a fixed dtype
    # keeps the returned tensors torch.long regardless of the platform's default int width.
    input_ids = np.array(input_ids, dtype=np.int64)
    labels = np.full_like(input_ids, IGNORE_INDEX)

    # Exclude the first (index 0) and last (index -1) positions from masking.
    valid_indices = np.arange(1, len(input_ids) - 1)
    num_mask = int(np.ceil(mask_prob * len(valid_indices)))
    mask_indices = np.random.choice(valid_indices, num_mask, replace=False)

    # Record targets before any input position is overwritten.
    labels[mask_indices] = input_ids[mask_indices]

    # One uniform draw per selected position decides its 80/10/10 treatment.
    probs = np.random.rand(num_mask)
    to_mask = mask_indices[probs < 0.8]
    to_random = mask_indices[(probs >= 0.8) & (probs < 0.9)]
    input_ids[to_mask] = mask_token_id
    if to_random.size:
        input_ids[to_random] = np.random.choice(random_tokens, size=to_random.size)
    # The remaining positions (probs >= 0.9) keep their original token but are still predicted.

    return dict(input_ids=torch.from_numpy(input_ids), labels=torch.from_numpy(labels))
# Exact-match special token names that must never be used as "random" replacements.
_SPECIAL_TOKEN_NAMES = frozenset(("[CLS]", "[SEP]", "[MASK]", "[PAD]", "[UNK]"))


def is_not_special_token(token_name):
    """Return True if *token_name* is an ordinary vocabulary entry.

    Special tokens and reserved "unused" slots are rejected. This function filters
    the candidate pool for the random-replacement branch of BERT masking, so it
    keeps those tokens out of the model input.
    """
    if token_name in _SPECIAL_TOKEN_NAMES:
        return False
    # Reserved slots may be spelled "unused7" or, as in standard BERT vocabs, "[unused7]".
    return not token_name.startswith(("unused", "[unused"))
class SupervisedDataset(Dataset):
    """Line-per-sample dataset of whitespace-separated tokens with dynamic BERT masking.

    Each line of ``data_path`` is split on whitespace. Only lines with at least
    ``max_length`` words are kept, and each kept line is truncated to exactly
    ``max_length`` words. Token ids are cached per index; masking is re-drawn on
    every ``__getitem__`` call.
    """

    def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizerFast, max_length: int = 64):
        """Load and filter the samples and build the random-replacement vocabulary.

        Args:
            data_path: Text file with one sample per line.
            tokenizer: Tokenizer used for ids, ``mask_token_id`` and the vocabulary.
            max_length: Max number of genes (words) per sample; shorter lines are dropped.
        """
        super().__init__()
        logging.warning("Loading data...")
        self.tokenizer = tokenizer
        self.max_length = max_length  # max number of genes
        with open(data_path, encoding="utf-8") as f:
            # Split each line once; lines shorter than max_length are skipped, not padded.
            self.list_data = [words[: self.max_length] for words in map(str.split, f) if len(words) >= self.max_length]
        self.cached_input_ids = {}
        # Candidate ids for the "replace with a random token" branch of masking.
        self.random_tokens = [
            token_id for token_name, token_id in self.tokenizer.vocab.items() if is_not_special_token(token_name)
        ]

    def __len__(self):
        return len(self.list_data)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        """Return a freshly masked example for index *i*; tokenization is done once and cached."""
        if i in self.cached_input_ids:
            input_ids = self.cached_input_ids[i]
        else:
            input_ids = self.tokenizer(self.list_data[i], is_split_into_words=True)["input_ids"]
            self.cached_input_ids[i] = input_ids
        # Relies on bert_masking not mutating the cached ids, since the mask changes per access.
        return bert_masking(input_ids, self.random_tokens, self.tokenizer.mask_token_id)
@dataclass
class DataCollatorForSupervisedDataset:
    """Pad a list of per-example dicts into one right-padded batch."""

    tokenizer: transformers.PreTrainedTokenizerFast

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        """Stack ``input_ids``/``labels``; pad ids with the tokenizer's pad id and labels with IGNORE_INDEX.

        The attention mask is 1 wherever the padded ``input_ids`` differ from the pad id.
        """
        pad_id = self.tokenizer.pad_token_id
        sequences = [example["input_ids"] for example in instances]
        targets = [example["labels"] for example in instances]
        batch_ids = torch.nn.utils.rnn.pad_sequence(sequences, batch_first=True, padding_value=pad_id)
        batch_labels = torch.nn.utils.rnn.pad_sequence(targets, batch_first=True, padding_value=IGNORE_INDEX)
        return {
            "input_ids": batch_ids,
            "labels": batch_labels,
            "attention_mask": batch_ids.ne(pad_id).long(),
        }
def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizerFast, data_args) -> Dict:
    """Build the data-related keyword arguments for ``Trainer``.

    Returns a dict with a training dataset read from ``data_args.data_path``, no
    evaluation dataset, and the matching padding collator.
    """
    return {
        "train_dataset": SupervisedDataset(tokenizer=tokenizer, data_path=data_args.data_path),
        "eval_dataset": None,
        "data_collator": DataCollatorForSupervisedDataset(tokenizer=tokenizer),
    }
def train():
    """Parse CLI arguments, build a BERT masked-LM from ``config.json`` and train it.

    If ``output_dir`` already contains a ``checkpoint-*`` entry, training resumes
    from it. The trainer state and final model are saved to ``output_dir``.
    """
    global local_rank

    parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    # rank0_print compares this module-level value with 0. Without this assignment it
    # stays None and nothing is ever printed. local_process_index is 0 in single-process runs.
    local_rank = training_args.local_process_index

    # NOTE(review): model_args.model_name_or_path is not used. The architecture comes from
    # ./config.json and the tokenizer from ./tokenizer, both relative to the working directory.
    config = transformers.AutoConfig.from_pretrained("config.json")
    model = transformers.BertForMaskedLM(config)
    # Number of trainable parameters, in millions (not a size in bytes).
    num_params_m = sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6
    rank0_print(model)
    rank0_print(f"model_size: {num_params_m:.3f}M parameters")

    tokenizer = transformers.PreTrainedTokenizerFast.from_pretrained("tokenizer")
    data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
    trainer = Trainer(model=model, tokenizer=tokenizer, args=training_args, **data_module)

    if any(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
        trainer.train(resume_from_checkpoint=True)
    else:
        trainer.train()
    trainer.save_state()
    trainer.save_model(output_dir=training_args.output_dir)
# Run training only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    train()