# activation/tests/kernels/test_poly_norm_perf.py
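"""Performance tests for the fused PolyNorm kernel in the ``activation`` package.

Each case first checks that ``activation.layers.PolyNorm`` matches the pure-PyTorch
reference ``poly_norm`` (imported from test_poly_norm.py), then times the forward
and backward passes of both and records the results in ``PERF_RESULTS``.

Assumptions not confirmed by this file alone: PolyNorm is taken to compute
``w0 * norm(x**3) + w1 * norm(x**2) + w2 * norm(x) + b`` with an RMS-style norm
(matching the 3-element weight and scalar bias used below), and the ``do_plot``
fixture plus any reporting of ``PERF_RESULTS`` are presumed to live in a
conftest.py next to these tests.
"""
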
import random
from dataclasses import dataclass

import pytest
import torch

import activation

from .test_poly_norm import poly_norm
from .utils import assert_close

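# Benchmark cases as (shape, dtype). The 3-D shapes are assumed to be
# (batch, seq_len, hidden_size), i.e. typical transformer activation sizes.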
CASES = [
    ((1, 2048, 8192), torch.bfloat16),
    ((1, 2048, 16384), torch.bfloat16),
    ((1, 16384, 8192), torch.bfloat16),
    ((1, 16384, 16384), torch.bfloat16),
]

NUM_REP = 100  # timed iterations per measurement (time_cuda also runs 5 untimed warmup calls)


@dataclass
class PerfResult:
    type: str  # forward or backward
    shape: tuple
    dtype: torch.dtype
    kernel_time_ms: float
    torch_time_ms: float

    @property
    def speedup(self) -> float:
        return self.torch_time_ms / self.kernel_time_ms


PERF_RESULTS: list[PerfResult] = []

@pytest.mark.parametrize("cases", CASES)
@pytest.mark.perf
def test_poly_norm(
    cases: tuple,
    do_plot: bool,  # fixture presumably provided by conftest.py; unused in this test body
) -> None:
    random.seed(12345)
    torch.manual_seed(12345)
    torch.set_default_device("cuda")

    shape, dtype = cases
    x = torch.randn(shape, dtype=dtype, requires_grad=True)
    weight = torch.randn(3, dtype=dtype, requires_grad=True)
    bias = torch.randn(1, dtype=dtype, requires_grad=True)
    eps = 1e-05

    x.retain_grad()
    weight.retain_grad()
    bias.retain_grad()

    # Clone the inputs so the reference path accumulates its gradients separately.
    x_ref = x.detach().clone().requires_grad_(True)
    weight_ref = weight.detach().clone().requires_grad_(True)
    bias_ref = bias.detach().clone().requires_grad_(True)

    torch_fn = poly_norm

    layer = activation.layers.PolyNorm(eps)
    layer.weight = torch.nn.Parameter(weight)
    layer.bias = torch.nn.Parameter(bias)

    # Check forward correctness before timing anything.
    mod_out = layer(x)
    ref_out = torch_fn(x_ref, weight_ref, bias_ref, eps)
    assert_close(mod_out, ref_out)

    out_grad = torch.rand_like(ref_out)
    out_grad = out_grad / out_grad.norm()

    # Check backward correctness; the parameter gradients are reduced over many
    # elements, so they get a looser relative tolerance.
    ref_out.backward(out_grad, retain_graph=True)
    mod_out.backward(out_grad, retain_graph=True)
    assert_close(x.grad, x_ref.grad)
    assert_close(layer.bias.grad, bias_ref.grad, rtol=0.05)
    assert_close(layer.weight.grad, weight_ref.grad, rtol=0.05)

    def time_cuda(fn):
        """Average time per call of ``fn`` in milliseconds, measured with CUDA events."""
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)

        # Untimed warmup calls.
        for _ in range(5):
            fn()

        start.record()
        for _ in range(NUM_REP):
            fn()
        end.record()
        torch.cuda.synchronize()

        return start.elapsed_time(end) / NUM_REP

    # Forward timing: fused kernel vs. PyTorch reference.
    kernel_time_ms = time_cuda(lambda: layer(x))
    torch_fn_time = time_cuda(lambda: torch_fn(x_ref, weight_ref, bias_ref, eps))

    PERF_RESULTS.append(
        PerfResult(
            type="forward",
            shape=shape,
            dtype=dtype,
            kernel_time_ms=kernel_time_ms,
            torch_time_ms=torch_fn_time,
        ))

    # Backward timing. Gradients accumulate across repetitions, but only the
    # timing matters here; correctness was already checked above.
    kernel_time_ms = time_cuda(lambda: mod_out.backward(out_grad, retain_graph=True))
    torch_fn_time = time_cuda(lambda: ref_out.backward(out_grad, retain_graph=True))

    PERF_RESULTS.append(
        PerfResult(
            type="backward",
            shape=shape,
            dtype=dtype,
            kernel_time_ms=kernel_time_ms,
            torch_time_ms=torch_fn_time,
        ))
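

# Usage sketch (assumes the custom ``perf`` marker is registered in the project's
# pytest configuration and that a CUDA device is available):
#
#   pytest -m perf tests/kernels/test_poly_norm_perf.py
#
# PERF_RESULTS can then be summarized (e.g. via PerfResult.speedup) by a
# session-level hook or fixture, which this file assumes exists elsewhere.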