{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": { "id": "CqFGp-OjP0_G" },
   "outputs": [],
   "source": [
    "# All imports in one cell. torch.autograd.Variable is deprecated since\n",
    "# PyTorch 0.4 -- tensors track gradients directly via requires_grad.\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.utils.checkpoint import checkpoint_sequential\n",
    "\n",
    "# Seed so the random input and weights are reproducible on re-run.\n",
    "torch.manual_seed(0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": { "id": "to7suvjJQJAM" },
   "source": [ "# [1] Checkpointing sequential models" ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "colab": { "base_uri": "https://localhost:8080/" },
    "id": "1YmlCf4MQEXV",
    "outputId": "03833d29-11aa-4def-a9e4-650e349201a3"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[Linear(in_features=100, out_features=50, bias=True),\n",
       " ReLU(),\n",
       " Linear(in_features=50, out_features=20, bias=True),\n",
       " ReLU(),\n",
       " Linear(in_features=20, out_features=5, bias=True),\n",
       " ReLU()]"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# A small sequential model to demonstrate activation checkpointing.\n",
    "model = nn.Sequential(\n",
    "    nn.Linear(100, 50),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(50, 20),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(20, 5),\n",
    "    nn.ReLU()\n",
    ")\n",
    "\n",
    "input_var = torch.randn(1, 100, requires_grad=True)\n",
    "segments = 2\n",
    "\n",
    "# Use the public children() API rather than the private _modules dict.\n",
    "modules = list(model.children())\n",
    "modules"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "colab": { "base_uri": "https://localhost:8080/" },
    "id": "aHSqU-keQaPe",
    "outputId": "7ebc66fb-99ab-4d22-fa39-5710fb7ca2cd"
   },
   "outputs": [],
   "source": [
    "# Forward pass with activation checkpointing: intermediate activations\n",
    "# are discarded and recomputed during backward, trading compute for memory.\n",
    "# use_reentrant=False is the recommended, more capable implementation.\n",
    "out = checkpoint_sequential(modules, segments, input_var, use_reentrant=False)\n",
    "out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": { "id": "Q94h7De4RBGA" },
   "outputs": [],
   "source": [
    "# Run the backward pass on the model. For simplicity we skip a real loss\n",
    "# function and backprop on out.sum() instead.\n",
    "model.zero_grad()\n",
    "out.sum().backward()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": { "id": "LgNWA7fyRGAk" },
   "outputs": [],
   "source": [
    "# Save the output and parameter gradients for comparison with the\n",
    "# non-checkpointed run. detach().clone() is the supported replacement\n",
    "# for the legacy .data.clone() idiom.\n",
    "output_checkpointed = out.detach().clone()\n",
    "grad_checkpointed = {}\n",
    "for name, param in model.named_parameters():\n",
    "    grad_checkpointed[name] = param.grad.detach().clone()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": { "id": "qkdJd-B3RRWh" },
   "source": [
    "Now that we have executed the checkpointed pass on the model, let's also run the non-checkpointed model and verify that the checkpoint API doesn't change the model outputs or the parameter gradients."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": { "id": "Ts5GQzxkRVrU" },
   "outputs": [],
   "source": [
    "# Non-checkpointed run of the same model.\n",
    "original = model\n",
    "\n",
    "# A fresh leaf tensor with the same values as input_var.\n",
    "x = input_var.detach().clone().requires_grad_(True)\n",
    "\n",
    "# Get the model output and save it to prevent any modifications.\n",
    "out = original(x)\n",
    "out_not_checkpointed = out.detach().clone()\n",
    "\n",
    "# Calculate the gradients and save the parameter gradient values.\n",
    "original.zero_grad()\n",
    "out.sum().backward()\n",
    "grad_not_checkpointed = {}\n",
    "for name, param in model.named_parameters():\n",
    "    grad_not_checkpointed[name] = param.grad.detach().clone()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": { "id": "YiV1VBzyRX2Y" },
   "source": [
    "Now that we have done the checkpointed and non-checkpointed pass of the model and saved the output and parameter gradients, let's compare their values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "colab": { "base_uri": "https://localhost:8080/" },
    "id": "v9Tj9o8VRYq2",
    "outputId": "bd8a8100-d660-4858-eb48-4a85aca01c69"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [ "Checkpointed and non-checkpointed results match!\n" ]
    }
   ],
   "source": [
    "try:\n",
    "    assert torch.equal(output_checkpointed, out_not_checkpointed), \"Outputs do not match!\"\n",
    "    for name in grad_checkpointed:\n",
    "        assert torch.equal(grad_checkpointed[name], grad_not_checkpointed[name]), f\"Gradients for {name} do not match!\"\n",
    "    print(\"Checkpointed and non-checkpointed results match!\")\n",
    "except AssertionError as e:\n",
    "    print(f\"Assertion failed: {e}\")"
   ]
  }
 ],
 "metadata": {
  "colab": { "provenance": [] },
  "kernelspec": { "display_name": "Python 3", "name": "python3" },
  "language_info": { "name": "python" }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}