In [1]:
import torch
from torch.autograd import Variable

# [1] Checkpointing sequential models

In [5]:
from torch.utils.checkpoint import checkpoint_sequential
import torch.nn as nn

# Small feed-forward stack used to demonstrate checkpointing: three Linear
# layers (100 -> 50 -> 20 -> 5), each followed by a ReLU.
model = nn.Sequential(
    nn.Linear(100, 50),
    nn.ReLU(),
    nn.Linear(50, 20),
    nn.ReLU(),
    nn.Linear(20, 5),
    nn.ReLU(),
)

# One random sample with 100 features. The tensor is created directly with
# requires_grad=True instead of going through the deprecated Variable wrapper;
# it is still a leaf tensor tracked by autograd.
input_var = torch.randn(1, 100, requires_grad=True)

# Number of chunks the layer list is split into when checkpointing.
segments = 2

# Flat list of the layers in definition order, taken from the public
# children() iterator rather than the private _modules mapping.
modules = list(model.children())
modules

[Linear(in_features=100, out_features=50, bias=True),
 ReLU(),
 Linear(in_features=50, out_features=20, bias=True),
 ReLU(),
 Linear(in_features=20, out_features=5, bias=True),
 ReLU()]

In [7]:
# Checkpointed forward pass: hand the layer list, the segment count and the
# input to checkpoint_sequential, explicitly selecting the non-reentrant
# checkpoint implementation. The trailing expression displays the result.
out = checkpoint_sequential(
    functions=modules,
    segments=segments,
    input=input_var,
    use_reentrant=False,
)
out

tensor([[0.0000, 0.3800, 0.0000, 0.0000, 0.0000]], grad_fn=<ReluBackward0>)

In [8]:
# Backward pass for the checkpointed run. To keep the example simple no real
# loss is computed; the sum of the outputs serves as the scalar to
# backpropagate from. Parameter gradients are reset before the backward call.
model.zero_grad()
torch.sum(out).backward()

In [9]:
# Snapshot the checkpointed run's output and every parameter gradient so they
# can be compared against the non-checkpointed run later on.
# detach() is used instead of the legacy .data attribute; clone() makes each
# snapshot an independent copy that later zero_grad()/backward() calls on the
# model cannot overwrite.
output_checkpointed = out.detach().clone()
grad_checkpointed = {
    name: param.grad.detach().clone()
    for name, param in model.named_parameters()
}

Now that we have executed the checkpointed pass on the model, let's also run the non-checkpointed model and verify that the checkpoint API doesn't change the model outputs or the parameter gradients.

In [10]:
# Non-checkpointed reference run of the very same model object.
original = model

# New leaf tensor over the same input values. detach() + requires_grad_()
# replaces the deprecated Variable(... .data ...) construction and keeps this
# run's autograd graph separate from the checkpointed one.
x = input_var.detach().requires_grad_(True)

# Plain forward pass; keep an independent copy of the output for comparison.
out = original(x)
out_not_checkpointed = out.detach().clone()

# Reset the parameter gradients, backpropagate from the same scalar
# (sum of the outputs) and snapshot the resulting parameter gradients.
# Parameters are read through `original`, matching the rest of this cell.
original.zero_grad()
out.sum().backward()
grad_not_checkpointed = {
    name: param.grad.detach().clone()
    for name, param in original.named_parameters()
}

Now that we have done the checkpointed and non-checkpointed pass of the model and saved the output and parameter gradients, let's compare their values.

In [13]:
# Compare the two runs element for element: first the outputs, then the
# gradient of every parameter recorded for the checkpointed run.
# The checks raise AssertionError explicitly (rather than via `assert`
# statements) so they are never stripped when Python runs with -O; the first
# mismatch is reported and the success message is skipped.
try:
    if not torch.equal(output_checkpointed, out_not_checkpointed):
        raise AssertionError("Outputs do not match!")
    for name in grad_checkpointed:
        if not torch.equal(grad_checkpointed[name], grad_not_checkpointed[name]):
            raise AssertionError(f"Gradients for {name} do not match!")
    print("Checkpointed and non-checkpointed results match!")
except AssertionError as e:
    print(f"Assertion failed: {e}")

Checkpointed and non-checkpointed results match!
