Fraser commited on
Commit
eb70d54
1 Parent(s): fa431ac

add base model code

Browse files
Files changed (5) hide show
  1. .gitignore +3 -0
  2. README.md +37 -0
  3. check_install.py +15 -0
  4. setup_tpu_vm_venv.sh +19 -0
  5. train.py +707 -0
.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ .vscode
2
+ venv
3
+ *.pyc
README.md ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ language: en
3
+ tags: vae
4
+ license: apache-2.0
5
+ ---
6
+
7
+ # T5-VAE-Wiki (flax)
8
+
9
+ A Transformer-VAE made using flax.
10
+
11
+ Try the [demo] (TODO)!
12
+
13
+ It has been trained to interpolate on sentences form wikipedia.
14
+
15
+ Done as part of Huggingface community training ([see forum post](https://discuss.huggingface.co/t/train-a-vae-to-interpolate-on-english-sentences/7548)).
16
+
17
+ Builds on T5, using an autoencoder to convert it into an MMD-VAE ([more info](http://fras.uk/ml/large%20prior-free%20models/transformer-vae/2020/08/13/Transformers-as-Variational-Autoencoders.html)).
18
+
19
+ ## How to use from the 🤗/transformers library
20
+
21
+ Add model repo as a submodule:
22
+ ```bash
23
+ git submodule add https://github.com/Fraser-Greenlee/t5-vae-flax.git t5_vae_flax
24
+ ```
25
+
26
+ ```python
27
+ from transformers import AutoTokenizer
28
+ from t5_vae_flax.src.t5_vae import FlaxT5VaeForAutoencoding
29
+
30
+ tokenizer = AutoTokenizer.from_pretrained("t5-base")
31
+
32
+ model = FlaxT5VaeForAutoencoding.from_pretrained("flax-community/t5-vae-python")
33
+ ```
34
+
35
+ ## Setup
36
+
37
+ Run `setup_tpu_vm_venv.sh` to setup a virtual enviroment on a TPU VM for training.
check_install.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import FlaxRobertaModel, RobertaTokenizerFast
2
+ from datasets import load_dataset
3
+ import jax
4
+
5
+ dataset = load_dataset('oscar', "unshuffled_deduplicated_en", split='train', streaming=True)
6
+
7
+ dummy_input = next(iter(dataset))["text"]
8
+
9
+ tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")
10
+ input_ids = tokenizer(dummy_input, return_tensors="np").input_ids[:, :10]
11
+
12
+ model = FlaxRobertaModel.from_pretrained("julien-c/dummy-unknown")
13
+
14
+ # run a forward pass, should return an object `FlaxBaseModelOutputWithPooling`
15
+ z = model(input_ids)
setup_tpu_vm_venv.sh ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # setup training on a TPU VM
2
+ rm -fr venv
3
+ python3 -m venv venv
4
+ source venv/bin/activate
5
+ pip install -U pip
6
+ pip install -U wheel
7
+ pip install requests
8
+ pip install "jax[tpu]>=0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
9
+
10
+ cd ..
11
+ git clone https://github.com/huggingface/transformers.git
12
+ cd transformers
13
+ pip install -e ".[flax]"
14
+ cd ..
15
+
16
+ git clone https://github.com/huggingface/datasets.git
17
+ cd datasets
18
+ pip install -e ".[streaming]"
19
+ cd ..
train.py ADDED
@@ -0,0 +1,707 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ Pre-training/Fine-tuning seq2seq models on autoencoding a dataset.
3
+
4
+ TODO:
5
+ - [ ] Add reg loss
6
+ - [x] calculate MMD loss
7
+ - [ ] schedule MMD loss weight
8
+ - [ ] Add these params to the training arguments.
9
+
10
+ reg_schedule_k (:obj:`float`, `optional`, defaults to 0.0025):
11
+ Multiplied by global_step in a sigmoid, more gradually increase regulariser loss weight.
12
+ reg_schedule_b (:obj:`float`, `optional`, defaults to 6.25):
13
+ Added to global step in sigmoid, further delays increase in regulariser loss weight.
14
+ use_extra_logs (:obj:`bool`, `optional`, defaults to False):
15
+ Store extra logs during each training inference.
16
+
17
+ - [ ] Send the schedule time to the compute_loss method and calculate a coefficient based on that.
18
+ '''
19
+ import logging
20
+ import math
21
+ import os
22
+ import sys
23
+ import time
24
+ from dataclasses import dataclass, field
25
+ from pathlib import Path
26
+ from typing import Callable, Optional
27
+
28
+ import datasets
29
+ from datasets import Dataset, load_dataset
30
+ from tqdm import tqdm
31
+
32
+ import jax
33
+ import jax.numpy as jnp
34
+ import optax
35
+ import transformers
36
+ from flax import jax_utils, traverse_util
37
+ from flax.jax_utils import unreplicate
38
+ from flax.training import train_state
39
+ from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
40
+ from transformers import (
41
+ AutoTokenizer,
42
+ HfArgumentParser,
43
+ TrainingArguments,
44
+ is_tensorboard_available,
45
+ )
46
+ from transformers.models.t5.modeling_flax_t5 import shift_tokens_right
47
+ from transformers.testing_utils import CaptureLogger
48
+
49
+ from t5_vae_flax.src.t5_vae import FlaxT5VaeForAutoencoding
50
+ from t5_vae_flax.src.config import T5VaeConfig
51
+
52
+
53
+ logger = logging.getLogger(__name__)
54
+
55
+
56
+ @dataclass
57
+ class ModelArguments:
58
+ """
59
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
60
+ """
61
+
62
+ model_name_or_path: Optional[str] = field(
63
+ default=None,
64
+ metadata={
65
+ "help": "The model checkpoint for weights initialization."
66
+ "Don't set if you want to train a model from scratch."
67
+ },
68
+ )
69
+ t5_model_name_or_path: Optional[str] = field(
70
+ default=None,
71
+ metadata={
72
+ "help": "The T5 model checkpoint for weights initialization."
73
+ "Needed when not starting from a T5-VAE model."
74
+ },
75
+ )
76
+ n_latent_tokens: Optional[int] = field(
77
+ default=6,
78
+ metadata={
79
+ "help": "Number of latent tokens (must be less than seq length)."
80
+ },
81
+ )
82
+ latent_token_size: Optional[int] = field(
83
+ default=32,
84
+ metadata={
85
+ "help": "Number of dimensions to use for each latent token."
86
+ },
87
+ )
88
+ add_special_tokens: bool = field(
89
+ default=False,
90
+ metadata={"help": "Add these special tokens to the tokenizer: {'pad_token': '<PAD>', 'bos_token': '<BOS>', 'eos_token': '<EOS>'}"},
91
+ )
92
+ config_path: Optional[str] = field(
93
+ default=None, metadata={"help": "Pretrained config path"}
94
+ )
95
+ tokenizer_name: Optional[str] = field(
96
+ default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
97
+ )
98
+ cache_dir: Optional[str] = field(
99
+ default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
100
+ )
101
+ use_fast_tokenizer: bool = field(
102
+ default=True,
103
+ metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
104
+ )
105
+ dtype: Optional[str] = field(
106
+ default="float32",
107
+ metadata={
108
+ "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
109
+ },
110
+ )
111
+
112
+
113
+ @dataclass
114
+ class DataTrainingArguments:
115
+ """
116
+ Arguments pertaining to what data we are going to input our model for training and eval.
117
+ """
118
+
119
+ dataset_name: Optional[str] = field(
120
+ default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
121
+ )
122
+ dataset_config_name: Optional[str] = field(
123
+ default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
124
+ )
125
+ train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
126
+ validation_file: Optional[str] = field(
127
+ default=None,
128
+ metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
129
+ )
130
+ max_train_samples: Optional[int] = field(
131
+ default=None,
132
+ metadata={
133
+ "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
134
+ "value if set."
135
+ },
136
+ )
137
+ max_eval_samples: Optional[int] = field(
138
+ default=None,
139
+ metadata={
140
+ "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
141
+ "value if set."
142
+ },
143
+ )
144
+ overwrite_cache: bool = field(
145
+ default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
146
+ )
147
+ validation_split_percentage: Optional[int] = field(
148
+ default=5,
149
+ metadata={
150
+ "help": "The percentage of the train set used as validation set in case there's no validation split"
151
+ },
152
+ )
153
+ block_size: Optional[int] = field(
154
+ default=None,
155
+ metadata={
156
+ "help": "Optional input sequence length after tokenization. "
157
+ "The training dataset will be truncated in block of this size for training. "
158
+ "Default to the model max input length for single sentence inputs (take into account special tokens)."
159
+ },
160
+ )
161
+ streaming: bool = field(
162
+ default=False, metadata={"help": "Stream the dataset."}
163
+ )
164
+ overwrite_cache: bool = field(
165
+ default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
166
+ )
167
+ preprocessing_num_workers: Optional[int] = field(
168
+ default=None,
169
+ metadata={"help": "The number of processes to use for the preprocessing."},
170
+ )
171
+
172
+ def __post_init__(self):
173
+ if self.dataset_name is None and self.train_file is None and self.validation_file is None:
174
+ raise ValueError("Need either a dataset name or a training/validation file.")
175
+ else:
176
+ if self.train_file is not None:
177
+ extension = self.train_file.split(".")[-1]
178
+ assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file."
179
+ if self.validation_file is not None:
180
+ extension = self.validation_file.split(".")[-1]
181
+ assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."
182
+
183
+
184
+ class TrainState(train_state.TrainState):
185
+ dropout_rng: jnp.ndarray
186
+
187
+ def replicate(self):
188
+ return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))
189
+
190
+
191
+ def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False):
192
+ """
193
+ Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
194
+ Shuffle batches if `shuffle` is `True`.
195
+ """
196
+ steps_per_epoch = len(dataset) // batch_size
197
+
198
+ if shuffle:
199
+ batch_idx = jax.random.permutation(rng, len(dataset))
200
+ else:
201
+ batch_idx = jnp.arange(len(dataset))
202
+
203
+ batch_idx = batch_idx[: steps_per_epoch * batch_size] # Skip incomplete batch.
204
+ batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
205
+
206
+ for idx in batch_idx:
207
+ batch = dataset[idx]
208
+ batch = {k: jnp.array(v) for k, v in batch.items()}
209
+
210
+ batch = shard(batch)
211
+
212
+ yield batch
213
+
214
+
215
+ def write_train_metric(summary_writer, train_metrics, train_time, step):
216
+ summary_writer.scalar("train_time", train_time, step)
217
+
218
+ train_metrics = get_metrics(train_metrics)
219
+ for key, vals in train_metrics.items():
220
+ tag = f"train_{key}"
221
+ for i, val in enumerate(vals):
222
+ summary_writer.scalar(tag, val, step - len(vals) + i + 1)
223
+
224
+
225
+ def write_eval_metric(summary_writer, eval_metrics, step):
226
+ for metric_name, value in eval_metrics.items():
227
+ summary_writer.scalar(f"eval_{metric_name}", value, step)
228
+
229
+
230
+ def create_learning_rate_fn(
231
+ train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
232
+ ) -> Callable[[int], jnp.array]:
233
+ """Returns a linear warmup, linear_decay learning rate function."""
234
+ steps_per_epoch = train_ds_size // train_batch_size
235
+ num_train_steps = steps_per_epoch * num_train_epochs
236
+ warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
237
+ decay_fn = optax.linear_schedule(
238
+ init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
239
+ )
240
+ schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
241
+ return schedule_fn
242
+
243
+
244
+ def main():
245
+ # See all possible arguments in src/transformers/training_args.py
246
+ # or by passing the --help flag to this script.
247
+ # We now keep distinct sets of args, for a cleaner separation of concerns.
248
+
249
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
250
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
251
+ # If we pass only one argument to the script and it's the path to a json file,
252
+ # let's parse it to get our arguments.
253
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
254
+ else:
255
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
256
+
257
+ if (
258
+ os.path.exists(training_args.output_dir)
259
+ and os.listdir(training_args.output_dir)
260
+ and training_args.do_train
261
+ and not training_args.overwrite_output_dir
262
+ ):
263
+ raise ValueError(
264
+ f"Output directory ({training_args.output_dir}) already exists and is not empty."
265
+ "Use --overwrite_output_dir to overcome."
266
+ )
267
+
268
+ if data_args.block_size is None:
269
+ raise Exception('Must set block_size so we know what length of sequence to autoencode.')
270
+
271
+ # Make one log on every process with the configuration for debugging.
272
+ logging.basicConfig(
273
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
274
+ datefmt="%m/%d/%Y %H:%M:%S",
275
+ level=logging.INFO,
276
+ )
277
+ # Setup logging, we only want one process per machine to log things on the screen.
278
+ logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
279
+ if jax.process_index() == 0:
280
+ datasets.utils.logging.set_verbosity_warning()
281
+ transformers.utils.logging.set_verbosity_info()
282
+ else:
283
+ datasets.utils.logging.set_verbosity_error()
284
+ transformers.utils.logging.set_verbosity_error()
285
+
286
+ # Set the verbosity to info of the Transformers logger (on main process only):
287
+ logger.info(f"Training/evaluation parameters {training_args}")
288
+
289
+ # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
290
+ # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
291
+ # (the dataset will be downloaded automatically from the datasets Hub).
292
+ #
293
+ # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
294
+ # 'text' is found. You can easily tweak this behavior (see below).
295
+ #
296
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
297
+ # download the dataset.
298
+ if data_args.dataset_name is not None:
299
+ # Downloading and loading a dataset from the hub.
300
+ dataset = load_dataset(
301
+ data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir, streaming=data_args.streaming, keep_in_memory=False
302
+ )
303
+
304
+ if "validation" not in dataset.keys():
305
+ dataset["validation"] = load_dataset(
306
+ data_args.dataset_name,
307
+ data_args.dataset_config_name,
308
+ split=f"train[:{data_args.validation_split_percentage}%]",
309
+ cache_dir=model_args.cache_dir,
310
+ )
311
+ dataset["train"] = load_dataset(
312
+ data_args.dataset_name,
313
+ data_args.dataset_config_name,
314
+ split=f"train[{data_args.validation_split_percentage}%:]",
315
+ cache_dir=model_args.cache_dir,
316
+ )
317
+ else:
318
+ data_files = {}
319
+ if data_args.train_file is not None:
320
+ data_files["train"] = data_args.train_file
321
+ if data_args.validation_file is not None:
322
+ data_files["validation"] = data_args.validation_file
323
+ extension = data_args.train_file.split(".")[-1]
324
+ if extension == "txt":
325
+ extension = "text"
326
+ dataset = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
327
+ # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
328
+ # https://huggingface.co/docs/datasets/loading_datasets.html.
329
+
330
+ # Load pretrained model and tokenizer
331
+
332
+ # Distributed training:
333
+ # The .from_pretrained methods guarantee that only one local process can concurrently
334
+ # download model & vocab.
335
+
336
+ if model_args.config_path:
337
+ config = T5VaeConfig.from_pretrained(
338
+ model_args.config_path, cache_dir=model_args.cache_dir
339
+ )
340
+ elif model_args.model_name_or_path:
341
+ config = T5VaeConfig.from_pretrained(
342
+ model_args.model_name_or_path, cache_dir=model_args.cache_dir
343
+ )
344
+ else:
345
+ config = T5VaeConfig(**model_args.__dict__)
346
+ logger.warning("You are instantiating a new config instance from scratch.")
347
+
348
+ if model_args.tokenizer_name:
349
+ tokenizer = AutoTokenizer.from_pretrained(
350
+ model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
351
+ )
352
+ elif model_args.t5_model_name_or_path:
353
+ tokenizer = AutoTokenizer.from_pretrained(
354
+ model_args.t5_model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
355
+ )
356
+ else:
357
+ raise ValueError(
358
+ "You are instantiating a new tokenizer from scratch. This is not supported by this script."
359
+ "You can do it from another script, save it, and load it from here, using --tokenizer_name."
360
+ )
361
+
362
+ if model_args.model_name_or_path:
363
+ model = FlaxT5VaeForAutoencoding.from_pretrained(
364
+ model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
365
+ )
366
+ assert model.params['t5']['shared']['embedding'].shape[0] == len(tokenizer), "T5 Tokenizer doesn't match T5Vae embedding size."
367
+ else:
368
+ vocab_size = len(tokenizer)
369
+ config.t5.vocab_size = vocab_size
370
+ config.vocab_size = vocab_size
371
+ logger.info("Training new model from scratch.")
372
+ model = FlaxT5VaeForAutoencoding(
373
+ config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
374
+ )
375
+
376
+ if model_args.add_special_tokens:
377
+ special_tokens_dict = {'pad_token': '<PAD>', 'bos_token': '<BOS>', 'eos_token': '<EOS>'}
378
+ num_added_tokens = tokenizer.add_special_tokens(special_tokens_dict)
379
+ print('We have added', num_added_tokens, 'tokens to GPT2')
380
+ model.resize_token_embeddings(len(tokenizer))
381
+ assert tokenizer.pad_token == '<PAD>'
382
+
383
+ # Preprocessing the datasets.
384
+ # First we tokenize all the texts.
385
+ if training_args.do_train:
386
+ column_names = dataset["train"].column_names
387
+ else:
388
+ column_names = dataset["validation"].column_names
389
+ text_column_name = "text" if "text" in column_names else column_names[0]
390
+
391
+ # since this will be pickled to avoid _LazyModule error in Hasher force logger loading before tokenize_function
392
+ tok_logger = transformers.utils.logging.get_logger("transformers.tokenization_utils_base")
393
+
394
+ def tokenize_function(examples):
395
+ with CaptureLogger(tok_logger) as cl:
396
+ output = tokenizer(examples[text_column_name])
397
+ # clm input could be much much longer than block_size
398
+ if "Token indices sequence length is longer than the" in cl.out:
399
+ tok_logger.warning(
400
+ "^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits before being passed to the model."
401
+ )
402
+ return output
403
+
404
+ # remove dataset tasks
405
+ for k in dataset.keys():
406
+ dataset[k].info.task_templates = []
407
+
408
+ tokenized_datasets = dataset.map(
409
+ tokenize_function,
410
+ batched=True,
411
+ num_proc=data_args.preprocessing_num_workers,
412
+ remove_columns=column_names,
413
+ load_from_cache_file=not data_args.overwrite_cache,
414
+ )
415
+
416
+ if data_args.block_size > tokenizer.model_max_length:
417
+ logger.warning(
418
+ f"The block_size passed ({data_args.block_size}) is larger than the maximum length for the model"
419
+ f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}."
420
+ )
421
+ block_size = min(data_args.block_size, tokenizer.model_max_length)
422
+
423
+ pad_token_id, start_token_id = tokenizer.pad_token_id, config.decoder_start_token_id
424
+
425
+ def clip_texts(examples):
426
+ examples["labels"] = examples["input_ids"].copy()
427
+
428
+ for i, input_ids in enumerate(examples["input_ids"]):
429
+ if len(input_ids) > block_size:
430
+ for k in examples.keys():
431
+ examples[k][i] = examples[k][i][:block_size]
432
+ elif len(input_ids) < block_size:
433
+ delta = block_size - len(input_ids)
434
+ examples['input_ids'][i] = examples['input_ids'][i] + [pad_token_id] * delta
435
+ examples['attention_mask'][i] = examples['attention_mask'][i] + [0] * delta
436
+ examples['labels'][i] = examples['labels'][i] + [-100] * delta
437
+
438
+ return examples
439
+
440
+ logger.info('clip_texts...')
441
+ clipped_lm_datasets = tokenized_datasets.map(
442
+ clip_texts,
443
+ batched=True,
444
+ num_proc=data_args.preprocessing_num_workers,
445
+ load_from_cache_file=not data_args.overwrite_cache,
446
+ )
447
+
448
+ def add_decoder_input_ids(examples):
449
+ arr_input_ids = jnp.array(examples["input_ids"])
450
+ pad = pad_token_id * jnp.ones((arr_input_ids.shape[0], 1), dtype=jnp.int32)
451
+ arr_pad_input_ids = jnp.concatenate((arr_input_ids, pad), axis=1)
452
+ examples['decoder_input_ids'] = shift_tokens_right(arr_pad_input_ids, pad_token_id, start_token_id)
453
+
454
+ arr_attention_mask = jnp.array(examples['attention_mask'])
455
+ ones = jnp.ones((arr_attention_mask.shape[0], 1), dtype=jnp.int32)
456
+ examples['decoder_attention_mask'] = jnp.concatenate((ones, arr_attention_mask), axis=1)
457
+
458
+ for k in ['decoder_input_ids', 'decoder_attention_mask']:
459
+ examples[k] = examples[k].tolist()
460
+
461
+ return examples
462
+
463
+ logger.info('add_decoder_input_ids...')
464
+ lm_datasets = clipped_lm_datasets.map(
465
+ add_decoder_input_ids,
466
+ batched=True,
467
+ num_proc=data_args.preprocessing_num_workers,
468
+ load_from_cache_file=not data_args.overwrite_cache,
469
+ )
470
+
471
+ if training_args.do_train:
472
+ if "train" not in tokenized_datasets:
473
+ raise ValueError("--do_train requires a train dataset")
474
+ train_dataset = lm_datasets["train"]
475
+ if data_args.max_train_samples is not None:
476
+ train_dataset = train_dataset.select(range(data_args.max_train_samples))
477
+
478
+ if training_args.do_eval:
479
+ if "validation" not in tokenized_datasets:
480
+ raise ValueError("--do_eval requires a validation dataset")
481
+ eval_dataset = lm_datasets["validation"]
482
+ if data_args.max_eval_samples is not None:
483
+ eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
484
+
485
+ # Enable tensorboard only on the master node
486
+ has_tensorboard = is_tensorboard_available()
487
+ if has_tensorboard and jax.process_index() == 0:
488
+ try:
489
+ from flax.metrics.tensorboard import SummaryWriter
490
+
491
+ summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
492
+ except ImportError as ie:
493
+ has_tensorboard = False
494
+ logger.warning(
495
+ f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
496
+ )
497
+ else:
498
+ logger.warning(
499
+ "Unable to display metrics through TensorBoard because the package is not installed: "
500
+ "Please run pip install tensorboard to enable."
501
+ )
502
+
503
+ # Initialize our training
504
+ rng = jax.random.PRNGKey(training_args.seed)
505
+ rng, dropout_rng = jax.random.split(rng)
506
+
507
+ # Store some constant
508
+ num_epochs = int(training_args.num_train_epochs)
509
+ train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
510
+ eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
511
+ steps_per_epoch = len(train_dataset) // train_batch_size
512
+ total_train_steps = steps_per_epoch * num_epochs
513
+
514
+ # Create learning rate schedule
515
+ linear_decay_lr_schedule_fn = create_learning_rate_fn(
516
+ len(train_dataset),
517
+ train_batch_size,
518
+ training_args.num_train_epochs,
519
+ training_args.warmup_steps,
520
+ training_args.learning_rate,
521
+ )
522
+
523
+ # We use Optax's "masking" functionality to not apply weight decay
524
+ # to bias and LayerNorm scale parameters. decay_mask_fn returns a
525
+ # mask boolean with the same structure as the parameters.
526
+ # The mask is True for parameters that should be decayed.
527
+ # Note that this mask is specifically adapted for FlaxGPT2.
528
+ # For other models, one should correct the layer norm parameter naming
529
+ # accordingly.
530
+ def decay_mask_fn(params):
531
+ flat_params = traverse_util.flatten_dict(params)
532
+ flat_mask = {
533
+ path: (path[-1] != "bias" and path[-2:] not in [("ln_1", "scale"), ("ln_2", "scale"), ("ln_f", "scale")])
534
+ for path in flat_params
535
+ }
536
+ return traverse_util.unflatten_dict(flat_mask)
537
+
538
+ # create adam optimizer
539
+ if training_args.adafactor:
540
+ # We use the default parameters here to initialize adafactor,
541
+ # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
542
+ optimizer = optax.adafactor(
543
+ learning_rate=linear_decay_lr_schedule_fn,
544
+ )
545
+ else:
546
+ optimizer = optax.adamw(
547
+ learning_rate=linear_decay_lr_schedule_fn,
548
+ b1=training_args.adam_beta1,
549
+ b2=training_args.adam_beta2,
550
+ eps=training_args.adam_epsilon,
551
+ weight_decay=training_args.weight_decay,
552
+ mask=decay_mask_fn,
553
+ )
554
+
555
+ # Setup train state
556
+ state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer, dropout_rng=dropout_rng)
557
+
558
+ def compute_kernel(x, y):
559
+ x_size = x.shape[0]
560
+ y_size = y.shape[0]
561
+ dim = x.shape[1]
562
+ tiled_x = jnp.repeat(jnp.reshape(x, (x_size, 1, dim)), y_size, axis=1)
563
+ tiled_y = jnp.repeat(jnp.reshape(y, (1, y_size, dim)), x_size, axis=0)
564
+ return jnp.exp(-jnp.mean((tiled_x - tiled_y) ** 2, axis=2) / dim * 1.0)
565
+
566
+ def compute_mmd(x, y):
567
+ x_kernel = compute_kernel(x, x)
568
+ y_kernel = compute_kernel(y, y)
569
+ xy_kernel = compute_kernel(x, y)
570
+ return jnp.mean(x_kernel) + jnp.mean(y_kernel) - 2 * jnp.mean(xy_kernel)
571
+
572
+ def regulariser_loss(latent_codes, rng):
573
+ true_samples = jax.random.normal(rng, latent_codes.shape)
574
+ # return jax.vmap(compute_mmd)(true_samples, latent_codes)
575
+ return compute_mmd(true_samples, latent_codes)
576
+
577
+ def loss_fn(logits, labels, latent_codes, regulariser_rng):
578
+ shift_logits = logits[..., :-1, :]
579
+ loss = optax.softmax_cross_entropy(shift_logits, onehot(labels, logits.shape[-1]))
580
+ reg_loss = regulariser_loss(latent_codes.reshape(-1, latent_codes.shape[-1]), regulariser_rng)
581
+ return loss.mean() + reg_loss.mean()
582
+
583
+ # Define gradient update step fn
584
+ def train_step(state, batch):
585
+ dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
586
+ new_dropout_rng, regulariser_rng = jax.random.split(new_dropout_rng)
587
+
588
+ def compute_loss(params):
589
+ labels = batch.pop("labels")
590
+ outputs = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)
591
+ loss = loss_fn(outputs[0], labels, outputs[1], regulariser_rng)
592
+ return loss
593
+
594
+ grad_fn = jax.value_and_grad(compute_loss)
595
+ loss, grad = grad_fn(state.params)
596
+ grad = jax.lax.pmean(grad, "batch")
597
+
598
+ new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)
599
+
600
+ metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
601
+ metrics = jax.lax.pmean(metrics, axis_name="batch")
602
+
603
+ return new_state, metrics
604
+
605
+ # Define eval fn
606
+ def eval_step(params, rng, batch):
607
+ labels = batch.pop("labels")
608
+ logits, latent_codes = model(**batch, params=params, train=False)[:2]
609
+ loss = loss_fn(logits, labels, latent_codes, rng)
610
+
611
+ # summarize metrics
612
+ metrics = {"loss": loss}
613
+ metrics = jax.lax.pmean(metrics, axis_name="batch")
614
+ return metrics
615
+
616
+ # Create parallel version of the train and eval step
617
+ p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
618
+ p_eval_step = jax.pmap(eval_step, "batch")
619
+
620
+ # Replicate the train state on each device
621
+ state = state.replicate()
622
+
623
+ logger.info("***** Running training *****")
624
+ logger.info(f" Num examples = {len(train_dataset)}")
625
+ logger.info(f" Num Epochs = {num_epochs}")
626
+ logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
627
+ logger.info(f" Total train batch size (w. parallel & distributed) = {train_batch_size}")
628
+ logger.info(f" Total optimization steps = {total_train_steps}")
629
+
630
+ train_time = 0
631
+ train_metrics = []
632
+ epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
633
+ for epoch in epochs:
634
+ # ======================== Training ================================
635
+ train_start = time.time()
636
+
637
+ # Create sampling rng
638
+ rng, input_rng = jax.random.split(rng)
639
+
640
+ # Generate an epoch by shuffling sampling indices from the train dataset
641
+ train_loader = data_loader(input_rng, train_dataset, train_batch_size, shuffle=True)
642
+ steps_per_epoch = len(train_dataset) // train_batch_size
643
+ # train
644
+ for step in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
645
+ batch = next(train_loader)
646
+ state, train_metric = p_train_step(state, batch)
647
+ train_metrics.append(train_metric)
648
+
649
+ cur_step = epoch * (len(train_dataset) // train_batch_size) + step
650
+
651
+ if cur_step % training_args.logging_steps == 0 and cur_step > 0:
652
+ # Save metrics
653
+ train_metric = unreplicate(train_metric)
654
+ train_time += time.time() - train_start
655
+ if has_tensorboard and jax.process_index() == 0:
656
+ write_train_metric(summary_writer, train_metrics, train_time, cur_step)
657
+
658
+ epochs.write(
659
+ f"Step... ({cur_step} | Loss: {train_metric['loss'].mean()}, Learning Rate: {train_metric['learning_rate'].mean()})"
660
+ )
661
+
662
+ train_metrics = []
663
+
664
+ if cur_step % training_args.eval_steps == 0 and cur_step > 0:
665
+ # ======================== Evaluating ==============================
666
+ eval_metrics = []
667
+ eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size)
668
+ eval_steps = len(eval_dataset) // eval_batch_size
669
+ for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False):
670
+ # Model forward
671
+ batch = next(eval_loader)
672
+ metrics = p_eval_step(state.params, state.dropout_rng, batch)
673
+ eval_metrics.append(metrics)
674
+
675
+ # normalize eval metrics
676
+ eval_metrics = get_metrics(eval_metrics)
677
+ eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
678
+
679
+ try:
680
+ eval_metrics["perplexity"] = math.exp(eval_metrics["loss"])
681
+ except OverflowError:
682
+ eval_metrics["perplexity"] = float("inf")
683
+
684
+ # Print metrics and update progress bar
685
+ desc = f"Step... ({cur_step} | Eval Loss: {eval_metrics['loss']} | Eval Perplexity: {eval_metrics['perplexity']})"
686
+ epochs.write(desc)
687
+ epochs.desc = desc
688
+
689
+ # Save metrics
690
+ if has_tensorboard and jax.process_index() == 0:
691
+ cur_step = epoch * (len(train_dataset) // train_batch_size)
692
+ write_eval_metric(summary_writer, eval_metrics, cur_step)
693
+
694
+ if cur_step % training_args.save_steps == 0 and cur_step > 0:
695
+ # save checkpoint after each epoch and push checkpoint to the hub
696
+ if jax.process_index() == 0:
697
+ params = jax.device_get(unreplicate(state.params))
698
+ model.save_pretrained(
699
+ training_args.output_dir,
700
+ params=params,
701
+ push_to_hub=training_args.push_to_hub,
702
+ commit_message=f"Saving weights and logs of step {cur_step}",
703
+ )
704
+
705
+
706
+ if __name__ == "__main__":
707
+ main()