|
''' |
|
Pre-training/Fine-tuning seq2seq models on autoencoding a dataset. |
|
|
|
TODO: |
|
- [ ] Add reg loss |
|
- [x] calculate MMD loss |
|
- [ ] schedule MMD loss weight |
|
- [ ] Add these params to the training arguments. |
|
|
|
reg_schedule_k (:obj:`float`, `optional`, defaults to 0.0025): |
|
Multiplied by global_step inside a sigmoid; controls how gradually the regulariser loss weight increases. |
|
reg_schedule_b (:obj:`float`, `optional`, defaults to 6.25): |
|
Added to global_step inside the sigmoid; further delays the increase in regulariser loss weight. |
|
use_extra_logs (:obj:`bool`, `optional`, defaults to False): |
|
Store extra logs during each training inference. |
|
|
|
- [ ] Send the schedule time to the compute_loss method and calculate a coefficient based on that. |
|
''' |
|
import logging |
|
import math |
|
import os |
|
import sys |
|
import time |
|
from dataclasses import dataclass, field |
|
from pathlib import Path |
|
from typing import Callable, Optional |
|
|
|
import datasets |
|
from datasets import Dataset, load_dataset |
|
from tqdm import tqdm |
|
|
|
import jax |
|
import jax.numpy as jnp |
|
import optax |
|
import transformers |
|
from flax import jax_utils, traverse_util |
|
from flax.jax_utils import unreplicate |
|
from flax.training import train_state |
|
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key |
|
from transformers import ( |
|
AutoTokenizer, |
|
HfArgumentParser, |
|
TrainingArguments, |
|
is_tensorboard_available, |
|
) |
|
from transformers.models.t5.modeling_flax_t5 import shift_tokens_right |
|
from transformers.testing_utils import CaptureLogger |
|
|
|
from t5_vae_flax.src.t5_vae import FlaxT5VaeForAutoencoding |
|
from t5_vae_flax.src.config import T5VaeConfig |
|
|
|
|
|
# Module-level logger; its level is set per JAX process inside main().
logger = logging.getLogger(__name__) |
|
|
|
|
|
@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.

    Every field becomes a command-line option through `HfArgumentParser`; the `help` metadata is the
    text shown for that option.
    """

    # Checkpoint of an existing T5-VAE model; leave unset to train from scratch.
    model_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            "help": "The model checkpoint for weights initialization. "
            "Don't set if you want to train a model from scratch."
        },
    )
    # Plain T5 checkpoint used when no T5-VAE checkpoint is given.
    t5_model_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            "help": "The T5 model checkpoint for weights initialization. "
            "Needed when not starting from a T5-VAE model."
        },
    )
    n_latent_tokens: Optional[int] = field(
        default=6,
        metadata={"help": "Number of latent tokens (must be less than seq length)."},
    )
    latent_token_size: Optional[int] = field(
        default=32,
        metadata={"help": "Number of dimensions to use for each latent token."},
    )
    add_special_tokens: bool = field(
        default=False,
        metadata={"help": "Add these special tokens to the tokenizer: {'pad_token': '<PAD>', 'bos_token': '<BOS>', 'eos_token': '<EOS>'}"},
    )
    config_path: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config path"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
    )
    # Name of a `jax.numpy` dtype; main() resolves it with `getattr(jnp, dtype)`.
    dtype: Optional[str] = field(
        default="float32",
        metadata={
            "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
        },
    )
|
|
|
|
|
@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.

    Either `dataset_name` or at least one of `train_file` / `validation_file` must be provided;
    this is validated in `__post_init__`.
    """

    dataset_name: Optional[str] = field(
        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
    )
    dataset_config_name: Optional[str] = field(
        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
    )
    train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
    validation_file: Optional[str] = field(
        default=None,
        metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
            "value if set."
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
            "value if set."
        },
    )
    # Declared once only; a second identical declaration further down was redundant.
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
    )
    validation_split_percentage: Optional[int] = field(
        default=5,
        metadata={
            "help": "The percentage of the train set used as validation set in case there's no validation split"
        },
    )
    block_size: Optional[int] = field(
        default=None,
        metadata={
            "help": "Optional input sequence length after tokenization. "
            "The training dataset will be truncated in block of this size for training. "
            "Default to the model max input length for single sentence inputs (take into account special tokens)."
        },
    )
    streaming: bool = field(
        default=False, metadata={"help": "Stream the dataset."}
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )

    def __post_init__(self):
        """
        Validate that a data source was given and that local files have a supported extension.

        Raises:
            ValueError: if none of `dataset_name`, `train_file`, `validation_file` is set.
            AssertionError: if `train_file` or `validation_file` does not end in csv/json/txt.
        """
        if self.dataset_name is None and self.train_file is None and self.validation_file is None:
            raise ValueError("Need either a dataset name or a training/validation file.")

        # The extension is whatever follows the last dot of the path.
        if self.train_file is not None:
            extension = self.train_file.split(".")[-1]
            assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file."
        if self.validation_file is not None:
            extension = self.validation_file.split(".")[-1]
            assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."
