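"""Tests for the fused multiply + PolyNorm activation op.

The fused kernel is validated against two references, for both the forward and
backward passes: a fully naive PyTorch implementation and a partially fused one
that reuses the standalone poly_norm kernel.
"""
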
import random

import pytest
import torch

import activation

from .utils import assert_close, opcheck

DTYPES = [torch.float, torch.bfloat16, torch.half]
NUM_TOKENS = [7, 83, 256, 2048]  # Arbitrary values for testing
D = [1, 7, 512, 13824]  # Arbitrary values for testing
SEEDS = [0]
# Exercise a second device when more than one GPU is available.
CUDA_DEVICES = [
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]


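# Reference implementations. PolyNorm normalizes x, x**2 and x**3 with an
# RMS-style norm, combines them with a learned 3-element weight and a scalar
# bias, and computes in float32 before casting back to the weight dtype:
#   poly_norm(x) = w[0] * norm(x**3) + w[1] * norm(x**2) + w[2] * norm(x) + b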
def norm(x, eps: float) -> torch.Tensor:
    return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps)


def poly_norm(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor,
              eps: float) -> torch.Tensor:
    x = x.float()
    return (weight[0] * norm(x**3, eps) + weight[1] * norm(x**2, eps) +
            weight[2] * norm(x, eps) + bias).to(weight.dtype)


def mul_poly_norm_all_naive(x: torch.Tensor, mul: torch.Tensor,
                            weight: torch.Tensor, bias: torch.Tensor,
                            eps: float) -> torch.Tensor:
    return poly_norm(x, weight, bias, eps) * mul


# Partially fused reference: use the standalone poly_norm kernel for the
# normalization, then apply the multiply in PyTorch.
def mul_poly_norm_partial_naive(x: torch.Tensor, mul: torch.Tensor,
                                weight: torch.Tensor, bias: torch.Tensor,
                                eps: float) -> torch.Tensor:
    return activation.poly_norm(x, weight, bias, eps) * mul


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_fused_mul_poly_norm(
    num_tokens: int,
    d: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    random.seed(seed)
    torch.manual_seed(seed)
    torch.set_default_device(device)

    x = torch.randn(num_tokens, d, dtype=dtype, requires_grad=True)
    mul = torch.randn(num_tokens, d, dtype=dtype, requires_grad=True)
    weight = torch.randn(3, dtype=dtype, requires_grad=True)
    bias = torch.randn(1, dtype=dtype, requires_grad=True)
    eps = 1e-05

    x.retain_grad()
    mul.retain_grad()
    weight.retain_grad()
    bias.retain_grad()

    # Clone the inputs so each reference path accumulates its own gradients.
    x_ref = x.detach().clone().requires_grad_(True)
    mul_ref = mul.detach().clone().requires_grad_(True)
    weight_ref = weight.detach().clone().requires_grad_(True)
    bias_ref = bias.detach().clone().requires_grad_(True)

    x_ref2 = x.detach().clone().requires_grad_(True)
    mul_ref2 = mul.detach().clone().requires_grad_(True)
    weight_ref2 = weight.detach().clone().requires_grad_(True)
    bias_ref2 = bias.detach().clone().requires_grad_(True)

    torch_fn = mul_poly_norm_all_naive
    torch_fn2 = mul_poly_norm_partial_naive

    op = activation.ops.fused_mul_poly_norm
    fn = activation.fused_mul_poly_norm

    layer = activation.layers.FusedMulPolyNorm(eps)
    layer.weight = torch.nn.Parameter(weight)
    layer.bias = torch.nn.Parameter(bias)

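    # The raw registered op takes a preallocated output tensor as its first
    # argument (hence the empty buffer below); opcheck, a helper from the
    # local test utils, sanity-checks that invocation.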
    out = torch.empty(x.shape, dtype=x.dtype, device=x.device)
    opcheck(op, (out, x, mul, weight, bias, eps))

    out = fn(x, mul, weight, bias, eps)
    mod_out = layer(x, mul)
    ref_out = torch_fn(x_ref, mul_ref, weight_ref, bias_ref, eps)
    ref_out2 = torch_fn2(x_ref2, mul_ref2, weight_ref2, bias_ref2, eps)

    # The multiply amplifies small numeric differences between the naive
    # poly_norm and the kernel, so a looser rtol/atol is used when comparing
    # against the fully naive reference.
    assert_close(out, ref_out, atol=0.01, rtol=0.01)
    assert_close(out, ref_out2)
    assert_close(mod_out, out, atol=0.0, rtol=0.0)

    # Test the backward pass with a unit-norm upstream gradient.
    out_grad = torch.randn_like(out)
    out_grad = out_grad / out_grad.norm()

    ref_out.backward(out_grad)
    ref_out2.backward(out_grad)
    mod_out.backward(out_grad)

    assert_close(x.grad, x_ref.grad)
    assert_close(x.grad, x_ref2.grad)
    assert_close(mul.grad, mul_ref.grad)
    assert_close(mul.grad, mul_ref2.grad)
    assert_close(layer.bias.grad, bias_ref.grad, rtol=0.05)
    assert_close(layer.bias.grad, bias_ref2.grad, rtol=0.05)
    assert_close(layer.weight.grad, weight_ref.grad, rtol=0.05)
    assert_close(layer.weight.grad, weight_ref2.grad, rtol=0.05)
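

# A minimal usage sketch (illustrative only, not executed by the tests): the
# fused op is called with the same signature exercised above; shapes, dtype,
# and eps here are arbitrary choices.
#
#   x = torch.randn(16, 512, dtype=torch.bfloat16, device="cuda")
#   mul = torch.randn_like(x)
#   weight = torch.randn(3, dtype=torch.bfloat16, device="cuda")
#   bias = torch.randn(1, dtype=torch.bfloat16, device="cuda")
#   out = activation.fused_mul_poly_norm(x, mul, weight, bias, 1e-5)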