import random

import pytest
import torch

import activation
from .utils import assert_close, opcheck

DTYPES = [torch.float, torch.bfloat16, torch.half]
NUM_TOKENS = [7, 83, 256, 2048] # Arbitrary values for testing
D = [1, 7, 512, 13824] # Arbitrary values for testing
SEEDS = [0]
CUDA_DEVICES = [
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]


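# RMS normalization over the last dimension; building block for the naive PolyNorm reference.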
def norm(x, eps: float) -> torch.Tensor:
    return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps)


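# Naive PolyNorm reference: weighted sum of the RMS-normalized first three
# powers of x plus a bias, computed in fp32 and cast back to the weight dtype.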
def poly_norm(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor,
              eps: float) -> torch.Tensor:
    x = x.float()
    return (weight[0] * norm(x**3, eps) + weight[1] * norm(x**2, eps) +
            weight[2] * norm(x, eps) + bias).to(weight.dtype)


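# Fully naive reference: naive poly_norm followed by an elementwise multiply.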
def mul_poly_norm_all_naive(x: torch.Tensor, mul: torch.Tensor,
                            weight: torch.Tensor, bias: torch.Tensor,
                            eps: float) -> torch.Tensor:
    return poly_norm(x, weight, bias, eps) * mul


# Partial reference: uses the poly_norm kernel, then multiplies naively.
def mul_poly_norm_partial_naive(x: torch.Tensor, mul: torch.Tensor,
                                weight: torch.Tensor, bias: torch.Tensor,
                                eps: float) -> torch.Tensor:
    return activation.poly_norm(x, weight, bias, eps) * mul


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_fused_mul_poly_norm(
    num_tokens: int,
    d: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    random.seed(seed)
    torch.manual_seed(seed)
    torch.set_default_device(device)

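    # Random inputs; weight holds the three PolyNorm coefficients.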
    x = torch.randn(num_tokens, d, dtype=dtype, requires_grad=True)
    mul = torch.randn(num_tokens, d, dtype=dtype, requires_grad=True)
    weight = torch.randn(3, dtype=dtype, requires_grad=True)
    bias = torch.randn(1, dtype=dtype, requires_grad=True)
    eps = 1e-05

    x.retain_grad()
    mul.retain_grad()
    weight.retain_grad()
    bias.retain_grad()

    # Clone the inputs so each reference path accumulates its own gradients.
    x_ref = x.detach().clone().requires_grad_(True)
    mul_ref = mul.detach().clone().requires_grad_(True)
    weight_ref = weight.detach().clone().requires_grad_(True)
    bias_ref = bias.detach().clone().requires_grad_(True)

    x_ref2 = x.detach().clone().requires_grad_(True)
    mul_ref2 = mul.detach().clone().requires_grad_(True)
    weight_ref2 = weight.detach().clone().requires_grad_(True)
    bias_ref2 = bias.detach().clone().requires_grad_(True)

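    # References and the kernel under test: raw custom op, functional wrapper, and nn layer.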
    torch_fn = mul_poly_norm_all_naive
    torch_fn2 = mul_poly_norm_partial_naive
    op = activation.ops.fused_mul_poly_norm
    fn = activation.fused_mul_poly_norm
    layer = activation.layers.FusedMulPolyNorm(eps)
    layer.weight = torch.nn.Parameter(weight)
    layer.bias = torch.nn.Parameter(bias)

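    # Check the op registration with opcheck, then run the fused forward paths.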
    out = torch.empty(x.shape, dtype=x.dtype, device=x.device)
    opcheck(op, (out, x, mul, weight, bias, eps))

    out = fn(x, mul, weight, bias, eps)
    mod_out = layer(x, mul)
    ref_out = torch_fn(x_ref, mul_ref, weight_ref, bias_ref, eps)
    ref_out2 = torch_fn2(x_ref2, mul_ref2, weight_ref2, bias_ref2, eps)

    # Mul amplifies small numeric differences between naive poly_norm and the kernel.
    # When validating against all_naive, use a looser rtol/atol.
    assert_close(out, ref_out, atol=0.01, rtol=0.01)
    assert_close(out, ref_out2)
    assert_close(mod_out, out, atol=0.0, rtol=0.0)

    # test backward pass
    out_grad = torch.randn_like(out)
    out_grad = out_grad / out_grad.norm()

    ref_out.backward(out_grad)
    ref_out2.backward(out_grad)
    mod_out.backward(out_grad)

    assert_close(x.grad, x_ref.grad)
    assert_close(x.grad, x_ref2.grad)
    assert_close(mul.grad, mul_ref.grad)
    assert_close(mul.grad, mul_ref2.grad)
    assert_close(layer.bias.grad, bias_ref.grad, rtol=0.05)
    assert_close(layer.bias.grad, bias_ref2.grad, rtol=0.05)
    assert_close(layer.weight.grad, weight_ref.grad, rtol=0.05)
    assert_close(layer.weight.grad, weight_ref2.grad, rtol=0.05)