# activation/tests/kernels/test_rms_norm.py

import random

import pytest
import torch

import activation
from .utils import assert_close, opcheck

DTYPES = [torch.float, torch.bfloat16, torch.half]
# NUM_TOKENS = [7, 83, 2048]  # Arbitrary values for testing
# D = [512, 13824]  # Arbitrary values for testing
NUM_TOKENS = [7, 13]  # Arbitrary values for testing
D = [513]  # Arbitrary values for testing
SEEDS = [0]
CUDA_DEVICES = [
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]

@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_rms_norm(
    num_tokens: int,
    d: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
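    """Compare the custom RMSNorm kernel (raw op, functional wrapper, and
    nn.Module layer) against torch.nn.RMSNorm on the forward and backward pass."""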
    random.seed(seed)
    torch.manual_seed(seed)
    torch.set_default_device(device)
    x = torch.randn(num_tokens, d, dtype=dtype, requires_grad=True)
    weight = torch.randn(d, dtype=dtype, requires_grad=True)
    eps = 1e-05
    x.retain_grad()
    weight.retain_grad()
    # Clone the inputs so the reference path accumulates its own gradients.
    x_ref = x.detach().clone().requires_grad_(True)
    weight_ref = weight.detach().clone().requires_grad_(True)
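    # Reference implementation: torch.nn.RMSNorm driven by the cloned tensors.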
    torch_layer = torch.nn.RMSNorm(d, eps=eps, dtype=dtype)
    torch_layer.weight = torch.nn.Parameter(weight_ref)
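    # Objects under test: the raw op, the functional wrapper, and the layer module.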
    op = activation.ops.rms_norm
    fn = activation.rms_norm
    layer = activation.layers.RMSNorm(d, eps=eps, dtype=dtype)
    layer.weight = torch.nn.Parameter(weight)
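    # opcheck exercises the raw out-variant op with a preallocated output buffer.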
    out = torch.empty(x.shape, dtype=x.dtype, device=x.device)
    opcheck(op, (out, x, weight, eps))
    out = fn(x, weight, eps)
    mod_out = layer(x)
    ref_out = torch_layer(x_ref)
    assert_close(out, ref_out)
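    # The layer is expected to match the functional path bit-for-bit.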
    assert_close(mod_out, out, atol=0.0, rtol=0.0)

    # Test the backward pass with a shared upstream gradient.
    out_grad = torch.randn_like(out)
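    # Scale the upstream gradient to unit norm before backpropagating.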
    out_grad = out_grad / out_grad.norm()
    ref_out.backward(out_grad)
    mod_out.backward(out_grad)
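    # Input gradients must match closely; the weight gradient gets a looser tolerance.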
    assert_close(x.grad, x_ref.grad)
    assert_close(layer.weight.grad, torch_layer.weight.grad, rtol=0.05)