|
|
|
|
|
class TrainState(train_state.TrainState):
    """Flax train state extended with the PRNG key used for dropout."""

    # Key that the training step splits on every update to drive dropout.
    dropout_rng: jnp.ndarray

    def replicate(self):
        """Return a replicated copy of this state whose `dropout_rng` went through `shard_prng_key`."""
        replicated_state = jax_utils.replicate(self)
        per_device_rng = shard_prng_key(self.dropout_rng)
        return replicated_state.replace(dropout_rng=per_device_rng)
|
|
|
|
|
def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False):
    """
    Yield batches of `batch_size` examples from `dataset`, each passed through `shard`.

    Examples that do not fill a complete final batch are dropped. With `shuffle=True` the example
    order is a permutation drawn from `rng`; otherwise examples are taken in their stored order.
    """
    num_examples = len(dataset)
    num_batches = num_examples // batch_size

    order = jax.random.permutation(rng, num_examples) if shuffle else jnp.arange(num_examples)

    # Trim the ragged tail so the indices split evenly into whole batches.
    usable = num_batches * batch_size
    batched_order = order[:usable].reshape((num_batches, batch_size))

    for row in batched_order:
        examples = dataset[row]
        arrays = {name: jnp.array(column) for name, column in examples.items()}
        yield shard(arrays)
|
|
|
|
|
def write_train_metric(summary_writer, train_metrics, train_time, step):
    """
    Write the accumulated training time and the buffered per-step training metrics.

    `train_metrics` is collated with `get_metrics`; each resulting series is written under the tag
    `train_<key>`, with its last value landing on `step` and earlier values on the preceding steps.
    """
    summary_writer.scalar("train_time", train_time, step)

    collated = get_metrics(train_metrics)
    for name, series in collated.items():
        # Step of the first buffered value, chosen so the final value is logged at `step`.
        first_step = step - len(series) + 1
        for offset, value in enumerate(series):
            summary_writer.scalar(f"train_{name}", value, first_step + offset)
|
|
|
|
|
def write_eval_metric(summary_writer, eval_metrics, step):
    """Write every entry of `eval_metrics` to `summary_writer` as the scalar `eval_<name>` at `step`."""
    for name in eval_metrics:
        tag = "eval_" + str(name)
        summary_writer.scalar(tag, eval_metrics[name], step)
