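"""Tests for the poly_norm activation kernel.

Covers the raw custom op, the functional wrapper, and the PolyNorm layer
against a float32 reference implementation, forward and backward.
"""
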
import random
import pytest
import torch
import activation
from .utils import assert_close, opcheck

DTYPES = [torch.float, torch.bfloat16, torch.half]
# Larger shapes for more exhaustive (and slower) runs:
# NUM_TOKENS = [7, 83, 2048]
# D = [512, 13824]
NUM_TOKENS = [7, 13]  # Arbitrary values for testing
D = [513]  # Arbitrary values for testing
SEEDS = [0]
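# Use one device when a single GPU is visible, otherwise the first two.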
CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
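

# Reference RMS normalization over the last dimension.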
def norm(x: torch.Tensor, eps: float) -> torch.Tensor:
return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps)
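

# Float32 reference for PolyNorm: a weighted sum of RMS-normalized powers
# of x (x**3, x**2, x) plus a bias, cast back to the weight dtype.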
def poly_norm(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor,
eps: float) -> torch.Tensor:
x = x.float()
return (weight[0] * norm(x**3, eps) + weight[1] * norm(x**2, eps) +
weight[2] * norm(x, eps) + bias).to(weight.dtype)


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_poly_norm(
num_tokens: int,
d: int,
dtype: torch.dtype,
seed: int,
device: str,
) -> None:
random.seed(seed)
torch.manual_seed(seed)
torch.set_default_device(device)
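
    # weight holds one coefficient per power term (x**3, x**2, x).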
x = torch.randn(num_tokens, d, dtype=dtype, requires_grad=True)
weight = torch.randn(3, dtype=dtype, requires_grad=True)
bias = torch.randn(1, dtype=dtype, requires_grad=True)
eps = 1e-05
x.retain_grad()
weight.retain_grad()
bias.retain_grad()

    # Clone the inputs so the reference path accumulates its own gradients.
x_ref = x.detach().clone().requires_grad_(True)
weight_ref = weight.detach().clone().requires_grad_(True)
bias_ref = bias.detach().clone().requires_grad_(True)
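
    # Three call paths under test: the raw custom op, the functional
    # wrapper, and the nn.Module layer.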
torch_fn = poly_norm
op = activation.ops.poly_norm
fn = activation.poly_norm
layer = activation.layers.PolyNorm(eps)
layer.weight = torch.nn.Parameter(weight)
layer.bias = torch.nn.Parameter(bias)
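
    # The raw op writes into a preallocated output buffer; opcheck
    # validates the custom-op registration with these arguments.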
out = torch.empty(x.shape, dtype=x.dtype, device=x.device)
opcheck(op, (out, x, weight, bias, eps))
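
    # Forward through the functional API, the module, and the reference.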
out = fn(x, weight, bias, eps)
mod_out = layer(x)
ref_out = torch_fn(x_ref, weight_ref, bias_ref, eps)
assert_close(out, ref_out)
assert_close(mod_out, out, atol=0.0, rtol=0.0)

    # Test the backward pass with a normalized upstream gradient.
out_grad = torch.randn_like(out)
out_grad = out_grad / out_grad.norm()
ref_out.backward(out_grad)
mod_out.backward(out_grad)
assert_close(x.grad, x_ref.grad)
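
    # Weight and bias gradients reduce over all tokens, so allow a looser
    # relative tolerance.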
assert_close(layer.bias.grad, bias_ref.grad, rtol=0.05)
assert_close(layer.weight.grad, weight_ref.grad, rtol=0.05)