Added Gradient Checkpointing and fix bugs

#6
Files changed (1) hide show
  1. smi-ted/training/trainer.py +36 -2
smi-ted/training/trainer.py CHANGED
@@ -2,12 +2,16 @@
2
  import torch
3
  import torch.nn as nn
4
  import torch.nn.functional as F
 
5
  from torch.utils.data import DataLoader
6
  from torch.nn.parallel import DistributedDataParallel as DDP
 
7
 
8
  # Standard library
9
  from tqdm import tqdm
10
  import pandas as pd
 
 
11
  import os
12
 
13
 
@@ -41,6 +45,7 @@ class Trainer:
41
  self.model = DDP(self.model, device_ids=[self.local_rank])
42
 
43
  def _load_checkpoint(self, checkpoint_path):
 
44
  loc = f"cuda:{self.local_rank}"
45
  ckpt_dict = torch.load(checkpoint_path, map_location=loc)
46
  if os.path.exists(os.path.join(self.save_checkpoint_path, 'OPTIMIZER_STATES.pt')):
@@ -262,6 +267,12 @@ class TrainerEncoderDecoder(Trainer):
262
  if self.local_rank == 0:
263
  loss_list.to_csv(os.path.join(self.config.save_checkpoint_path, f'training_loss_epoch{epoch}.csv'), index=False)
264
 
 
 
 
 
 
 
265
  def _run_batch(self, batch_idx, bucket_idx_masked, bucket_targets, bucket_idx_not_masked):
266
  self.optimE.zero_grad(set_to_none=True)
267
  self.optimD.zero_grad(set_to_none=True)
@@ -292,7 +303,13 @@ class TrainerEncoderDecoder(Trainer):
292
  for param in self.model.module.decoder.parameters():
293
  param.requires_grad = False
294
 
295
- logits = self.model.module.encoder(idx_masked)
 
 
 
 
 
 
296
  logits = logits.view(-1, logits.size(-1))
297
  targets = targets.view(-1)
298
  errorE_tmp = self.criterionC(logits, targets) / len(bucket_idx_masked)
@@ -370,6 +387,12 @@ class TrainerDirectDecoder(Trainer):
370
  self.criterionC = nn.CrossEntropyLoss(ignore_index=-100)
371
  self.criterionR = nn.MSELoss()
372
 
 
 
 
 
 
 
373
  def _run_batch(self, bucket_idx_masked, bucket_targets, bucket_idx_not_masked):
374
  padding_idx = 2
375
  error = torch.zeros(1).to(self.local_rank)
@@ -385,7 +408,18 @@ class TrainerDirectDecoder(Trainer):
385
  mask = (idx_masked != padding_idx)
386
 
387
  # encoder forward
388
- true_set, true_cte = self.model.module.encoder(idx_masked, mask=mask, inference=True)
 
 
 
 
 
 
 
 
 
 
 
389
 
390
  # add padding
391
  input_mask_expanded = mask.unsqueeze(-1).expand(true_cte.size()).float()
 
2
  import torch
3
  import torch.nn as nn
4
  import torch.nn.functional as F
5
+ import torch.utils.checkpoint as checkpoint
6
  from torch.utils.data import DataLoader
7
  from torch.nn.parallel import DistributedDataParallel as DDP
8
+ from fast_transformers.masking import LengthMask
9
 
10
  # Standard library
11
  from tqdm import tqdm
12
  import pandas as pd
13
+ import numpy as np
14
+ import random
15
  import os
16
 
17
 
 
45
  self.model = DDP(self.model, device_ids=[self.local_rank])
46
 
47
  def _load_checkpoint(self, checkpoint_path):
48
+ opt_dict = None
49
  loc = f"cuda:{self.local_rank}"
50
  ckpt_dict = torch.load(checkpoint_path, map_location=loc)
51
  if os.path.exists(os.path.join(self.save_checkpoint_path, 'OPTIMIZER_STATES.pt')):
 
267
  if self.local_rank == 0:
268
  loss_list.to_csv(os.path.join(self.config.save_checkpoint_path, f'training_loss_epoch{epoch}.csv'), index=False)
269
 
270
+ def custom(self, module):
271
+ def custom_forward(*inputs):
272
+ inputs = module(inputs[0])
273
+ return inputs
274
+ return custom_forward
275
+
276
  def _run_batch(self, batch_idx, bucket_idx_masked, bucket_targets, bucket_idx_not_masked):
277
  self.optimE.zero_grad(set_to_none=True)
278
  self.optimD.zero_grad(set_to_none=True)
 
303
  for param in self.model.module.decoder.parameters():
304
  param.requires_grad = False
305
 
306
+ # encoder forward
307
+ x = self.model.module.encoder.tok_emb(idx_masked)
308
+ x = self.model.module.encoder.drop(x)
309
+ x = checkpoint.checkpoint(self.custom(self.model.module.encoder.blocks), x)
310
+ logits = self.model.module.encoder.lang_model(x)
311
+
312
+ # loss function
313
  logits = logits.view(-1, logits.size(-1))
314
  targets = targets.view(-1)
315
  errorE_tmp = self.criterionC(logits, targets) / len(bucket_idx_masked)
 
387
  self.criterionC = nn.CrossEntropyLoss(ignore_index=-100)
388
  self.criterionR = nn.MSELoss()
389
 
390
+ def custom(self, module):
391
+ def custom_forward(*inputs):
392
+ inputs = module(inputs[0], length_mask=inputs[1])
393
+ return inputs
394
+ return custom_forward
395
+
396
  def _run_batch(self, bucket_idx_masked, bucket_targets, bucket_idx_not_masked):
397
  padding_idx = 2
398
  error = torch.zeros(1).to(self.local_rank)
 
408
  mask = (idx_masked != padding_idx)
409
 
410
  # encoder forward
411
+ x = self.model.module.encoder.tok_emb(idx_masked)
412
+ x = self.model.module.encoder.drop(x)
413
+ x = checkpoint.checkpoint(self.custom(self.model.module.encoder.blocks), x, LengthMask(mask.sum(-1), max_len=idx_masked.shape[1]))
414
+
415
+ # mean pooling
416
+ input_masked_expanded = mask.unsqueeze(-1).expand(x.size()).float()
417
+ sum_embeddings = torch.sum(x*input_masked_expanded, 1)
418
+ sum_mask = torch.clamp(input_masked_expanded.sum(1), min=1e-9)
419
+ true_set = sum_embeddings/sum_mask
420
+ true_cte = x
421
+ del x
422
+ torch.cuda.empty_cache()
423
 
424
  # add padding
425
  input_mask_expanded = mask.unsqueeze(-1).expand(true_cte.size()).float()