Fraser committed on
Commit 2095da4
1 parent: 7de4509

add dataset scripts

Files changed (8)
  1. .gitignore +2 -0
  2. convert_files.py +17 -0
  3. get_data.sh +23 -0
  4. merge_datasets.py +12 -0
  5. prepare_data.sh +0 -0
  6. train.py +41 -215
  7. train.sh +22 -0
  8. wiki_sentences.py +46 -0
.gitignore CHANGED
@@ -1,3 +1,5 @@
  .vscode
  venv
  *.pyc
+ segment_*
+ dataset.csv
convert_files.py ADDED
@@ -0,0 +1,17 @@
+ import json
+ from tqdm import tqdm
+ from transformers import AutoTokenizer
+
+ tokenizer = AutoTokenizer.from_pretrained('gpt2')
+
+ for i in tqdm(range(298)):
+
+     with open(f'wikipedia_json_64_filtered/wikipedia.segmented.nltk.split.seq64.{i}.json', 'r') as f:
+         rows = json.load(f)
+
+     tokens = [row['gpt2_token'] for row in rows]
+     texts = tokenizer.batch_decode(tokens)
+
+     with open(f'wikipedia/{i}.txt', 'w') as f:
+         for txt in texts:
+             f.write(txt.strip() + '\n')
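
convert_files.py decodes each segment's stored GPT-2 token ids back into plain text, one sentence per line, producing the wikipedia/{i}.txt files that train.py later streams. A small sketch of the round trip on a single row (illustrative only: the sample sentence is invented, and the wikipedia/ output directory is created here because the script above assumes it already exists):

    import os
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained('gpt2')
    os.makedirs('wikipedia', exist_ok=True)  # convert_files.py writes into this directory

    # One made-up row, shaped like the entries in the wikipedia_json_64_filtered JSON files.
    row = {'gpt2_token': tokenizer.encode('The quick brown fox jumps over the lazy dog.')}
    [text] = tokenizer.batch_decode([row['gpt2_token']])
    print(text.strip())  # round-trips back to the original sentence
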
get_data.sh ADDED
@@ -0,0 +1,23 @@
+
+
+ wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=13NnkYAhwszQxc1C5HHfThnF7c1cjzjAD' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=13NnkYAhwszQxc1C5HHfThnF7c1cjzjAD" -O segment_1.zip && rm -rf /tmp/cookies.txt
+ wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=14p6FHip_hGTXC-_7SYaK32BpEhZRDJI4' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=14p6FHip_hGTXC-_7SYaK32BpEhZRDJI4" -O segment_2.zip && rm -rf /tmp/cookies.txt
+ wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1IaRfTFh51Wf_zPtK6tjE6xw-up_Z6EyN' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1IaRfTFh51Wf_zPtK6tjE6xw-up_Z6EyN" -O segment_3.zip && rm -rf /tmp/cookies.txt
+ wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1KGhV397Xfej56uJ9H10xD7tfLdhWlg4q' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1KGhV397Xfej56uJ9H10xD7tfLdhWlg4q" -O segment_4.zip && rm -rf /tmp/cookies.txt
+ wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1LfsQ1s9wr1mBG3I1bbvnbyrYmnsrXxZt' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1LfsQ1s9wr1mBG3I1bbvnbyrYmnsrXxZt" -O segment_5.zip && rm -rf /tmp/cookies.txt
+ wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1OctFe_JPR0Ajh77FzWdfeYnWZinKl2sW' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1OctFe_JPR0Ajh77FzWdfeYnWZinKl2sW" -O segment_6.zip && rm -rf /tmp/cookies.txt
+ wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1W-Yi8gHCcT8O5F4TcDHScH7pOb0GQZdu' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1W-Yi8gHCcT8O5F4TcDHScH7pOb0GQZdu" -O segment_7.zip && rm -rf /tmp/cookies.txt
+ wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1jgHjnpe7Vk1pvRgfnH4S4KiRrpUQyqyp' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1jgHjnpe7Vk1pvRgfnH4S4KiRrpUQyqyp" -O segment_8.zip && rm -rf /tmp/cookies.txt
+ wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1oVst8RG8G2d21DL6q4DwO7aJxE1vA2fc' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1oVst8RG8G2d21DL6q4DwO7aJxE1vA2fc" -O segment_9.zip && rm -rf /tmp/cookies.txt
+ wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1qwckIM8YBbU9bnArB6bAoStY3e9I1kqU' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1qwckIM8YBbU9bnArB6bAoStY3e9I1kqU" -O segment_0.zip && rm -rf /tmp/cookies.txt
+
+ unzip segment_1.zip
+ unzip segment_2.zip
+ unzip segment_3.zip
+ unzip segment_4.zip
+ unzip segment_5.zip
+ unzip segment_6.zip
+ unzip segment_7.zip
+ unzip segment_8.zip
+ unzip segment_9.zip
+
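
The wget/cookie dance above is the usual workaround for Google Drive's large-file confirmation page. The same ten downloads can also be scripted with the gdown package (a sketch, not part of the commit; assumes gdown is installed and reuses the file IDs from get_data.sh):

    import zipfile
    import gdown

    # Google Drive file IDs from get_data.sh (segments 1-9, then segment 0).
    FILE_IDS = [
        '13NnkYAhwszQxc1C5HHfThnF7c1cjzjAD', '14p6FHip_hGTXC-_7SYaK32BpEhZRDJI4',
        '1IaRfTFh51Wf_zPtK6tjE6xw-up_Z6EyN', '1KGhV397Xfej56uJ9H10xD7tfLdhWlg4q',
        '1LfsQ1s9wr1mBG3I1bbvnbyrYmnsrXxZt', '1OctFe_JPR0Ajh77FzWdfeYnWZinKl2sW',
        '1W-Yi8gHCcT8O5F4TcDHScH7pOb0GQZdu', '1jgHjnpe7Vk1pvRgfnH4S4KiRrpUQyqyp',
        '1oVst8RG8G2d21DL6q4DwO7aJxE1vA2fc', '1qwckIM8YBbU9bnArB6bAoStY3e9I1kqU',
    ]

    for i, file_id in enumerate(FILE_IDS, start=1):
        name = f'segment_{i % 10}.zip'  # the tenth ID is saved as segment_0.zip, matching get_data.sh
        gdown.download(f'https://drive.google.com/uc?id={file_id}', name, quiet=False)
        with zipfile.ZipFile(name) as zf:
            zf.extractall()

Note that this sketch also extracts segment_0.zip, which get_data.sh downloads but does not unzip.
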
merge_datasets.py ADDED
@@ -0,0 +1,12 @@
+ import datasets
+ import pandas as pd
+
+ dfs = []
+
+ for i in range(10):
+     dfs.append(
+         datasets.ArrowReader.read_table(f'segment_{i}/dataset.arrow').to_pandas()
+     )
+
+ full_df = pd.concat(dfs, ignore_index=True)
+ full_df.to_csv('dataset.csv')
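
merge_datasets.py flattens the ten downloaded Arrow segments into a single dataset.csv (the file now listed in .gitignore). If the merged CSV needs to be read back, the datasets library can load it directly (a sketch, not part of the commit; the column names depend on what the segment Arrow files contain):

    from datasets import load_dataset

    merged = load_dataset('csv', data_files='dataset.csv')['train']
    print(merged)  # shows the number of rows and the column names
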
prepare_data.sh ADDED
File without changes
train.py CHANGED
@@ -17,7 +17,6 @@
   - [ ] Send the schedule time to the compute_loss method and calculate a coefficient based on that.
  '''
  import logging
- import math
  import os
  import sys
  import time
@@ -31,6 +30,7 @@ from tqdm import tqdm

  import jax
  import jax.numpy as jnp
+ import numpy as onp
  import optax
  import transformers
  from flax import jax_utils, traverse_util
@@ -44,7 +44,6 @@ from transformers import (
      is_tensorboard_available,
  )
  from transformers.models.t5.modeling_flax_t5 import shift_tokens_right
- from transformers.testing_utils import CaptureLogger

  from t5_vae_flax.src.t5_vae import FlaxT5VaeForAutoencoding
  from t5_vae_flax.src.config import T5VaeConfig
@@ -113,7 +112,7 @@ class ModelArguments:
  @dataclass
  class DataTrainingArguments:
      """
-     Arguments pertaining to what data we are going to input our model for training and eval.
+     Arguments pertaining to what data we are going to input our model for training.
      """

      dataset_name: Optional[str] = field(
@@ -123,10 +122,6 @@ class DataTrainingArguments:
          default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
      )
      train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
-     validation_file: Optional[str] = field(
-         default=None,
-         metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
-     )
      max_train_samples: Optional[int] = field(
          default=None,
          metadata={
@@ -134,21 +129,8 @@ class DataTrainingArguments:
              "value if set."
          },
      )
-     max_eval_samples: Optional[int] = field(
-         default=None,
-         metadata={
-             "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
-             "value if set."
-         },
-     )
      overwrite_cache: bool = field(
-         default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
-     )
-     validation_split_percentage: Optional[int] = field(
-         default=5,
-         metadata={
-             "help": "The percentage of the train set used as validation set in case there's no validation split"
-         },
+         default=False, metadata={"help": "Overwrite the cached training sets"}
      )
      block_size: Optional[int] = field(
          default=None,
@@ -162,7 +144,7 @@ class DataTrainingArguments:
          default=False, metadata={"help": "Stream the dataset."}
      )
      overwrite_cache: bool = field(
-         default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
+         default=False, metadata={"help": "Overwrite the cached training sets"}
      )
      preprocessing_num_workers: Optional[int] = field(
          default=None,
@@ -170,15 +152,12 @@ class DataTrainingArguments:
      )

      def __post_init__(self):
-         if self.dataset_name is None and self.train_file is None and self.validation_file is None:
-             raise ValueError("Need either a dataset name or a training/validation file.")
+         if self.dataset_name is None and self.train_file is None:
+             raise ValueError("Need either a dataset name or a training file.")
          else:
              if self.train_file is not None:
                  extension = self.train_file.split(".")[-1]
                  assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file."
-             if self.validation_file is not None:
-                 extension = self.validation_file.split(".")[-1]
-                 assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."


  class TrainState(train_state.TrainState):
@@ -188,28 +167,19 @@ class TrainState(train_state.TrainState):
          return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))


- def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False):
+ def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int):
      """
      Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
      Shuffle batches if `shuffle` is `True`.
      """
-     steps_per_epoch = len(dataset) // batch_size
-
-     if shuffle:
-         batch_idx = jax.random.permutation(rng, len(dataset))
-     else:
-         batch_idx = jnp.arange(len(dataset))
-
-     batch_idx = batch_idx[: steps_per_epoch * batch_size]  # Skip incomplete batch.
-     batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
-
-     for idx in batch_idx:
-         batch = dataset[idx]
-         batch = {k: jnp.array(v) for k, v in batch.items()}
-
-         batch = shard(batch)
-
-         yield batch
+     batch = []
+     for row in dataset:
+         batch.append(row)
+         if len(batch) >= batch_size:
+             batch = {k: jnp.stack([row[k] for row in batch]) for k in batch[0].keys()}
+             batch = shard(batch)
+             yield batch
+             batch = []


  def write_train_metric(summary_writer, train_metrics, train_time, step):
@@ -222,11 +192,6 @@ def write_train_metric(summary_writer, train_metrics, train_time, step):
          summary_writer.scalar(tag, val, step - len(vals) + i + 1)


- def write_eval_metric(summary_writer, eval_metrics, step):
-     for metric_name, value in eval_metrics.items():
-         summary_writer.scalar(f"eval_{metric_name}", value, step)
-
-
  def create_learning_rate_fn(
      train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
  ) -> Callable[[int], jnp.array]:
@@ -284,9 +249,9 @@ def main():
      transformers.utils.logging.set_verbosity_error()

      # Set the verbosity to info of the Transformers logger (on main process only):
-     logger.info(f"Training/evaluation parameters {training_args}")
+     logger.info(f"Training parameters {training_args}")

-     # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
+     # Get the datasets: you can either provide your own CSV/JSON/TXT training files (see below)
      # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
      # (the dataset will be downloaded automatically from the datasets Hub).
      #
@@ -295,35 +260,7 @@ def main():
      #
      # In distributed training, the load_dataset function guarantees that only one local process can concurrently
      # download the dataset.
-     if data_args.dataset_name is not None:
-         # Downloading and loading a dataset from the hub.
-         dataset = load_dataset(
-             data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir, streaming=data_args.streaming, keep_in_memory=False
-         )
-
-         if "validation" not in dataset.keys():
-             dataset["validation"] = load_dataset(
-                 data_args.dataset_name,
-                 data_args.dataset_config_name,
-                 split=f"train[:{data_args.validation_split_percentage}%]",
-                 cache_dir=model_args.cache_dir,
-             )
-             dataset["train"] = load_dataset(
-                 data_args.dataset_name,
-                 data_args.dataset_config_name,
-                 split=f"train[{data_args.validation_split_percentage}%:]",
-                 cache_dir=model_args.cache_dir,
-             )
-     else:
-         data_files = {}
-         if data_args.train_file is not None:
-             data_files["train"] = data_args.train_file
-         if data_args.validation_file is not None:
-             data_files["validation"] = data_args.validation_file
-         extension = data_args.train_file.split(".")[-1]
-         if extension == "txt":
-             extension = "text"
-         dataset = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
+     dataset = load_dataset('text', data_files=[f'wikipedia/{i}.txt' for i in range(298)], cache_dir=model_args.cache_dir, streaming=True)['train']
      # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
      # https://huggingface.co/docs/datasets/loading_datasets.html.

@@ -381,37 +318,6 @@ def main():
      assert tokenizer.pad_token == '<PAD>'

      # Preprocessing the datasets.
-     # First we tokenize all the texts.
-     if training_args.do_train:
-         column_names = dataset["train"].column_names
-     else:
-         column_names = dataset["validation"].column_names
-     text_column_name = "text" if "text" in column_names else column_names[0]
-
-     # since this will be pickled to avoid _LazyModule error in Hasher force logger loading before tokenize_function
-     tok_logger = transformers.utils.logging.get_logger("transformers.tokenization_utils_base")
-
-     def tokenize_function(examples):
-         with CaptureLogger(tok_logger) as cl:
-             output = tokenizer(examples[text_column_name])
-         # clm input could be much much longer than block_size
-         if "Token indices sequence length is longer than the" in cl.out:
-             tok_logger.warning(
-                 "^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits before being passed to the model."
-             )
-         return output
-
-     # remove dataset tasks
-     for k in dataset.keys():
-         dataset[k].info.task_templates = []
-
-     tokenized_datasets = dataset.map(
-         tokenize_function,
-         batched=True,
-         num_proc=data_args.preprocessing_num_workers,
-         remove_columns=column_names,
-         load_from_cache_file=not data_args.overwrite_cache,
-     )

      if data_args.block_size > tokenizer.model_max_length:
          logger.warning(
@@ -422,65 +328,27 @@ def main():

      pad_token_id, start_token_id = tokenizer.pad_token_id, config.decoder_start_token_id

-     def clip_texts(examples):
-         examples["labels"] = examples["input_ids"].copy()
-
-         for i, input_ids in enumerate(examples["input_ids"]):
-             if len(input_ids) > block_size:
-                 for k in examples.keys():
-                     examples[k][i] = examples[k][i][:block_size]
-             elif len(input_ids) < block_size:
-                 delta = block_size - len(input_ids)
-                 examples['input_ids'][i] = examples['input_ids'][i] + [pad_token_id] * delta
-                 examples['attention_mask'][i] = examples['attention_mask'][i] + [0] * delta
-                 examples['labels'][i] = examples['labels'][i] + [-100] * delta
-
-         return examples
-
-     logger.info('clip_texts...')
-     clipped_lm_datasets = tokenized_datasets.map(
-         clip_texts,
-         batched=True,
-         num_proc=data_args.preprocessing_num_workers,
-         load_from_cache_file=not data_args.overwrite_cache,
-     )
-
-     def add_decoder_input_ids(examples):
-         arr_input_ids = jnp.array(examples["input_ids"])
-         pad = pad_token_id * jnp.ones((arr_input_ids.shape[0], 1), dtype=jnp.int32)
-         arr_pad_input_ids = jnp.concatenate((arr_input_ids, pad), axis=1)
-         examples['decoder_input_ids'] = shift_tokens_right(arr_pad_input_ids, pad_token_id, start_token_id)
+     def tokenize_function(examples):
+         output = tokenizer(examples["text"], return_tensors='jax', padding='max_length', max_length=block_size, truncation=True)

-         arr_attention_mask = jnp.array(examples['attention_mask'])
-         ones = jnp.ones((arr_attention_mask.shape[0], 1), dtype=jnp.int32)
-         examples['decoder_attention_mask'] = jnp.concatenate((ones, arr_attention_mask), axis=1)
+         output['labels'] = onp.array(output['input_ids'].copy())
+         output['labels'][output['labels'] == pad_token_id] = -100
+         output['labels'] = jnp.array(output['labels'])

-         for k in ['decoder_input_ids', 'decoder_attention_mask']:
-             examples[k] = examples[k].tolist()
+         pad = pad_token_id * jnp.ones((output['input_ids'].shape[0], 1), dtype=jnp.int32)
+         arr_pad_input_ids = jnp.concatenate((output['input_ids'], pad), axis=1)
+         output['decoder_input_ids'] = shift_tokens_right(arr_pad_input_ids, pad_token_id, start_token_id)

-         return examples
+         ones = jnp.ones((output['attention_mask'].shape[0], 1), dtype=jnp.int32)
+         output['decoder_attention_mask'] = jnp.concatenate((ones, output['attention_mask']), axis=1)

-     logger.info('add_decoder_input_ids...')
-     lm_datasets = clipped_lm_datasets.map(
-         add_decoder_input_ids,
-         batched=True,
-         num_proc=data_args.preprocessing_num_workers,
-         load_from_cache_file=not data_args.overwrite_cache,
-     )
+         return output

-     if training_args.do_train:
-         if "train" not in tokenized_datasets:
-             raise ValueError("--do_train requires a train dataset")
-         train_dataset = lm_datasets["train"]
-         if data_args.max_train_samples is not None:
-             train_dataset = train_dataset.select(range(data_args.max_train_samples))
+     tokenized_datasets = dataset.map(tokenize_function, batched=True)

-     if training_args.do_eval:
-         if "validation" not in tokenized_datasets:
-             raise ValueError("--do_eval requires a validation dataset")
-         eval_dataset = lm_datasets["validation"]
-         if data_args.max_eval_samples is not None:
-             eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
+     train_dataset = tokenized_datasets
+     if data_args.max_train_samples is not None:
+         train_dataset = train_dataset.select(range(data_args.max_train_samples))

      # Enable tensorboard only on the master node
      has_tensorboard = is_tensorboard_available()
@@ -507,13 +375,13 @@ def main():
      # Store some constant
      num_epochs = int(training_args.num_train_epochs)
      train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
-     eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
-     steps_per_epoch = len(train_dataset) // train_batch_size
+     train_dataset_len = 97876602
+     steps_per_epoch = train_dataset_len // train_batch_size
      total_train_steps = steps_per_epoch * num_epochs

      # Create learning rate schedule
      linear_decay_lr_schedule_fn = create_learning_rate_fn(
-         len(train_dataset),
+         train_dataset_len,
          train_batch_size,
          training_args.num_train_epochs,
          training_args.warmup_steps,
@@ -602,26 +470,14 @@ def main():

          return new_state, metrics

-     # Define eval fn
-     def eval_step(params, rng, batch):
-         labels = batch.pop("labels")
-         logits, latent_codes = model(**batch, params=params, train=False)[:2]
-         loss = loss_fn(logits, labels, latent_codes, rng)
-
-         # summarize metrics
-         metrics = {"loss": loss}
-         metrics = jax.lax.pmean(metrics, axis_name="batch")
-         return metrics
-
-     # Create parallel version of the train and eval step
+     # Create parallel version of the train step
      p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
-     p_eval_step = jax.pmap(eval_step, "batch")

      # Replicate the train state on each device
      state = state.replicate()

      logger.info("***** Running training *****")
-     logger.info(f" Num examples = {len(train_dataset)}")
+     logger.info(f" Num examples = {train_dataset_len}")
      logger.info(f" Num Epochs = {num_epochs}")
      logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
      logger.info(f" Total train batch size (w. parallel & distributed) = {train_batch_size}")
@@ -638,15 +494,15 @@ def main():
          rng, input_rng = jax.random.split(rng)

          # Generate an epoch by shuffling sampling indices from the train dataset
-         train_loader = data_loader(input_rng, train_dataset, train_batch_size, shuffle=True)
-         steps_per_epoch = len(train_dataset) // train_batch_size
+         train_loader = data_loader(input_rng, train_dataset, train_batch_size)
+         steps_per_epoch = train_dataset_len // train_batch_size
          # train
          for step in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
              batch = next(train_loader)
              state, train_metric = p_train_step(state, batch)
              train_metrics.append(train_metric)

-             cur_step = epoch * (len(train_dataset) // train_batch_size) + step
+             cur_step = epoch * (train_dataset_len // train_batch_size) + step

              if cur_step % training_args.logging_steps == 0 and cur_step > 0:
                  # Save metrics
@@ -661,36 +517,6 @@ def main():

                  train_metrics = []

-             if cur_step % training_args.eval_steps == 0 and cur_step > 0:
-                 # ======================== Evaluating ==============================
-                 eval_metrics = []
-                 eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size)
-                 eval_steps = len(eval_dataset) // eval_batch_size
-                 for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False):
-                     # Model forward
-                     batch = next(eval_loader)
-                     metrics = p_eval_step(state.params, state.dropout_rng, batch)
-                     eval_metrics.append(metrics)
-
-                 # normalize eval metrics
-                 eval_metrics = get_metrics(eval_metrics)
-                 eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
-
-                 try:
-                     eval_metrics["perplexity"] = math.exp(eval_metrics["loss"])
-                 except OverflowError:
-                     eval_metrics["perplexity"] = float("inf")
-
-                 # Print metrics and update progress bar
-                 desc = f"Step... ({cur_step} | Eval Loss: {eval_metrics['loss']} | Eval Perplexity: {eval_metrics['perplexity']})"
-                 epochs.write(desc)
-                 epochs.desc = desc
-
-                 # Save metrics
-                 if has_tensorboard and jax.process_index() == 0:
-                     cur_step = epoch * (len(train_dataset) // train_batch_size)
-                     write_eval_metric(summary_writer, eval_metrics, cur_step)
-
              if cur_step % training_args.save_steps == 0 and cur_step > 0:
                  # save checkpoint after each epoch and push checkpoint to the hub
                  if jax.process_index() == 0:
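
Taken together, the train.py changes drop the evaluation path and switch to a streamed dataset: the converted wikipedia/{i}.txt files are loaded with streaming=True, tokenize_function pads or truncates each line to block_size and builds labels and decoder inputs, and the new generator-style data_loader batches rows as they stream by. A standalone sketch of that pipeline (illustrative only: the single file path, block size, and pad-token choice are assumptions, and the real script also stacks each batch with jnp and shards it across devices):

    from datasets import load_dataset
    from transformers import AutoTokenizer

    block_size = 128
    tokenizer = AutoTokenizer.from_pretrained('gpt2')
    tokenizer.pad_token = tokenizer.eos_token  # assumption; train.py expects its own tokenizer to carry a '<PAD>' token

    # Stream the converted plain-text Wikipedia shards (one sentence per line).
    dataset = load_dataset('text', data_files=['wikipedia/0.txt'], streaming=True)['train']

    def tokenize_function(examples):
        # Fixed-length encoding, as in the new train.py (which additionally builds labels and decoder inputs).
        return tokenizer(examples['text'], padding='max_length', max_length=block_size, truncation=True)

    tokenized = dataset.map(tokenize_function, batched=True)

    def data_loader(rows, batch_size):
        # Same shape as the new generator in train.py, minus the jnp.stack and per-device sharding.
        batch = []
        for row in rows:
            batch.append(row)
            if len(batch) >= batch_size:
                yield {k: [r[k] for r in batch] for k in batch[0]}
                batch = []

    first = next(data_loader(tokenized, 4))
    print(sorted(first.keys()), len(first['input_ids']))
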
train.sh ADDED
@@ -0,0 +1,22 @@
+ export RUN_NAME=single_latent
+
+ # TODO update to not use tokenizer, instead use gpt2 one
+ ./venv/bin/python train.py \
+     --t5_model_name_or_path="t5-base" \
+     --output_dir="output/${RUN_NAME}" \
+     --overwrite_output_dir \
+     --do_train \
+     --n_latent_tokens 1 \
+     --latent_token_size 32 \
+     --save_steps="2000" \
+     --block_size="128" \
+     --per_device_train_batch_size="100" \
+     --train_file="INVALID.txt" \
+     --overwrite_output_dir \
+     --num_train_epochs="1" \
+
+ # 200 batch size, 128 sequence len: ? (breaks)
+ # 100 batch size, 128 sequence len: 252:38:58
+ # 10 batch size, 128 sequence len: 281:32:53
+
+ # Got ~12 hours to train, want 3 saves so one save every 4 hours
wiki_sentences.py ADDED
@@ -0,0 +1,46 @@
+ # unused
+ """Wikipedia Sentences"""
+
+ from __future__ import absolute_import, division, print_function
+
+ import os
+ import json
+
+ import datasets
+
+
+ _DESCRIPTION = """\
+ Dataset of sentences from Wikipedia (from the [Optimus paper](https://arxiv.org/abs/2004.04092)).
+ Each is of max 64 words & <=256 GPT2 tokens.
+ Each row is a tokenised sentence.
+ {'token_ids': '{gpt2 token ids}'}
+ This is to test the semantics of a Transformer-VAE's latent space by interpolating on sentences.
+ """
+
+ NUM_SEGMENTS = 5
+ DOWNLOAD_URLS = 'https://drive.google.com/file/d/13NnkYAhwszQxc1C5HHfThnF7c1cjzjAD/view?usp=sharing, https://drive.google.com/file/d/14p6FHip_hGTXC-_7SYaK32BpEhZRDJI4/view?usp=sharing, https://drive.google.com/file/d/1IaRfTFh51Wf_zPtK6tjE6xw-up_Z6EyN/view?usp=sharing, https://drive.google.com/file/d/1KGhV397Xfej56uJ9H10xD7tfLdhWlg4q/view?usp=sharing, https://drive.google.com/file/d/1LfsQ1s9wr1mBG3I1bbvnbyrYmnsrXxZt/view?usp=sharing, https://drive.google.com/file/d/1OctFe_JPR0Ajh77FzWdfeYnWZinKl2sW/view?usp=sharing, https://drive.google.com/file/d/1W-Yi8gHCcT8O5F4TcDHScH7pOb0GQZdu/view?usp=sharing, https://drive.google.com/file/d/1jgHjnpe7Vk1pvRgfnH4S4KiRrpUQyqyp/view?usp=sharing, https://drive.google.com/file/d/1oVst8RG8G2d21DL6q4DwO7aJxE1vA2fc/view?usp=sharing, https://drive.google.com/file/d/1qwckIM8YBbU9bnArB6bAoStY3e9I1kqU/view?usp=sharing'.split(', ')
+
+
+ class WikiSentences(datasets.GeneratorBasedBuilder):
+     """Sentences from Wikipedia."""
+
+     BUILDER_CONFIGS = [datasets.BuilderConfig(name="main", description="Run through json files one by one.",)]
+
+     def _info(self):
+         return datasets.DatasetInfo(
+             description=_DESCRIPTION,
+             features=datasets.Features(
+                 {
+                     'token_ids': [datasets.Value("int32")],
+                 }
+             ),
+             homepage="https://github.com/Fraser-Greenlee/transformer-vae",
+         )
+
+     def _generate_examples(self, filepath):
+         """Generate examples."""
+         with open(filepath, encoding="utf-8") as json_lines_file:
+             for id_, line in enumerate(json_lines_file):
+                 yield id_, json.loads(line)
+                 if id_ >= self.config.max_num_samples:
+                     break