{
"cells": [
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import typing\n",
"import functorch\n",
"import itertools"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 2.3 Tensors\n",
"### We diagrams tensors, which can be vertically and horizontally decomposed.\n",
"
"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[0.6837, 0.6853]])"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# This diagram shows a function h : 3, 4 2, 6 -> 1 2 constructed out of f: 4 2, 6 -> 3 3 and g: 3, 3 3 -> 1 2\n",
"# We use assertions and random outputs to represent generic functions, and how diagrams relate to code.\n",
"T = torch.Tensor\n",
"def f(x0 : T, x1 : T):\n",
" \"\"\" f: 4 2, 6 -> 3 3 \"\"\"\n",
" assert x0.size() == torch.Size([4,2])\n",
" assert x1.size() == torch.Size([6])\n",
" return torch.rand([3,3])\n",
"def g(x0 : T, x1: T):\n",
" \"\"\" g: 3, 3 3 -> 1 2 \"\"\"\n",
" assert x0.size() == torch.Size([3])\n",
" assert x1.size() == torch.Size([3, 3])\n",
" return torch.rand([1,2])\n",
"def h(x0 : T, x1 : T, x2 : T):\n",
" \"\"\" h: 3, 4 2, 6 -> 1 2\"\"\"\n",
" assert x0.size() == torch.Size([3])\n",
" assert x1.size() == torch.Size([4, 2])\n",
" assert x2.size() == torch.Size([6])\n",
" return g(x0, f(x1,x2))\n",
"\n",
"h(torch.rand([3]), torch.rand([4, 2]), torch.rand([6]))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2.3.1 Indexes\n",
"### Figure 8: Indexes\n",
"
"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([6, 7, 8])"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Extracting a subtensor is a process we are familiar with. Consider,\n",
"# A (4 3) tensor\n",
"table = torch.arange(0,12).view(4,3)\n",
"row = table[2,:]\n",
"row"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Figure 9: Subtensors\n",
"
"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"# Different orders of access give the same result.\n",
"# Set up a random (5 7) tensor\n",
"a, b = 5, 7\n",
"Xab = torch.rand([a] + [b])\n",
"# Show that all pairs of indexes give the same result\n",
"for ia, jb in itertools.product(range(a), range(b)):\n",
" assert Xab[ia, jb] == Xab[ia, :][jb]\n",
" assert Xab[ia, jb] == Xab[:, jb][ia]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2.3.2 Broadcasting\n",
"### Figure 10: Broadcasting\n",
"
\n",
"
"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"a, b, c, d = [3], [2], [4], [3]\n",
"T = torch.Tensor\n",
"\n",
"# We have some function from a to b;\n",
"def G(Xa: T) -> T:\n",
" \"\"\" G: a -> b \"\"\"\n",
" return sum(Xa**2) + torch.ones(b)\n",
"\n",
"# We could bootstrap a definition of broadcasting,\n",
"# Note that we are using spaces to indicate tensoring. \n",
"# We will use commas for tupling, which is in line with standard notation while writing code.\n",
"def Gc(Xac: T) -> T:\n",
" \"\"\" G c : a c -> b c \"\"\"\n",
" Ybc = torch.zeros(b + c)\n",
" for j in range(c[0]):\n",
" Ybc[:,jc] = G(Xac[:,jc])\n",
" return Ybc\n",
"\n",
"# Or use a PyTorch command,\n",
"# G *: a * -> b *\n",
"Gs = torch.vmap(G, -1, -1)\n",
"\n",
"# We feed a random input, and see whether applying an index before or after\n",
"# gives the same result.\n",
"Xac = torch.rand(a + c)\n",
"for jc in range(c[0]):\n",
" assert torch.allclose(G(Xac[:,jc]), Gc(Xac)[:,jc])\n",
" assert torch.allclose(G(Xac[:,jc]), Gs(Xac)[:,jc])\n",
"\n",
"# This shows how our definition of broadcasting lines up with that used by PyTorch vmap."
]
},
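  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# An added check (a sketch reusing a, c, d, G, and Gs from the cell above):\n",
    "# broadcasting can be repeated, giving G * *: a * * -> b * * as a nested vmap.\n",
    "Gss = torch.vmap(Gs, -1, -1)\n",
    "Xacd = torch.rand(a + c + d)\n",
    "for jc, jd in itertools.product(range(c[0]), range(d[0])):\n",
    "    # Indexing before or after the repeated broadcast gives the same result.\n",
    "    assert torch.allclose(G(Xacd[:, jc, jd]), Gss(Xacd)[:, jc, jd])"
   ]
  },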
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Figure 11: Inner Broadcasting\n",
"
\n",
"
"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
"a, b, c, d = [3], [2], [4], [3]\n",
"T = torch.Tensor\n",
"\n",
"# We have some function which can be inner broadcast,\n",
"def H(Xa: T, Xd: T) -> T:\n",
" \"\"\" H: a, d -> b \"\"\"\n",
" return torch.sum(torch.sqrt(Xa**2)) + torch.sum(torch.sqrt(Xd ** 2)) + torch.ones(b)\n",
"\n",
"# We can bootstrap inner broadcasting,\n",
"def Hc0(Xca: T, Xd : T) -> T:\n",
" \"\"\" c0 H: c a, d -> c d \"\"\"\n",
" # Recall that we defined a, b, c, d in [_] arrays.\n",
" Ycb = torch.zeros(c + b)\n",
" for ic in range(c[0]):\n",
" Ycb[ic, :] = H(Xca[ic, :], Xd)\n",
" return Ycb\n",
"\n",
"# But vmap offers a clear way of doing it,\n",
"# *0 H: * a, d -> * c\n",
"Hs0 = torch.vmap(H, (0, None), 0)\n",
"\n",
"# We can show this satisfies Definition 2.14 by,\n",
"Xca = torch.rand(c + a)\n",
"Xd = torch.rand(d)\n",
"for ic in range(c[0]):\n",
" assert torch.allclose(Hc0(Xca, Xd)[ic, :], H(Xca[ic, :], Xd))\n",
" assert torch.allclose(Hs0(Xca, Xd)[ic, :], H(Xca[ic, :], Xd))\n"
]
},
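  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# An added sketch, reusing H and the sizes a, b, c, d from above: inner\n",
    "# broadcasting can equally target the second argument, a, c d -> c b.\n",
    "Hs1 = torch.vmap(H, (None, 0), 0)\n",
    "Xa = torch.rand(a)\n",
    "Xcd = torch.rand(c + d)\n",
    "for ic in range(c[0]):\n",
    "    # Indexing the broadcast axis before or after gives the same result.\n",
    "    assert torch.allclose(Hs1(Xa, Xcd)[ic, :], H(Xa, Xcd[ic, :]))"
   ]
  },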
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Figure 12 Elementwise operations\n",
"
"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
"\n",
"# Elementwise operations are implemented as usual ie\n",
"def f(x):\n",
" \"f : 1 -> 1\"\n",
" return x ** 2\n",
"\n",
"# We broadcast an elementwise operation,\n",
"# f *: * -> *\n",
"fs = torch.vmap(f)\n",
"\n",
"Xa = torch.rand(a)\n",
"for i in range(a[0]):\n",
" # And see that it aligns with the index before = index after framework.\n",
" assert torch.allclose(f(Xa[i]), fs(Xa)[i])\n",
" # But, elementwise operations are implied, so no special implementation is needed. \n",
" assert torch.allclose(f(Xa[i]), f(Xa)[i])"
]
},
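  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# A small added check: since f is elementwise, applying it to the whole tensor\n",
    "# directly agrees with the explicitly broadcast version.\n",
    "assert torch.allclose(f(Xa), fs(Xa))"
   ]
  },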
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 2.4 Linearity\n",
"## 2.4.2 Implementing Linearity and Common Operations\n",
"### Figure 17: Multi-head Attention and Einsum\n",
"
"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"import einops\n",
"x, y, k, h = 5, 3, 4, 2\n",
"Q = torch.rand([y, k, h])\n",
"K = torch.rand([x, k, h])\n",
"\n",
"# Local memory contains,\n",
"# Q: y k h # K: x k h\n",
"# Outer products, transposes, inner products, and\n",
"# diagonalization reduce to einops expressions.\n",
"# Transpose K,\n",
"K = einops.einsum(K, 'x k h -> k x h')\n",
"# Outer product and diagonalize,\n",
"X = einops.einsum(Q, K, 'y k1 h, k2 x h -> y k1 k2 x h')\n",
"# Inner product,\n",
"X = einops.einsum(X, 'y k k x h -> y x h')\n",
"# Scale,\n",
"X = X / math.sqrt(k)\n",
"\n",
"Q = torch.rand([y, k, h])\n",
"K = torch.rand([x, k, h])\n",
"\n",
"# Local memory contains,\n",
"# Q: y k h # K: x k h\n",
"X = einops.einsum(Q, K, 'y k h, x k h -> y x h')\n",
"X = X / math.sqrt(k)\n"
]
},
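  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# An added sanity check (a sketch reusing the sizes x, y, k, h above): the fused\n",
    "# einsum agrees with the explicit transpose / outer product / diagonalization route.\n",
    "Q = torch.rand([y, k, h])\n",
    "K = torch.rand([x, k, h])\n",
    "# Step-by-step route: transpose, outer product and diagonalize, inner product, scale.\n",
    "Kt = einops.einsum(K, 'x k h -> k x h')\n",
    "X_step = einops.einsum(Q, Kt, 'y k1 h, k2 x h -> y k1 k2 x h')\n",
    "X_step = einops.einsum(X_step, 'y k k x h -> y x h') / math.sqrt(k)\n",
    "# Fused route,\n",
    "X_fused = einops.einsum(Q, K, 'y k h, x k h -> y x h') / math.sqrt(k)\n",
    "assert torch.allclose(X_step, X_fused)"
   ]
  },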
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2.4.3 Linear Algebra\n",
"### Figure 18: Graphical Linear Algebra\n",
"
"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [],
"source": [
"# We will do an exercise implementing some of these equivalences.\n",
"# The reader can follow this exercise to get a better sense of how linear functions can be implemented,\n",
"# and how different forms are equivalent.\n",
"\n",
"a, b, c, d = [3], [4], [5], [3]\n",
"\n",
"# We will be using this function *a lot*\n",
"es = einops.einsum\n",
"\n",
"# F: a b c\n",
"F_matrix = torch.rand(a + b + c)\n",
"\n",
"# As an exericse we will show that the linear map F: a -> b c can be transposed in two ways.\n",
"# Either, we can broadcast, or take an outer product. We will show these are the same.\n",
"\n",
"# Transposing by broadcasting\n",
"# \n",
"def F_func(Xa: T):\n",
" \"\"\" F: a -> b c \"\"\"\n",
" return es(Xa,F_matrix,'a,a b c->b c',)\n",
"# * F: * a -> * b c\n",
"F_broadcast = torch.vmap(F_func, 0, 0)\n",
"\n",
"# We then reduce it, as in the diagram,\n",
"# b a -> b b c -> c\n",
"def F_broadcast_transpose(Xba: T):\n",
" \"\"\" (b F) (.b c): b a -> c \"\"\"\n",
" Xbbc = F_broadcast(Xba)\n",
" return es(Xbbc, 'b b c -> c')\n",
"\n",
"# Transpoing by linearity\n",
"#\n",
"# We take the outer product of Id(b) and F, and follow up with a inner product.\n",
"# This gives us,\n",
"F_outerproduct = es(torch.eye(b[0]), F_matrix,'b0 b1, a b2 c->b0 b1 a b2 c',)\n",
"# Think of this as Id(b) F: b0 a -> b1 b2 c arranged into an associated b0 b1 a b2 c tensor.\n",
"# We then take the inner product. This gives a (b a c) matrix, which can be used for a (b a -> c) map.\n",
"F_linear_transpose = es(F_outerproduct,'b B a B c->b a c',)\n",
"\n",
"# We contend that these are the same.\n",
"#\n",
"Xba = torch.rand(b + a)\n",
"assert torch.allclose(\n",
" F_broadcast_transpose(Xba), \n",
" es(Xba,F_linear_transpose, 'b a, b a c -> c'))\n",
"\n",
"# Furthermore, lets prove the unit-inner product identity.\n",
"#\n",
"# The first step is an outer product with the unit,\n",
"outerUnit = lambda Xb: es(Xb, torch.eye(b[0]), 'b0, b1 b2 -> b0 b1 b2')\n",
"# The next is a inner product over the first two axes,\n",
"dotOuter = lambda Xbbb: es(Xbbb, 'b0 b0 b1 -> b1')\n",
"# Applying both of these *should* be the identity, and hence leave any input unchanged.\n",
"Xb = torch.rand(b)\n",
"assert torch.allclose(\n",
" Xb,\n",
" dotOuter(outerUnit(Xb)))\n",
"\n",
"# Therefore, we can confidently use the expressions in Figure 18 to manipulate expressions."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 3.1 Basic Multi-Layer Perceptron\n",
"### Figure 19: Implementing a Basic Multi-Layer Perceptron\n",
"
"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Softmax(\n",
" dim=tensor([[ 0.0150, -0.0301, 0.1395, -0.0558, 0.0024, -0.0613, -0.0163, 0.0134,\n",
" 0.0577, -0.0624]], grad_fn=)\n",
")"
]
},
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import torch.nn as nn\n",
"# Basic Image Recogniser\n",
"# This is a close copy of an introductory PyTorch tutorial:\n",
"# https://pytorch.org/tutorials/beginner/basics/buildmodel_tutorial.html\n",
"class BasicImageRecogniser(nn.Module):\n",
" def __init__(self):\n",
" super().__init__()\n",
" self.flatten = nn.Flatten()\n",
" self.linear_relu_stack = nn.Sequential(\n",
" nn.Linear(28*28, 512),\n",
" nn.ReLU(),\n",
" nn.Linear(512, 512),\n",
" nn.ReLU(),\n",
" nn.Linear(512, 10),\n",
" )\n",
" def forward(self, x):\n",
" x = self.flatten(x)\n",
" x = self.linear_relu_stack(x)\n",
" y_pred = nn.Softmax(x)\n",
" return y_pred\n",
" \n",
"my_BasicImageRecogniser = BasicImageRecogniser()\n",
"my_BasicImageRecogniser.forward(torch.rand([1,28,28]))"
]
},
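  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# An added check (relies on the softmax over the class axis above): the output\n",
    "# for a single 28 x 28 image is a distribution over the 10 classes.\n",
    "y_pred = my_BasicImageRecogniser(torch.rand([1, 28, 28]))\n",
    "assert y_pred.size() == torch.Size([1, 10])\n",
    "assert torch.allclose(y_pred.sum(), torch.tensor(1.))"
   ]
  },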
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 3.2 Neural Circuit Diagrams for the Transformer Architecture\n",
"### Figure 20: Scaled Dot-Product Attention\n",
"
"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [],
"source": [
"# Note, that we need to accomodate batches, hence the ... to capture additional axes.\n",
"\n",
"# We can do the algorithm step by step,\n",
"def ScaledDotProductAttention(q: T, k: T, v: T) -> T:\n",
" ''' yk, xk, xk -> yk '''\n",
" klength = k.size()[-1]\n",
" # Transpose\n",
" k = einops.einsum(k, '... x k -> ... k x')\n",
" # Matrix Multiply / Inner Product\n",
" x = einops.einsum(q, k, '... y k, ... k x -> ... y x')\n",
" # Scale\n",
" x = x / math.sqrt(klength)\n",
" # SoftMax\n",
" x = torch.nn.Softmax(-1)(x)\n",
" # Matrix Multiply / Inner Product\n",
" x = einops.einsum(x, v, '... y x, ... x k -> ... y k')\n",
" return x\n",
"\n",
"# Alternatively, we can simultaneously broadcast linear functions.\n",
"def ScaledDotProductAttention(q: T, k: T, v: T) -> T:\n",
" ''' yk, xk, xk -> yk '''\n",
" klength = k.size()[-1]\n",
" # Inner Product and Scale\n",
" x = einops.einsum(q, k, '... y k, ... x k -> ... y x')\n",
" # Scale and SoftMax \n",
" x = torch.nn.Softmax(-1)(x / math.sqrt(klength))\n",
" # Final Inner Product\n",
" x = einops.einsum(x, v, '... y x, ... x k -> ... y k')\n",
" return x"
]
},
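  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# An added shape check with made-up sizes: with a leading batch axis, the\n",
    "# pattern y k, x k, x k -> y k is preserved.\n",
    "q0, k0, v0 = torch.rand([2, 5, 4]), torch.rand([2, 7, 4]), torch.rand([2, 7, 4])\n",
    "assert ScaledDotProductAttention(q0, k0, v0).size() == torch.Size([2, 5, 4])"
   ]
  },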
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Figure 21: Multi-Head Attention\n",
"
\n",
"\n",
"We will be implementing this algorithm. This shows us how we go from diagrams to implementations, and begins to give an idea of how organized diagrams leads to organized code."
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [],
"source": [
"def MultiHeadDotProductAttention(q: T, k: T, v: T) -> T:\n",
" ''' ykh, xkh, xkh -> ykh '''\n",
" klength = k.size()[-2]\n",
" x = einops.einsum(q, k, '... y k h, ... x k h -> ... y x h')\n",
" x = torch.nn.Softmax(-2)(x / math.sqrt(klength))\n",
" x = einops.einsum(x, v, '... y x h, ... x k h -> ... y k h')\n",
" return x\n",
"\n",
"# We implement this component as a neural network model.\n",
"# This is necessary when there are bold, learned components that need to be initialized.\n",
"class MultiHeadAttention(nn.Module):\n",
" # Multi-Head attention has various settings, which become variables\n",
" # for the initializer.\n",
" def __init__(self, m, k, h):\n",
" super().__init__()\n",
" self.m, self.k, self.h = m, k, h\n",
" # Set up all the boldface, learned components\n",
" # Note how they bind axes we want to split, which we do later with einops.\n",
" self.Lq = nn.Linear(m, k*h, False)\n",
" self.Lk = nn.Linear(m, k*h, False)\n",
" self.Lv = nn.Linear(m, k*h, False)\n",
" self.Lo = nn.Linear(k*h, m, False)\n",
"\n",
"\n",
" # We have endogenous data (Eym) and external / injected data (Xxm)\n",
" def forward(self, Eym, Xxm):\n",
" \"\"\" y m, x m -> y m \"\"\"\n",
" # We first generate query, key, and value vectors.\n",
" # Linear layers are automatically broadcast.\n",
"\n",
" # However, the k and h axes are bound. We define an unbinder to handle the outputs,\n",
" unbind = lambda x: einops.rearrange(x, '... (k h)->... k h', h=self.h)\n",
" q = unbind(self.Lq(Eym))\n",
" k = unbind(self.Lk(Xxm))\n",
" v = unbind(self.Lv(Xxm))\n",
"\n",
" # We feed q, k, and v to standard Multi-Head inner product Attention\n",
" o = MultiHeadDotProductAttention(q, k, v)\n",
"\n",
" # Rebind to feed to the final learned layer,\n",
" o = einops.rearrange(o, '... k h-> ... (k h)', h=self.h)\n",
" return self.Lo(o)\n",
"\n",
"# Now we can run it on fake data;\n",
"y, x, m, jc, heads = [20], [22], [128], [16], 4\n",
"# Internal Data\n",
"Eym = torch.rand(y + m)\n",
"# External Data\n",
"Xxm = torch.rand(x + m)\n",
"\n",
"mha = MultiHeadAttention(m[0],jc[0],heads)\n",
"assert list(mha.forward(Eym, Xxm).size()) == y + m\n"
]
},
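  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# An added check: because the einsum patterns use ..., the same module accepts\n",
    "# a leading batch axis without any changes.\n",
    "batch = [7]\n",
    "assert list(mha.forward(torch.rand(batch + y + m), torch.rand(batch + x + m)).size()) == batch + y + m"
   ]
  },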
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 3.4 Computer Vision\n",
"\n",
"Here, we really start to understand why splitting diagrams into ``fenced off'' blocks aids implementation. \n",
"In addition to making diagrams easier to understand and patterns more clearn, blocks indicate how code can structured and organized.\n",
"\n",
"## Figure 26: Identity Residual Network\n",
"
\n"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [],
"source": [
"# For Figure 26, every fenced off region is its own module.\n",
"\n",
"# Batch norm and then activate is a repeated motif,\n",
"class NormActivate(nn.Sequential):\n",
" def __init__(self, nf, Norm=nn.BatchNorm2d, Activation=nn.ReLU):\n",
" super().__init__(Norm(nf), Activation())\n",
"\n",
"def size_to_string(size):\n",
" return \" \".join(map(str,list(size)))\n",
"\n",
"# The Identity ResNet block breaks down into a manageable sequence of components.\n",
"class IdentityResNet(nn.Sequential):\n",
" def __init__(self, N=3, n_mu=[16,64,128,256], y=10):\n",
" super().__init__(\n",
" nn.Conv2d(3, n_mu[0], 3, padding=1),\n",
" Block(1, N, n_mu[0], n_mu[1]),\n",
" Block(2, N, n_mu[1], n_mu[2]),\n",
" Block(2, N, n_mu[2], n_mu[3]),\n",
" NormActivate(n_mu[3]),\n",
" nn.AdaptiveAvgPool2d(1),\n",
" nn.Flatten(),\n",
" nn.Linear(n_mu[3], y),\n",
" nn.Softmax(-1),\n",
" )"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The Block can be defined in a seperate model, keeping the code manageable and closely connected to the diagram.\n",
"\n",
"
"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [],
"source": [
"# We then follow how diagrams define each ``block''\n",
"class Block(nn.Sequential):\n",
" def __init__(self, s, N, n0, n1):\n",
" \"\"\" n0 and n1 as inputs to the initializer are implicit from having them in the domain and codomain in the diagram. \"\"\"\n",
" nb = n1 // 4\n",
" super().__init__(\n",
" *[\n",
" NormActivate(n0),\n",
" ResidualConnection(\n",
" nn.Sequential(\n",
" nn.Conv2d(n0, nb, 1, s),\n",
" NormActivate(nb),\n",
" nn.Conv2d(nb, nb, 3, padding=1),\n",
" NormActivate(nb),\n",
" nn.Conv2d(nb, n1, 1),\n",
" ),\n",
" nn.Conv2d(n0, n1, 1, s),\n",
" )\n",
" ] + [\n",
" ResidualConnection(\n",
" nn.Sequential(\n",
" NormActivate(n1),\n",
" nn.Conv2d(n1, nb, 1),\n",
" NormActivate(nb),\n",
" nn.Conv2d(nb, nb, 3, padding=1),\n",
" NormActivate(nb),\n",
" nn.Conv2d(nb, n1, 1)\n",
" ),\n",
" )\n",
" ] * N\n",
" \n",
" ) \n",
"# Residual connections are a repeated pattern in the diagram. So, we are motivated to encapsulate them\n",
"# as a seperate module.\n",
"class ResidualConnection(nn.Module):\n",
" def __init__(self, mainline : nn.Module, connection : nn.Module | None = None) -> None:\n",
" super().__init__()\n",
" self.main = mainline\n",
" self.secondary = nn.Identity() if connection == None else connection\n",
" def forward(self, x):\n",
" return self.main(x) + self.secondary(x)"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [],
"source": [
"# A standard image processing algorithm has inputs shaped b c h w.\n",
"b, c, hw = [3], [3], [16, 16]\n",
"\n",
"idresnet = IdentityResNet()\n",
"Xbchw = torch.rand(b + c + hw)\n",
"\n",
"# And we see if the overall size is maintained,\n",
"assert list(idresnet.forward(Xbchw).size()) == b + [10]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The UNet is a more complicated algorithm than residual networks. The ``fenced off'' sections help keep our code organized. Diagrams streamline implementation, and helps keep code organized.\n",
"\n",
"## Figure 27: The UNet architecture\n",
"
"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [],
"source": [
"# We notice that double convolution where the numbers of channels change is a repeated motif.\n",
"# We denote the input with c0 and output with c1. \n",
"# This can also be done for subsequent members of an iteration.\n",
"# When we go down an iteration eg. 5, 4, etc. we may have the input be c1 and the output c0.\n",
"class DoubleConvolution(nn.Sequential):\n",
" def __init__(self, c0, c1, Activation=nn.ReLU):\n",
" super().__init__(\n",
" nn.Conv2d(c0, c1, 3, padding=1),\n",
" Activation(),\n",
" nn.Conv2d(c0, c1, 3, padding=1),\n",
" Activation(),\n",
" )\n",
"\n",
"# The model is specified for a very specific number of layers,\n",
"# so we will not make it very flexible.\n",
"class UNet(nn.Module):\n",
" def __init__(self, y=2):\n",
" super().__init__()\n",
" # Set up the channel sizes;\n",
" c = [1 if i == 0 else 64 * 2 ** i for i in range(6)]\n",
"\n",
" # Saving and loading from memory means we can not use a single,\n",
" # sequential chain.\n",
"\n",
" # Set up and initialize the components;\n",
" self.DownScaleBlocks = [\n",
" DownScaleBlock(c[i],c[i+1])\n",
" for i in range(0,4)\n",
" ] # Note how this imitates the lambda operators in the diagram.\n",
" self.middleDoubleConvolution = DoubleConvolution(c[4], c[5])\n",
" self.middleUpscale = nn.ConvTranspose2d(c[5], c[4], 2, 2, 1)\n",
" self.upScaleBlocks = [\n",
" UpScaleBlock(c[5-i],c[4-i])\n",
" for i in range(1,4)\n",
" ]\n",
" self.finalConvolution = nn.Conv2d(c[1], y)\n",
"\n",
" def forward(self, x):\n",
" cLambdas = []\n",
" for dsb in self.DownScaleBlocks:\n",
" x, cLambda = dsb(x)\n",
" cLambdas.append(cLambda)\n",
" x = self.middleDoubleConvolution(x)\n",
" x = self.middleUpscale(x)\n",
" for usb in self.upScaleBlocks:\n",
" cLambda = cLambdas.pop()\n",
" x = usb(x, cLambda)\n",
" x = self.finalConvolution(x)\n",
"\n",
"class DownScaleBlock(nn.Module):\n",
" def __init__(self, c0, c1) -> None:\n",
" super().__init__()\n",
" self.doubleConvolution = DoubleConvolution(c0, c1)\n",
" self.downScaler = nn.MaxPool2d(2, 2, 1)\n",
" def forward(self, x):\n",
" cLambda = self.doubleConvolution(x)\n",
" x = self.downScaler(cLambda)\n",
" return x, cLambda\n",
"\n",
"class UpScaleBlock(nn.Module):\n",
" def __init__(self, c1, c0) -> None:\n",
" super().__init__()\n",
" self.doubleConvolution = DoubleConvolution(2*c1, c1)\n",
" self.upScaler = nn.ConvTranspose2d(c1,c0,2,2,1)\n",
" def forward(self, x, cLambda):\n",
" # Concatenation occurs over the C channel axis (dim=1)\n",
" x = torch.concat(x, cLambda, 1)\n",
" x = self.doubleConvolution(x)\n",
" x = self.upScaler(x)\n",
" return x"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 3.5 Vision Transformer\n",
"\n",
"We adapt our code for Multi-Head Attention to apply it to the vision case. This is a good exercise in how neural circuit diagrams allow code to be easily adapted for new modalities.\n",
"## Figure 28: Visual Attention\n",
"
"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([1, 33, 15, 15])"
]
},
"execution_count": 35,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"class VisualAttention(nn.Module):\n",
" def __init__(self, c, k, heads = 1, kernel = 1, stride = 1):\n",
" super().__init__()\n",
" \n",
" # w gives the kernel size, which we make adjustable.\n",
" self.c, self.k, self.h, self.w = c, k, heads, kernel\n",
" # Set up all the boldface, learned components\n",
" # Note how standard components may not have axes bound in \n",
" # the same way as diagrams. This requires us to rearrange\n",
" # using the einops package.\n",
"\n",
" # The learned layers form convolutions\n",
" self.Cq = nn.Conv2d(c, k * heads, kernel, stride)\n",
" self.Ck = nn.Conv2d(c, k * heads, kernel, stride)\n",
" self.Cv = nn.Conv2d(c, k * heads, kernel, stride)\n",
" self.Co = nn.ConvTranspose2d(\n",
" k * heads, c, kernel, stride)\n",
"\n",
" # Defined previously, closely follows the diagram.\n",
" def MultiHeadDotProductAttention(self, q: T, k: T, v: T) -> T:\n",
" ''' ykh, xkh, xkh -> ykh '''\n",
" klength = k.size()[-2]\n",
" x = einops.einsum(q, k, '... y k h, ... x k h -> ... y x h')\n",
" x = torch.nn.Softmax(-2)(x / math.sqrt(klength))\n",
" x = einops.einsum(x, v, '... y x h, ... x k h -> ... y k h')\n",
" return x\n",
"\n",
" # We have endogenous data (EYc) and external / injected data (XXc)\n",
" def forward(self, EcY, XcX):\n",
" \"\"\" cY, cX -> cY \n",
" The visual attention algorithm. Injects information from Xc into Yc. \"\"\"\n",
" # query, key, and value vectors.\n",
" # We unbind the k h axes which were produced by the convolutions, and feed them\n",
" # in the normal manner to MultiHeadDotProductAttention.\n",
" unbind = lambda x: einops.rearrange(x, 'N (k h) H W -> N (H W) k h', h=self.h)\n",
" # Save size to recover it later\n",
" q = self.Cq(EcY)\n",
" W = q.size()[-1]\n",
"\n",
" # By appropriately managing the axes, minimal changes to our previous code\n",
" # is necessary.\n",
" q = unbind(q)\n",
" k = unbind(self.Ck(XcX))\n",
" v = unbind(self.Cv(XcX))\n",
" o = self.MultiHeadDotProductAttention(q, k, v)\n",
"\n",
" # Rebind to feed to the transposed convolution layer.\n",
" o = einops.rearrange(o, 'N (H W) k h -> N (k h) H W', \n",
" h=self.h, W=W)\n",
" return self.Co(o)\n",
"\n",
"# Single batch element,\n",
"b = [1]\n",
"Y, X, c, k = [16, 16], [16, 16], [33], 8\n",
"# The additional configurations,\n",
"heads, kernel, stride = 4, 3, 3\n",
"\n",
"# Internal Data,\n",
"EYc = torch.rand(b + c + Y)\n",
"# External Data,\n",
"XXc = torch.rand(b + c + X)\n",
"\n",
"# We can now run the algorithm,\n",
"visualAttention = VisualAttention(c[0], k, heads, kernel, stride)\n",
"\n",
"# Interestingly, the height/width reduces by 1 for stride\n",
"# values above 1. Otherwise, it stays the same.\n",
"visualAttention.forward(EYc, XXc).size()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Appendix"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [],
"source": [
"# A container to track the size of modules,\n",
"# Replace a module definition eg.\n",
"# > self.Cq = nn.Conv2d(c, k * heads, kernel, stride)\n",
"# With;\n",
"# > self.Cq = Tracker(nn.Conv2d(c, k * heads, kernel, stride), \"Query convolution\")\n",
"# And the input / output sizes (to check diagrams) will be printed.\n",
"class Tracker(nn.Module):\n",
" def __init__(self, module: nn.Module, name : str = \"\"):\n",
" super().__init__()\n",
" self.module = module\n",
" if name:\n",
" self.name = name\n",
" else:\n",
" self.name = self.module._get_name()\n",
" def forward(self, x):\n",
" x_size = size_to_string(x.size())\n",
" x = self.module.forward(x)\n",
" y_size = size_to_string(x.size())\n",
" print(f\"{self.name}: \\t {x_size} -> {y_size}\")\n",
" return x"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.11"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}