medmekk (HF Staff) committed
Commit 43eb41d · verified · 1 Parent(s): d054198

Upload custom kernels

README.md ADDED
File without changes
build.toml ADDED
@@ -0,0 +1,5 @@
+ [general]
+ name = "triton_llama_attn"
+
+ [torch]
+ universal = true
build/torch-universal/triton_llama_attn/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .attn import attn_forward_kernel
+
+ __all__ = ["attn_forward_kernel"]
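The package exports a single entry point. Below is a minimal usage sketch (illustrative, not part of this commit): it assumes a CUDA GPU, fp16 inputs, and the attn_forward_kernel(query, key, value, scaling, causal) signature defined in attn.py below, with tensors shaped [batch, num_heads, seq_len, head_dim] and head_dim in {16, 32, 64, 128, 256}.

import torch
from triton_llama_attn import attn_forward_kernel

# Query/key/value laid out as [batch, num_heads, seq_len, head_dim].
q = torch.randn(1, 8, 1024, 64, device="cuda", dtype=torch.float16)
k = torch.randn(1, 8, 1024, 64, device="cuda", dtype=torch.float16)
v = torch.randn(1, 8, 1024, 64, device="cuda", dtype=torch.float16)

scaling = q.shape[-1] ** -0.5  # 1/sqrt(head_dim), as LlamaAttention uses
out = attn_forward_kernel(q, k, v, scaling=scaling, causal=True)
print(out.shape)  # torch.Size([1, 8, 1024, 64])
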
build/torch-universal/triton_llama_attn/attn.py ADDED
@@ -0,0 +1,1148 @@
1
+ import torch
2
+ from transformers.models.llama.configuration_llama import LlamaConfig
3
+ from transformers import AutoModelForCausalLM
4
+ import triton.tools.experimental_descriptor
5
+ from typing import Tuple, Optional, Callable
6
+ import triton
7
+ import triton.language as tl
8
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
9
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
10
+ from transformers import Cache
11
+ import torch.nn as nn
12
+ from transformers.processing_utils import Unpack
13
+ # ENABLE_LHS_TO_TMEM is an experimental environment variable for Blackwell.
14
+ # If it is set to 1 it can improve performance of Blackwell attention. However,
15
+ # it defaults to 0 as it is known to cause correctness issues outside of the
16
+ # _attn_fwd_tma kernel below.
17
+
18
+ # DEVICE = triton.runtime.driver.active.get_active_torch_device()
19
+
20
+
21
+ def is_hip():
22
+ return triton.runtime.driver.active.get_current_target().backend == "hip"
23
+
24
+
25
+ def is_cuda():
26
+ return triton.runtime.driver.active.get_current_target().backend == "cuda"
27
+
28
+
29
+ def supports_tma():
30
+ return is_cuda() and torch.cuda.get_device_capability()[0] >= 9
31
+
32
+
33
+ HAS_TMA_DESC = "nv_tma_desc_type" in dir(tl)
34
+
35
+ if HAS_TMA_DESC:
36
+ print("TMA benchmarks will be running with experimental grid constant TMA descriptor.", )
37
+ else:
38
+ print("TMA benchmarks will be running without grid constant TMA descriptor.", )
39
+
40
+
41
+ # TmaAutoTuneHelper used in htyu's PR #5622
42
+ class TmaAutoTuneHelper:
43
+
44
+ # duck typing wrapper to implement the same interface as TmaDescKernelParam in Triton PR #4498
45
+ class KernelParamWrapper:
46
+
47
+ def __init__(self, desc):
48
+ self.desc = desc
49
+
50
+ def tma_desc_cpu_ptr(self):
51
+ return self.desc.data_ptr()
52
+
53
+ TMA_SIZE = 128
54
+
55
+ def __init__(self):
56
+ self.fill_1d_tma_descriptor_inner = (triton.runtime.driver.active.utils.fill_1d_tma_descriptor)
57
+ self.fill_2d_tma_descriptor_inner = (triton.runtime.driver.active.utils.fill_2d_tma_descriptor)
58
+ if HAS_TMA_DESC:
59
+ self.descriptors = {}
60
+ else:
61
+ self.cuda_descriptors = {}
62
+
63
+ # Call this method outside of the lambda function for grid size
64
+ def init_tma_descriptor(self, name):
65
+ if HAS_TMA_DESC:
66
+ self.descriptors[name] = torch.empty(TmaAutoTuneHelper.TMA_SIZE, device="cpu", dtype=torch.int8)
67
+ else:
68
+ self.cuda_descriptors[name] = torch.empty(TmaAutoTuneHelper.TMA_SIZE, device="cuda", dtype=torch.int8)
69
+
70
+ # Call this method inside the lambda function for grid size
71
+ def fill_1d_tma_descriptor(self, name, ptr, dim, block_dim, element_size):
72
+ if HAS_TMA_DESC:
73
+ desc_x = self.descriptors[name]
74
+ assert desc_x.data_ptr() % 64 == 0
75
+ self.fill_1d_tma_descriptor_inner(ptr, dim, block_dim, element_size, desc_x.data_ptr())
76
+ else:
77
+ desc_x = self.cuda_descriptors[name]
78
+ buf_x = torch.empty_like(desc_x, device="cpu", pin_memory=True)
79
+ self.fill_1d_tma_descriptor_inner(ptr, dim, block_dim, element_size, buf_x.data_ptr())
80
+ desc_x.copy_(buf_x, non_blocking=True)
81
+
82
+ # Call this method inside the lambda function for grid size
83
+ def fill_2d_tma_descriptor(self, name, ptr, dim1, dim0, block_dim1, block_dim0, element_size):
84
+ if HAS_TMA_DESC:
85
+ desc_x = self.descriptors[name]
86
+ assert desc_x.data_ptr() % 64 == 0
87
+ self.fill_2d_tma_descriptor_inner(ptr, dim1, dim0, block_dim1, block_dim0, element_size, desc_x.data_ptr())
88
+ else:
89
+ desc_x = self.cuda_descriptors[name]
90
+ buf_x = torch.empty_like(desc_x, device="cpu", pin_memory=True)
91
+ self.fill_2d_tma_descriptor_inner(ptr, dim1, dim0, block_dim1, block_dim0, element_size, buf_x.data_ptr())
92
+ desc_x.copy_(buf_x, non_blocking=True)
93
+
94
+ def get_tma_descriptor_kernel_param(self, name):
95
+ if HAS_TMA_DESC:
96
+ assert self.descriptors[name] is not None
97
+ return self.KernelParamWrapper(self.descriptors[name])
98
+ else:
99
+ assert self.cuda_descriptors[name] is not None
100
+ return self.cuda_descriptors[name]
101
+
102
+
103
+ @triton.jit
104
+ def _attn_fwd_inner(acc, l_i, m_i, q, #
105
+ K_block_ptr, V_block_ptr, #
106
+ start_m, qk_scale, #
107
+ BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr, #
108
+ STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr, #
109
+ N_CTX: tl.constexpr, fp8_v: tl.constexpr):
110
+ # range of values handled by this stage
111
+ if STAGE == 1:
112
+ lo, hi = 0, start_m * BLOCK_M
113
+ elif STAGE == 2:
114
+ lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
115
+ lo = tl.multiple_of(lo, BLOCK_M)
116
+ # causal = False
117
+ else:
118
+ lo, hi = 0, N_CTX
119
+ K_block_ptr = tl.advance(K_block_ptr, (0, lo))
120
+ V_block_ptr = tl.advance(V_block_ptr, (lo, 0))
121
+ # loop over k, v and update accumulator
122
+ for start_n in range(lo, hi, BLOCK_N):
123
+ start_n = tl.multiple_of(start_n, BLOCK_N)
124
+ # -- compute qk ----
125
+ k = tl.load(K_block_ptr)
126
+ qk = tl.dot(q, k)
127
+ if STAGE == 2:
128
+ mask = offs_m[:, None] >= (start_n + offs_n[None, :])
129
+ qk = qk * qk_scale + tl.where(mask, 0, -1.0e6)
130
+ m_ij = tl.maximum(m_i, tl.max(qk, 1))
131
+ qk -= m_ij[:, None]
132
+ else:
133
+ m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale)
134
+ qk = qk * qk_scale - m_ij[:, None]
135
+ p = tl.math.exp2(qk)
136
+ l_ij = tl.sum(p, 1)
137
+ # -- update m_i and l_i
138
+ alpha = tl.math.exp2(m_i - m_ij)
139
+ l_i = l_i * alpha + l_ij
140
+ # -- update output accumulator --
141
+ acc = acc * alpha[:, None]
142
+ # update acc
143
+ v = tl.load(V_block_ptr)
144
+ if fp8_v:
145
+ p = p.to(tl.float8e5)
146
+ else:
147
+ p = p.to(tl.float16)
148
+ acc = tl.dot(p, v, acc)
149
+ # update m_i and l_i
150
+ m_i = m_ij
151
+ V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
152
+ K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
153
+ return acc, l_i, m_i
154
+
155
+
156
+ @triton.jit
157
+ def _attn_fwd_inner_tma(acc, l_i, m_i, q, #
158
+ desc_k, desc_v, #
159
+ offset_y, dtype: tl.constexpr, start_m, qk_scale, #
160
+ BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr, #
161
+ STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr, #
162
+ N_CTX: tl.constexpr):
163
+ # range of values handled by this stage
164
+ if STAGE == 1:
165
+ lo, hi = 0, start_m * BLOCK_M
166
+ elif STAGE == 2:
167
+ lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
168
+ lo = tl.multiple_of(lo, BLOCK_M)
169
+ # causal = False
170
+ else:
171
+ lo, hi = 0, N_CTX
172
+ offsetkv_y = offset_y + lo
173
+ # loop over k, v and update accumulator
174
+ for start_n in range(lo, hi, BLOCK_N):
175
+ start_n = tl.multiple_of(start_n, BLOCK_N)
176
+ # -- compute qk ----
177
+ k = tl._experimental_descriptor_load(desc_k, [offsetkv_y, 0], [BLOCK_N, HEAD_DIM], dtype).T
178
+ qk = tl.dot(q, k)
179
+ if STAGE == 2:
180
+ mask = offs_m[:, None] >= (start_n + offs_n[None, :])
181
+ qk = qk * qk_scale + tl.where(mask, 0, -1.0e6)
182
+ m_ij = tl.maximum(m_i, tl.max(qk, 1))
183
+ qk -= m_ij[:, None]
184
+ else:
185
+ m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale)
186
+ qk = qk * qk_scale - m_ij[:, None]
187
+ p = tl.math.exp2(qk)
188
+ l_ij = tl.sum(p, 1)
189
+ # -- update m_i and l_i
190
+ alpha = tl.math.exp2(m_i - m_ij)
191
+ l_i = l_i * alpha + l_ij
192
+ # -- update output accumulator --
193
+ acc = acc * alpha[:, None]
194
+ # update acc
195
+ v = tl._experimental_descriptor_load(desc_v, [offsetkv_y, 0], [BLOCK_N, HEAD_DIM], dtype)
196
+ p = p.to(dtype)
197
+ # note that this non transposed v for FP8 is only supported on Blackwell
198
+ acc = tl.dot(p, v, acc)
199
+ # update m_i and l_i
200
+ m_i = m_ij
201
+ offsetkv_y += BLOCK_N
202
+ return acc, l_i, m_i
203
+
204
+
205
+ # We don't run auto-tuning every time to keep the tutorial fast. Keeping
206
+ # the code below and commenting out the equivalent parameters is convenient for
207
+ # re-tuning.
208
+ configs = [
209
+ triton.Config({'BLOCK_M': BM, 'BLOCK_N': BN}, num_stages=s, num_warps=w) \
210
+ for BM in [64, 128]\
211
+ for BN in [32, 64]\
212
+ for s in ([1] if is_hip() else [3, 4, 7])\
213
+ for w in [4, 8]\
214
+ ]
215
+
216
+
217
+ def keep(conf):
218
+ BLOCK_M = conf.kwargs["BLOCK_M"]
219
+ BLOCK_N = conf.kwargs["BLOCK_N"]
220
+ if BLOCK_M * BLOCK_N < 128 * 128 and conf.num_warps == 8:
221
+ return False
222
+ return True
223
+
224
+
225
+ @triton.autotune(list(filter(keep, configs)), key=["N_CTX", "HEAD_DIM"])
226
+ @triton.jit
227
+ def _attn_fwd(Q, K, V, sm_scale, M, Out, #
228
+ stride_qz, stride_qh, stride_qm, stride_qk, #
229
+ stride_kz, stride_kh, stride_kn, stride_kk, #
230
+ stride_vz, stride_vh, stride_vk, stride_vn, #
231
+ stride_oz, stride_oh, stride_om, stride_on, #
232
+ Z, H, N_CTX, #
233
+ HEAD_DIM: tl.constexpr, #
234
+ BLOCK_M: tl.constexpr, #
235
+ BLOCK_N: tl.constexpr, #
236
+ STAGE: tl.constexpr #
237
+ ):
238
+ tl.static_assert(BLOCK_N <= HEAD_DIM)
239
+ start_m = tl.program_id(0)
240
+ off_hz = tl.program_id(1)
241
+ off_z = off_hz // H
242
+ off_h = off_hz % H
243
+ qvk_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh
244
+
245
+ # block pointers
246
+ Q_block_ptr = tl.make_block_ptr(
247
+ base=Q + qvk_offset,
248
+ shape=(N_CTX, HEAD_DIM),
249
+ strides=(stride_qm, stride_qk),
250
+ offsets=(start_m * BLOCK_M, 0),
251
+ block_shape=(BLOCK_M, HEAD_DIM),
252
+ order=(1, 0),
253
+ )
254
+ v_order: tl.constexpr = (0, 1) if V.dtype.element_ty == tl.float8e5 else (1, 0)
255
+ V_block_ptr = tl.make_block_ptr(
256
+ base=V + qvk_offset,
257
+ shape=(N_CTX, HEAD_DIM),
258
+ strides=(stride_vk, stride_vn),
259
+ offsets=(0, 0),
260
+ block_shape=(BLOCK_N, HEAD_DIM),
261
+ order=v_order,
262
+ )
263
+ K_block_ptr = tl.make_block_ptr(
264
+ base=K + qvk_offset,
265
+ shape=(HEAD_DIM, N_CTX),
266
+ strides=(stride_kk, stride_kn),
267
+ offsets=(0, 0),
268
+ block_shape=(HEAD_DIM, BLOCK_N),
269
+ order=(0, 1),
270
+ )
271
+ O_block_ptr = tl.make_block_ptr(
272
+ base=Out + qvk_offset,
273
+ shape=(N_CTX, HEAD_DIM),
274
+ strides=(stride_om, stride_on),
275
+ offsets=(start_m * BLOCK_M, 0),
276
+ block_shape=(BLOCK_M, HEAD_DIM),
277
+ order=(1, 0),
278
+ )
279
+ # initialize offsets
280
+ offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
281
+ offs_n = tl.arange(0, BLOCK_N)
282
+ # initialize pointer to m and l
283
+ m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
284
+ l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0
285
+ acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
286
+ # load scales
287
+ qk_scale = sm_scale
288
+ qk_scale *= 1.44269504 # log2(e) = 1/ln(2)
289
+ # load q: it will stay in SRAM throughout
290
+ q = tl.load(Q_block_ptr)
291
+ # stage 1: off-band
292
+ # For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE
293
+ # For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE
294
+ if STAGE & 1:
295
+ acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr, #
296
+ start_m, qk_scale, #
297
+ BLOCK_M, HEAD_DIM, BLOCK_N, #
298
+ 4 - STAGE, offs_m, offs_n, N_CTX, V.dtype.element_ty == tl.float8e5 #
299
+ )
300
+ # stage 2: on-band
301
+ if STAGE & 2:
302
+ # barrier makes it easier for the compiler to schedule the
303
+ # two loops independently
304
+ acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr, #
305
+ start_m, qk_scale, #
306
+ BLOCK_M, HEAD_DIM, BLOCK_N, #
307
+ 2, offs_m, offs_n, N_CTX, V.dtype.element_ty == tl.float8e5 #
308
+ )
309
+ # epilogue
310
+ m_i += tl.math.log2(l_i)
311
+ acc = acc / l_i[:, None]
312
+ m_ptrs = M + off_hz * N_CTX + offs_m
313
+ tl.store(m_ptrs, m_i)
314
+ tl.store(O_block_ptr, acc.to(Out.type.element_ty))
315
+
316
+
317
+ # We don't run auto-tuning every time to keep the tutorial fast. Keeping
318
+ # the code below and commenting out the equivalent parameters is convenient for
319
+ # re-tuning.
320
+ configs_tma = [
321
+ triton.Config({'BLOCK_M': BM, 'BLOCK_N': BN}, num_stages=s, num_warps=w) \
322
+ for BM in [64, 128]\
323
+ for BN in [32, 64, 128]\
324
+ for s in [2, 3, 4, 6]\
325
+ for w in [4, 8]\
326
+ ]
327
+
328
+
329
+ def keep_tma(conf):
330
+ BLOCK_M = conf.kwargs["BLOCK_M"]
331
+ BLOCK_N = conf.kwargs["BLOCK_N"]
332
+ if (torch.cuda.get_device_capability()[0] == 9 and BLOCK_M * BLOCK_N < 128 * 128 and conf.num_warps == 8):
333
+ return False
334
+ return True
335
+
336
+
337
+ @triton.autotune(configs=list(filter(keep_tma, configs_tma)), key=["N_CTX", "HEAD_DIM", "FP8_OUTPUT"])
338
+ @triton.jit
339
+ def _attn_fwd_tma(sm_scale, M, #
340
+ Z, H, desc_q, desc_k, desc_v, desc_o, N_CTX, #
341
+ HEAD_DIM: tl.constexpr, #
342
+ BLOCK_M: tl.constexpr, #
343
+ BLOCK_N: tl.constexpr, #
344
+ FP8_OUTPUT: tl.constexpr, #
345
+ STAGE: tl.constexpr #
346
+ ):
347
+ dtype = tl.float8e5 if FP8_OUTPUT else tl.float16
348
+ tl.static_assert(BLOCK_N <= HEAD_DIM)
349
+ start_m = tl.program_id(0)
350
+ off_hz = tl.program_id(1)
351
+ off_z = off_hz // H
352
+ off_h = off_hz % H
353
+
354
+ offset_y = off_z + off_h * N_CTX
355
+ qo_offset_y = offset_y + start_m * BLOCK_M
356
+ # initialize offsets
357
+ offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
358
+ offs_n = tl.arange(0, BLOCK_N)
359
+ # initialize pointer to m and l
360
+ m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
361
+ l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0
362
+ acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
363
+ # load scales
364
+ qk_scale = sm_scale
365
+ qk_scale *= 1.44269504 # log2(e) = 1/ln(2)
366
+ # load q: it will stay in SRAM throughout
367
+ q = tl._experimental_descriptor_load(desc_q, [qo_offset_y, 0], [BLOCK_M, HEAD_DIM], dtype)
368
+ # stage 1: off-band
369
+ # For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE
370
+ # For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE
371
+ if STAGE & 1:
372
+ acc, l_i, m_i = _attn_fwd_inner_tma(acc, l_i, m_i, q, #
373
+ desc_k, desc_v, #
374
+ offset_y, dtype, start_m, qk_scale, #
375
+ BLOCK_M, HEAD_DIM, BLOCK_N, #
376
+ 4 - STAGE, offs_m, offs_n, N_CTX, #
377
+ )
378
+ # stage 2: on-band
379
+ if STAGE & 2:
380
+ # barrier makes it easier for the compiler to schedule the
381
+ # two loops independently
382
+ acc, l_i, m_i = _attn_fwd_inner_tma(acc, l_i, m_i, q, #
383
+ desc_k, desc_v, #
384
+ offset_y, dtype, start_m, qk_scale, #
385
+ BLOCK_M, HEAD_DIM, BLOCK_N, #
386
+ 2, offs_m, offs_n, N_CTX, #
387
+ )
388
+ # epilogue
389
+ m_i += tl.math.log2(l_i)
390
+ acc = acc / l_i[:, None]
391
+ m_ptrs = M + off_hz * N_CTX + offs_m
392
+ tl.store(m_ptrs, m_i)
393
+ tl._experimental_descriptor_store(desc_o, acc.to(dtype), [qo_offset_y, 0])
394
+
395
+
396
+ @triton.jit
397
+ def _attn_bwd_preprocess(O, DO, #
398
+ Delta, #
399
+ Z, H, N_CTX, #
400
+ BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr #
401
+ ):
402
+ off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
403
+ off_hz = tl.program_id(1)
404
+ off_n = tl.arange(0, HEAD_DIM)
405
+ # load
406
+ o = tl.load(O + off_hz * HEAD_DIM * N_CTX + off_m[:, None] * HEAD_DIM + off_n[None, :])
407
+ do = tl.load(DO + off_hz * HEAD_DIM * N_CTX + off_m[:, None] * HEAD_DIM + off_n[None, :]).to(tl.float32)
408
+ delta = tl.sum(o * do, axis=1)
409
+ # write-back
410
+ tl.store(Delta + off_hz * N_CTX + off_m, delta)
411
+
412
+
413
+ # The main inner-loop logic for computing dK and dV.
414
+ @triton.jit
415
+ def _attn_bwd_dkdv(dk, dv, #
416
+ Q, k, v, sm_scale, #
417
+ DO, #
418
+ M, D, #
419
+ # shared by Q/K/V/DO.
420
+ stride_tok, stride_d, #
421
+ H, N_CTX, BLOCK_M1: tl.constexpr, #
422
+ BLOCK_N1: tl.constexpr, #
423
+ HEAD_DIM: tl.constexpr, #
424
+ # Filled in by the wrapper.
425
+ start_n, start_m, num_steps, #
426
+ MASK: tl.constexpr):
427
+ offs_m = start_m + tl.arange(0, BLOCK_M1)
428
+ offs_n = start_n + tl.arange(0, BLOCK_N1)
429
+ offs_k = tl.arange(0, HEAD_DIM)
430
+ qT_ptrs = Q + offs_m[None, :] * stride_tok + offs_k[:, None] * stride_d
431
+ do_ptrs = DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
432
+ # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
433
+ tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
434
+ curr_m = start_m
435
+ step_m = BLOCK_M1
436
+ for blk_idx in range(num_steps):
437
+ qT = tl.load(qT_ptrs)
438
+ # Load m before computing qk to reduce pipeline stall.
439
+ offs_m = curr_m + tl.arange(0, BLOCK_M1)
440
+ m = tl.load(M + offs_m)
441
+ qkT = tl.dot(k, qT)
442
+ pT = tl.math.exp2(qkT - m[None, :])
443
+ # Autoregressive masking.
444
+ if MASK:
445
+ mask = (offs_m[None, :] >= offs_n[:, None])
446
+ pT = tl.where(mask, pT, 0.0)
447
+ do = tl.load(do_ptrs)
448
+ # Compute dV.
449
+ ppT = pT
450
+ ppT = ppT.to(tl.float16)
451
+ dv += tl.dot(ppT, do)
452
+ # D (= delta) is pre-divided by ds_scale.
453
+ Di = tl.load(D + offs_m)
454
+ # Compute dP and dS.
455
+ dpT = tl.dot(v, tl.trans(do)).to(tl.float32)
456
+ dsT = pT * (dpT - Di[None, :])
457
+ dsT = dsT.to(tl.float16)
458
+ dk += tl.dot(dsT, tl.trans(qT))
459
+ # Increment pointers.
460
+ curr_m += step_m
461
+ qT_ptrs += step_m * stride_tok
462
+ do_ptrs += step_m * stride_tok
463
+ return dk, dv
464
+
465
+
466
+ # the main inner-loop logic for computing dQ
467
+ @triton.jit
468
+ def _attn_bwd_dq(dq, q, K, V, #
469
+ do, m, D,
470
+ # shared by Q/K/V/DO.
471
+ stride_tok, stride_d, #
472
+ H, N_CTX, #
473
+ BLOCK_M2: tl.constexpr, #
474
+ BLOCK_N2: tl.constexpr, #
475
+ HEAD_DIM: tl.constexpr,
476
+ # Filled in by the wrapper.
477
+ start_m, start_n, num_steps, #
478
+ MASK: tl.constexpr):
479
+ offs_m = start_m + tl.arange(0, BLOCK_M2)
480
+ offs_n = start_n + tl.arange(0, BLOCK_N2)
481
+ offs_k = tl.arange(0, HEAD_DIM)
482
+ kT_ptrs = K + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d
483
+ vT_ptrs = V + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d
484
+ # D (= delta) is pre-divided by ds_scale.
485
+ Di = tl.load(D + offs_m)
486
+ # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
487
+ tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
488
+ curr_n = start_n
489
+ step_n = BLOCK_N2
490
+ for blk_idx in range(num_steps):
491
+ kT = tl.load(kT_ptrs)
492
+ vT = tl.load(vT_ptrs)
493
+ qk = tl.dot(q, kT)
494
+ p = tl.math.exp2(qk - m)
495
+ # Autoregressive masking.
496
+ if MASK:
497
+ offs_n = curr_n + tl.arange(0, BLOCK_N2)
498
+ mask = (offs_m[:, None] >= offs_n[None, :])
499
+ p = tl.where(mask, p, 0.0)
500
+ # Compute dP and dS.
501
+ dp = tl.dot(do, vT).to(tl.float32)
502
+ ds = p * (dp - Di[:, None])
503
+ ds = ds.to(tl.float16)
504
+ # Compute dQ.
505
+ # NOTE: We need to de-scale dq in the end, because kT was pre-scaled.
506
+ dq += tl.dot(ds, tl.trans(kT))
507
+ # Increment pointers.
508
+ curr_n += step_n
509
+ kT_ptrs += step_n * stride_tok
510
+ vT_ptrs += step_n * stride_tok
511
+ return dq
512
+
513
+
514
+ @triton.jit
515
+ def _attn_bwd(Q, K, V, sm_scale, #
516
+ DO, #
517
+ DQ, DK, DV, #
518
+ M, D,
519
+ # shared by Q/K/V/DO.
520
+ stride_z, stride_h, stride_tok, stride_d, #
521
+ H, N_CTX, #
522
+ BLOCK_M1: tl.constexpr, #
523
+ BLOCK_N1: tl.constexpr, #
524
+ BLOCK_M2: tl.constexpr, #
525
+ BLOCK_N2: tl.constexpr, #
526
+ BLK_SLICE_FACTOR: tl.constexpr, #
527
+ HEAD_DIM: tl.constexpr):
528
+ LN2: tl.constexpr = 0.6931471824645996 # = ln(2)
529
+
530
+ bhid = tl.program_id(2)
531
+ off_chz = (bhid * N_CTX).to(tl.int64)
532
+ adj = (stride_h * (bhid % H) + stride_z * (bhid // H)).to(tl.int64)
533
+ pid = tl.program_id(0)
534
+
535
+ # offset pointers for batch/head
536
+ Q += adj
537
+ K += adj
538
+ V += adj
539
+ DO += adj
540
+ DQ += adj
541
+ DK += adj
542
+ DV += adj
543
+ M += off_chz
544
+ D += off_chz
545
+
546
+ # load scales
547
+ offs_k = tl.arange(0, HEAD_DIM)
548
+
549
+ start_n = pid * BLOCK_N1
550
+ start_m = start_n
551
+
552
+ MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR
553
+ offs_n = start_n + tl.arange(0, BLOCK_N1)
554
+
555
+ dv = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32)
556
+ dk = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32)
557
+
558
+ # load K and V: they stay in SRAM throughout the inner loop.
559
+ k = tl.load(K + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)
560
+ v = tl.load(V + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)
561
+
562
+ num_steps = BLOCK_N1 // MASK_BLOCK_M1
563
+
564
+ dk, dv = _attn_bwd_dkdv(dk, dv, #
565
+ Q, k, v, sm_scale, #
566
+ DO, #
567
+ M, D, #
568
+ stride_tok, stride_d, #
569
+ H, N_CTX, #
570
+ MASK_BLOCK_M1, BLOCK_N1, HEAD_DIM, #
571
+ start_n, start_m, num_steps, #
572
+ MASK=True #
573
+ )
574
+
575
+ start_m += num_steps * MASK_BLOCK_M1
576
+ num_steps = (N_CTX - start_m) // BLOCK_M1
577
+
578
+ # Compute dK and dV for non-masked blocks.
579
+ dk, dv = _attn_bwd_dkdv( #
580
+ dk, dv, #
581
+ Q, k, v, sm_scale, #
582
+ DO, #
583
+ M, D, #
584
+ stride_tok, stride_d, #
585
+ H, N_CTX, #
586
+ BLOCK_M1, BLOCK_N1, HEAD_DIM, #
587
+ start_n, start_m, num_steps, #
588
+ MASK=False #
589
+ )
590
+
591
+ dv_ptrs = DV + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
592
+ tl.store(dv_ptrs, dv)
593
+
594
+ # Write back dK.
595
+ dk *= sm_scale
596
+ dk_ptrs = DK + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
597
+ tl.store(dk_ptrs, dk)
598
+
599
+ # THIS BLOCK DOES DQ:
600
+ start_m = pid * BLOCK_M2
601
+ end_n = start_m + BLOCK_M2
602
+
603
+ MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR
604
+ offs_m = start_m + tl.arange(0, BLOCK_M2)
605
+
606
+ q = tl.load(Q + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d)
607
+ dq = tl.zeros([BLOCK_M2, HEAD_DIM], dtype=tl.float32)
608
+ do = tl.load(DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d)
609
+
610
+ m = tl.load(M + offs_m)
611
+ m = m[:, None]
612
+
613
+ # Compute dQ for masked (diagonal) blocks.
614
+ # NOTE: This code scans each row of QK^T backward (from right to left,
615
+ # but inside each call to _attn_bwd_dq, from left to right), but that's
616
+ # not due to anything important. I just wanted to reuse the loop
617
+ # structure for dK & dV above as much as possible.
618
+ num_steps = BLOCK_M2 // MASK_BLOCK_N2
619
+ dq = _attn_bwd_dq(dq, q, K, V, #
620
+ do, m, D, #
621
+ stride_tok, stride_d, #
622
+ H, N_CTX, #
623
+ BLOCK_M2, MASK_BLOCK_N2, HEAD_DIM, #
624
+ start_m, end_n - num_steps * MASK_BLOCK_N2, num_steps, #
625
+ MASK=True #
626
+ )
627
+ end_n -= num_steps * MASK_BLOCK_N2
628
+ # stage 2
629
+ num_steps = end_n // BLOCK_N2
630
+ dq = _attn_bwd_dq(dq, q, K, V, #
631
+ do, m, D, #
632
+ stride_tok, stride_d, #
633
+ H, N_CTX, #
634
+ BLOCK_M2, BLOCK_N2, HEAD_DIM, #
635
+ start_m, end_n - num_steps * BLOCK_N2, num_steps, #
636
+ MASK=False #
637
+ )
638
+ # Write back dQ.
639
+ dq_ptrs = DQ + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
640
+ dq *= LN2
641
+ tl.store(dq_ptrs, dq)
642
+
643
+
644
+ class Attention(torch.autograd.Function):
645
+
646
+ @staticmethod
647
+ def forward(ctx, q, k, v, causal, sm_scale, USE_TMA=True):
648
+ # shape constraints
649
+ HEAD_DIM_Q, HEAD_DIM_K = q.shape[-1], k.shape[-1]
650
+ # when v is in float8_e5m2 it is transposed.
651
+ HEAD_DIM_V = v.shape[-1]
652
+ assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V
653
+ assert HEAD_DIM_K in {16, 32, 64, 128, 256}
654
+ o = torch.empty_like(q)
655
+ stage = 3 if causal else 1
656
+ extra_kern_args = {}
657
+ # Tuning for AMD target
658
+ if is_hip():
659
+ waves_per_eu = 3 if HEAD_DIM_K <= 64 else 2
660
+ extra_kern_args = {"waves_per_eu": waves_per_eu, "allow_flush_denorm": True}
661
+
662
+ M = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
663
+ if USE_TMA and supports_tma() and not (torch.cuda.get_device_capability()[0] == 9
664
+ and q.dtype == torch.float8_e5m2):
665
+ # Note that on Hopper we cannot perform a FP8 dot with a non-transposed second tensor
666
+ y_dim = q.shape[0] * q.shape[1] * q.shape[2]
667
+
668
+ desc_helper = TmaAutoTuneHelper()
669
+ desc_helper.init_tma_descriptor("q")
670
+ desc_helper.init_tma_descriptor("v")
671
+ desc_helper.init_tma_descriptor("k")
672
+ desc_helper.init_tma_descriptor("o")
673
+
674
+ def grid(META):
675
+ nonlocal desc_helper
676
+
677
+ desc_helper.fill_2d_tma_descriptor("q", q.data_ptr(), y_dim, HEAD_DIM_K, META["BLOCK_M"], HEAD_DIM_K,
678
+ q.element_size())
679
+
680
+ desc_helper.fill_2d_tma_descriptor("v", v.data_ptr(), y_dim, HEAD_DIM_K, META["BLOCK_N"], HEAD_DIM_K,
681
+ v.element_size())
682
+
683
+ desc_helper.fill_2d_tma_descriptor("k", k.data_ptr(), y_dim, HEAD_DIM_K, META["BLOCK_N"], HEAD_DIM_K,
684
+ k.element_size())
685
+
686
+ desc_helper.fill_2d_tma_descriptor("o", o.data_ptr(), y_dim, HEAD_DIM_K, META["BLOCK_M"], HEAD_DIM_K,
687
+ o.element_size())
688
+
689
+ return (triton.cdiv(q.shape[2], META["BLOCK_M"]), q.shape[0] * q.shape[1], 1)
690
+
691
+ desc_q = desc_helper.get_tma_descriptor_kernel_param("q")
692
+ desc_v = desc_helper.get_tma_descriptor_kernel_param("v")
693
+ desc_k = desc_helper.get_tma_descriptor_kernel_param("k")
694
+ desc_o = desc_helper.get_tma_descriptor_kernel_param("o")
695
+
696
+ ctx.grid = grid
697
+ _attn_fwd_tma[grid](
698
+ sm_scale, M, #
699
+ q.shape[0], q.shape[1], #
700
+ desc_q, desc_k, desc_v, desc_o, #
701
+ N_CTX=q.shape[2], #
702
+ HEAD_DIM=HEAD_DIM_K, #
703
+ FP8_OUTPUT=q.dtype == torch.float8_e5m2, #
704
+ STAGE=stage, #
705
+ **extra_kern_args)
706
+ else:
707
+ grid = lambda args: (triton.cdiv(q.shape[2], args["BLOCK_M"]), q.shape[0] * q.shape[1], 1)
708
+ ctx.grid = grid
709
+ _attn_fwd[grid](
710
+ q, k, v, sm_scale, M, o, #
711
+ q.stride(0), q.stride(1), q.stride(2), q.stride(3), #
712
+ k.stride(0), k.stride(1), k.stride(2), k.stride(3), #
713
+ v.stride(0), v.stride(1), v.stride(2), v.stride(3), #
714
+ o.stride(0), o.stride(1), o.stride(2), o.stride(3), #
715
+ q.shape[0], q.shape[1], #
716
+ N_CTX=q.shape[2], #
717
+ HEAD_DIM=HEAD_DIM_K, #
718
+ STAGE=stage, #
719
+ **extra_kern_args)
720
+
721
+ ctx.save_for_backward(q, k, v, o, M)
722
+ ctx.sm_scale = sm_scale
723
+ ctx.HEAD_DIM = HEAD_DIM_K
724
+ ctx.causal = causal
725
+ return o
726
+
727
+ @staticmethod
728
+ def backward(ctx, do):
729
+ q, k, v, o, M = ctx.saved_tensors
730
+ assert do.is_contiguous()
731
+ assert q.stride() == k.stride() == v.stride() == o.stride() == do.stride()
732
+ dq = torch.empty_like(q)
733
+ dk = torch.empty_like(k)
734
+ dv = torch.empty_like(v)
735
+ BATCH, N_HEAD, N_CTX = q.shape[:3]
736
+ PRE_BLOCK = 128
737
+ NUM_WARPS, NUM_STAGES = 4, 5
738
+ BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 128, 128, 32
739
+ BLK_SLICE_FACTOR = 2
740
+ RCP_LN2 = 1.4426950408889634 # = 1.0 / ln(2)
741
+ arg_k = k
742
+ arg_k = arg_k * (ctx.sm_scale * RCP_LN2)
743
+ PRE_BLOCK = 128
744
+ assert N_CTX % PRE_BLOCK == 0
745
+ pre_grid = (N_CTX // PRE_BLOCK, BATCH * N_HEAD)
746
+ delta = torch.empty_like(M)
747
+ _attn_bwd_preprocess[pre_grid](
748
+ o, do, #
749
+ delta, #
750
+ BATCH, N_HEAD, N_CTX, #
751
+ BLOCK_M=PRE_BLOCK, HEAD_DIM=ctx.HEAD_DIM #
752
+ )
753
+ grid = (N_CTX // BLOCK_N1, 1, BATCH * N_HEAD)
754
+ _attn_bwd[grid](
755
+ q, arg_k, v, ctx.sm_scale, do, dq, dk, dv, #
756
+ M, delta, #
757
+ q.stride(0), q.stride(1), q.stride(2), q.stride(3), #
758
+ N_HEAD, N_CTX, #
759
+ BLOCK_M1=BLOCK_M1, BLOCK_N1=BLOCK_N1, #
760
+ BLOCK_M2=BLOCK_M2, BLOCK_N2=BLOCK_N2, #
761
+ BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, #
762
+ HEAD_DIM=ctx.HEAD_DIM, #
763
+ num_warps=NUM_WARPS, #
764
+ num_stages=NUM_STAGES #
765
+ )
766
+
767
+ return dq, dk, dv, None, None
768
+
769
+ def rotate_half(x):
770
+ """Rotates half the hidden dims of the input."""
771
+ x1 = x[..., : x.shape[-1] // 2]
772
+ x2 = x[..., x.shape[-1] // 2 :]
773
+ return torch.cat((-x2, x1), dim=-1)
774
+
775
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
776
+ """Applies Rotary Position Embedding to the query and key tensors.
777
+
778
+ Args:
779
+ q (`torch.Tensor`): The query tensor.
780
+ k (`torch.Tensor`): The key tensor.
781
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
782
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
783
+ position_ids (`torch.Tensor`, *optional*):
784
+ Deprecated and unused.
785
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
786
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
787
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
788
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
789
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
790
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
791
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
792
+ Returns:
793
+ `tuple(torch.Tensor)` comprising the query and key tensors rotated using the Rotary Position Embedding.
794
+ """
795
+ cos = cos.unsqueeze(unsqueeze_dim)
796
+ sin = sin.unsqueeze(unsqueeze_dim)
797
+ q_embed = (q * cos) + (rotate_half(q) * sin)
798
+ k_embed = (k * cos) + (rotate_half(k) * sin)
799
+ return q_embed, k_embed
800
+
801
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
802
+ """
803
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
804
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
805
+ """
806
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
807
+ if n_rep == 1:
808
+ return hidden_states
809
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
810
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
811
+
812
+
813
+ class LlamaAttention(nn.Module):
814
+ q_proj: nn.Linear
815
+ k_proj: nn.Linear
816
+ v_proj: nn.Linear
817
+ o_proj: nn.Linear
818
+ scaling: float
819
+ attention_dropout: float
820
+ is_causal: bool
821
+ layer_idx: int
822
+ num_key_value_groups: int
823
+ num_attention_heads: int
824
+ num_key_value_heads: int
825
+ head_dim: int
826
+ def forward(
827
+ self,
828
+ hidden_states: torch.Tensor,
829
+ position_embeddings: Tuple[torch.Tensor, torch.Tensor],
830
+ attention_mask: Optional[torch.Tensor],
831
+ past_key_value: Optional[Cache] = None,
832
+ cache_position: Optional[torch.LongTensor] = None,
833
+ **kwargs: Unpack[FlashAttentionKwargs],
834
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
835
+ input_shape_orig = hidden_states.shape
836
+ input_shape = hidden_states.shape[:-1]
837
+ hidden_shape = (*input_shape, -1, self.head_dim)
838
+
839
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
840
+ key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
841
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
842
+ # key_states = repeat_kv(key_states, self.num_key_value_groups)
843
+ # value_states = repeat_kv(value_states, self.num_key_value_groups)
844
+ cos, sin = position_embeddings
845
+
846
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
847
+
848
+ if past_key_value is not None:
849
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
850
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
851
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
852
+
853
+ # query_states = query_states.transpose(1, 2).view(*input_shape, self.head_dim*self.num_attention_heads)
854
+ # key_states = key_states.transpose(1, 2).view(*input_shape, self.head_dim*self.num_key_value_heads)
855
+ # value_states = value_states.transpose(1, 2).view(*input_shape, self.head_dim*self.num_key_value_heads)
856
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
857
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
858
+ attn_output = Attention.apply(
859
+ query_states,
860
+ key_states,
861
+ value_states,
862
+ self.is_causal,
863
+ self.scaling,
864
+ # extra attention kwargs are intentionally not forwarded: Attention.forward only takes (q, k, v, causal, sm_scale)
865
+ )
866
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
867
+ attn_output = self.o_proj(attn_output)
868
+ return attn_output, None
869
+
870
+ def eager_attention_forward(
871
+ query: torch.Tensor,
872
+ key: torch.Tensor,
873
+ value: torch.Tensor,
874
+ attention_mask: Optional[torch.Tensor],
875
+ scaling: float,
876
+ dropout: float = 0.0,
877
+ num_key_value_groups: int = 1,
878
+ training: bool = False,
879
+ **kwargs,
880
+ ):
881
+ key_states = repeat_kv(key, num_key_value_groups)
882
+ value_states = repeat_kv(value, num_key_value_groups)
883
+
884
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
885
+ if attention_mask is not None:
886
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
887
+ attn_weights = attn_weights + causal_mask
888
+
889
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
890
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=training)
891
+ attn_output = torch.matmul(attn_weights, value_states)
892
+ # attn_output = attn_output.transpose(1, 2).contiguous()
893
+
894
+ return attn_output, attn_weights
895
+
896
+ class HFLlamaAttention(nn.Module):
897
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
898
+
899
+ def __init__(self, config: LlamaConfig, layer_idx: int, device: str):
900
+ super().__init__()
901
+ self.config = config
902
+ self.layer_idx = layer_idx
903
+ self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
904
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
905
+ self.scaling = self.head_dim**-0.5
906
+ self.attention_dropout = config.attention_dropout
907
+ self.is_causal = True
908
+
909
+ self.q_proj = nn.Linear(
910
+ config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias, device=device
911
+ )
912
+ self.k_proj = nn.Linear(
913
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias, device=device
914
+ )
915
+ self.v_proj = nn.Linear(
916
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias, device=device
917
+ )
918
+ self.o_proj = nn.Linear(
919
+ config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias, device=device
920
+ )
921
+
922
+ def forward(
923
+ self,
924
+ hidden_states: torch.Tensor,
925
+ position_embeddings: Tuple[torch.Tensor, torch.Tensor],
926
+ attention_mask: Optional[torch.Tensor],
927
+ past_key_value: Optional[Cache] = None,
928
+ cache_position: Optional[torch.LongTensor] = None,
929
+ **kwargs: Unpack[FlashAttentionKwargs],
930
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
931
+ input_shape = hidden_states.shape[:-1]
932
+ hidden_shape = (*input_shape, -1, self.head_dim)
933
+
934
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
935
+ key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
936
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
937
+
938
+ cos, sin = position_embeddings
939
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
940
+
941
+ if past_key_value is not None:
942
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
943
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
944
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
945
+
946
+ attention_interface: Callable = eager_attention_forward
947
+
948
+ attn_output, attn_weights = attention_interface(
949
+ query_states,
950
+ key_states,
951
+ value_states,
952
+ attention_mask,
953
+ dropout=0.0 if not self.training else self.attention_dropout,
954
+ scaling=self.scaling,
955
+ num_key_value_groups=self.num_key_value_groups,
956
+ **kwargs,
957
+ )
958
+
959
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
960
+ attn_output = self.o_proj(attn_output)
961
+ return attn_output, attn_weights
962
+
963
+ def attn_forward_kernel(
964
+ query: torch.Tensor,
965
+ key: torch.Tensor,
966
+ value: torch.Tensor,
967
+ scaling: float,
968
+ causal: bool,
969
+ ):
970
+ return Attention.apply(query, key, value, causal, scaling)
971
+
972
+ # def test_llama_attention_output():
973
+ # """
974
+ # Test to verify that the LlamaAttention module produces correct outputs by comparing
975
+ # with the reference implementation from transformers.
976
+ # """
977
+ # import torch
978
+
979
+ # # Set up test parameters
980
+ # batch_size = 2
981
+ # seq_len = 16
982
+ # model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B-Instruct", device_map="cuda:0")
983
+ # # print("######################### model", model)
984
+ # config = model.config
985
+ # # Create a custom LlamaAttention instance
986
+ # custom_attn = LlamaAttention()
987
+ # custom_attn.num_attention_heads = config.num_attention_heads
988
+ # custom_attn.num_key_value_heads = config.num_key_value_heads
989
+ # custom_attn.head_dim = config.head_dim
990
+ # custom_attn.hidden_size = config.hidden_size
991
+ # custom_attn.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * config.head_dim, bias=False, device="cuda:0")
992
+ # custom_attn.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * config.head_dim, bias=False, device="cuda:0")
993
+ # custom_attn.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * config.head_dim, bias=False, device="cuda:0")
994
+ # custom_attn.o_proj = nn.Linear(config.num_attention_heads * config.head_dim, config.hidden_size, bias=False, device="cuda:0")
995
+ # print("######################### custom_attn.o_proj", custom_attn.o_proj.weight.data.shape)
996
+ # custom_attn.scaling = 1.0 / (config.head_dim ** 0.5)
997
+ # custom_attn.attention_dropout = 0.0
998
+ # custom_attn.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
999
+ # custom_attn.is_causal = True
1000
+ # custom_attn.layer_idx = 0
1001
+
1002
+ # # Create a reference HF LlamaAttention instance
1003
+ # hf_attn = HFLlamaAttention(
1004
+ # config=config,
1005
+ # layer_idx=0,
1006
+ # device="cuda:0"
1007
+ # )
1008
+
1009
+ # # Copy weights to ensure identical parameters
1010
+ # hf_attn.q_proj.weight.data.copy_(custom_attn.q_proj.weight.data)
1011
+ # hf_attn.k_proj.weight.data.copy_(custom_attn.k_proj.weight.data)
1012
+ # hf_attn.v_proj.weight.data.copy_(custom_attn.v_proj.weight.data)
1013
+ # hf_attn.o_proj.weight.data.copy_(custom_attn.o_proj.weight.data)
1014
+
1015
+ # # Create test inputs
1016
+ # hidden_states = torch.randn(batch_size, seq_len, config.hidden_size).to("cuda:0")
1017
+
1018
+ # # Create position embeddings (cos, sin)
1019
+ # cos = torch.cos(torch.arange(0, config.head_dim).float() / 10000.0).unsqueeze(0).to("cuda:0")
1020
+ # sin = torch.sin(torch.arange(0, config.head_dim).float() / 10000.0).unsqueeze(0).to("cuda:0")
1021
+ # position_embeddings = (cos, sin)
1022
+
1023
+ # # Create attention mask
1024
+ # attention_mask = torch.ones(batch_size, seq_len, ).to("cuda:0")
1025
+
1026
+ # # Run custom attention
1027
+ # custom_output, _ = custom_attn(
1028
+ # hidden_states=hidden_states,
1029
+ # position_embeddings=position_embeddings,
1030
+ # attention_mask=None
1031
+ # )
1032
+
1033
+ # # Run HF attention
1034
+ # hf_output = hf_attn(
1035
+ # hidden_states=hidden_states,
1036
+ # attention_mask=None,
1037
+ # position_embeddings=position_embeddings,
1038
+ # )[0]
1039
+
1040
+ # # Check if outputs are close
1041
+ # assert torch.allclose(custom_output, hf_output, atol=1e-5), \
1042
+ # f"Custom attention output doesn't match reference. Max diff: {(custom_output - hf_output).abs().max()}"
1043
+
1044
+ # print("LlamaAttention test passed! Custom implementation matches reference.")
1045
+ # return True
1046
+
1047
+
1048
+ # def test_triton_attention_vs_eager():
1049
+ # """
1050
+ # Test that compares the output of the Triton-based Attention implementation
1051
+ # with the eager implementation from HuggingFace.
1052
+ # """
1053
+ # import torch
1054
+ # from transformers import LlamaConfig
1055
+
1056
+ # # Set up test parameters
1057
+ # batch_size = 2
1058
+ # seq_len = 16
1059
+
1060
+ # # Create a simple config
1061
+ # model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B-Instruct", device_map="cuda:0")
1062
+ # # print("######################### model", model)
1063
+ # config = model.config
1064
+
1065
+ # # Create inputs
1066
+ # q = torch.randn(batch_size, config.num_attention_heads, seq_len, config.head_dim, device="cuda")
1067
+ # k = torch.randn(batch_size, config.num_key_value_heads, seq_len, config.head_dim, device="cuda")
1068
+ # v = torch.randn(batch_size, config.num_key_value_heads, seq_len, config.head_dim, device="cuda")
1069
+
1070
+ # # Make inputs contiguous
1071
+ # q = q.contiguous()
1072
+ # k = k.contiguous()
1073
+ # v = v.contiguous()
1074
+
1075
+ # # Scaling factor
1076
+ # sm_scale = 1.0 / (config.head_dim ** 0.5)
1077
+
1078
+ # # Run Triton-based attention
1079
+ # causal = True
1080
+ # triton_output = Attention.apply(q, k, v, causal, sm_scale)
1081
+
1082
+ # # Run eager implementation
1083
+ # # Create a dummy attention mask for causal attention
1084
+ # attention_mask = torch.ones(batch_size, 1, seq_len, seq_len, device="cuda")
1085
+ # # if causal:
1086
+ # # # Create a causal mask (lower triangular)
1087
+ # # attention_mask = torch.tril(torch.ones((seq_len, seq_len), device="cuda"))
1088
+
1089
+ # # Compute attention scores
1090
+ # eager_output, _ = eager_attention_forward(
1091
+ # query=q,
1092
+ # key=k,
1093
+ # value=v,
1094
+ # attention_mask=attention_mask,
1095
+ # scaling=sm_scale,
1096
+ # num_key_value_groups=config.num_attention_heads // config.num_key_value_heads,
1097
+ # training=False,
1098
+ # )
1099
+ # print("######################### triton_output", triton_output.shape)
1100
+ # print("######################### eager_output", eager_output.shape)
1101
+ # print("######################### triton_output", triton_output)
1102
+ # print("######################### eager_output", eager_output)
1103
+ # is_close = torch.allclose(triton_output, eager_output, atol=1e-4)
1104
+ # print(f"Triton attention matches eager implementation: {is_close}")
1105
+
1106
+ # return is_close
1107
+
1108
+ # attention = Attention.apply
1109
+ # DEVICE = "cuda:0"
1110
+
1111
+ # import pytest
1112
+ # @pytest.mark.parametrize("Z, H, N_CTX, HEAD_DIM", [(1, 2, 1024, 64)])
1113
+ # @pytest.mark.parametrize("causal", [True])
1114
+ # def test_op(Z, H, N_CTX, HEAD_DIM, causal, dtype=torch.float16):
1115
+ # torch.manual_seed(20)
1116
+ # q = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_())
1117
+ # k = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_())
1118
+ # v = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_())
1119
+ # sm_scale = 0.5
1120
+ # dout = torch.randn_like(q)
1121
+ # # reference implementation
1122
+ # M = torch.tril(torch.ones((N_CTX, N_CTX), device=DEVICE))
1123
+ # p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
1124
+ # if causal:
1125
+ # p[:, :, M == 0] = float("-inf")
1126
+ # p = torch.softmax(p.float(), dim=-1).half()
1127
+ # # p = torch.exp(p)
1128
+ # ref_out = torch.matmul(p, v)
1129
+ # ref_out.backward(dout)
1130
+ # ref_dv, v.grad = v.grad.clone(), None
1131
+ # ref_dk, k.grad = k.grad.clone(), None
1132
+ # ref_dq, q.grad = q.grad.clone(), None
1133
+ # # triton implementation
1134
+ # tri_out = attention(q, k, v, causal, sm_scale).half()
1135
+ # tri_out.backward(dout)
1136
+ # tri_dv, v.grad = v.grad.clone(), None
1137
+ # tri_dk, k.grad = k.grad.clone(), None
1138
+ # tri_dq, q.grad = q.grad.clone(), None
1139
+ # # compare
1140
+ # assert torch.allclose(ref_out, tri_out, atol=1e-2, rtol=0)
1141
+ # rtol = 0.0
1142
+ # # Relative tolerance workaround for known hardware limitation of CDNA2 GPU.
1143
+ # # For details see https://pytorch.org/docs/stable/notes/numerical_accuracy.html#reduced-precision-fp16-and-bf16-gemms-and-convolutions-on-amd-instinct-mi200-devices
1144
+ # if torch.version.hip is not None and triton.runtime.driver.active.get_current_target().arch == "gfx90a":
1145
+ # rtol = 1e-2
1146
+ # assert torch.allclose(ref_dv, tri_dv, atol=1e-2, rtol=rtol)
1147
+ # assert torch.allclose(ref_dk, tri_dk, atol=1e-2, rtol=rtol)
1148
+ # assert torch.allclose(ref_dq, tri_dq, atol=1e-2, rtol=rtol)
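
A note on the scaling constant used throughout the forward kernels: sm_scale is multiplied by 1.44269504 (= log2(e)) and the softmax is evaluated with exp2, since 2**(x * log2(e)) == e**x. A short PyTorch check of that identity (illustrative only, not part of the kernel), assuming a block of attention logits and head_dim = 64:

import torch

scores = torch.randn(4, 16, dtype=torch.float32)    # a block of q @ k^T logits
sm_scale = 64 ** -0.5                                # 1/sqrt(head_dim), head_dim = 64 assumed

ref = torch.softmax(scores * sm_scale, dim=-1)

qk_scale = sm_scale * 1.44269504                     # fold log2(e) into the scale
shifted = scores * qk_scale - (scores * qk_scale).amax(dim=-1, keepdim=True)
p = torch.exp2(shifted)                              # base-2 exponential, as in the kernel
base2 = p / p.sum(dim=-1, keepdim=True)

print(torch.allclose(ref, base2, atol=1e-6))         # True: the max-shift cancels in the softmax
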
build/torch-universal/triton_llama_attn/layers.py ADDED
File without changes
flake.nix ADDED
@@ -0,0 +1,17 @@
+ {
+   description = "Flake for Triton Llama attention kernels";
+
+   inputs = {
+     kernel-builder.url = "github:huggingface/kernel-builder";
+   };
+
+   outputs =
+     {
+       self,
+       kernel-builder,
+     }:
+     kernel-builder.lib.genFlakeOutputs {
+       path = ./.;
+       rev = self.shortRev or self.dirtyShortRev or self.lastModifiedDate;
+     };
+ }
torch-ext/triton_llama_attn/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .attn import attn_forward_kernel
+
+ __all__ = ["attn_forward_kernel"]
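The torch-ext/ tree is the source layout consumed by kernel-builder, while build/torch-universal/ above is its generated output. As a hedged sketch of how such a kernel is typically consumed once pushed to the Hub via the Hugging Face kernels library (the repository id below is a placeholder for illustration, not the actual location of this upload):

import torch
from kernels import get_kernel

attn = get_kernel("<user>/triton_llama_attn")  # hypothetical repo id
q = torch.randn(1, 8, 512, 64, device="cuda", dtype=torch.float16)
k = torch.randn(1, 8, 512, 64, device="cuda", dtype=torch.float16)
v = torch.randn(1, 8, 512, 64, device="cuda", dtype=torch.float16)
out = attn.attn_forward_kernel(q, k, v, scaling=64 ** -0.5, causal=True)
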
torch-ext/triton_llama_attn/attn.py ADDED
@@ -0,0 +1,1148 @@
1
+ import torch
2
+ from transformers.models.llama.configuration_llama import LlamaConfig
3
+ from transformers import AutoModelForCausalLM
4
+ import triton.tools.experimental_descriptor
5
+ from typing import Tuple, Optional, Callable
6
+ import triton
7
+ import triton.language as tl
8
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
9
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
10
+ from transformers import Cache
11
+ import torch.nn as nn
12
+ from transformers.processing_utils import Unpack
13
+ # ENABLE_LHS_TO_TMEM is an experimental environment variable for Blackwell.
14
+ # If it is set to 1 it can improve performance of Blackwell attention. However,
15
+ # it defaults to 0 as it is known to cause correctness issues outside of the
16
+ # _attn_fwd_tma kernel below.
17
+
18
+ # DEVICE = triton.runtime.driver.active.get_active_torch_device()
19
+
20
+
21
+ def is_hip():
22
+ return triton.runtime.driver.active.get_current_target().backend == "hip"
23
+
24
+
25
+ def is_cuda():
26
+ return triton.runtime.driver.active.get_current_target().backend == "cuda"
27
+
28
+
29
+ def supports_tma():
30
+ return is_cuda() and torch.cuda.get_device_capability()[0] >= 9
31
+
32
+
33
+ HAS_TMA_DESC = "nv_tma_desc_type" in dir(tl)
34
+
35
+ if HAS_TMA_DESC:
36
+ print("TMA benchmarks will be running with experimental grid constant TMA descriptor.", )
37
+ else:
38
+ print("TMA benchmarks will be running without grid constant TMA descriptor.", )
39
+
40
+
41
+ # TmaAutoTuneHelper used in htyu's PR #5622
42
+ class TmaAutoTuneHelper:
43
+
44
+ # duck typing wrapper to implement the same interface as TmaDescKernelParam in Triton PR #4498
45
+ class KernelParamWrapper:
46
+
47
+ def __init__(self, desc):
48
+ self.desc = desc
49
+
50
+ def tma_desc_cpu_ptr(self):
51
+ return self.desc.data_ptr()
52
+
53
+ TMA_SIZE = 128
54
+
55
+ def __init__(self):
56
+ self.fill_1d_tma_descriptor_inner = (triton.runtime.driver.active.utils.fill_1d_tma_descriptor)
57
+ self.fill_2d_tma_descriptor_inner = (triton.runtime.driver.active.utils.fill_2d_tma_descriptor)
58
+ if HAS_TMA_DESC:
59
+ self.descriptors = {}
60
+ else:
61
+ self.cuda_descriptors = {}
62
+
63
+ # Call this method outside of the lambda function for grid size
64
+ def init_tma_descriptor(self, name):
65
+ if HAS_TMA_DESC:
66
+ self.descriptors[name] = torch.empty(TmaAutoTuneHelper.TMA_SIZE, device="cpu", dtype=torch.int8)
67
+ else:
68
+ self.cuda_descriptors[name] = torch.empty(TmaAutoTuneHelper.TMA_SIZE, device="cuda", dtype=torch.int8)
69
+
70
+ # Call this method inside the lambda function for grid size
71
+ def fill_1d_tma_descriptor(self, name, ptr, dim, block_dim, element_size):
72
+ if HAS_TMA_DESC:
73
+ desc_x = self.descriptors[name]
74
+ assert desc_x.data_ptr() % 64 == 0
75
+ self.fill_1d_tma_descriptor_inner(ptr, dim, block_dim, element_size, desc_x.data_ptr())
76
+ else:
77
+ desc_x = self.cuda_descriptors[name]
78
+ buf_x = torch.empty_like(desc_x, device="cpu", pin_memory=True)
79
+ self.fill_1d_tma_descriptor_inner(ptr, dim, block_dim, element_size, buf_x.data_ptr())
80
+ desc_x.copy_(buf_x, non_blocking=True)
81
+
82
+ # Call this method inside the lambda function for grid size
83
+ def fill_2d_tma_descriptor(self, name, ptr, dim1, dim0, block_dim1, block_dim0, element_size):
84
+ if HAS_TMA_DESC:
85
+ desc_x = self.descriptors[name]
86
+ assert desc_x.data_ptr() % 64 == 0
87
+ self.fill_2d_tma_descriptor_inner(ptr, dim1, dim0, block_dim1, block_dim0, element_size, desc_x.data_ptr())
88
+ else:
89
+ desc_x = self.cuda_descriptors[name]
90
+ buf_x = torch.empty_like(desc_x, device="cpu", pin_memory=True)
91
+ self.fill_2d_tma_descriptor_inner(ptr, dim1, dim0, block_dim1, block_dim0, element_size, buf_x.data_ptr())
92
+ desc_x.copy_(buf_x, non_blocking=True)
93
+
94
+ def get_tma_descriptor_kernel_param(self, name):
95
+ if HAS_TMA_DESC:
96
+ assert self.descriptors[name] is not None
97
+ return self.KernelParamWrapper(self.descriptors[name])
98
+ else:
99
+ assert self.cuda_descriptors[name] is not None
100
+ return self.cuda_descriptors[name]
101
+
102
+
103
+ @triton.jit
104
+ def _attn_fwd_inner(acc, l_i, m_i, q, #
105
+ K_block_ptr, V_block_ptr, #
106
+ start_m, qk_scale, #
107
+ BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr, #
108
+ STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr, #
109
+ N_CTX: tl.constexpr, fp8_v: tl.constexpr):
110
+ # range of values handled by this stage
111
+ if STAGE == 1:
112
+ lo, hi = 0, start_m * BLOCK_M
113
+ elif STAGE == 2:
114
+ lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
115
+ lo = tl.multiple_of(lo, BLOCK_M)
116
+ # causal = False
117
+ else:
118
+ lo, hi = 0, N_CTX
119
+ K_block_ptr = tl.advance(K_block_ptr, (0, lo))
120
+ V_block_ptr = tl.advance(V_block_ptr, (lo, 0))
121
+ # loop over k, v and update accumulator
122
+ for start_n in range(lo, hi, BLOCK_N):
123
+ start_n = tl.multiple_of(start_n, BLOCK_N)
124
+ # -- compute qk ----
125
+ k = tl.load(K_block_ptr)
126
+ qk = tl.dot(q, k)
127
+ if STAGE == 2:
128
+ mask = offs_m[:, None] >= (start_n + offs_n[None, :])
129
+ qk = qk * qk_scale + tl.where(mask, 0, -1.0e6)
130
+ m_ij = tl.maximum(m_i, tl.max(qk, 1))
131
+ qk -= m_ij[:, None]
132
+ else:
133
+ m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale)
134
+ qk = qk * qk_scale - m_ij[:, None]
135
+ p = tl.math.exp2(qk)
136
+ l_ij = tl.sum(p, 1)
137
+ # -- update m_i and l_i
138
+ alpha = tl.math.exp2(m_i - m_ij)
139
+ l_i = l_i * alpha + l_ij
140
+ # -- update output accumulator --
141
+ acc = acc * alpha[:, None]
142
+ # update acc
143
+ v = tl.load(V_block_ptr)
144
+ if fp8_v:
145
+ p = p.to(tl.float8e5)
146
+ else:
147
+ p = p.to(tl.float16)
148
+ acc = tl.dot(p, v, acc)
149
+ # update m_i and l_i
150
+ m_i = m_ij
151
+ V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
152
+ K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
153
+ return acc, l_i, m_i
154
+
155
+
156
+ @triton.jit
157
+ def _attn_fwd_inner_tma(acc, l_i, m_i, q, #
158
+ desc_k, desc_v, #
159
+ offset_y, dtype: tl.constexpr, start_m, qk_scale, #
160
+ BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr, #
161
+ STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr, #
162
+ N_CTX: tl.constexpr):
163
+ # range of values handled by this stage
164
+ if STAGE == 1:
165
+ lo, hi = 0, start_m * BLOCK_M
166
+ elif STAGE == 2:
167
+ lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
168
+ lo = tl.multiple_of(lo, BLOCK_M)
169
+ # causal = False (the else branch below covers the full context)
170
+ else:
171
+ lo, hi = 0, N_CTX
172
+ offsetkv_y = offset_y + lo
173
+ # loop over k, v and update accumulator
174
+ for start_n in range(lo, hi, BLOCK_N):
175
+ start_n = tl.multiple_of(start_n, BLOCK_N)
176
+ # -- compute qk ----
177
+ k = tl._experimental_descriptor_load(desc_k, [offsetkv_y, 0], [BLOCK_N, HEAD_DIM], dtype).T
178
+ qk = tl.dot(q, k)
179
+ if STAGE == 2:
180
+ mask = offs_m[:, None] >= (start_n + offs_n[None, :])
181
+ qk = qk * qk_scale + tl.where(mask, 0, -1.0e6)
182
+ m_ij = tl.maximum(m_i, tl.max(qk, 1))
183
+ qk -= m_ij[:, None]
184
+ else:
185
+ m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale)
186
+ qk = qk * qk_scale - m_ij[:, None]
187
+ p = tl.math.exp2(qk)
188
+ l_ij = tl.sum(p, 1)
189
+ # -- update m_i and l_i
190
+ alpha = tl.math.exp2(m_i - m_ij)
191
+ l_i = l_i * alpha + l_ij
192
+ # -- update output accumulator --
193
+ acc = acc * alpha[:, None]
194
+ # update acc
195
+ v = tl._experimental_descriptor_load(desc_v, [offsetkv_y, 0], [BLOCK_N, HEAD_DIM], dtype)
196
+ p = p.to(dtype)
197
+ # Note that this non-transposed V for FP8 is only supported on Blackwell.
198
+ acc = tl.dot(p, v, acc)
199
+ # update m_i and l_i
200
+ m_i = m_ij
201
+ offsetkv_y += BLOCK_N
202
+ return acc, l_i, m_i
203
+
204
+
205
+ # We don't run auto-tuning every time, to keep compilation fast. Keeping the
206
+ # configuration space below (and commenting parameters in or out) makes it
207
+ # convenient to re-tune.
208
+ configs = [
209
+ triton.Config({'BLOCK_M': BM, 'BLOCK_N': BN}, num_stages=s, num_warps=w) \
210
+ for BM in [64, 128]\
211
+ for BN in [32, 64]\
212
+ for s in ([1] if is_hip() else [3, 4, 7])\
213
+ for w in [4, 8]\
214
+ ]
215
+
216
+
217
+ def keep(conf):
218
+ BLOCK_M = conf.kwargs["BLOCK_M"]
219
+ BLOCK_N = conf.kwargs["BLOCK_N"]
220
+ if BLOCK_M * BLOCK_N < 128 * 128 and conf.num_warps == 8:
221
+ return False
222
+ return True
223
+
224
+
225
+ @triton.autotune(list(filter(keep, configs)), key=["N_CTX", "HEAD_DIM"])
226
+ @triton.jit
227
+ def _attn_fwd(Q, K, V, sm_scale, M, Out, #
228
+ stride_qz, stride_qh, stride_qm, stride_qk, #
229
+ stride_kz, stride_kh, stride_kn, stride_kk, #
230
+ stride_vz, stride_vh, stride_vk, stride_vn, #
231
+ stride_oz, stride_oh, stride_om, stride_on, #
232
+ Z, H, N_CTX, #
233
+ HEAD_DIM: tl.constexpr, #
234
+ BLOCK_M: tl.constexpr, #
235
+ BLOCK_N: tl.constexpr, #
236
+ STAGE: tl.constexpr #
237
+ ):
238
+ tl.static_assert(BLOCK_N <= HEAD_DIM)
239
+ start_m = tl.program_id(0)
240
+ off_hz = tl.program_id(1)
241
+ off_z = off_hz // H
242
+ off_h = off_hz % H
243
+ qvk_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh
244
+
245
+ # block pointers
246
+ Q_block_ptr = tl.make_block_ptr(
247
+ base=Q + qvk_offset,
248
+ shape=(N_CTX, HEAD_DIM),
249
+ strides=(stride_qm, stride_qk),
250
+ offsets=(start_m * BLOCK_M, 0),
251
+ block_shape=(BLOCK_M, HEAD_DIM),
252
+ order=(1, 0),
253
+ )
254
+ v_order: tl.constexpr = (0, 1) if V.dtype.element_ty == tl.float8e5 else (1, 0)
255
+ V_block_ptr = tl.make_block_ptr(
256
+ base=V + qvk_offset,
257
+ shape=(N_CTX, HEAD_DIM),
258
+ strides=(stride_vk, stride_vn),
259
+ offsets=(0, 0),
260
+ block_shape=(BLOCK_N, HEAD_DIM),
261
+ order=v_order,
262
+ )
263
+ K_block_ptr = tl.make_block_ptr(
264
+ base=K + qvk_offset,
265
+ shape=(HEAD_DIM, N_CTX),
266
+ strides=(stride_kk, stride_kn),
267
+ offsets=(0, 0),
268
+ block_shape=(HEAD_DIM, BLOCK_N),
269
+ order=(0, 1),
270
+ )
271
+ O_block_ptr = tl.make_block_ptr(
272
+ base=Out + qvk_offset,
273
+ shape=(N_CTX, HEAD_DIM),
274
+ strides=(stride_om, stride_on),
275
+ offsets=(start_m * BLOCK_M, 0),
276
+ block_shape=(BLOCK_M, HEAD_DIM),
277
+ order=(1, 0),
278
+ )
279
+ # initialize offsets
280
+ offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
281
+ offs_n = tl.arange(0, BLOCK_N)
282
+ # initialize pointer to m and l
283
+ m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
284
+ l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0
285
+ acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
286
+ # load scales
287
+ qk_scale = sm_scale
288
+ qk_scale *= 1.44269504 # 1/ln(2) = log2(e)
289
+ # load q: it will stay in SRAM throughout
290
+ q = tl.load(Q_block_ptr)
291
+ # stage 1: off-band
292
+ # For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE
293
+ # For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE
294
+ if STAGE & 1:
295
+ acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr, #
296
+ start_m, qk_scale, #
297
+ BLOCK_M, HEAD_DIM, BLOCK_N, #
298
+ 4 - STAGE, offs_m, offs_n, N_CTX, V.dtype.element_ty == tl.float8e5 #
299
+ )
300
+ # stage 2: on-band
301
+ if STAGE & 2:
302
+ # barrier makes it easier for the compiler to schedule the
303
+ # two loops independently
304
+ acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr, #
305
+ start_m, qk_scale, #
306
+ BLOCK_M, HEAD_DIM, BLOCK_N, #
307
+ 2, offs_m, offs_n, N_CTX, V.dtype.element_ty == tl.float8e5 #
308
+ )
309
+ # epilogue
310
+ m_i += tl.math.log2(l_i)
311
+ acc = acc / l_i[:, None]
312
+ m_ptrs = M + off_hz * N_CTX + offs_m
313
+ tl.store(m_ptrs, m_i)
314
+ tl.store(O_block_ptr, acc.to(Out.type.element_ty))
315
+
316
+
317
+ # We don't run auto-tuning every time, to keep compilation fast. Keeping the
318
+ # configuration space below (and commenting parameters in or out) makes it
319
+ # convenient to re-tune.
320
+ configs_tma = [
321
+ triton.Config({'BLOCK_M': BM, 'BLOCK_N': BN}, num_stages=s, num_warps=w) \
322
+ for BM in [64, 128]\
323
+ for BN in [32, 64, 128]\
324
+ for s in [2, 3, 4, 6]\
325
+ for w in [4, 8]\
326
+ ]
327
+
328
+
329
+ def keep_tma(conf):
330
+ BLOCK_M = conf.kwargs["BLOCK_M"]
331
+ BLOCK_N = conf.kwargs["BLOCK_N"]
332
+ if (torch.cuda.get_device_capability()[0] == 9 and BLOCK_M * BLOCK_N < 128 * 128 and conf.num_warps == 8):
333
+ return False
334
+ return True
335
+
336
+
337
+ @triton.autotune(configs=list(filter(keep_tma, configs_tma)), key=["N_CTX", "HEAD_DIM", "FP8_OUTPUT"])
338
+ @triton.jit
339
+ def _attn_fwd_tma(sm_scale, M, #
340
+ Z, H, desc_q, desc_k, desc_v, desc_o, N_CTX, #
341
+ HEAD_DIM: tl.constexpr, #
342
+ BLOCK_M: tl.constexpr, #
343
+ BLOCK_N: tl.constexpr, #
344
+ FP8_OUTPUT: tl.constexpr, #
345
+ STAGE: tl.constexpr #
346
+ ):
347
+ dtype = tl.float8e5 if FP8_OUTPUT else tl.float16
348
+ tl.static_assert(BLOCK_N <= HEAD_DIM)
349
+ start_m = tl.program_id(0)
350
+ off_hz = tl.program_id(1)
351
+ off_z = off_hz // H
352
+ off_h = off_hz % H
353
+
354
+ offset_y = off_z + off_h * N_CTX
355
+ qo_offset_y = offset_y + start_m * BLOCK_M
356
+ # initialize offsets
357
+ offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
358
+ offs_n = tl.arange(0, BLOCK_N)
359
+ # initialize pointer to m and l
360
+ m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
361
+ l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0
362
+ acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
363
+ # load scales
364
+ qk_scale = sm_scale
365
+ qk_scale *= 1.44269504 # 1/ln(2) = log2(e)
366
+ # load q: it will stay in SRAM throughout
367
+ q = tl._experimental_descriptor_load(desc_q, [qo_offset_y, 0], [BLOCK_M, HEAD_DIM], dtype)
368
+ # stage 1: off-band
369
+ # For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE
370
+ # For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE
371
+ if STAGE & 1:
372
+ acc, l_i, m_i = _attn_fwd_inner_tma(acc, l_i, m_i, q, #
373
+ desc_k, desc_v, #
374
+ offset_y, dtype, start_m, qk_scale, #
375
+ BLOCK_M, HEAD_DIM, BLOCK_N, #
376
+ 4 - STAGE, offs_m, offs_n, N_CTX, #
377
+ )
378
+ # stage 2: on-band
379
+ if STAGE & 2:
380
+ # barrier makes it easier for the compiler to schedule the
381
+ # two loops independently
382
+ acc, l_i, m_i = _attn_fwd_inner_tma(acc, l_i, m_i, q, #
383
+ desc_k, desc_v, #
384
+ offset_y, dtype, start_m, qk_scale, #
385
+ BLOCK_M, HEAD_DIM, BLOCK_N, #
386
+ 2, offs_m, offs_n, N_CTX, #
387
+ )
388
+ # epilogue
389
+ m_i += tl.math.log2(l_i)
390
+ acc = acc / l_i[:, None]
391
+ m_ptrs = M + off_hz * N_CTX + offs_m
392
+ tl.store(m_ptrs, m_i)
393
+ tl._experimental_descriptor_store(desc_o, acc.to(dtype), [qo_offset_y, 0])
394
+
395
+
396
+ @triton.jit
397
+ def _attn_bwd_preprocess(O, DO, #
398
+ Delta, #
399
+ Z, H, N_CTX, #
400
+ BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr #
401
+ ):
402
+ off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
403
+ off_hz = tl.program_id(1)
404
+ off_n = tl.arange(0, HEAD_DIM)
405
+ # load
406
+ o = tl.load(O + off_hz * HEAD_DIM * N_CTX + off_m[:, None] * HEAD_DIM + off_n[None, :])
407
+ do = tl.load(DO + off_hz * HEAD_DIM * N_CTX + off_m[:, None] * HEAD_DIM + off_n[None, :]).to(tl.float32)
408
+ delta = tl.sum(o * do, axis=1)
409
+ # write-back
410
+ tl.store(Delta + off_hz * N_CTX + off_m, delta)
411
+
412
+
413
+ # The main inner-loop logic for computing dK and dV.
414
+ @triton.jit
415
+ def _attn_bwd_dkdv(dk, dv, #
416
+ Q, k, v, sm_scale, #
417
+ DO, #
418
+ M, D, #
419
+ # shared by Q/K/V/DO.
420
+ stride_tok, stride_d, #
421
+ H, N_CTX, BLOCK_M1: tl.constexpr, #
422
+ BLOCK_N1: tl.constexpr, #
423
+ HEAD_DIM: tl.constexpr, #
424
+ # Filled in by the wrapper.
425
+ start_n, start_m, num_steps, #
426
+ MASK: tl.constexpr):
427
+ offs_m = start_m + tl.arange(0, BLOCK_M1)
428
+ offs_n = start_n + tl.arange(0, BLOCK_N1)
429
+ offs_k = tl.arange(0, HEAD_DIM)
430
+ qT_ptrs = Q + offs_m[None, :] * stride_tok + offs_k[:, None] * stride_d
431
+ do_ptrs = DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
432
+ # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
433
+ tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
434
+ curr_m = start_m
435
+ step_m = BLOCK_M1
436
+ for blk_idx in range(num_steps):
437
+ qT = tl.load(qT_ptrs)
438
+ # Load m before computing qk to reduce pipeline stall.
439
+ offs_m = curr_m + tl.arange(0, BLOCK_M1)
440
+ m = tl.load(M + offs_m)
441
+ qkT = tl.dot(k, qT)
442
+ pT = tl.math.exp2(qkT - m[None, :])
443
+ # Autoregressive masking.
444
+ if MASK:
445
+ mask = (offs_m[None, :] >= offs_n[:, None])
446
+ pT = tl.where(mask, pT, 0.0)
447
+ do = tl.load(do_ptrs)
448
+ # Compute dV.
449
+ ppT = pT
450
+ ppT = ppT.to(tl.float16)
451
+ dv += tl.dot(ppT, do)
452
+ # D (= delta) is pre-divided by ds_scale.
453
+ Di = tl.load(D + offs_m)
454
+ # Compute dP and dS.
455
+ dpT = tl.dot(v, tl.trans(do)).to(tl.float32)
456
+ dsT = pT * (dpT - Di[None, :])
457
+ dsT = dsT.to(tl.float16)
458
+ dk += tl.dot(dsT, tl.trans(qT))
459
+ # Increment pointers.
460
+ curr_m += step_m
461
+ qT_ptrs += step_m * stride_tok
462
+ do_ptrs += step_m * stride_tok
463
+ return dk, dv
464
+
465
+
466
+ # the main inner-loop logic for computing dQ
467
+ @triton.jit
468
+ def _attn_bwd_dq(dq, q, K, V, #
469
+ do, m, D,
470
+ # shared by Q/K/V/DO.
471
+ stride_tok, stride_d, #
472
+ H, N_CTX, #
473
+ BLOCK_M2: tl.constexpr, #
474
+ BLOCK_N2: tl.constexpr, #
475
+ HEAD_DIM: tl.constexpr,
476
+ # Filled in by the wrapper.
477
+ start_m, start_n, num_steps, #
478
+ MASK: tl.constexpr):
479
+ offs_m = start_m + tl.arange(0, BLOCK_M2)
480
+ offs_n = start_n + tl.arange(0, BLOCK_N2)
481
+ offs_k = tl.arange(0, HEAD_DIM)
482
+ kT_ptrs = K + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d
483
+ vT_ptrs = V + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d
484
+ # D (= delta) is pre-divided by ds_scale.
485
+ Di = tl.load(D + offs_m)
486
+ # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
487
+ tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
488
+ curr_n = start_n
489
+ step_n = BLOCK_N2
490
+ for blk_idx in range(num_steps):
491
+ kT = tl.load(kT_ptrs)
492
+ vT = tl.load(vT_ptrs)
493
+ qk = tl.dot(q, kT)
494
+ p = tl.math.exp2(qk - m)
495
+ # Autoregressive masking.
496
+ if MASK:
497
+ offs_n = curr_n + tl.arange(0, BLOCK_N2)
498
+ mask = (offs_m[:, None] >= offs_n[None, :])
499
+ p = tl.where(mask, p, 0.0)
500
+ # Compute dP and dS.
501
+ dp = tl.dot(do, vT).to(tl.float32)
502
+ ds = p * (dp - Di[:, None])
503
+ ds = ds.to(tl.float16)
504
+ # Compute dQ.
505
+ # NOTE: We need to de-scale dq in the end, because kT was pre-scaled.
506
+ dq += tl.dot(ds, tl.trans(kT))
507
+ # Increment pointers.
508
+ curr_n += step_n
509
+ kT_ptrs += step_n * stride_tok
510
+ vT_ptrs += step_n * stride_tok
511
+ return dq
512
+
513
+
514
+ @triton.jit
515
+ def _attn_bwd(Q, K, V, sm_scale, #
516
+ DO, #
517
+ DQ, DK, DV, #
518
+ M, D,
519
+ # shared by Q/K/V/DO.
520
+ stride_z, stride_h, stride_tok, stride_d, #
521
+ H, N_CTX, #
522
+ BLOCK_M1: tl.constexpr, #
523
+ BLOCK_N1: tl.constexpr, #
524
+ BLOCK_M2: tl.constexpr, #
525
+ BLOCK_N2: tl.constexpr, #
526
+ BLK_SLICE_FACTOR: tl.constexpr, #
527
+ HEAD_DIM: tl.constexpr):
528
+ LN2: tl.constexpr = 0.6931471824645996 # = ln(2)
529
+
530
+ bhid = tl.program_id(2)
531
+ off_chz = (bhid * N_CTX).to(tl.int64)
532
+ adj = (stride_h * (bhid % H) + stride_z * (bhid // H)).to(tl.int64)
533
+ pid = tl.program_id(0)
534
+
535
+ # offset pointers for batch/head
536
+ Q += adj
537
+ K += adj
538
+ V += adj
539
+ DO += adj
540
+ DQ += adj
541
+ DK += adj
542
+ DV += adj
543
+ M += off_chz
544
+ D += off_chz
545
+
546
+ # load scales
547
+ offs_k = tl.arange(0, HEAD_DIM)
548
+
549
+ start_n = pid * BLOCK_N1
550
+ start_m = start_n
551
+
552
+ MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR
553
+ offs_n = start_n + tl.arange(0, BLOCK_N1)
554
+
555
+ dv = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32)
556
+ dk = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32)
557
+
558
+ # load K and V: they stay in SRAM throughout the inner loop.
559
+ k = tl.load(K + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)
560
+ v = tl.load(V + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)
561
+
562
+ num_steps = BLOCK_N1 // MASK_BLOCK_M1
563
+
564
+ dk, dv = _attn_bwd_dkdv(dk, dv, #
565
+ Q, k, v, sm_scale, #
566
+ DO, #
567
+ M, D, #
568
+ stride_tok, stride_d, #
569
+ H, N_CTX, #
570
+ MASK_BLOCK_M1, BLOCK_N1, HEAD_DIM, #
571
+ start_n, start_m, num_steps, #
572
+ MASK=True #
573
+ )
574
+
575
+ start_m += num_steps * MASK_BLOCK_M1
576
+ num_steps = (N_CTX - start_m) // BLOCK_M1
577
+
578
+ # Compute dK and dV for non-masked blocks.
579
+ dk, dv = _attn_bwd_dkdv( #
580
+ dk, dv, #
581
+ Q, k, v, sm_scale, #
582
+ DO, #
583
+ M, D, #
584
+ stride_tok, stride_d, #
585
+ H, N_CTX, #
586
+ BLOCK_M1, BLOCK_N1, HEAD_DIM, #
587
+ start_n, start_m, num_steps, #
588
+ MASK=False #
589
+ )
590
+
591
+ dv_ptrs = DV + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
592
+ tl.store(dv_ptrs, dv)
593
+
594
+ # Write back dK.
595
+ dk *= sm_scale
596
+ dk_ptrs = DK + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
597
+ tl.store(dk_ptrs, dk)
598
+
599
+ # THIS BLOCK DOES DQ:
600
+ start_m = pid * BLOCK_M2
601
+ end_n = start_m + BLOCK_M2
602
+
603
+ MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR
604
+ offs_m = start_m + tl.arange(0, BLOCK_M2)
605
+
606
+ q = tl.load(Q + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d)
607
+ dq = tl.zeros([BLOCK_M2, HEAD_DIM], dtype=tl.float32)
608
+ do = tl.load(DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d)
609
+
610
+ m = tl.load(M + offs_m)
611
+ m = m[:, None]
612
+
613
+ # Compute dQ for masked (diagonal) blocks.
614
+ # NOTE: This code scans each row of QK^T backward (from right to left,
615
+ # but inside each call to _attn_bwd_dq, from left to right), but that's
616
+ # not due to anything important. I just wanted to reuse the loop
617
+ # structure for dK & dV above as much as possible.
618
+ num_steps = BLOCK_M2 // MASK_BLOCK_N2
619
+ dq = _attn_bwd_dq(dq, q, K, V, #
620
+ do, m, D, #
621
+ stride_tok, stride_d, #
622
+ H, N_CTX, #
623
+ BLOCK_M2, MASK_BLOCK_N2, HEAD_DIM, #
624
+ start_m, end_n - num_steps * MASK_BLOCK_N2, num_steps, #
625
+ MASK=True #
626
+ )
627
+ end_n -= num_steps * MASK_BLOCK_N2
628
+ # stage 2
629
+ num_steps = end_n // BLOCK_N2
630
+ dq = _attn_bwd_dq(dq, q, K, V, #
631
+ do, m, D, #
632
+ stride_tok, stride_d, #
633
+ H, N_CTX, #
634
+ BLOCK_M2, BLOCK_N2, HEAD_DIM, #
635
+ start_m, end_n - num_steps * BLOCK_N2, num_steps, #
636
+ MASK=False #
637
+ )
638
+ # Write back dQ.
639
+ dq_ptrs = DQ + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
640
+ dq *= LN2
641
+ tl.store(dq_ptrs, dq)
642
+
643
+
644
+ class Attention(torch.autograd.Function):
645
+
646
+ @staticmethod
647
+ def forward(ctx, q, k, v, causal, sm_scale, USE_TMA=True):
648
+ # shape constraints
649
+ HEAD_DIM_Q, HEAD_DIM_K = q.shape[-1], k.shape[-1]
650
+ # when v is in float8_e5m2 it is transposed.
651
+ HEAD_DIM_V = v.shape[-1]
652
+ assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V
653
+ assert HEAD_DIM_K in {16, 32, 64, 128, 256}
654
+ o = torch.empty_like(q)
655
+ stage = 3 if causal else 1
656
+ extra_kern_args = {}
657
+ # Tuning for AMD target
658
+ if is_hip():
659
+ waves_per_eu = 3 if HEAD_DIM_K <= 64 else 2
660
+ extra_kern_args = {"waves_per_eu": waves_per_eu, "allow_flush_denorm": True}
661
+
662
+ M = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
663
+ if USE_TMA and supports_tma() and not (torch.cuda.get_device_capability()[0] == 9
664
+ and q.dtype == torch.float8_e5m2):
665
+ # Note that on Hopper we cannot perform an FP8 dot with a non-transposed second tensor.
666
+ y_dim = q.shape[0] * q.shape[1] * q.shape[2]
667
+
668
+ desc_helper = TmaAutoTuneHelper()
669
+ desc_helper.init_tma_descriptor("q")
670
+ desc_helper.init_tma_descriptor("v")
671
+ desc_helper.init_tma_descriptor("k")
672
+ desc_helper.init_tma_descriptor("o")
673
+
674
+ def grid(META):
675
+ nonlocal desc_helper
676
+
677
+ desc_helper.fill_2d_tma_descriptor("q", q.data_ptr(), y_dim, HEAD_DIM_K, META["BLOCK_M"], HEAD_DIM_K,
678
+ q.element_size())
679
+
680
+ desc_helper.fill_2d_tma_descriptor("v", v.data_ptr(), y_dim, HEAD_DIM_K, META["BLOCK_N"], HEAD_DIM_K,
681
+ v.element_size())
682
+
683
+ desc_helper.fill_2d_tma_descriptor("k", k.data_ptr(), y_dim, HEAD_DIM_K, META["BLOCK_N"], HEAD_DIM_K,
684
+ k.element_size())
685
+
686
+ desc_helper.fill_2d_tma_descriptor("o", o.data_ptr(), y_dim, HEAD_DIM_K, META["BLOCK_M"], HEAD_DIM_K,
687
+ o.element_size())
688
+
689
+ return (triton.cdiv(q.shape[2], META["BLOCK_M"]), q.shape[0] * q.shape[1], 1)
690
+
691
+ desc_q = desc_helper.get_tma_descriptor_kernel_param("q")
692
+ desc_v = desc_helper.get_tma_descriptor_kernel_param("v")
693
+ desc_k = desc_helper.get_tma_descriptor_kernel_param("k")
694
+ desc_o = desc_helper.get_tma_descriptor_kernel_param("o")
695
+
696
+ ctx.grid = grid
697
+ _attn_fwd_tma[grid](
698
+ sm_scale, M, #
699
+ q.shape[0], q.shape[1], #
700
+ desc_q, desc_k, desc_v, desc_o, #
701
+ N_CTX=q.shape[2], #
702
+ HEAD_DIM=HEAD_DIM_K, #
703
+ FP8_OUTPUT=q.dtype == torch.float8_e5m2, #
704
+ STAGE=stage, #
705
+ **extra_kern_args)
706
+ else:
707
+ grid = lambda args: (triton.cdiv(q.shape[2], args["BLOCK_M"]), q.shape[0] * q.shape[1], 1)
708
+ ctx.grid = grid
709
+ _attn_fwd[grid](
710
+ q, k, v, sm_scale, M, o, #
711
+ q.stride(0), q.stride(1), q.stride(2), q.stride(3), #
712
+ k.stride(0), k.stride(1), k.stride(2), k.stride(3), #
713
+ v.stride(0), v.stride(1), v.stride(2), v.stride(3), #
714
+ o.stride(0), o.stride(1), o.stride(2), o.stride(3), #
715
+ q.shape[0], q.shape[1], #
716
+ N_CTX=q.shape[2], #
717
+ HEAD_DIM=HEAD_DIM_K, #
718
+ STAGE=stage, #
719
+ **extra_kern_args)
720
+
721
+ ctx.save_for_backward(q, k, v, o, M)
722
+ ctx.sm_scale = sm_scale
723
+ ctx.HEAD_DIM = HEAD_DIM_K
724
+ ctx.causal = causal
725
+ return o
726
+
727
+ @staticmethod
728
+ def backward(ctx, do):
729
+ q, k, v, o, M = ctx.saved_tensors
730
+ assert do.is_contiguous()
731
+ assert q.stride() == k.stride() == v.stride() == o.stride() == do.stride()
732
+ dq = torch.empty_like(q)
733
+ dk = torch.empty_like(k)
734
+ dv = torch.empty_like(v)
735
+ BATCH, N_HEAD, N_CTX = q.shape[:3]
736
+ PRE_BLOCK = 128
737
+ NUM_WARPS, NUM_STAGES = 4, 5
738
+ BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 128, 128, 32
739
+ BLK_SLICE_FACTOR = 2
740
+ RCP_LN2 = 1.4426950408889634 # = 1.0 / ln(2)
741
+ arg_k = k
742
+ arg_k = arg_k * (ctx.sm_scale * RCP_LN2)
743
+ PRE_BLOCK = 128
744
+ assert N_CTX % PRE_BLOCK == 0
745
+ pre_grid = (N_CTX // PRE_BLOCK, BATCH * N_HEAD)
746
+ delta = torch.empty_like(M)
747
+ _attn_bwd_preprocess[pre_grid](
748
+ o, do, #
749
+ delta, #
750
+ BATCH, N_HEAD, N_CTX, #
751
+ BLOCK_M=PRE_BLOCK, HEAD_DIM=ctx.HEAD_DIM #
752
+ )
753
+ grid = (N_CTX // BLOCK_N1, 1, BATCH * N_HEAD)
754
+ _attn_bwd[grid](
755
+ q, arg_k, v, ctx.sm_scale, do, dq, dk, dv, #
756
+ M, delta, #
757
+ q.stride(0), q.stride(1), q.stride(2), q.stride(3), #
758
+ N_HEAD, N_CTX, #
759
+ BLOCK_M1=BLOCK_M1, BLOCK_N1=BLOCK_N1, #
760
+ BLOCK_M2=BLOCK_M2, BLOCK_N2=BLOCK_N2, #
761
+ BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, #
762
+ HEAD_DIM=ctx.HEAD_DIM, #
763
+ num_warps=NUM_WARPS, #
764
+ num_stages=NUM_STAGES #
765
+ )
766
+
767
+ return dq, dk, dv, None, None
768
+
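A minimal usage sketch of the `Attention` autograd function above; the sizes, dtype and scale are assumptions chosen to satisfy the kernel's constraints (head_dim in {16, 32, 64, 128, 256}, sequence length divisible by the 128-wide preprocessing block used in backward).

import torch

batch, heads, seq_len, head_dim = 2, 8, 1024, 64    # assumed sizes
q = torch.randn(batch, heads, seq_len, head_dim, device="cuda",
                dtype=torch.float16, requires_grad=True)
k = torch.randn_like(q, requires_grad=True)
v = torch.randn_like(q, requires_grad=True)
sm_scale = head_dim ** -0.5                          # standard softmax scaling

out = Attention.apply(q, k, v, True, sm_scale)       # causal=True
out.sum().backward()                                 # fills q.grad, k.grad, v.grad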
769
+ def rotate_half(x):
770
+ """Rotates half the hidden dims of the input."""
771
+ x1 = x[..., : x.shape[-1] // 2]
772
+ x2 = x[..., x.shape[-1] // 2 :]
773
+ return torch.cat((-x2, x1), dim=-1)
774
+
775
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
776
+ """Applies Rotary Position Embedding to the query and key tensors.
777
+
778
+ Args:
779
+ q (`torch.Tensor`): The query tensor.
780
+ k (`torch.Tensor`): The key tensor.
781
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
782
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
783
+ position_ids (`torch.Tensor`, *optional*):
784
+ Deprecated and unused.
785
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
786
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
787
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
788
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
789
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
790
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
791
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
792
+ Returns:
793
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
794
+ """
795
+ cos = cos.unsqueeze(unsqueeze_dim)
796
+ sin = sin.unsqueeze(unsqueeze_dim)
797
+ q_embed = (q * cos) + (rotate_half(q) * sin)
798
+ k_embed = (k * cos) + (rotate_half(k) * sin)
799
+ return q_embed, k_embed
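A small, assumed-shape illustration of `apply_rotary_pos_emb`: with the default `unsqueeze_dim=1`, `cos` and `sin` of shape `(batch, seq_len, head_dim)` broadcast against `q` and `k` of shape `(batch, heads, seq_len, head_dim)`.

import torch

batch, heads, seq_len, head_dim = 1, 4, 16, 64       # illustrative sizes
q = torch.randn(batch, heads, seq_len, head_dim)
k = torch.randn(batch, heads, seq_len, head_dim)
cos = torch.randn(batch, seq_len, head_dim)          # normally produced by a rotary embedding module
sin = torch.randn(batch, seq_len, head_dim)

q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)
assert q_rot.shape == q.shape and k_rot.shape == k.shape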
800
+
801
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
802
+ """
803
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
804
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
805
+ """
806
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
807
+ if n_rep == 1:
808
+ return hidden_states
809
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
810
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
811
+
812
+
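A quick shape check for `repeat_kv`, using assumed grouped-query sizes (4 query heads per key/value head):

import torch

kv = torch.randn(2, 2, 16, 64)            # (batch, num_key_value_heads, seq_len, head_dim)
expanded = repeat_kv(kv, n_rep=4)
assert expanded.shape == (2, 8, 16, 64)   # (batch, num_attention_heads, seq_len, head_dim)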
813
+ class LlamaAttention(nn.Module):
814
+ q_proj: nn.Linear
815
+ k_proj: nn.Linear
816
+ v_proj: nn.Linear
817
+ o_proj: nn.Linear
818
+ scaling: float
819
+ attention_dropout: float
820
+ is_causal: bool
821
+ layer_idx: int
822
+ num_key_value_groups: int
823
+ num_attention_heads: int
824
+ num_key_value_heads: int
825
+ head_dim: int
826
+ def forward(
827
+ self,
828
+ hidden_states: torch.Tensor,
829
+ position_embeddings: Tuple[torch.Tensor, torch.Tensor],
830
+ attention_mask: Optional[torch.Tensor],
831
+ past_key_value: Optional[Cache] = None,
832
+ cache_position: Optional[torch.LongTensor] = None,
833
+ **kwargs: Unpack[FlashAttentionKwargs],
834
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
835
+ input_shape_orig = hidden_states.shape
836
+ input_shape = hidden_states.shape[:-1]
837
+ hidden_shape = (*input_shape, -1, self.head_dim)
838
+
839
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
840
+ key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
841
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
842
+ # key_states = repeat_kv(key_states, self.num_key_value_groups)
843
+ # value_states = repeat_kv(value_states, self.num_key_value_groups)
844
+ cos, sin = position_embeddings
845
+
846
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
847
+
848
+ if past_key_value is not None:
849
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
850
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
851
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
852
+
853
+ # query_states = query_states.transpose(1, 2).view(*input_shape, self.head_dim*self.num_attention_heads)
854
+ # key_states = key_states.transpose(1, 2).view(*input_shape, self.head_dim*self.num_key_value_heads)
855
+ # value_states = value_states.transpose(1, 2).view(*input_shape, self.head_dim*self.num_key_value_heads)
856
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
857
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
858
+ attn_output = Attention.apply(
859
+ query_states,
860
+ key_states,
861
+ value_states,
862
+ self.is_causal,
863
+ self.scaling,
864
+ # Note: extra FlashAttentionKwargs are not forwarded; Attention.forward only takes (q, k, v, causal, sm_scale).
865
+ )
866
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
867
+ attn_output = self.o_proj(attn_output)
868
+ return attn_output, None
869
+
870
+ def eager_attention_forward(
871
+ query: torch.Tensor,
872
+ key: torch.Tensor,
873
+ value: torch.Tensor,
874
+ attention_mask: Optional[torch.Tensor],
875
+ scaling: float,
876
+ dropout: float = 0.0,
877
+ num_key_value_groups: int = 1,
878
+ training: bool = False,
879
+ **kwargs,
880
+ ):
881
+ key_states = repeat_kv(key, num_key_value_groups)
882
+ value_states = repeat_kv(value, num_key_value_groups)
883
+
884
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
885
+ if attention_mask is not None:
886
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
887
+ attn_weights = attn_weights + causal_mask
888
+
889
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
890
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=training)
891
+ attn_output = torch.matmul(attn_weights, value_states)
892
+ # attn_output = attn_output.transpose(1, 2).contiguous()
893
+
894
+ return attn_output, attn_weights
895
+
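The eager path above can serve as a reference for the Triton kernel. The snippet below is a hedged sketch of such a comparison; the shapes, non-causal setting and tolerance are assumptions, not values taken from this file.

import torch

batch, q_heads, kv_heads, seq_len, head_dim = 1, 8, 2, 256, 64
q = torch.randn(batch, q_heads, seq_len, head_dim, device="cuda", dtype=torch.float16)
k = torch.randn(batch, kv_heads, seq_len, head_dim, device="cuda", dtype=torch.float16)
v = torch.randn(batch, kv_heads, seq_len, head_dim, device="cuda", dtype=torch.float16)
scale = head_dim ** -0.5
n_rep = q_heads // kv_heads

ref, _ = eager_attention_forward(q, k, v, attention_mask=None,
                                 scaling=scale, num_key_value_groups=n_rep)
tri = Attention.apply(q, repeat_kv(k, n_rep), repeat_kv(v, n_rep), False, scale)
torch.testing.assert_close(ref, tri, atol=2e-2, rtol=0)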
896
+ class HFLlamaAttention(nn.Module):
897
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
898
+
899
+ def __init__(self, config: LlamaConfig, layer_idx: int, device: str):
900
+ super().__init__()
901
+ self.config = config
902
+ self.layer_idx = layer_idx
903
+ self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
904
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
905
+ self.scaling = self.head_dim**-0.5
906
+ self.attention_dropout = config.attention_dropout
907
+ self.is_causal = True
908
+
909
+ self.q_proj = nn.Linear(
910
+ config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias, device=device
911
+ )
912
+ self.k_proj = nn.Linear(
913
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias, device=device
914
+ )
915
+ self.v_proj = nn.Linear(
916
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias, device=device
917
+ )
918
+ self.o_proj = nn.Linear(
919
+ config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias, device=device
920
+ )
921
+
922
+ def forward(
923
+ self,
924
+ hidden_states: torch.Tensor,
925
+ position_embeddings: Tuple[torch.Tensor, torch.Tensor],
926
+ attention_mask: Optional[torch.Tensor],
927
+ past_key_value: Optional[Cache] = None,
928
+ cache_position: Optional[torch.LongTensor] = None,
929
+ **kwargs: Unpack[FlashAttentionKwargs],
930
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
931
+ input_shape = hidden_states.shape[:-1]
932
+ hidden_shape = (*input_shape, -1, self.head_dim)
933
+
934
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
935
+ key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
936
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
937
+
938
+ cos, sin = position_embeddings
939
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
940
+
941
+ if past_key_value is not None:
942
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
943
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
944
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
945
+
946
+ attention_interface: Callable = eager_attention_forward
947
+
948
+ attn_output, attn_weights = attention_interface(
949
950
+ query_states,
951
+ key_states,
952
+ value_states,
953
+ attention_mask,
954
+ dropout=0.0 if not self.training else self.attention_dropout,
955
+ scaling=self.scaling,
+ num_key_value_groups=self.num_key_value_groups,
+ training=self.training,
956
+ **kwargs,
957
+ )
958
+
959
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
960
+ attn_output = self.o_proj(attn_output)
961
+ return attn_output, attn_weights
962
+
963
+ def attn_forward_kernel(
964
+ query: torch.Tensor,
965
+ key: torch.Tensor,
966
+ value: torch.Tensor,
967
+ scaling: float,
968
+ causal: bool,
969
+ ):
970
+ return Attention.apply(query, key, value, causal, scaling)
971
+
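Finally, a minimal call of the exported `attn_forward_kernel` entry point; sizes are assumptions, and query, key and value must already share the `(batch, heads, seq_len, head_dim)` layout (i.e. KV heads already repeated).

import torch

q = torch.randn(1, 8, 512, 64, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

out = attn_forward_kernel(q, k, v, scaling=q.shape[-1] ** -0.5, causal=True)
print(out.shape)  # torch.Size([1, 8, 512, 64])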
972
+ # def test_llama_attention_output():
973
+ # """
974
+ # Test to verify that the LlamaAttention module produces correct outputs by comparing
975
+ # with the reference implementation from transformers.
976
+ # """
977
+ # import torch
978
+
979
+ # # Set up test parameters
980
+ # batch_size = 2
981
+ # seq_len = 16
982
+ # model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B-Instruct", device_map="cuda:0")
983
+ # # print("######################### model", model)
984
+ # config = model.config
985
+ # # Create a custom LlamaAttention instance
986
+ # custom_attn = LlamaAttention()
987
+ # custom_attn.num_attention_heads = config.num_attention_heads
988
+ # custom_attn.num_key_value_heads = config.num_key_value_heads
989
+ # custom_attn.head_dim = config.head_dim
990
+ # custom_attn.hidden_size = config.hidden_size
991
+ # custom_attn.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * config.head_dim, bias=False, device="cuda:0")
992
+ # custom_attn.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * config.head_dim, bias=False, device="cuda:0")
993
+ # custom_attn.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * config.head_dim, bias=False, device="cuda:0")
994
+ # custom_attn.o_proj = nn.Linear(config.num_attention_heads * config.head_dim, config.hidden_size, bias=False, device="cuda:0")
995
+ # print("######################### custom_attn.o_proj", custom_attn.o_proj.weight.data.shape)
996
+ # custom_attn.scaling = 1.0 / (config.head_dim ** 0.5)
997
+ # custom_attn.attention_dropout = 0.0
998
+ # custom_attn.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
999
+ # custom_attn.is_causal = True
1000
+ # custom_attn.layer_idx = 0
1001
+
1002
+ # # Create a reference HF LlamaAttention instance
1003
+ # hf_attn = HFLlamaAttention(
1004
+ # config=config,
1005
+ # layer_idx=0,
1006
+ # device="cuda:0"
1007
+ # )
1008
+
1009
+ # # Copy weights to ensure identical parameters
1010
+ # hf_attn.q_proj.weight.data.copy_(custom_attn.q_proj.weight.data)
1011
+ # hf_attn.k_proj.weight.data.copy_(custom_attn.k_proj.weight.data)
1012
+ # hf_attn.v_proj.weight.data.copy_(custom_attn.v_proj.weight.data)
1013
+ # hf_attn.o_proj.weight.data.copy_(custom_attn.o_proj.weight.data)
1014
+
1015
+ # # Create test inputs
1016
+ # hidden_states = torch.randn(batch_size, seq_len, config.hidden_size).to("cuda:0")
1017
+
1018
+ # # Create position embeddings (cos, sin)
1019
+ # cos = torch.cos(torch.arange(0, config.head_dim).float() / 10000.0).unsqueeze(0).to("cuda:0")
1020
+ # sin = torch.sin(torch.arange(0, config.head_dim).float() / 10000.0).unsqueeze(0).to("cuda:0")
1021
+ # position_embeddings = (cos, sin)
1022
+
1023
+ # # Create attention mask
1024
+ # attention_mask = torch.ones(batch_size, seq_len, ).to("cuda:0")
1025
+
1026
+ # # Run custom attention
1027
+ # custom_output, _ = custom_attn(
1028
+ # hidden_states=hidden_states,
1029
+ # position_embeddings=position_embeddings,
1030
+ # attention_mask=None
1031
+ # )
1032
+
1033
+ # # Run HF attention
1034
+ # hf_output = hf_attn(
1035
+ # hidden_states=hidden_states,
1036
+ # attention_mask=None,
1037
+ # position_embeddings=position_embeddings,
1038
+ # )[0]
1039
+
1040
+ # # Check if outputs are close
1041
+ # assert torch.allclose(custom_output, hf_output, atol=1e-5), \
1042
+ # f"Custom attention output doesn't match reference. Max diff: {(custom_output - hf_output).abs().max()}"
1043
+
1044
+ # print("LlamaAttention test passed! Custom implementation matches reference.")
1045
+ # return True
1046
+
1047
+
1048
+ # def test_triton_attention_vs_eager():
1049
+ # """
1050
+ # Test that compares the output of the Triton-based Attention implementation
1051
+ # with the eager implementation from HuggingFace.
1052
+ # """
1053
+ # import torch
1054
+ # from transformers import LlamaConfig
1055
+
1056
+ # # Set up test parameters
1057
+ # batch_size = 2
1058
+ # seq_len = 16
1059
+
1060
+ # # Create a simple config
1061
+ # model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B-Instruct", device_map="cuda:0")
1062
+ # # print("######################### model", model)
1063
+ # config = model.config
1064
+
1065
+ # # Create inputs
1066
+ # q = torch.randn(batch_size, config.num_attention_heads, seq_len, config.head_dim, device="cuda")
1067
+ # k = torch.randn(batch_size, config.num_key_value_heads, seq_len, config.head_dim, device="cuda")
1068
+ # v = torch.randn(batch_size, config.num_key_value_heads, seq_len, config.head_dim, device="cuda")
1069
+
1070
+ # # Make inputs contiguous
1071
+ # q = q.contiguous()
1072
+ # k = k.contiguous()
1073
+ # v = v.contiguous()
1074
+
1075
+ # # Scaling factor
1076
+ # sm_scale = 1.0 / (config.head_dim ** 0.5)
1077
+
1078
+ # # Run Triton-based attention
1079
+ # causal = True
1080
+ # triton_output = Attention.apply(q, k, v, causal, sm_scale)
1081
+
1082
+ # # Run eager implementation
1083
+ # # Create a dummy attention mask for causal attention
1084
+ # attention_mask = torch.ones(batch_size, 1, seq_len, seq_len, device="cuda")
1085
+ # # if causal:
1086
+ # # # Create a causal mask (lower triangular)
1087
+ # # attention_mask = torch.tril(torch.ones((seq_len, seq_len), device="cuda"))
1088
+
1089
+ # # Compute attention scores
1090
+ # eager_output, _ = eager_attention_forward(
1091
+ # query=q,
1092
+ # key=k,
1093
+ # value=v,
1094
+ # attention_mask=attention_mask,
1095
+ # scaling=sm_scale,
1096
+ # num_key_value_groups=config.num_attention_heads // config.num_key_value_heads,
1097
+ # training=False,
1098
+ # )
1099
+ # print("######################### triton_output", triton_output.shape)
1100
+ # print("######################### eager_output", eager_output.shape)
1101
+ # print("######################### triton_output", triton_output)
1102
+ # print("######################### eager_output", eager_output)
1103
+ # is_close = torch.allclose(triton_output, eager_output, atol=1e-4)
1104
+ # print(f"Triton attention matches eager implementation: {is_close}")
1105
+
1106
+ # return is_close
1107
+
1108
+ # attention = Attention.apply
1109
+ # DEVICE = "cuda:0"
1110
+
1111
+ # import pytest
1112
+ # @pytest.mark.parametrize("Z, H, N_CTX, HEAD_DIM", [(1, 2, 1024, 64)])
1113
+ # @pytest.mark.parametrize("causal", [True])
1114
+ # def test_op(Z, H, N_CTX, HEAD_DIM, causal, dtype=torch.float16):
1115
+ # torch.manual_seed(20)
1116
+ # q = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_())
1117
+ # k = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_())
1118
+ # v = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_())
1119
+ # sm_scale = 0.5
1120
+ # dout = torch.randn_like(q)
1121
+ # # reference implementation
1122
+ # M = torch.tril(torch.ones((N_CTX, N_CTX), device=DEVICE))
1123
+ # p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
1124
+ # if causal:
1125
+ # p[:, :, M == 0] = float("-inf")
1126
+ # p = torch.softmax(p.float(), dim=-1).half()
1127
+ # # p = torch.exp(p)
1128
+ # ref_out = torch.matmul(p, v)
1129
+ # ref_out.backward(dout)
1130
+ # ref_dv, v.grad = v.grad.clone(), None
1131
+ # ref_dk, k.grad = k.grad.clone(), None
1132
+ # ref_dq, q.grad = q.grad.clone(), None
1133
+ # # triton implementation
1134
+ # tri_out = attention(q, k, v, causal, sm_scale).half()
1135
+ # tri_out.backward(dout)
1136
+ # tri_dv, v.grad = v.grad.clone(), None
1137
+ # tri_dk, k.grad = k.grad.clone(), None
1138
+ # tri_dq, q.grad = q.grad.clone(), None
1139
+ # # compare
1140
+ # assert torch.allclose(ref_out, tri_out, atol=1e-2, rtol=0)
1141
+ # rtol = 0.0
1142
+ # # Relative tolerance workaround for known hardware limitation of CDNA2 GPU.
1143
+ # # For details see https://pytorch.org/docs/stable/notes/numerical_accuracy.html#reduced-precision-fp16-and-bf16-gemms-and-convolutions-on-amd-instinct-mi200-devices
1144
+ # if torch.version.hip is not None and triton.runtime.driver.active.get_current_target().arch == "gfx90a":
1145
+ # rtol = 1e-2
1146
+ # assert torch.allclose(ref_dv, tri_dv, atol=1e-2, rtol=rtol)
1147
+ # assert torch.allclose(ref_dk, tri_dk, atol=1e-2, rtol=rtol)
1148
+ # assert torch.allclose(ref_dq, tri_dq, atol=1e-2, rtol=rtol)
torch-ext/triton_llama_attn/layers.py ADDED
File without changes