medmekk (HF Staff) committed
Commit 9099bf8 · verified · 1 Parent(s): 755e5cd

Upload custom kernels

build/torch-universal/triton_llama_attn/attn.py CHANGED
@@ -144,7 +144,7 @@ def _attn_fwd_inner(acc, l_i, m_i, q, #
         if fp8_v:
             p = p.to(tl.float8e5)
         else:
-            p = p.to(tl.float16)
+            p = p.to(tl.float32)
         acc = tl.dot(p, v, acc)
         # update m_i and l_i
         m_i = m_ij
@@ -344,7 +344,7 @@ def _attn_fwd_tma(sm_scale, M, #
                   FP8_OUTPUT: tl.constexpr, #
                   STAGE: tl.constexpr #
                   ):
-    dtype = tl.float8e5 if FP8_OUTPUT else tl.float16
+    dtype = tl.float8e5 if FP8_OUTPUT else tl.float32
     tl.static_assert(BLOCK_N <= HEAD_DIM)
     start_m = tl.program_id(0)
     off_hz = tl.program_id(1)
@@ -447,14 +447,14 @@ def _attn_bwd_dkdv(dk, dv, #
         do = tl.load(do_ptrs)
         # Compute dV.
         ppT = pT
-        ppT = ppT.to(tl.float16)
+        ppT = ppT.to(tl.float32)
         dv += tl.dot(ppT, do)
         # D (= delta) is pre-divided by ds_scale.
         Di = tl.load(D + offs_m)
         # Compute dP and dS.
         dpT = tl.dot(v, tl.trans(do)).to(tl.float32)
         dsT = pT * (dpT - Di[None, :])
-        dsT = dsT.to(tl.float16)
+        dsT = dsT.to(tl.float32)
         dk += tl.dot(dsT, tl.trans(qT))
         # Increment pointers.
         curr_m += step_m
@@ -500,7 +500,7 @@ def _attn_bwd_dq(dq, q, K, V, #
         # Compute dP and dS.
         dp = tl.dot(do, vT).to(tl.float32)
         ds = p * (dp - Di[:, None])
-        ds = ds.to(tl.float16)
+        ds = ds.to(tl.float32)
         # Compute dQ.
         # NOTE: We need to de-scale dq in the end, because kT was pre-scaled.
         dq += tl.dot(ds, tl.trans(kT))
@@ -1106,44 +1106,44 @@ def attn_forward_kernel(
 
     # return is_close
 
-# attention = Attention.apply
-# DEVICE = "cuda:0"
-
-# import pytest
-# @pytest.mark.parametrize("Z, H, N_CTX, HEAD_DIM", [(1, 2, 1024, 64)])
-# @pytest.mark.parametrize("causal", [True])
-# def test_op(Z, H, N_CTX, HEAD_DIM, causal, dtype=torch.float16):
-#     torch.manual_seed(20)
-#     q = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_())
-#     k = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_())
-#     v = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_())
-#     sm_scale = 0.5
-#     dout = torch.randn_like(q)
-#     # reference implementation
-#     M = torch.tril(torch.ones((N_CTX, N_CTX), device=DEVICE))
-#     p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
-#     if causal:
-#         p[:, :, M == 0] = float("-inf")
-#     p = torch.softmax(p.float(), dim=-1).half()
-#     # p = torch.exp(p)
-#     ref_out = torch.matmul(p, v)
-#     ref_out.backward(dout)
-#     ref_dv, v.grad = v.grad.clone(), None
-#     ref_dk, k.grad = k.grad.clone(), None
-#     ref_dq, q.grad = q.grad.clone(), None
-#     # triton implementation
-#     tri_out = attention(q, k, v, causal, sm_scale).half()
-#     tri_out.backward(dout)
-#     tri_dv, v.grad = v.grad.clone(), None
-#     tri_dk, k.grad = k.grad.clone(), None
-#     tri_dq, q.grad = q.grad.clone(), None
-#     # compare
-#     assert torch.allclose(ref_out, tri_out, atol=1e-2, rtol=0)
-#     rtol = 0.0
-#     # Relative tolerance workaround for known hardware limitation of CDNA2 GPU.
-#     # For details see https://pytorch.org/docs/stable/notes/numerical_accuracy.html#reduced-precision-fp16-and-bf16-gemms-and-convolutions-on-amd-instinct-mi200-devices
-#     if torch.version.hip is not None and triton.runtime.driver.active.get_current_target().arch == "gfx90a":
-#         rtol = 1e-2
-#     assert torch.allclose(ref_dv, tri_dv, atol=1e-2, rtol=rtol)
-#     assert torch.allclose(ref_dk, tri_dk, atol=1e-2, rtol=rtol)
-#     assert torch.allclose(ref_dq, tri_dq, atol=1e-2, rtol=rtol)
+attention = Attention.apply
+DEVICE = "cuda:0"
+
+import pytest
+@pytest.mark.parametrize("Z, H, N_CTX, HEAD_DIM", [(2, 32, 1024, 64)])
+@pytest.mark.parametrize("causal", [True])
+def test_op(Z, H, N_CTX, HEAD_DIM, causal, dtype=torch.float32):
+    torch.manual_seed(20)
+    q = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_())
+    k = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_())
+    v = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_())
+    sm_scale = 0.5
+    dout = torch.randn_like(q)
+    # reference implementation
+    M = torch.tril(torch.ones((N_CTX, N_CTX), device=DEVICE))
+    p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
+    if causal:
+        p[:, :, M == 0] = float("-inf")
+    p = torch.softmax(p.float(), dim=-1)
+    # p = torch.exp(p)
+    ref_out = torch.matmul(p, v)
+    ref_out.backward(dout)
+    ref_dv, v.grad = v.grad.clone(), None
+    ref_dk, k.grad = k.grad.clone(), None
+    ref_dq, q.grad = q.grad.clone(), None
+    # triton implementation
+    tri_out = attention(q, k, v, causal, sm_scale)
+    tri_out.backward(dout)
+    tri_dv, v.grad = v.grad.clone(), None
+    tri_dk, k.grad = k.grad.clone(), None
+    tri_dq, q.grad = q.grad.clone(), None
+    # compare
+    assert torch.allclose(ref_out, tri_out, atol=1e-2, rtol=0)
+    rtol = 0.0
+    # Relative tolerance workaround for known hardware limitation of CDNA2 GPU.
+    # For details see https://pytorch.org/docs/stable/notes/numerical_accuracy.html#reduced-precision-fp16-and-bf16-gemms-and-convolutions-on-amd-instinct-mi200-devices
+    if torch.version.hip is not None and triton.runtime.driver.active.get_current_target().arch == "gfx90a":
+        rtol = 1e-2
+    assert torch.allclose(ref_dv, tri_dv, atol=1e-2, rtol=rtol)
+    assert torch.allclose(ref_dk, tri_dk, atol=1e-2, rtol=rtol)
+    assert torch.allclose(ref_dq, tri_dq, atol=1e-2, rtol=rtol)
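
A note on the dtype changes above (hedged, since the commit message does not explain them): the softmax probabilities p, and the backward-pass tensors ppT, dsT and ds, are now kept in float32 instead of being downcast to float16 before the tl.dot calls, and the non-FP8 default dtype in _attn_fwd_tma changes to match. The sketch below is hypothetical and not part of the commit; it only illustrates, at the PyTorch level, the rounding error an fp16 downcast of the probabilities introduces into the P @ V product:

# Hypothetical illustration (not part of this commit): error from downcasting
# softmax probabilities to fp16 before the P @ V product, versus keeping fp32.
import torch

torch.manual_seed(0)
p = torch.softmax(torch.randn(1024, 1024, dtype=torch.float64), dim=-1)
v = torch.randn(1024, 64, dtype=torch.float64)
ref = p @ v  # float64 reference

err_fp16 = (p.half().double() @ v.half().double() - ref).abs().max().item()
err_fp32 = (p.float().double() @ v.float().double() - ref).abs().max().item()
print(f"max |error| with fp16 probabilities: {err_fp16:.3e}")
print(f"max |error| with fp32 probabilities: {err_fp32:.3e}")
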
torch-ext/triton_llama_attn/.pytest_cache/.gitignore ADDED
@@ -0,0 +1,2 @@
+# Created by pytest automatically.
+*
torch-ext/triton_llama_attn/.pytest_cache/CACHEDIR.TAG ADDED
@@ -0,0 +1,4 @@
+Signature: 8a477f597d28d172789f06886806bc55
+# This file is a cache directory tag created by pytest.
+# For information about cache directory tags, see:
+# https://bford.info/cachedir/spec.html
torch-ext/triton_llama_attn/.pytest_cache/README.md ADDED
@@ -0,0 +1,8 @@
+# pytest cache directory #
+
+This directory contains data from the pytest's cache plugin,
+which provides the `--lf` and `--ff` options, as well as the `cache` fixture.
+
+**Do not** commit this to version control.
+
+See [the docs](https://docs.pytest.org/en/stable/how-to/cache.html) for more information.
torch-ext/triton_llama_attn/.pytest_cache/v/cache/lastfailed ADDED
@@ -0,0 +1 @@
+{}
torch-ext/triton_llama_attn/.pytest_cache/v/cache/nodeids ADDED
@@ -0,0 +1,3 @@
+[
+  "attn.py::test_op[True-2-32-1024-64]"
+]
torch-ext/triton_llama_attn/.pytest_cache/v/cache/stepwise ADDED
@@ -0,0 +1 @@
+[]
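
The cached node id above, attn.py::test_op[True-2-32-1024-64], matches the single parametrization of the test_op test uncommented in attn.py. To re-run exactly that node, an invocation along these lines should work (illustrative only; it assumes pytest, Triton, and a CUDA device are available):

# Illustrative re-run of the cached test node; node id taken from the file above.
import pytest

pytest.main(["torch-ext/triton_llama_attn/attn.py::test_op[True-2-32-1024-64]", "-q"])
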
torch-ext/triton_llama_attn/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (258 Bytes)
 
torch-ext/triton_llama_attn/__pycache__/attn.cpython-310-pytest-8.3.5.pyc ADDED
Binary file (29.6 kB)
 
torch-ext/triton_llama_attn/attn.py CHANGED
@@ -144,7 +144,7 @@ def _attn_fwd_inner(acc, l_i, m_i, q, #
         if fp8_v:
             p = p.to(tl.float8e5)
         else:
-            p = p.to(tl.float16)
+            p = p.to(tl.float32)
         acc = tl.dot(p, v, acc)
         # update m_i and l_i
         m_i = m_ij
@@ -344,7 +344,7 @@ def _attn_fwd_tma(sm_scale, M, #
                   FP8_OUTPUT: tl.constexpr, #
                   STAGE: tl.constexpr #
                   ):
-    dtype = tl.float8e5 if FP8_OUTPUT else tl.float16
+    dtype = tl.float8e5 if FP8_OUTPUT else tl.float32
     tl.static_assert(BLOCK_N <= HEAD_DIM)
     start_m = tl.program_id(0)
     off_hz = tl.program_id(1)
@@ -447,14 +447,14 @@ def _attn_bwd_dkdv(dk, dv, #
        do = tl.load(do_ptrs)
         # Compute dV.
         ppT = pT
-        ppT = ppT.to(tl.float16)
+        ppT = ppT.to(tl.float32)
         dv += tl.dot(ppT, do)
         # D (= delta) is pre-divided by ds_scale.
         Di = tl.load(D + offs_m)
         # Compute dP and dS.
         dpT = tl.dot(v, tl.trans(do)).to(tl.float32)
         dsT = pT * (dpT - Di[None, :])
-        dsT = dsT.to(tl.float16)
+        dsT = dsT.to(tl.float32)
         dk += tl.dot(dsT, tl.trans(qT))
         # Increment pointers.
         curr_m += step_m
@@ -500,7 +500,7 @@ def _attn_bwd_dq(dq, q, K, V, #
         # Compute dP and dS.
         dp = tl.dot(do, vT).to(tl.float32)
         ds = p * (dp - Di[:, None])
-        ds = ds.to(tl.float16)
+        ds = ds.to(tl.float32)
         # Compute dQ.
         # NOTE: We need to de-scale dq in the end, because kT was pre-scaled.
         dq += tl.dot(ds, tl.trans(kT))
@@ -967,6 +967,7 @@ def attn_forward_kernel(
     scaling: float,
     causal: bool,
 ):
+    print("######################### attn_forward_kernel", query.shape, key.shape, value.shape, scaling, causal)
     return Attention.apply(query, key, value, causal, scaling)
 
 # def test_llama_attention_output():
@@ -1105,44 +1106,44 @@ def attn_forward_kernel(
 
     # return is_close
 
-# attention = Attention.apply
-# DEVICE = "cuda:0"
-
-# import pytest
-# @pytest.mark.parametrize("Z, H, N_CTX, HEAD_DIM", [(1, 2, 1024, 64)])
-# @pytest.mark.parametrize("causal", [True])
-# def test_op(Z, H, N_CTX, HEAD_DIM, causal, dtype=torch.float16):
-#     torch.manual_seed(20)
-#     q = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_())
-#     k = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_())
-#     v = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_())
-#     sm_scale = 0.5
-#     dout = torch.randn_like(q)
-#     # reference implementation
-#     M = torch.tril(torch.ones((N_CTX, N_CTX), device=DEVICE))
-#     p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
-#     if causal:
-#         p[:, :, M == 0] = float("-inf")
-#     p = torch.softmax(p.float(), dim=-1).half()
-#     # p = torch.exp(p)
-#     ref_out = torch.matmul(p, v)
-#     ref_out.backward(dout)
-#     ref_dv, v.grad = v.grad.clone(), None
-#     ref_dk, k.grad = k.grad.clone(), None
-#     ref_dq, q.grad = q.grad.clone(), None
-#     # triton implementation
-#     tri_out = attention(q, k, v, causal, sm_scale).half()
-#     tri_out.backward(dout)
-#     tri_dv, v.grad = v.grad.clone(), None
-#     tri_dk, k.grad = k.grad.clone(), None
-#     tri_dq, q.grad = q.grad.clone(), None
-#     # compare
-#     assert torch.allclose(ref_out, tri_out, atol=1e-2, rtol=0)
-#     rtol = 0.0
-#     # Relative tolerance workaround for known hardware limitation of CDNA2 GPU.
-#     # For details see https://pytorch.org/docs/stable/notes/numerical_accuracy.html#reduced-precision-fp16-and-bf16-gemms-and-convolutions-on-amd-instinct-mi200-devices
-#     if torch.version.hip is not None and triton.runtime.driver.active.get_current_target().arch == "gfx90a":
-#         rtol = 1e-2
-#     assert torch.allclose(ref_dv, tri_dv, atol=1e-2, rtol=rtol)
-#     assert torch.allclose(ref_dk, tri_dk, atol=1e-2, rtol=rtol)
-#     assert torch.allclose(ref_dq, tri_dq, atol=1e-2, rtol=rtol)
+attention = Attention.apply
+DEVICE = "cuda:0"
+
+import pytest
+@pytest.mark.parametrize("Z, H, N_CTX, HEAD_DIM", [(2, 32, 1024, 64)])
+@pytest.mark.parametrize("causal", [True])
+def test_op(Z, H, N_CTX, HEAD_DIM, causal, dtype=torch.float32):
+    torch.manual_seed(20)
+    q = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_())
+    k = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_())
+    v = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_())
+    sm_scale = 0.5
+    dout = torch.randn_like(q)
+    # reference implementation
+    M = torch.tril(torch.ones((N_CTX, N_CTX), device=DEVICE))
+    p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
+    if causal:
+        p[:, :, M == 0] = float("-inf")
+    p = torch.softmax(p.float(), dim=-1)
+    # p = torch.exp(p)
+    ref_out = torch.matmul(p, v)
+    ref_out.backward(dout)
+    ref_dv, v.grad = v.grad.clone(), None
+    ref_dk, k.grad = k.grad.clone(), None
+    ref_dq, q.grad = q.grad.clone(), None
+    # triton implementation
+    tri_out = attention(q, k, v, causal, sm_scale)
+    tri_out.backward(dout)
+    tri_dv, v.grad = v.grad.clone(), None
+    tri_dk, k.grad = k.grad.clone(), None
+    tri_dq, q.grad = q.grad.clone(), None
+    # compare
+    assert torch.allclose(ref_out, tri_out, atol=1e-2, rtol=0)
+    rtol = 0.0
+    # Relative tolerance workaround for known hardware limitation of CDNA2 GPU.
+    # For details see https://pytorch.org/docs/stable/notes/numerical_accuracy.html#reduced-precision-fp16-and-bf16-gemms-and-convolutions-on-amd-instinct-mi200-devices
+    if torch.version.hip is not None and triton.runtime.driver.active.get_current_target().arch == "gfx90a":
+        rtol = 1e-2
+    assert torch.allclose(ref_dv, tri_dv, atol=1e-2, rtol=rtol)
+    assert torch.allclose(ref_dk, tri_dk, atol=1e-2, rtol=rtol)
+    assert torch.allclose(ref_dq, tri_dq, atol=1e-2, rtol=rtol)
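
Finally, a hedged usage sketch for the attn_forward_kernel wrapper touched in the @@ -967 hunk (which, after this commit, also prints its argument shapes on every call). The import path is assumed from the directory layout in this commit; shapes and scaling are illustrative only:

# Hypothetical usage; assumes the package is importable as triton_llama_attn
# and that a CUDA device with Triton is available.
import torch
from triton_llama_attn.attn import attn_forward_kernel  # import path assumed

q = torch.randn(2, 32, 1024, 64, device="cuda", dtype=torch.float32, requires_grad=True)
k = torch.randn_like(q, requires_grad=True)
v = torch.randn_like(q, requires_grad=True)

out = attn_forward_kernel(q, k, v, scaling=q.shape[-1] ** -0.5, causal=True)
out.backward(torch.randn_like(out))  # exercises the updated backward kernels
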