|
|
|
|
|
def create_learning_rate_fn(
    train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
) -> Callable[[int], jnp.array]:
    """
    Build a learning-rate schedule: linear warmup from 0 to `learning_rate`, then linear decay to 0.

    The warmup lasts `num_warmup_steps`; the decay spans the remaining steps, where the total number
    of steps is `(train_ds_size // train_batch_size) * num_train_epochs`.
    """
    total_steps = (train_ds_size // train_batch_size) * num_train_epochs
    warmup = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
    decay = optax.linear_schedule(
        init_value=learning_rate, end_value=0, transition_steps=total_steps - num_warmup_steps
    )
    # Hand over from warmup to decay exactly at the end of the warmup phase.
    return optax.join_schedules(schedules=[warmup, decay], boundaries=[num_warmup_steps])
|
|
|
|
|
def main():
    """
    Train a T5-VAE autoencoder with Flax from command-line (or JSON-file) arguments.

    Steps: parse arguments, load and tokenize the dataset, clip/pad every example to `block_size`,
    build the model and optimizer, then run the pmapped training loop with periodic logging,
    evaluation and checkpointing.

    Raises:
        ValueError: if the output directory is non-empty and `--overwrite_output_dir` is not set,
            if `block_size` is unset, if no tokenizer can be resolved, if `--do_train` is missing,
            or if a requested train/validation split is missing.
    """
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # A single `.json` argument is treated as a file holding all argument values.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    if (
        os.path.exists(training_args.output_dir)
        and os.listdir(training_args.output_dir)
        and training_args.do_train
        and not training_args.overwrite_output_dir
    ):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty. "
            "Use --overwrite_output_dir to overcome."
        )

    if data_args.block_size is None:
        raise ValueError('Must set block_size so we know what length of sequence to autoencode.')

    # ---------------------------------------------------------------- logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )

    # Only process 0 logs at INFO level; every other process only reports errors.
    logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
    if jax.process_index() == 0:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    logger.info(f"Training/evaluation parameters {training_args}")

    # ---------------------------------------------------------------- dataset
    if data_args.dataset_name is not None:
        dataset = load_dataset(
            data_args.dataset_name,
            data_args.dataset_config_name,
            cache_dir=model_args.cache_dir,
            streaming=data_args.streaming,
            keep_in_memory=False,
        )

        if "validation" not in dataset.keys():
            # No validation split available: use the first N% of train for validation
            # and the remainder for training.
            dataset["validation"] = load_dataset(
                data_args.dataset_name,
                data_args.dataset_config_name,
                split=f"train[:{data_args.validation_split_percentage}%]",
                cache_dir=model_args.cache_dir,
            )
            dataset["train"] = load_dataset(
                data_args.dataset_name,
                data_args.dataset_config_name,
                split=f"train[{data_args.validation_split_percentage}%:]",
                cache_dir=model_args.cache_dir,
            )
    else:
        data_files = {}
        if data_args.train_file is not None:
            data_files["train"] = data_args.train_file
        if data_args.validation_file is not None:
            data_files["validation"] = data_args.validation_file
        # Take the loader type from whichever file was supplied, so a validation-only run
        # does not fail on a missing train file.
        reference_file = data_args.train_file if data_args.train_file is not None else data_args.validation_file
        extension = reference_file.split(".")[-1]
        if extension == "txt":
            extension = "text"
        dataset = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)

    # ------------------------------------------------ config, tokenizer, model
    if model_args.config_path:
        config = T5VaeConfig.from_pretrained(
            model_args.config_path, cache_dir=model_args.cache_dir
        )
    elif model_args.model_name_or_path:
        config = T5VaeConfig.from_pretrained(
            model_args.model_name_or_path, cache_dir=model_args.cache_dir
        )
    else:
        # All ModelArguments fields are forwarded as config keyword arguments.
        config = T5VaeConfig(**model_args.__dict__)
        logger.warning("You are instantiating a new config instance from scratch.")

    if model_args.tokenizer_name:
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
        )
    elif model_args.t5_model_name_or_path:
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.t5_model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
        )
    else:
        raise ValueError(
            "You are instantiating a new tokenizer from scratch. This is not supported by this script. "
            "You can do it from another script, save it, and load it from here, using --tokenizer_name."
        )

    if model_args.model_name_or_path:
        model = FlaxT5VaeForAutoencoding.from_pretrained(
            model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
        )
        assert model.params['t5']['shared']['embedding'].shape[0] == len(tokenizer), "T5 Tokenizer doesn't match T5Vae embedding size."
    else:
        vocab_size = len(tokenizer)
        config.t5.vocab_size = vocab_size
        config.vocab_size = vocab_size
        logger.info("Training new model from scratch.")
        model = FlaxT5VaeForAutoencoding(
            config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
        )

    if model_args.add_special_tokens:
        special_tokens_dict = {'pad_token': '<PAD>', 'bos_token': '<BOS>', 'eos_token': '<EOS>'}
        num_added_tokens = tokenizer.add_special_tokens(special_tokens_dict)
        logger.info(f"We have added {num_added_tokens} tokens to the tokenizer")
        model.resize_token_embeddings(len(tokenizer))
        assert tokenizer.pad_token == '<PAD>'

    # ---------------------------------------------------------- preprocessing
    if training_args.do_train:
        column_names = dataset["train"].column_names
    else:
        column_names = dataset["validation"].column_names
    text_column_name = "text" if "text" in column_names else column_names[0]

    tok_logger = transformers.utils.logging.get_logger("transformers.tokenization_utils_base")

    def tokenize_function(examples):
        """Tokenize the text column, adding a follow-up note when the tokenizer warns about length."""
        with CaptureLogger(tok_logger) as cl:
            output = tokenizer(examples[text_column_name])

        if "Token indices sequence length is longer than the" in cl.out:
            tok_logger.warning(
                "^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits before being passed to the model."
            )
        return output

    for k in dataset.keys():
        dataset[k].info.task_templates = []

    tokenized_datasets = dataset.map(
        tokenize_function,
        batched=True,
        num_proc=data_args.preprocessing_num_workers,
        remove_columns=column_names,
        load_from_cache_file=not data_args.overwrite_cache,
    )

    if data_args.block_size > tokenizer.model_max_length:
        logger.warning(
            f"The block_size passed ({data_args.block_size}) is larger than the maximum length for the model "
            f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}."
        )
    block_size = min(data_args.block_size, tokenizer.model_max_length)

    pad_token_id, start_token_id = tokenizer.pad_token_id, config.decoder_start_token_id

    def clip_texts(examples):
        """Truncate or pad every example to exactly `block_size` tokens and add matching `labels`."""
        examples["labels"] = examples["input_ids"].copy()

        for i, input_ids in enumerate(examples["input_ids"]):
            if len(input_ids) > block_size:
                for k in examples.keys():
                    examples[k][i] = examples[k][i][:block_size]
            elif len(input_ids) < block_size:
                delta = block_size - len(input_ids)
                examples['input_ids'][i] = examples['input_ids'][i] + [pad_token_id] * delta
                examples['attention_mask'][i] = examples['attention_mask'][i] + [0] * delta
                # Padded label positions are filled with -100.
                examples['labels'][i] = examples['labels'][i] + [-100] * delta

        return examples

    logger.info('clip_texts...')
    clipped_lm_datasets = tokenized_datasets.map(
        clip_texts,
        batched=True,
        num_proc=data_args.preprocessing_num_workers,
        load_from_cache_file=not data_args.overwrite_cache,
    )

    def add_decoder_input_ids(examples):
        """Add `decoder_input_ids` and `decoder_attention_mask`, each one token longer than the input."""
        arr_input_ids = jnp.array(examples["input_ids"])
        # Append one pad column, then shift right so the sequence begins with the start token.
        pad = pad_token_id * jnp.ones((arr_input_ids.shape[0], 1), dtype=jnp.int32)
        arr_pad_input_ids = jnp.concatenate((arr_input_ids, pad), axis=1)
        examples['decoder_input_ids'] = shift_tokens_right(arr_pad_input_ids, pad_token_id, start_token_id)

        # Prepend a column of ones so the mask lines up with the shifted decoder inputs.
        arr_attention_mask = jnp.array(examples['attention_mask'])
        ones = jnp.ones((arr_attention_mask.shape[0], 1), dtype=jnp.int32)
        examples['decoder_attention_mask'] = jnp.concatenate((ones, arr_attention_mask), axis=1)

        for k in ['decoder_input_ids', 'decoder_attention_mask']:
            examples[k] = examples[k].tolist()

        return examples

    logger.info('add_decoder_input_ids...')
    lm_datasets = clipped_lm_datasets.map(
        add_decoder_input_ids,
        batched=True,
        num_proc=data_args.preprocessing_num_workers,
        load_from_cache_file=not data_args.overwrite_cache,
    )

    if training_args.do_train:
        if "train" not in tokenized_datasets:
            raise ValueError("--do_train requires a train dataset")
        train_dataset = lm_datasets["train"]
        if data_args.max_train_samples is not None:
            train_dataset = train_dataset.select(range(data_args.max_train_samples))

    if training_args.do_eval:
        if "validation" not in tokenized_datasets:
            raise ValueError("--do_eval requires a validation dataset")
        eval_dataset = lm_datasets["validation"]
        if data_args.max_eval_samples is not None:
            eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))

    # ------------------------------------------------------------ tensorboard
    has_tensorboard = is_tensorboard_available()
    if has_tensorboard and jax.process_index() == 0:
        try:
            from flax.metrics.tensorboard import SummaryWriter

            summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
        except ImportError as ie:
            has_tensorboard = False
            logger.warning(
                f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
            )
    else:
        logger.warning(
            "Unable to display metrics through TensorBoard because the package is not installed: "
            "Please run pip install tensorboard to enable."
        )

    # --------------------------------------------------------- training setup
    if not training_args.do_train:
        # Everything below depends on `train_dataset`, which only exists with --do_train.
        raise ValueError("This script only supports training; please pass --do_train.")

    rng = jax.random.PRNGKey(training_args.seed)
    rng, dropout_rng = jax.random.split(rng)

    num_epochs = int(training_args.num_train_epochs)
    train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
    eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
    steps_per_epoch = len(train_dataset) // train_batch_size
    total_train_steps = steps_per_epoch * num_epochs

    linear_decay_lr_schedule_fn = create_learning_rate_fn(
        len(train_dataset),
        train_batch_size,
        training_args.num_train_epochs,
        training_args.warmup_steps,
        training_args.learning_rate,
    )

    def decay_mask_fn(params):
        """Return a boolean tree marking which parameters receive weight decay."""
        flat_params = traverse_util.flatten_dict(params)
        # Excludes biases and the listed layer-norm scales.
        # NOTE(review): the `ln_1`/`ln_2`/`ln_f` names look GPT-2 style - confirm they match
        # this model's parameter names.
        flat_mask = {
            path: (path[-1] != "bias" and path[-2:] not in [("ln_1", "scale"), ("ln_2", "scale"), ("ln_f", "scale")])
            for path in flat_params
        }
        return traverse_util.unflatten_dict(flat_mask)

    if training_args.adafactor:
        optimizer = optax.adafactor(
            learning_rate=linear_decay_lr_schedule_fn,
        )
    else:
        optimizer = optax.adamw(
            learning_rate=linear_decay_lr_schedule_fn,
            b1=training_args.adam_beta1,
            b2=training_args.adam_beta2,
            eps=training_args.adam_epsilon,
            weight_decay=training_args.weight_decay,
            mask=decay_mask_fn,
        )

    state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer, dropout_rng=dropout_rng)

    # ------------------------------------------------------------------- loss
    def compute_kernel(x, y):
        """Return the pairwise kernel matrix between the rows of 2-D arrays `x` and `y`."""
        x_size = x.shape[0]
        y_size = y.shape[0]
        dim = x.shape[1]
        tiled_x = jnp.repeat(jnp.reshape(x, (x_size, 1, dim)), y_size, axis=1)
        tiled_y = jnp.repeat(jnp.reshape(y, (1, y_size, dim)), x_size, axis=0)
        return jnp.exp(-jnp.mean((tiled_x - tiled_y) ** 2, axis=2) / dim * 1.0)

    def compute_mmd(x, y):
        """Return the MMD estimate between sample sets `x` and `y` built from `compute_kernel`."""
        x_kernel = compute_kernel(x, x)
        y_kernel = compute_kernel(y, y)
        xy_kernel = compute_kernel(x, y)
        return jnp.mean(x_kernel) + jnp.mean(y_kernel) - 2 * jnp.mean(xy_kernel)

    def regulariser_loss(latent_codes, rng):
        """Return the MMD between `latent_codes` and standard-normal samples of the same shape."""
        true_samples = jax.random.normal(rng, latent_codes.shape)

        return compute_mmd(true_samples, latent_codes)

    def loss_fn(logits, labels, latent_codes, regulariser_rng):
        """Return mean token cross-entropy plus the MMD regulariser on the flattened latent codes."""
        # Drop the last logit position so the logits line up with `labels`.
        shift_logits = logits[..., :-1, :]
        loss = optax.softmax_cross_entropy(shift_logits, onehot(labels, logits.shape[-1]))
        reg_loss = regulariser_loss(latent_codes.reshape(-1, latent_codes.shape[-1]), regulariser_rng)
        return loss.mean() + reg_loss.mean()

    def train_step(state, batch):
        """Run one optimisation step and return the new state and the device-averaged metrics."""
        dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
        new_dropout_rng, regulariser_rng = jax.random.split(new_dropout_rng)

        def compute_loss(params):
            labels = batch.pop("labels")
            outputs = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)
            loss = loss_fn(outputs[0], labels, outputs[1], regulariser_rng)
            return loss

        grad_fn = jax.value_and_grad(compute_loss)
        loss, grad = grad_fn(state.params)
        grad = jax.lax.pmean(grad, "batch")

        new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)

        metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
        metrics = jax.lax.pmean(metrics, axis_name="batch")

        return new_state, metrics

    def eval_step(params, rng, batch):
        """Compute the device-averaged loss for one evaluation batch."""
        labels = batch.pop("labels")
        logits, latent_codes = model(**batch, params=params, train=False)[:2]
        loss = loss_fn(logits, labels, latent_codes, rng)

        metrics = {"loss": loss}
        metrics = jax.lax.pmean(metrics, axis_name="batch")
        return metrics

    p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
    p_eval_step = jax.pmap(eval_step, "batch")

    state = state.replicate()

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num Epochs = {num_epochs}")
    logger.info(f"  Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel & distributed) = {train_batch_size}")
    logger.info(f"  Total optimization steps = {total_train_steps}")

    # ---------------------------------------------------------- training loop
    train_time = 0
    train_metrics = []
    epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
    for epoch in epochs:
        train_start = time.time()

        # Fresh key per epoch so every epoch shuffles differently.
        rng, input_rng = jax.random.split(rng)

        train_loader = data_loader(input_rng, train_dataset, train_batch_size, shuffle=True)
        steps_per_epoch = len(train_dataset) // train_batch_size

        for step in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
            batch = next(train_loader)
            state, train_metric = p_train_step(state, batch)
            train_metrics.append(train_metric)

            # Global step across epochs; logging, evaluation and saving all key off this value.
            cur_step = epoch * steps_per_epoch + step

            if cur_step % training_args.logging_steps == 0 and cur_step > 0:
                train_metric = unreplicate(train_metric)
                train_time += time.time() - train_start
                if has_tensorboard and jax.process_index() == 0:
                    write_train_metric(summary_writer, train_metrics, train_time, cur_step)

                epochs.write(
                    f"Step... ({cur_step} | Loss: {train_metric['loss'].mean()}, Learning Rate: {train_metric['learning_rate'].mean()})"
                )

                train_metrics = []

            # `eval_dataset` only exists when --do_eval was passed.
            if training_args.do_eval and cur_step % training_args.eval_steps == 0 and cur_step > 0:
                eval_metrics = []
                eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size)
                eval_steps = len(eval_dataset) // eval_batch_size
                for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False):
                    batch = next(eval_loader)
                    metrics = p_eval_step(state.params, state.dropout_rng, batch)
                    eval_metrics.append(metrics)

                eval_metrics = get_metrics(eval_metrics)
                eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)

                try:
                    eval_metrics["perplexity"] = math.exp(eval_metrics["loss"])
                except OverflowError:
                    eval_metrics["perplexity"] = float("inf")

                desc = f"Step... ({cur_step} | Eval Loss: {eval_metrics['loss']} | Eval Perplexity: {eval_metrics['perplexity']})"
                epochs.write(desc)
                epochs.desc = desc

                if has_tensorboard and jax.process_index() == 0:
                    # Log at the true global step; `cur_step` must stay intact for the save check below.
                    write_eval_metric(summary_writer, eval_metrics, cur_step)

            if cur_step % training_args.save_steps == 0 and cur_step > 0:
                # Only process 0 writes the checkpoint.
                if jax.process_index() == 0:
                    params = jax.device_get(unreplicate(state.params))
                    model.save_pretrained(
                        training_args.output_dir,
                        params=params,
                        push_to_hub=training_args.push_to_hub,
                        commit_message=f"Saving weights and logs of step {cur_step}",
                    )
|
|
|
|
|
# Script entry point: run training only when this file is executed directly, not on import.
if __name__ == "__main__": |
|
main() |
|
|