Upload custom kernels
- build/torch-universal/triton_llama_attn/attn.py +46 -46
- torch-ext/triton_llama_attn/.pytest_cache/.gitignore +2 -0
- torch-ext/triton_llama_attn/.pytest_cache/CACHEDIR.TAG +4 -0
- torch-ext/triton_llama_attn/.pytest_cache/README.md +8 -0
- torch-ext/triton_llama_attn/.pytest_cache/v/cache/lastfailed +1 -0
- torch-ext/triton_llama_attn/.pytest_cache/v/cache/nodeids +3 -0
- torch-ext/triton_llama_attn/.pytest_cache/v/cache/stepwise +1 -0
- torch-ext/triton_llama_attn/__pycache__/__init__.cpython-310.pyc +0 -0
- torch-ext/triton_llama_attn/__pycache__/attn.cpython-310-pytest-8.3.5.pyc +0 -0
- torch-ext/triton_llama_attn/attn.py +47 -46
build/torch-universal/triton_llama_attn/attn.py (CHANGED)

@@ -144,7 +144,7 @@ def _attn_fwd_inner(acc, l_i, m_i, q, #
         if fp8_v:
             p = p.to(tl.float8e5)
         else:
-            p = p.to(tl.
+            p = p.to(tl.float32)
         acc = tl.dot(p, v, acc)
         # update m_i and l_i
         m_i = m_ij
@@ -344,7 +344,7 @@ def _attn_fwd_tma(sm_scale, M, #
                   FP8_OUTPUT: tl.constexpr, #
                   STAGE: tl.constexpr #
                   ):
-    dtype = tl.float8e5 if FP8_OUTPUT else tl.
+    dtype = tl.float8e5 if FP8_OUTPUT else tl.float32
     tl.static_assert(BLOCK_N <= HEAD_DIM)
     start_m = tl.program_id(0)
     off_hz = tl.program_id(1)
@@ -447,14 +447,14 @@ def _attn_bwd_dkdv(dk, dv, #
         do = tl.load(do_ptrs)
         # Compute dV.
         ppT = pT
-        ppT = ppT.to(tl.
+        ppT = ppT.to(tl.float32)
         dv += tl.dot(ppT, do)
         # D (= delta) is pre-divided by ds_scale.
         Di = tl.load(D + offs_m)
         # Compute dP and dS.
         dpT = tl.dot(v, tl.trans(do)).to(tl.float32)
         dsT = pT * (dpT - Di[None, :])
-        dsT = dsT.to(tl.
+        dsT = dsT.to(tl.float32)
         dk += tl.dot(dsT, tl.trans(qT))
         # Increment pointers.
         curr_m += step_m
@@ -500,7 +500,7 @@ def _attn_bwd_dq(dq, q, K, V, #
         # Compute dP and dS.
         dp = tl.dot(do, vT).to(tl.float32)
         ds = p * (dp - Di[:, None])
-        ds = ds.to(tl.
+        ds = ds.to(tl.float32)
         # Compute dQ.
         # NOTE: We need to de-scale dq in the end, because kT was pre-scaled.
         dq += tl.dot(ds, tl.trans(kT))
@@ -1106,44 +1106,44 @@ def attn_forward_kernel(
 
 # return is_close
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-#
-
-
-
-
-
-#
-
-
-
-
-
-#
-
-
-
-
-
-#
-
-
-#
-#
-
-
-
-
-
+attention = Attention.apply
+DEVICE = "cuda:0"
+
+import pytest
+@pytest.mark.parametrize("Z, H, N_CTX, HEAD_DIM", [(2, 32, 1024, 64)])
+@pytest.mark.parametrize("causal", [True])
+def test_op(Z, H, N_CTX, HEAD_DIM, causal, dtype=torch.float32):
+    torch.manual_seed(20)
+    q = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_())
+    k = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_())
+    v = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_())
+    sm_scale = 0.5
+    dout = torch.randn_like(q)
+    # reference implementation
+    M = torch.tril(torch.ones((N_CTX, N_CTX), device=DEVICE))
+    p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
+    if causal:
+        p[:, :, M == 0] = float("-inf")
+    p = torch.softmax(p.float(), dim=-1)
+    # p = torch.exp(p)
+    ref_out = torch.matmul(p, v)
+    ref_out.backward(dout)
+    ref_dv, v.grad = v.grad.clone(), None
+    ref_dk, k.grad = k.grad.clone(), None
+    ref_dq, q.grad = q.grad.clone(), None
+    # triton implementation
+    tri_out = attention(q, k, v, causal, sm_scale)
+    tri_out.backward(dout)
+    tri_dv, v.grad = v.grad.clone(), None
+    tri_dk, k.grad = k.grad.clone(), None
+    tri_dq, q.grad = q.grad.clone(), None
+    # compare
+    assert torch.allclose(ref_out, tri_out, atol=1e-2, rtol=0)
+    rtol = 0.0
+    # Relative tolerance workaround for known hardware limitation of CDNA2 GPU.
+    # For details see https://pytorch.org/docs/stable/notes/numerical_accuracy.html#reduced-precision-fp16-and-bf16-gemms-and-convolutions-on-amd-instinct-mi200-devices
+    if torch.version.hip is not None and triton.runtime.driver.active.get_current_target().arch == "gfx90a":
+        rtol = 1e-2
+    assert torch.allclose(ref_dv, tri_dv, atol=1e-2, rtol=rtol)
+    assert torch.allclose(ref_dk, tri_dk, atol=1e-2, rtol=rtol)
+    assert torch.allclose(ref_dq, tri_dq, atol=1e-2, rtol=rtol)
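For context, the new test_op compares the Triton kernel against an eager PyTorch reference built from a causal mask, a scaled matmul, and a softmax. The sketch below restates that reference path as a standalone function; it is illustrative only (the function name and the masked_fill formulation are not part of the diff), but it follows the same masking and scaling convention as the test:

import torch

def reference_attention(q, k, v, sm_scale, causal=True):
    # q, k, v: (Z, H, N_CTX, HEAD_DIM) float32 tensors, as in test_op.
    n_ctx = q.shape[-2]
    # Scaled dot-product scores.
    p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
    if causal:
        # Future positions are set to -inf before the softmax,
        # mirroring the torch.tril mask used in the test.
        mask = torch.tril(torch.ones(n_ctx, n_ctx, dtype=torch.bool, device=q.device))
        p = p.masked_fill(~mask, float("-inf"))
    p = torch.softmax(p.float(), dim=-1)
    return torch.matmul(p, v)

# Example with the shapes from the test parametrization:
# q = k = v = torch.randn(2, 32, 1024, 64, device="cuda:0")
# out = reference_attention(q, k, v, sm_scale=0.5)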
torch-ext/triton_llama_attn/.pytest_cache/.gitignore (ADDED)

@@ -0,0 +1,2 @@
+# Created by pytest automatically.
+*
torch-ext/triton_llama_attn/.pytest_cache/CACHEDIR.TAG (ADDED)

@@ -0,0 +1,4 @@
+Signature: 8a477f597d28d172789f06886806bc55
+# This file is a cache directory tag created by pytest.
+# For information about cache directory tags, see:
+# https://bford.info/cachedir/spec.html
torch-ext/triton_llama_attn/.pytest_cache/README.md (ADDED)

@@ -0,0 +1,8 @@
+# pytest cache directory #
+
+This directory contains data from the pytest's cache plugin,
+which provides the `--lf` and `--ff` options, as well as the `cache` fixture.
+
+**Do not** commit this to version control.
+
+See [the docs](https://docs.pytest.org/en/stable/how-to/cache.html) for more information.
torch-ext/triton_llama_attn/.pytest_cache/v/cache/lastfailed (ADDED)

@@ -0,0 +1 @@
+{}
torch-ext/triton_llama_attn/.pytest_cache/v/cache/nodeids (ADDED)

@@ -0,0 +1,3 @@
+[
+  "attn.py::test_op[True-2-32-1024-64]"
+]
torch-ext/triton_llama_attn/.pytest_cache/v/cache/stepwise (ADDED)

@@ -0,0 +1 @@
+[]
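The cache files above record a local pytest run that collected attn.py::test_op[True-2-32-1024-64]; the bundled README notes that this directory is not meant to be committed. For reference, a run like the following regenerates equivalent cache state (a sketch: it assumes the working directory is torch-ext/triton_llama_attn and that a CUDA device is available):

import pytest

# Select only the attention test; pytest recreates .pytest_cache/
# (lastfailed, nodeids, stepwise) as a side effect of the run.
pytest.main(["attn.py::test_op", "-q"])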
torch-ext/triton_llama_attn/__pycache__/__init__.cpython-310.pyc (ADDED, binary file, 258 Bytes)

torch-ext/triton_llama_attn/__pycache__/attn.cpython-310-pytest-8.3.5.pyc (ADDED, binary file, 29.6 kB)
torch-ext/triton_llama_attn/attn.py (CHANGED)

@@ -144,7 +144,7 @@ def _attn_fwd_inner(acc, l_i, m_i, q, #
         if fp8_v:
             p = p.to(tl.float8e5)
         else:
-            p = p.to(tl.
+            p = p.to(tl.float32)
         acc = tl.dot(p, v, acc)
         # update m_i and l_i
         m_i = m_ij
@@ -344,7 +344,7 @@ def _attn_fwd_tma(sm_scale, M, #
                   FP8_OUTPUT: tl.constexpr, #
                   STAGE: tl.constexpr #
                   ):
-    dtype = tl.float8e5 if FP8_OUTPUT else tl.
+    dtype = tl.float8e5 if FP8_OUTPUT else tl.float32
     tl.static_assert(BLOCK_N <= HEAD_DIM)
     start_m = tl.program_id(0)
     off_hz = tl.program_id(1)
@@ -447,14 +447,14 @@ def _attn_bwd_dkdv(dk, dv, #
         do = tl.load(do_ptrs)
         # Compute dV.
         ppT = pT
-        ppT = ppT.to(tl.
+        ppT = ppT.to(tl.float32)
         dv += tl.dot(ppT, do)
         # D (= delta) is pre-divided by ds_scale.
         Di = tl.load(D + offs_m)
         # Compute dP and dS.
         dpT = tl.dot(v, tl.trans(do)).to(tl.float32)
         dsT = pT * (dpT - Di[None, :])
-        dsT = dsT.to(tl.
+        dsT = dsT.to(tl.float32)
         dk += tl.dot(dsT, tl.trans(qT))
         # Increment pointers.
         curr_m += step_m
@@ -500,7 +500,7 @@ def _attn_bwd_dq(dq, q, K, V, #
         # Compute dP and dS.
         dp = tl.dot(do, vT).to(tl.float32)
         ds = p * (dp - Di[:, None])
-        ds = ds.to(tl.
+        ds = ds.to(tl.float32)
         # Compute dQ.
         # NOTE: We need to de-scale dq in the end, because kT was pre-scaled.
         dq += tl.dot(ds, tl.trans(kT))
@@ -967,6 +967,7 @@ def attn_forward_kernel(
     scaling: float,
     causal: bool,
 ):
+    print("######################### attn_forward_kernel", query.shape, key.shape, value.shape, scaling, causal)
     return Attention.apply(query, key, value, causal, scaling)
 
 # def test_llama_attention_output():
@@ -1105,44 +1106,44 @@ def attn_forward_kernel(
 
 # return is_close
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-#
-
-
-
-
-
-#
-
-
-
-
-
-#
-
-
-
-
-
-#
-
-
-#
-#
-
-
-
-
-
+attention = Attention.apply
+DEVICE = "cuda:0"
+
+import pytest
+@pytest.mark.parametrize("Z, H, N_CTX, HEAD_DIM", [(2, 32, 1024, 64)])
+@pytest.mark.parametrize("causal", [True])
+def test_op(Z, H, N_CTX, HEAD_DIM, causal, dtype=torch.float32):
+    torch.manual_seed(20)
+    q = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_())
+    k = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_())
+    v = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_())
+    sm_scale = 0.5
+    dout = torch.randn_like(q)
+    # reference implementation
+    M = torch.tril(torch.ones((N_CTX, N_CTX), device=DEVICE))
+    p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
+    if causal:
+        p[:, :, M == 0] = float("-inf")
+    p = torch.softmax(p.float(), dim=-1)
+    # p = torch.exp(p)
+    ref_out = torch.matmul(p, v)
+    ref_out.backward(dout)
+    ref_dv, v.grad = v.grad.clone(), None
+    ref_dk, k.grad = k.grad.clone(), None
+    ref_dq, q.grad = q.grad.clone(), None
+    # triton implementation
+    tri_out = attention(q, k, v, causal, sm_scale)
+    tri_out.backward(dout)
+    tri_dv, v.grad = v.grad.clone(), None
+    tri_dk, k.grad = k.grad.clone(), None
+    tri_dq, q.grad = q.grad.clone(), None
+    # compare
+    assert torch.allclose(ref_out, tri_out, atol=1e-2, rtol=0)
+    rtol = 0.0
+    # Relative tolerance workaround for known hardware limitation of CDNA2 GPU.
+    # For details see https://pytorch.org/docs/stable/notes/numerical_accuracy.html#reduced-precision-fp16-and-bf16-gemms-and-convolutions-on-amd-instinct-mi200-devices
+    if torch.version.hip is not None and triton.runtime.driver.active.get_current_target().arch == "gfx90a":
+        rtol = 1e-2
+    assert torch.allclose(ref_dv, tri_dv, atol=1e-2, rtol=rtol)
+    assert torch.allclose(ref_dk, tri_dk, atol=1e-2, rtol=rtol)
+    assert torch.allclose(ref_dq, tri_dq, atol=1e-2, rtol=rtol)
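For reference, the attn_forward_kernel entry point touched in the hunk at 967 forwards straight to Attention.apply, so it can be driven the same way the new test drives the autograd function. A minimal call sketch follows (shapes and dtype mirror test_op; the keyword names follow the signature shown in the hunk, while the import path and helper name are assumptions, not part of the diff):

import torch
# from triton_llama_attn import attn_forward_kernel  # import path assumed

def smoke_test(attn_forward_kernel, device="cuda:0"):
    # Same shape as the test parametrization: (Z, H, N_CTX, HEAD_DIM).
    Z, H, N_CTX, HEAD_DIM = 2, 32, 1024, 64
    q, k, v = (torch.randn(Z, H, N_CTX, HEAD_DIM, dtype=torch.float32,
                           device=device, requires_grad=True)
               for _ in range(3))
    # scaling plays the role of sm_scale in the test; causal=True matches it.
    out = attn_forward_kernel(q, k, v, scaling=0.5, causal=True)
    out.backward(torch.randn_like(out))
    return out, q.grad, k.grad, v.grad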