zaydzuhri committed on
Commit f72219a · verified · 1 Parent(s): 4135502

Add files using upload-large-folder tool

This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50)
  1. .gitattributes +11 -0
  2. checkpoint/step-1/.metadata +3 -0
  3. checkpoint/step-10000/.metadata +3 -0
  4. checkpoint/step-100000/.metadata +3 -0
  5. checkpoint/step-20000/.metadata +3 -0
  6. checkpoint/step-30000/.metadata +3 -0
  7. checkpoint/step-40000/.metadata +3 -0
  8. checkpoint/step-50000/.metadata +3 -0
  9. checkpoint/step-60000/.metadata +3 -0
  10. checkpoint/step-70000/.metadata +3 -0
  11. checkpoint/step-80000/.metadata +3 -0
  12. checkpoint/step-90000/.metadata +3 -0
  13. fla/ops/__pycache__/__init__.cpython-311.pyc +0 -0
  14. fla/ops/attn/__pycache__/naive_softpick.cpython-311.pyc +0 -0
  15. fla/ops/attn/__pycache__/parallel_rectified.cpython-311.pyc +0 -0
  16. fla/ops/attn/__pycache__/parallel_softpick.cpython-311.pyc +0 -0
  17. fla/ops/gated_delta_rule/__init__.py +7 -0
  18. fla/ops/gated_delta_rule/__pycache__/__init__.cpython-311.pyc +0 -0
  19. fla/ops/gated_delta_rule/__pycache__/chunk.cpython-311.pyc +0 -0
  20. fla/ops/gated_delta_rule/__pycache__/fused_recurrent.cpython-311.pyc +0 -0
  21. fla/ops/gated_delta_rule/__pycache__/wy_fast.cpython-311.pyc +0 -0
  22. fla/ops/gated_delta_rule/chunk.py +392 -0
  23. fla/ops/gated_delta_rule/fused_recurrent.py +321 -0
  24. fla/ops/gated_delta_rule/wy_fast.py +620 -0
  25. fla/ops/generalized_delta_rule/README.md +37 -0
  26. fla/ops/generalized_delta_rule/__init__.py +9 -0
  27. fla/ops/generalized_delta_rule/__pycache__/__init__.cpython-311.pyc +0 -0
  28. fla/ops/generalized_delta_rule/dplr/chunk.py +388 -0
  29. fla/ops/generalized_delta_rule/dplr/chunk_A_bwd.py +446 -0
  30. fla/ops/generalized_delta_rule/dplr/chunk_A_fwd.py +324 -0
  31. fla/ops/generalized_delta_rule/dplr/chunk_h_bwd.py +196 -0
  32. fla/ops/generalized_delta_rule/dplr/chunk_h_fwd.py +197 -0
  33. fla/ops/generalized_delta_rule/dplr/chunk_o_fwd.py +138 -0
  34. fla/ops/generalized_delta_rule/dplr/fused_recurrent.py +292 -0
  35. fla/ops/generalized_delta_rule/dplr/naive.py +96 -0
  36. fla/ops/generalized_delta_rule/dplr/wy_fast_bwd.py +184 -0
  37. fla/ops/generalized_delta_rule/dplr/wy_fast_fwd.py +318 -0
  38. fla/ops/generalized_delta_rule/iplr/__init__.py +7 -0
  39. fla/ops/generalized_delta_rule/iplr/__pycache__/__init__.cpython-311.pyc +0 -0
  40. fla/ops/generalized_delta_rule/iplr/__pycache__/chunk.cpython-311.pyc +0 -0
  41. fla/ops/generalized_delta_rule/iplr/__pycache__/fused_recurrent.cpython-311.pyc +0 -0
  42. fla/ops/generalized_delta_rule/iplr/__pycache__/wy_fast.cpython-311.pyc +0 -0
  43. fla/ops/generalized_delta_rule/iplr/chunk.py +528 -0
  44. fla/ops/generalized_delta_rule/iplr/fused_recurrent.py +451 -0
  45. fla/ops/generalized_delta_rule/iplr/naive.py +69 -0
  46. fla/ops/generalized_delta_rule/iplr/wy_fast.py +338 -0
  47. fla/ops/gla/__init__.py +11 -0
  48. fla/ops/gla/__pycache__/__init__.cpython-311.pyc +0 -0
  49. fla/ops/gla/__pycache__/chunk.cpython-311.pyc +0 -0
  50. fla/ops/gla/__pycache__/fused_chunk.cpython-311.pyc +0 -0
.gitattributes CHANGED
@@ -33,3 +33,14 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ checkpoint/step-60000/.metadata filter=lfs diff=lfs merge=lfs -text
37
+ checkpoint/step-70000/.metadata filter=lfs diff=lfs merge=lfs -text
38
+ checkpoint/step-90000/.metadata filter=lfs diff=lfs merge=lfs -text
39
+ checkpoint/step-100000/.metadata filter=lfs diff=lfs merge=lfs -text
40
+ checkpoint/step-20000/.metadata filter=lfs diff=lfs merge=lfs -text
41
+ checkpoint/step-80000/.metadata filter=lfs diff=lfs merge=lfs -text
42
+ checkpoint/step-10000/.metadata filter=lfs diff=lfs merge=lfs -text
43
+ checkpoint/step-50000/.metadata filter=lfs diff=lfs merge=lfs -text
44
+ checkpoint/step-1/.metadata filter=lfs diff=lfs merge=lfs -text
45
+ checkpoint/step-30000/.metadata filter=lfs diff=lfs merge=lfs -text
46
+ checkpoint/step-40000/.metadata filter=lfs diff=lfs merge=lfs -text
checkpoint/step-1/.metadata ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a7a1dc8097b7a6f7f04d8d3b1ac59dbca7331cd4f7554d465556a775cc8fb2a3
3
+ size 1966399
checkpoint/step-10000/.metadata ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c8230628b2b9318c77cc629a747246173e1f08994304d6eb6c277b937fe4a122
3
+ size 1966605
checkpoint/step-100000/.metadata ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:20172fb8269576c2752f0c9c5571db66977472808e11f8a08a62ecebc8c1ca3f
3
+ size 1966953
checkpoint/step-20000/.metadata ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b6129edb991ed42f312a5437d565bb7921e331f0c8ee6e8d2b9271fd28d05350
3
+ size 1966726
checkpoint/step-30000/.metadata ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cdd2f42f953103801f648d026a8c5473671a3970a42b90dbec3d61eb6887bba6
3
+ size 1966842
checkpoint/step-40000/.metadata ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1be5eabe3423f35d02413a59af0177d6cfc9e7a30fad1bf5d7c98ae6740fe503
3
+ size 1966870
checkpoint/step-50000/.metadata ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3b8de54a3bd58bec547d87c515963acf7b678e1976263b0c7d6106842d07b8ce
3
+ size 1966890
checkpoint/step-60000/.metadata ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cdf7894da9066d4b1b627693dec1f3df1211e7a894995740daaa7d77b7dd1985
3
+ size 1966912
checkpoint/step-70000/.metadata ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:843a213824fe630f60d4c48810db333f8ffec4138a436c837403aba0704abad8
3
+ size 1966934
checkpoint/step-80000/.metadata ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:65cc9b063d0a6a992190ccd8efccbf5142716eb5c2082a9ae581bbede3d15fd7
3
+ size 1966952
checkpoint/step-90000/.metadata ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7a25b7f564d152ef6984a3e23b67a896d57c0aae5a269437b077ef5bee760039
3
+ size 1966952
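Each of the checkpoint `.metadata` entries above is stored as a Git LFS pointer rather than the metadata blob itself: a spec `version` line, a `sha256` object id, and the payload `size` in bytes. As a minimal illustration (the helper function and the example path are hypothetical, not part of this commit), the three fields can be read back like this:

```python
# Minimal sketch: parse the Git LFS pointer format shown above,
# i.e. lines of the form "version ...", "oid sha256:...", "size ...".
from pathlib import Path

def parse_lfs_pointer(path: str) -> dict:
    """Return the version, oid and size fields of a Git LFS pointer file."""
    fields = {}
    for line in Path(path).read_text().splitlines():
        key, _, value = line.partition(" ")
        fields[key] = value
    return {
        "version": fields["version"],
        "oid": fields["oid"],          # e.g. "sha256:a7a1dc80..."
        "size": int(fields["size"]),   # payload size in bytes
    }

# Example (hypothetical local path):
# info = parse_lfs_pointer("checkpoint/step-1/.metadata")
# print(info["size"])  # 1966399
```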
fla/ops/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (2.55 kB). View file
 
fla/ops/attn/__pycache__/naive_softpick.cpython-311.pyc ADDED
Binary file (3 kB). View file
 
fla/ops/attn/__pycache__/parallel_rectified.cpython-311.pyc ADDED
Binary file (34.7 kB). View file
 
fla/ops/attn/__pycache__/parallel_softpick.cpython-311.pyc ADDED
Binary file (35.9 kB). View file
 
fla/ops/gated_delta_rule/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ from .chunk import chunk_gated_delta_rule
2
+ from .fused_recurrent import fused_recurrent_gated_delta_rule
3
+
4
+ __all__ = [
5
+ "chunk_gated_delta_rule",
6
+ "fused_recurrent_gated_delta_rule"
7
+ ]
fla/ops/gated_delta_rule/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (358 Bytes). View file
 
fla/ops/gated_delta_rule/__pycache__/chunk.cpython-311.pyc ADDED
Binary file (15 kB). View file
 
fla/ops/gated_delta_rule/__pycache__/fused_recurrent.cpython-311.pyc ADDED
Binary file (15.4 kB). View file
 
fla/ops/gated_delta_rule/__pycache__/wy_fast.cpython-311.pyc ADDED
Binary file (45.9 kB). View file
 
fla/ops/gated_delta_rule/chunk.py ADDED
@@ -0,0 +1,392 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import triton
8
+ from einops import rearrange
9
+
10
+ from fla.modules.l2norm import l2norm_bwd, l2norm_fwd
11
+ from fla.ops.common.chunk_delta_h import chunk_gated_delta_rule_bwd_dhu, chunk_gated_delta_rule_fwd_h
12
+ from fla.ops.common.chunk_o import chunk_bwd_dqkwg, chunk_bwd_dv_local, chunk_fwd_o
13
+ from fla.ops.gated_delta_rule.wy_fast import bwd_prepare_wy_repr, fwd_prepare_wy_repr, fwd_recompute_w_u
14
+ from fla.ops.utils import chunk_local_cumsum
15
+ from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard
16
+
17
+
18
+ def chunk_gated_delta_rule_fwd(
19
+ q: torch.Tensor,
20
+ k: torch.Tensor,
21
+ v: torch.Tensor,
22
+ g: torch.Tensor,
23
+ beta: torch.Tensor,
24
+ scale: float,
25
+ initial_state: torch.Tensor,
26
+ output_final_state: bool,
27
+ offsets: Optional[torch.LongTensor] = None,
28
+ indices: Optional[torch.LongTensor] = None,
29
+ head_first: bool = True,
30
+ chunk_size: int = 64
31
+ ):
32
+ g = chunk_local_cumsum(g, chunk_size, offsets=offsets, indices=indices, head_first=head_first)
33
+ # obtain WY representation. u is actually the new v.
34
+ w, u, Aw, Au = fwd_prepare_wy_repr(
35
+ k=k,
36
+ v=v,
37
+ beta=beta,
38
+ g=g,
39
+ offsets=offsets,
40
+ indices=indices,
41
+ head_first=head_first,
42
+ chunk_size=chunk_size
43
+ )
44
+
45
+ h, v_new, final_state = chunk_gated_delta_rule_fwd_h(
46
+ k=k,
47
+ w=w,
48
+ u=u,
49
+ g=g,
50
+ initial_state=initial_state,
51
+ output_final_state=output_final_state,
52
+ offsets=offsets,
53
+ indices=indices,
54
+ head_first=head_first,
55
+ chunk_size=chunk_size
56
+ )
57
+
58
+ # obtain output
59
+ o = chunk_fwd_o(
60
+ q=q,
61
+ k=k,
62
+ v=v_new,
63
+ h=h,
64
+ g=g,
65
+ scale=scale,
66
+ offsets=offsets,
67
+ indices=indices,
68
+ head_first=head_first,
69
+ chunk_size=chunk_size
70
+ )
71
+ return g, o, Aw, Au, final_state
72
+
73
+
74
+ def chunk_gated_delta_rule_bwd(
75
+ q: torch.Tensor,
76
+ k: torch.Tensor,
77
+ v: torch.Tensor,
78
+ g: torch.Tensor,
79
+ beta: torch.Tensor,
80
+ Aw: torch.Tensor,
81
+ Au: torch.Tensor,
82
+ scale: float,
83
+ initial_state: torch.Tensor,
84
+ do: torch.Tensor,
85
+ dht: torch.Tensor,
86
+ offsets: Optional[torch.LongTensor] = None,
87
+ indices: Optional[torch.LongTensor] = None,
88
+ head_first: bool = True,
89
+ chunk_size: int = 64
90
+ ):
91
+ T = q.shape[2] if head_first else q.shape[1]
92
+ BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
93
+ w, u = fwd_recompute_w_u(
94
+ k=k,
95
+ v=v,
96
+ beta=beta,
97
+ Aw=Aw,
98
+ Au=Au,
99
+ offsets=offsets,
100
+ indices=indices,
101
+ head_first=head_first,
102
+ chunk_size=BT
103
+ )
104
+ h, v_new, _ = chunk_gated_delta_rule_fwd_h(
105
+ k=k,
106
+ w=w,
107
+ u=u,
108
+ g=g,
109
+ initial_state=initial_state,
110
+ output_final_state=False,
111
+ offsets=offsets,
112
+ indices=indices,
113
+ head_first=head_first,
114
+ chunk_size=BT
115
+ )
116
+ dv = chunk_bwd_dv_local(
117
+ q=q,
118
+ k=k,
119
+ g=g,
120
+ do=do,
121
+ dh=None,
122
+ scale=scale,
123
+ offsets=offsets,
124
+ indices=indices,
125
+ head_first=head_first,
126
+ chunk_size=BT
127
+ )
128
+ dh, dh0, dv = chunk_gated_delta_rule_bwd_dhu(
129
+ q=q,
130
+ k=k,
131
+ w=w,
132
+ g=g,
133
+ h0=initial_state,
134
+ dht=dht,
135
+ do=do,
136
+ dv=dv,
137
+ scale=scale,
138
+ offsets=offsets,
139
+ indices=indices,
140
+ head_first=head_first,
141
+ chunk_size=BT
142
+ )
143
+ dq, dk, dw, dg = chunk_bwd_dqkwg(
144
+ q=q,
145
+ k=k,
146
+ v=v_new,
147
+ w=w,
148
+ g=g,
149
+ h=h,
150
+ dv=dv,
151
+ do=do,
152
+ dh=dh,
153
+ scale=scale,
154
+ offsets=offsets,
155
+ indices=indices,
156
+ head_first=head_first,
157
+ chunk_size=BT
158
+ )
159
+ dk2, dv, db, dg2 = bwd_prepare_wy_repr(
160
+ k=k,
161
+ v=v,
162
+ beta=beta,
163
+ g=g,
164
+ Aw=Aw,
165
+ Au=Au,
166
+ dw=dw,
167
+ du=dv,
168
+ offsets=offsets,
169
+ indices=indices,
170
+ head_first=head_first,
171
+ chunk_size=BT
172
+ )
173
+ dk.add_(dk2)
174
+ dg.add_(dg2)
175
+ assert dg.dtype == torch.float32, "dg should be fp32"
176
+ dg = chunk_local_cumsum(dg, chunk_size, reverse=True, offsets=offsets, indices=indices, head_first=head_first)
177
+ return dq, dk, dv, db, dg, dh0
178
+
179
+
180
+ class ChunkGatedDeltaRuleFunction(torch.autograd.Function):
181
+
182
+ @staticmethod
183
+ @input_guard
184
+ @autocast_custom_fwd
185
+ def forward(
186
+ ctx,
187
+ q: torch.Tensor,
188
+ k: torch.Tensor,
189
+ v: torch.Tensor,
190
+ g: torch.Tensor,
191
+ beta: torch.Tensor,
192
+ scale: float,
193
+ initial_state: torch.Tensor,
194
+ output_final_state: bool,
195
+ offsets: Optional[torch.LongTensor] = None,
196
+ head_first: bool = True,
197
+ use_qk_l2norm_in_kernel: bool = False
198
+ ):
199
+ chunk_size = 64
200
+ q_orig = q
201
+ k_orig = k
202
+
203
+ if use_qk_l2norm_in_kernel:
204
+ q = l2norm_fwd(q)
205
+ k = l2norm_fwd(k)
206
+
207
+ # 2-d indices denoting the offsets of chunks in each sequence
208
+ # for example, if the passed `offsets` is [0, 100, 356] and `chunk_size` is 64,
209
+ # then there are 2 and 4 chunks in the 1st and 2nd sequences respectively, and `indices` will be
210
+ # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]]
211
+ indices = None
212
+ if offsets is not None:
213
+ indices = torch.cat([torch.arange(n) for n in triton.cdiv(offsets[1:] - offsets[:-1], chunk_size).tolist()])
214
+ indices = torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(offsets)
215
+
216
+ g, o, Aw, Au, final_state = chunk_gated_delta_rule_fwd(
217
+ q=q,
218
+ k=k,
219
+ v=v,
220
+ g=g,
221
+ beta=beta,
222
+ scale=scale,
223
+ initial_state=initial_state,
224
+ output_final_state=output_final_state,
225
+ offsets=offsets,
226
+ indices=indices,
227
+ head_first=head_first,
228
+ chunk_size=chunk_size,
229
+ )
230
+ ctx.save_for_backward(q_orig, k_orig, v, g, beta, Aw, Au, initial_state, offsets, indices)
231
+ ctx.chunk_size = chunk_size
232
+ ctx.scale = scale
233
+ ctx.head_first = head_first
234
+ ctx.use_qk_l2norm_in_kernel = use_qk_l2norm_in_kernel
235
+ return o.to(q.dtype), final_state
236
+
237
+ @staticmethod
238
+ @input_guard
239
+ @autocast_custom_bwd
240
+ def backward(
241
+ ctx,
242
+ do: torch.Tensor,
243
+ dht: torch.Tensor
244
+ ):
245
+ q, k, v, g, beta, Aw, Au, initial_state, offsets, indices = ctx.saved_tensors
246
+ if ctx.use_qk_l2norm_in_kernel:
247
+ q, q_orig = l2norm_fwd(q), q
248
+ k, k_orig = l2norm_fwd(k), k
249
+ dq, dk, dv, db, dg, dh0 = chunk_gated_delta_rule_bwd(
250
+ q=q,
251
+ k=k,
252
+ v=v,
253
+ g=g,
254
+ beta=beta,
255
+ Aw=Aw,
256
+ Au=Au,
257
+ scale=ctx.scale,
258
+ initial_state=initial_state,
259
+ do=do,
260
+ dht=dht,
261
+ offsets=offsets,
262
+ indices=indices,
263
+ head_first=ctx.head_first,
264
+ chunk_size=ctx.chunk_size
265
+ )
266
+ if ctx.use_qk_l2norm_in_kernel:
267
+ dq = l2norm_bwd(q_orig, dq)
268
+ dk = l2norm_bwd(k_orig, dk)
269
+ return dq.to(q), dk.to(k), dv.to(v), dg.to(g), db.to(beta), None, dh0, None, None, None, None
270
+
271
+
272
+ @torch.compiler.disable
273
+ def chunk_gated_delta_rule(
274
+ q: torch.Tensor,
275
+ k: torch.Tensor,
276
+ v: torch.Tensor,
277
+ g: torch.Tensor,
278
+ beta: torch.Tensor,
279
+ scale: float = None,
280
+ initial_state: torch.Tensor = None,
281
+ output_final_state: bool = False,
282
+ cu_seqlens: Optional[torch.LongTensor] = None,
283
+ head_first: bool = False,
284
+ use_qk_l2norm_in_kernel: bool = False
285
+ ):
286
+ r"""
287
+ Args:
288
+ q (torch.Tensor):
289
+ queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
290
+ k (torch.Tensor):
291
+ keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
292
+ v (torch.Tensor):
293
+ values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
294
+ g (torch.Tensor):
295
+ (forget) gating tensor (in log space!) of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`.
296
+ beta (torch.Tensor):
297
+ betas of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`.
298
+ scale (Optional[float]):
299
+ Scale factor for the attention scores.
300
+ If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
301
+ initial_state (Optional[torch.Tensor]):
302
+ Initial state of shape `[N, H, K, V]` for `N` input sequences.
303
+ For equal-length input sequences, `N` equals the batch size `B`.
304
+ Default: `None`.
305
+ output_final_state (Optional[bool]):
306
+ Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
307
+ cu_seqlens (torch.LongTensor):
308
+ Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
309
+ consistent with the FlashAttention API.
310
+ head_first (Optional[bool]):
311
+ Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
312
+ Default: `False`.
313
+
314
+ Returns:
315
+ o (torch.Tensor):
316
+ Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
317
+ final_state (torch.Tensor):
318
+ Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.
319
+
320
+ Examples::
321
+ >>> import torch
322
+ >>> import torch.nn.functional as F
323
+ >>> from einops import rearrange
324
+ >>> from fla.ops.gated_delta_rule import chunk_gated_delta_rule
325
+ # inputs with equal lengths
326
+ >>> B, T, H, K, V = 4, 2048, 4, 512, 512
327
+ >>> q = torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda')
328
+ >>> k = F.normalize(torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda'), p=2, dim=-1)
329
+ >>> v = torch.randn(B, T, H, V, dtype=torch.bfloat16, device='cuda')
330
+ >>> beta = torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda').sigmoid()
331
+ >>> g = F.logsigmoid(torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda'))
332
+ >>> h0 = torch.randn(B, H, K, V, dtype=torch.bfloat16, device='cuda')
333
+ >>> o, ht = chunk_gated_delta_rule(
334
+ q, k, v, g, beta,
335
+ initial_state=h0,
336
+ output_final_state=True,
337
+ head_first=False
338
+ )
339
+ # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
340
+ >>> q, k, v, beta, g = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, beta, g))
341
+ # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions is expected
342
+ >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
343
+ >>> o_var, ht_var = chunk_gated_delta_rule(
344
+ q, k, v, g, beta,
345
+ initial_state=h0,
346
+ output_final_state=True,
347
+ cu_seqlens=cu_seqlens,
348
+ head_first=False
349
+ )
350
+ """
351
+ assert q.dtype == k.dtype == v.dtype
352
+ assert q.dtype != torch.float32, "ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16."
353
+ assert len(beta.shape) == 3, "beta must be of shape [B, H, T] if head_first=True, or [B, T, H] if head_first=False."
354
+
355
+ if cu_seqlens is not None:
356
+ if q.shape[0] != 1:
357
+ raise ValueError(
358
+ f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
359
+ f"Please flatten variable-length inputs before processing."
360
+ )
361
+ if head_first:
362
+ raise RuntimeError(
363
+ "Sequences with variable lengths are not supported for head-first mode"
364
+ )
365
+ if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
366
+ raise ValueError(
367
+ f"The number of initial states is expected to be equal to the number of input sequences, "
368
+ f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
369
+ )
370
+ if head_first:
371
+ q, k, v = map(lambda x: rearrange(x, 'b h t d -> b t h d'), (q, k, v))
372
+ beta, g = map(lambda x: rearrange(x, 'b h t -> b t h'), (beta, g))
373
+ if scale is None:
374
+ scale = k.shape[-1] ** -0.5
375
+ else:
376
+ assert scale > 0, "Scale must be positive."
377
+ o, final_state = ChunkGatedDeltaRuleFunction.apply(
378
+ q,
379
+ k,
380
+ v,
381
+ g,
382
+ beta,
383
+ scale,
384
+ initial_state,
385
+ output_final_state,
386
+ cu_seqlens,
387
+ False,
388
+ use_qk_l2norm_in_kernel
389
+ )
390
+ if head_first:
391
+ o = rearrange(o, 'b t h v -> b h t v')
392
+ return o, final_state
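For reference, the `offsets`-to-`indices` mapping built in `ChunkGatedDeltaRuleFunction.forward` above can be reproduced outside the autograd wrapper. The sketch below is illustrative only (it swaps `triton.cdiv` for plain ceiling division) and reuses the example from the in-code comment:

```python
# Standalone sketch of the `offsets` -> `indices` mapping used in
# ChunkGatedDeltaRuleFunction.forward: each row is (sequence id, local chunk id).
import torch

def build_chunk_indices(offsets: torch.Tensor, chunk_size: int) -> torch.Tensor:
    # number of chunks per sequence (ceiling division)
    lens = offsets[1:] - offsets[:-1]
    n_chunks = [-(-int(n) // chunk_size) for n in lens.tolist()]
    # local chunk ids within each sequence
    local = torch.cat([torch.arange(n) for n in n_chunks])
    # sequence id for each chunk, recovered from where the local id resets to 0
    seq = local.eq(0).cumsum(0) - 1
    return torch.stack([seq, local], 1).to(offsets)

offsets = torch.tensor([0, 100, 356])
print(build_chunk_indices(offsets, chunk_size=64).tolist())
# [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]]
```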
fla/ops/gated_delta_rule/fused_recurrent.py ADDED
@@ -0,0 +1,321 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+ from einops import rearrange
10
+
11
+ from fla.ops.utils.op import exp
12
+ from fla.utils import input_guard
13
+
14
+
15
+ @triton.heuristics({
16
+ 'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
17
+ 'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
18
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
19
+ })
20
+ @triton.jit(do_not_specialize=['T'])
21
+ def fused_recurrent_gated_delta_rule_fwd_kernel(
22
+ q,
23
+ k,
24
+ v,
25
+ g,
26
+ beta,
27
+ o,
28
+ h0,
29
+ ht,
30
+ offsets,
31
+ scale,
32
+ T,
33
+ B: tl.constexpr,
34
+ H: tl.constexpr,
35
+ K: tl.constexpr,
36
+ V: tl.constexpr,
37
+ BK: tl.constexpr,
38
+ BV: tl.constexpr,
39
+ USE_INITIAL_STATE: tl.constexpr, # whether to use initial state
40
+ STORE_FINAL_STATE: tl.constexpr, # whether to store final state
41
+ IS_BETA_HEADWISE: tl.constexpr, # whether beta is headwise vector or scalar,
42
+ USE_QK_L2NORM_IN_KERNEL: tl.constexpr,
43
+ USE_OFFSETS: tl.constexpr
44
+ ):
45
+ i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
46
+ i_n, i_h = i_nh // H, i_nh % H
47
+ if USE_OFFSETS:
48
+ bos, eos = tl.load(offsets + i_n).to(tl.int64), tl.load(offsets + i_n + 1).to(tl.int64)
49
+ all = T
50
+ T = eos - bos
51
+ else:
52
+ bos, eos = i_n * T, i_n * T + T
53
+ all = B * T
54
+ o_k = i_k * BK + tl.arange(0, BK)
55
+ o_v = i_v * BV + tl.arange(0, BV)
56
+
57
+ p_q = q + (bos * H + i_h) * K + o_k
58
+ p_k = k + (bos * H + i_h) * K + o_k
59
+ p_v = v + (bos * H + i_h) * V + o_v
60
+ if IS_BETA_HEADWISE:
61
+ p_beta = beta + (bos * H + i_h) * V + o_v
62
+ else:
63
+ p_beta = beta + bos * H + i_h
64
+ p_g = g + bos * H + i_h
65
+ p_o = o + ((i_k * all + bos) * H + i_h) * V + o_v
66
+
67
+ mask_k = o_k < K
68
+ mask_v = o_v < V
69
+ mask_h = mask_k[:, None] & mask_v[None, :]
70
+
71
+ b_h = tl.zeros([BK, BV], dtype=tl.float32)
72
+ if USE_INITIAL_STATE:
73
+ p_h0 = h0 + i_nh * K*V + o_k[:, None] * V + o_v[None, :]
74
+ b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)
75
+
76
+ for _ in range(0, T):
77
+ b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32)
78
+ b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
79
+ b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
80
+ b_g = tl.load(p_g).to(tl.float32)
81
+
82
+ if USE_QK_L2NORM_IN_KERNEL:
83
+ b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q)) + 1e-6)
84
+ b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k)) + 1e-6)
85
+ b_q = b_q * scale
86
+ # [BK, BV]
87
+ b_h *= exp(b_g)
88
+ # [BV]
89
+ b_v -= tl.sum(b_h * b_k[:, None], 0)
90
+ if IS_BETA_HEADWISE:
91
+ b_beta = tl.load(p_beta, mask=mask_v, other=0).to(tl.float32)
92
+ else:
93
+ b_beta = tl.load(p_beta).to(tl.float32)
94
+ b_v *= b_beta
95
+ # [BK, BV]
96
+ b_h += b_k[:, None] * b_v[None, :]
97
+ # [BV]
98
+ b_o = tl.sum(b_h * b_q[:, None], 0)
99
+ tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)
100
+
101
+ p_q += H*K
102
+ p_k += H*K
103
+ p_o += H*V
104
+ p_v += H*V
105
+ p_g += H
106
+ p_beta += H * (V if IS_BETA_HEADWISE else 1)
107
+
108
+ if STORE_FINAL_STATE:
109
+ p_ht = ht + i_nh * K*V + o_k[:, None] * V + o_v[None, :]
110
+ tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
111
+
112
+
113
+ def fused_recurrent_gated_delta_rule_fwd(
114
+ q: torch.Tensor,
115
+ k: torch.Tensor,
116
+ v: torch.Tensor,
117
+ g: torch.Tensor,
118
+ beta: torch.Tensor,
119
+ scale: float,
120
+ initial_state: torch.Tensor,
121
+ output_final_state: bool,
122
+ use_qk_l2norm_in_kernel: bool = False,
123
+ offsets: Optional[torch.LongTensor] = None,
124
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
125
+ B, T, H, K, V = *k.shape, v.shape[-1]
126
+ N = B if offsets is None else len(offsets) - 1
127
+ BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 8)
128
+ NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
129
+ assert NK == 1, "NK > 1 is not supported yet"
130
+ num_stages = 3
131
+ num_warps = 1
132
+
133
+ o = q.new_empty(NK, *v.shape)
134
+ if output_final_state:
135
+ final_state = q.new_empty(N, H, K, V, dtype=torch.float32)
136
+ else:
137
+ final_state = None
138
+
139
+ grid = (NK, NV, N * H)
140
+ fused_recurrent_gated_delta_rule_fwd_kernel[grid](
141
+ q=q,
142
+ k=k,
143
+ v=v,
144
+ g=g,
145
+ beta=beta,
146
+ o=o,
147
+ h0=initial_state,
148
+ ht=final_state,
149
+ offsets=offsets,
150
+ scale=scale,
151
+ T=T,
152
+ B=B,
153
+ H=H,
154
+ K=K,
155
+ V=V,
156
+ BK=BK,
157
+ BV=BV,
158
+ IS_BETA_HEADWISE=beta.ndim == v.ndim,
159
+ USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel,
160
+ num_warps=num_warps,
161
+ num_stages=num_stages,
162
+ )
163
+ o = o.squeeze(0)
164
+ return o, final_state
165
+
166
+
167
+ class FusedRecurrentFunction(torch.autograd.Function):
168
+
169
+ @staticmethod
170
+ @input_guard
171
+ def forward(
172
+ ctx,
173
+ q: torch.Tensor,
174
+ k: torch.Tensor,
175
+ v: torch.Tensor,
176
+ g: torch.Tensor,
177
+ beta: torch.Tensor,
178
+ scale: float,
179
+ initial_state: torch.Tensor,
180
+ output_final_state: bool,
181
+ offsets: Optional[torch.LongTensor] = None,
182
+ use_qk_l2norm_in_kernel: bool = False
183
+ ):
184
+ o, final_state = fused_recurrent_gated_delta_rule_fwd(
185
+ q=q,
186
+ k=k,
187
+ v=v,
188
+ g=g,
189
+ beta=beta,
190
+ scale=scale,
191
+ initial_state=initial_state,
192
+ output_final_state=output_final_state,
193
+ use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
194
+ offsets=offsets
195
+ )
196
+
197
+ return o, final_state
198
+
199
+ @staticmethod
200
+ @input_guard
201
+ def backward(ctx, do, dht):
202
+ raise NotImplementedError(
203
+ "Backward pass is not implemented yet and we do not have plans to implement it "
204
+ "because we haven't figured out how to compute dg without materializing the full "
205
+ "hidden states for all time steps."
206
+ )
207
+
208
+
209
+ def fused_recurrent_gated_delta_rule(
210
+ q: torch.Tensor,
211
+ k: torch.Tensor,
212
+ v: torch.Tensor,
213
+ g: torch.Tensor,
214
+ beta: torch.Tensor = None,
215
+ scale: float = None,
216
+ initial_state: torch.Tensor = None,
217
+ output_final_state: bool = False,
218
+ cu_seqlens: Optional[torch.LongTensor] = None,
219
+ use_qk_l2norm_in_kernel: bool = False,
220
+ head_first: bool = False,
221
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
222
+ r"""
223
+ Args:
224
+ q (torch.Tensor):
225
+ queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
226
+ k (torch.Tensor):
227
+ keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
228
+ v (torch.Tensor):
229
+ values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
230
+ g (torch.Tensor):
231
+ g (decays) of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`.
232
+ beta (torch.Tensor):
233
+ betas of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`.
234
+ scale (Optional[float]):
235
+ Scale factor for the attention scores.
236
+ If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
237
+ initial_state (Optional[torch.Tensor]):
238
+ Initial state of shape `[N, H, K, V]` for `N` input sequences.
239
+ For equal-length input sequences, `N` equals the batch size `B`.
240
+ Default: `None`.
241
+ output_final_state (Optional[bool]):
242
+ Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
243
+ cu_seqlens (torch.LongTensor):
244
+ Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
245
+ consistent with the FlashAttention API.
246
+
247
+ Returns:
248
+ o (torch.Tensor):
249
+ Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
250
+ final_state (torch.Tensor):
251
+ Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.
252
+
253
+ Examples::
254
+ >>> import torch
255
+ >>> import torch.nn.functional as F
256
+ >>> from einops import rearrange
257
+ >>> from fla.ops.gated_delta_rule import fused_recurrent_gated_delta_rule
258
+ # inputs with equal lengths
259
+ >>> B, T, H, K, V = 4, 2048, 4, 512, 512
260
+ >>> q = torch.randn(B, T, H, K, device='cuda')
261
+ >>> k = F.normalize(torch.randn(B, T, H, K, device='cuda'), p=2, dim=-1)
262
+ >>> v = torch.randn(B, T, H, V, device='cuda')
263
+ >>> g = F.logsigmoid(torch.rand(B, T, H, device='cuda'))
264
+ >>> beta = torch.rand(B, T, H, device='cuda').sigmoid()
265
+ >>> h0 = torch.randn(B, H, K, V, device='cuda')
266
+ >>> o, ht = fused_recurrent_gated_delta_rule(
267
+ q, k, v, g, beta,
268
+ initial_state=h0,
269
+ output_final_state=True,
270
+ )
271
+ # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
272
+ >>> q, k, v, g, beta = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, g, beta))
273
+ # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions is expected
274
+ >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
275
+ >>> o_var, ht_var = fused_recurrent_gated_delta_rule(
276
+ q, k, v, g, beta,
277
+ initial_state=h0,
278
+ output_final_state=True,
279
+ cu_seqlens=cu_seqlens
280
+ )
281
+ >>> assert o.allclose(o_var.view(o.shape))
282
+ >>> assert ht.allclose(ht_var)
283
+ """
284
+ if cu_seqlens is not None:
285
+ if q.shape[0] != 1:
286
+ raise ValueError(
287
+ f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
288
+ f"Please flatten variable-length inputs before processing."
289
+ )
290
+ if head_first:
291
+ raise RuntimeError(
292
+ "Sequences with variable lengths are not supported for head-first mode"
293
+ )
294
+ if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
295
+ raise ValueError(
296
+ f"The number of initial states is expected to be equal to the number of input sequences, "
297
+ f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
298
+ )
299
+ if scale is None:
300
+ scale = k.shape[-1] ** -0.5
301
+ else:
302
+ assert scale > 0, "scale must be positive"
303
+ if beta is None:
304
+ beta = torch.ones_like(q[..., 0])
305
+ if head_first:
306
+ q, k, v, g, beta = map(lambda x: rearrange(x, 'b h t ... -> b t h ...'), (q, k, v, g, beta))
307
+ o, final_state = FusedRecurrentFunction.apply(
308
+ q,
309
+ k,
310
+ v,
311
+ g,
312
+ beta,
313
+ scale,
314
+ initial_state,
315
+ output_final_state,
316
+ cu_seqlens,
317
+ use_qk_l2norm_in_kernel
318
+ )
319
+ if head_first:
320
+ o = rearrange(o, 'b t h v -> b h t v')
321
+ return o, final_state
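To make the Triton kernel above easier to follow, here is a hedged, unfused PyTorch sketch of the same per-timestep recurrence (scalar `beta`, equal-length sequences, no in-kernel L2 normalization and no `cu_seqlens` handling). It is a readability aid under those assumptions, not this repository's reference implementation:

```python
# Naive per-step reference for what fused_recurrent_gated_delta_rule_fwd_kernel
# computes: decay the state, apply the beta-scaled delta correction, read out.
import torch

def naive_recurrent_gated_delta_rule(q, k, v, g, beta, scale=None, h0=None):
    # q, k: [B, T, H, K]; v: [B, T, H, V]; g, beta: [B, T, H]
    B, T, H, K = q.shape
    V = v.shape[-1]
    scale = K ** -0.5 if scale is None else scale
    h = torch.zeros(B, H, K, V, dtype=torch.float32, device=q.device) if h0 is None else h0.float().clone()
    o = torch.empty(B, T, H, V, dtype=torch.float32, device=q.device)
    for t in range(T):
        q_t, k_t, v_t = q[:, t].float(), k[:, t].float(), v[:, t].float()
        # gated decay of the state
        h = h * g[:, t].float().exp()[..., None, None]           # [B, H, K, V]
        # delta correction: v_t minus the current prediction k_t^T S
        v_err = v_t - torch.einsum('bhkv,bhk->bhv', h, k_t)
        v_err = v_err * beta[:, t].float()[..., None]
        # rank-1 state update and readout
        h = h + torch.einsum('bhk,bhv->bhkv', k_t, v_err)
        o[:, t] = torch.einsum('bhkv,bhk->bhv', h, q_t * scale)
    return o, h
```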
fla/ops/gated_delta_rule/wy_fast.py ADDED
@@ -0,0 +1,620 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.ops.utils.op import safe_exp
11
+ from fla.utils import check_shared_mem
12
+
13
+
14
+ @triton.heuristics({
15
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
16
+ })
17
+ @triton.autotune(
18
+ configs=[
19
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
20
+ for num_warps in [2, 4, 8]
21
+ for num_stages in [2, 3, 4]
22
+ ],
23
+ key=['H', 'K', 'BT', 'BK', 'BC', 'HEAD_FIRST', 'USE_OFFSETS'],
24
+ )
25
+ @triton.jit(do_not_specialize=['T'])
26
+ def fwd_prepare_wy_repr_kernel_chunk32(
27
+ k,
28
+ g,
29
+ beta,
30
+ Aw,
31
+ Au,
32
+ offsets,
33
+ indices,
34
+ T,
35
+ H: tl.constexpr,
36
+ K: tl.constexpr,
37
+ BT: tl.constexpr,
38
+ BK: tl.constexpr,
39
+ BC: tl.constexpr,
40
+ HEAD_FIRST: tl.constexpr,
41
+ USE_OFFSETS: tl.constexpr
42
+ ):
43
+ i_t, i_bh = tl.program_id(0), tl.program_id(1)
44
+ i_b, i_h = i_bh // H, i_bh % H
45
+ if USE_OFFSETS:
46
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
47
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
48
+ T = eos - bos
49
+ else:
50
+ bos, eos = i_b * T, i_b * T + T
51
+
52
+ b_Aw = tl.zeros([BC, BC], dtype=tl.float32)
53
+ if HEAD_FIRST:
54
+ p_beta = tl.make_block_ptr(beta + i_bh*T, (T,), (1,), (i_t * BT,), (BT,), (0,))
55
+ else:
56
+ p_beta = tl.make_block_ptr(beta + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
57
+
58
+ b_beta = tl.load(p_beta, boundary_check=(0,))
59
+
60
+ for i_k in range(tl.cdiv(K, BK)):
61
+ if HEAD_FIRST:
62
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
63
+ else:
64
+ p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
65
+ b_k = tl.load(p_k, boundary_check=(0, 1))
66
+ b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)
67
+ b_Aw += tl.dot(b_kb, tl.trans(b_k))
68
+
69
+ b_Aw = -tl.where(tl.arange(0, BC)[:, None] > tl.arange(0, BC)[None, :], b_Aw, 0)
70
+
71
+ if HEAD_FIRST:
72
+ p_g = tl.make_block_ptr(g + i_bh*T, (T,), (1,), (i_t * BT,), (BT,), (0,))
73
+ else:
74
+ p_g = tl.make_block_ptr(g + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
75
+
76
+ b_g = tl.load(p_g, boundary_check=(0,))
77
+ b_Au = b_Aw * safe_exp(b_g[:, None] - b_g[None, :])
78
+
79
+ for i in range(1, BC):
80
+ mask = tl.arange(0, BC) == i
81
+ b_aw = tl.sum(tl.where(mask[:, None], b_Aw, 0), 0)
82
+ b_au = tl.sum(tl.where(mask[:, None], b_Au, 0), 0)
83
+ b_aw = b_aw + tl.sum(b_aw[:, None] * b_Aw, 0) * (tl.arange(0, BC) < i)
84
+ b_au = b_au + tl.sum(b_au[:, None] * b_Au, 0) * (tl.arange(0, BC) < i)
85
+ b_Aw = tl.where(mask[:, None], b_aw, b_Aw)
86
+ b_Au = tl.where(mask[:, None], b_au, b_Au)
87
+
88
+ # blockwise computation of lower triangular matrix's inverse
89
+ # i.e., [A11, 0; A21, A22]^-1 = [A11^-1, 0; -A22^-1 A21 A11^-1, A22^-1]
90
+ b_Aw += tl.arange(0, BC)[:, None] == tl.arange(0, BC)[None, :]
91
+ b_Au += tl.arange(0, BC)[:, None] == tl.arange(0, BC)[None, :]
92
+ if HEAD_FIRST:
93
+ p_Aw = tl.make_block_ptr(Aw + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BC, BC), (1, 0))
94
+ p_Au = tl.make_block_ptr(Au + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BC, BC), (1, 0))
95
+ else:
96
+ p_Aw = tl.make_block_ptr(Aw + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BC, BC), (1, 0))
97
+ p_Au = tl.make_block_ptr(Au + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BC, BC), (1, 0))
98
+ tl.store(p_Aw, b_Aw.to(p_Aw.dtype.element_ty), boundary_check=(0, 1))
99
+ tl.store(p_Au, b_Au.to(p_Au.dtype.element_ty), boundary_check=(0, 1))
100
+
101
+
102
+ @triton.heuristics({
103
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
104
+ })
105
+ @triton.autotune(
106
+ configs=[
107
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
108
+ for num_warps in [2, 4, 8]
109
+ for num_stages in [2, 3, 4]
110
+ ],
111
+ key=['H', 'K', 'BT', 'BK', 'BC', 'USE_OFFSETS', 'HEAD_FIRST'],
112
+ )
113
+ @triton.jit(do_not_specialize=['T'])
114
+ def fwd_prepare_wy_repr_kernel_chunk64(
115
+ k,
116
+ g,
117
+ beta,
118
+ Aw,
119
+ Au,
120
+ offsets,
121
+ indices,
122
+ T,
123
+ H: tl.constexpr,
124
+ K: tl.constexpr,
125
+ BT: tl.constexpr,
126
+ BK: tl.constexpr,
127
+ BC: tl.constexpr,
128
+ USE_OFFSETS: tl.constexpr,
129
+ HEAD_FIRST: tl.constexpr
130
+ ):
131
+ i_t, i_bh = tl.program_id(0), tl.program_id(1)
132
+ i_b, i_h = i_bh // H, i_bh % H
133
+ if USE_OFFSETS:
134
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
135
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
136
+ T = eos - bos
137
+ else:
138
+ bos, eos = i_b * T, i_b * T + T
139
+
140
+ b_Aw = tl.zeros([BC, BC], dtype=tl.float32)
141
+ b_Aw2 = tl.zeros([BC, BC], dtype=tl.float32)
142
+ b_Aw3 = tl.zeros([BC, BC], dtype=tl.float32)
143
+ if HEAD_FIRST:
144
+ p_beta = tl.make_block_ptr(beta + i_bh*T, (T,), (1,), (i_t * BT,), (BC,), (0,))
145
+ p_beta2 = tl.make_block_ptr(beta + i_bh*T, (T,), (1,), (i_t * BT + BC,), (BC,), (0,))
146
+ else:
147
+ p_beta = tl.make_block_ptr(beta + bos*H + i_h, (T,), (H,), (i_t * BT,), (BC,), (0,))
148
+ p_beta2 = tl.make_block_ptr(beta + bos*H + i_h, (T,), (H,), (i_t * BT + BC,), (BC,), (0,))
149
+
150
+ b_beta = tl.load(p_beta, boundary_check=(0,))
151
+ b_beta2 = tl.load(p_beta2, boundary_check=(0,))
152
+
153
+ for i_k in range(tl.cdiv(K, BK)):
154
+ if HEAD_FIRST:
155
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0))
156
+ p_k2 = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT + BC, i_k * BK), (BC, BK), (1, 0))
157
+ else:
158
+ p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0))
159
+ p_k2 = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + BC, i_k * BK), (BC, BK), (1, 0))
160
+ b_k = tl.load(p_k, boundary_check=(0, 1))
161
+ b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)
162
+ b_k2 = tl.load(p_k2, boundary_check=(0, 1))
163
+ b_kb2 = (b_k2 * b_beta2[:, None]).to(b_k2.dtype)
164
+ b_Aw += tl.dot(b_kb, tl.trans(b_k))
165
+ b_Aw2 += tl.dot(b_kb2, tl.trans(b_k2))
166
+ b_Aw3 += tl.dot(b_kb2, tl.trans(b_k))
167
+
168
+ b_Aw = -tl.where(tl.arange(0, BC)[:, None] > tl.arange(0, BC)[None, :], b_Aw, 0)
169
+ b_Aw2 = -tl.where(tl.arange(0, BC)[:, None] > tl.arange(0, BC)[None, :], b_Aw2, 0)
170
+
171
+ if HEAD_FIRST:
172
+ p_g = tl.make_block_ptr(g + i_bh*T, (T,), (1,), (i_t * BT,), (BC,), (0,))
173
+ p_g2 = tl.make_block_ptr(g + i_bh*T, (T,), (1,), (i_t * BT + BC,), (BC,), (0,))
174
+ else:
175
+ p_g = tl.make_block_ptr(g + bos*H + i_h, (T,), (H,), (i_t * BT,), (BC,), (0,))
176
+ p_g2 = tl.make_block_ptr(g + bos*H + i_h, (T,), (H,), (i_t * BT + BC,), (BC,), (0,))
177
+ b_g = tl.load(p_g, boundary_check=(0,))
178
+ b_g2 = tl.load(p_g2, boundary_check=(0,))
179
+
180
+ mask_c = tl.arange(0, BC)[:, None] >= tl.arange(0, BC)[None, :]
181
+ mask_g = i_t * BT + tl.arange(0, BC) < T
182
+ mask_g2 = i_t * BT + BC + tl.arange(0, BC) < T
183
+
184
+ b_Au = tl.where(mask_g[None, :] & mask_c, b_Aw * safe_exp(b_g[:, None] - b_g[None, :]), 0)
185
+ b_Au2 = tl.where(mask_g2[None, :] & mask_c, b_Aw2 * safe_exp(b_g2[:, None] - b_g2[None, :]), 0)
186
+ b_Au3 = tl.where(mask_g[None, :], b_Aw3 * safe_exp(b_g2[:, None] - b_g[None, :]), 0)
187
+
188
+ for i in range(1, BC):
189
+ mask = tl.arange(0, BC) == i
190
+ b_aw = tl.sum(tl.where(mask[:, None], b_Aw, 0), 0)
191
+ b_aw2 = tl.sum(tl.where(mask[:, None], b_Aw2, 0), 0)
192
+ b_au = tl.sum(tl.where(mask[:, None], b_Au, 0), 0)
193
+ b_au2 = tl.sum(tl.where(mask[:, None], b_Au2, 0), 0)
194
+ b_aw = b_aw + tl.sum(b_aw[:, None] * b_Aw, 0) * (tl.arange(0, BC) < i)
195
+ b_aw2 = b_aw2 + tl.sum(b_aw2[:, None] * b_Aw2, 0) * (tl.arange(0, BC) < i)
196
+ b_au = b_au + tl.sum(b_au[:, None] * b_Au, 0) * (tl.arange(0, BC) < i)
197
+ b_au2 = b_au2 + tl.sum(b_au2[:, None] * b_Au2, 0) * (tl.arange(0, BC) < i)
198
+ b_Aw = tl.where(mask[:, None], b_aw, b_Aw)
199
+ b_Aw2 = tl.where(mask[:, None], b_aw2, b_Aw2)
200
+ b_Au = tl.where(mask[:, None], b_au, b_Au)
201
+ b_Au2 = tl.where(mask[:, None], b_au2, b_Au2)
202
+ # blockwise computation of lower triangular matrix's inverse
203
+ # i.e., [A11, 0; A21, A22]^-1 = [A11^-1, 0; -A22^-1 A21 A11^-1, A22^-1]
204
+ b_Aw += tl.arange(0, BC)[:, None] == tl.arange(0, BC)[None, :]
205
+ b_Aw2 += tl.arange(0, BC)[:, None] == tl.arange(0, BC)[None, :]
206
+ # improve precision by disallowing tf32.
207
+ b_Aw3 = -tl.dot(tl.dot(b_Aw2, b_Aw3, allow_tf32=False), b_Aw, allow_tf32=False)
208
+ b_Au += tl.arange(0, BC)[:, None] == tl.arange(0, BC)[None, :]
209
+ b_Au2 += tl.arange(0, BC)[:, None] == tl.arange(0, BC)[None, :]
210
+ b_Au3 = -tl.dot(tl.dot(b_Au2, b_Au3, allow_tf32=False), b_Au, allow_tf32=False)
211
+
212
+ if HEAD_FIRST:
213
+ p_Aw1 = tl.make_block_ptr(Aw + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BC, BC), (1, 0))
214
+ p_Aw2 = tl.make_block_ptr(Aw + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + BC, BC), (BC, BC), (1, 0))
215
+ p_Aw3 = tl.make_block_ptr(Aw + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + BC, 0), (BC, BC), (1, 0))
216
+ p_Aw4 = tl.make_block_ptr(Aw + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, BC), (BC, BC), (1, 0))
217
+ p_Au1 = tl.make_block_ptr(Au + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BC, BC), (1, 0))
218
+ p_Au2 = tl.make_block_ptr(Au + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + BC, BC), (BC, BC), (1, 0))
219
+ p_Au3 = tl.make_block_ptr(Au + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + BC, 0), (BC, BC), (1, 0))
220
+ p_Au4 = tl.make_block_ptr(Au + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, BC), (BC, BC), (1, 0))
221
+ else:
222
+ p_Aw1 = tl.make_block_ptr(Aw + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BC, BC), (1, 0))
223
+ p_Aw2 = tl.make_block_ptr(Aw + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, BC), (BC, BC), (1, 0))
224
+ p_Aw3 = tl.make_block_ptr(Aw + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, 0), (BC, BC), (1, 0))
225
+ p_Aw4 = tl.make_block_ptr(Aw + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, BC), (BC, BC), (1, 0))
226
+ p_Au1 = tl.make_block_ptr(Au + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BC, BC), (1, 0))
227
+ p_Au2 = tl.make_block_ptr(Au + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, BC), (BC, BC), (1, 0))
228
+ p_Au3 = tl.make_block_ptr(Au + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, 0), (BC, BC), (1, 0))
229
+ p_Au4 = tl.make_block_ptr(Au + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, BC), (BC, BC), (1, 0))
230
+
231
+ tl.store(p_Aw1, b_Aw.to(p_Aw1.dtype.element_ty), boundary_check=(0, 1))
232
+ tl.store(p_Aw2, b_Aw2.to(p_Aw2.dtype.element_ty), boundary_check=(0, 1))
233
+ tl.store(p_Aw3, b_Aw3.to(p_Aw3.dtype.element_ty), boundary_check=(0, 1))
234
+ tl.store(p_Aw4, tl.zeros([BC, BC], dtype=tl.float32).to(p_Aw4.dtype.element_ty), boundary_check=(0, 1))
235
+ tl.store(p_Au1, b_Au.to(p_Au1.dtype.element_ty), boundary_check=(0, 1))
236
+ tl.store(p_Au2, b_Au2.to(p_Au2.dtype.element_ty), boundary_check=(0, 1))
237
+ tl.store(p_Au3, b_Au3.to(p_Au3.dtype.element_ty), boundary_check=(0, 1))
238
+ tl.store(p_Au4, tl.zeros([BC, BC], dtype=tl.float32).to(p_Au4.dtype.element_ty), boundary_check=(0, 1))
239
+
240
+
241
+ @triton.heuristics({
242
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
243
+ })
244
+ @triton.autotune(
245
+ configs=[
246
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
247
+ for num_warps in [2, 4, 8]
248
+ for num_stages in [2, 3, 4]
249
+ ],
250
+ key=['H', 'K', 'V', 'BT', 'BK', 'BV', 'HEAD_FIRST', 'USE_OFFSETS'],
251
+ )
252
+ @triton.jit(do_not_specialize=['T'])
253
+ def fwd_recompute_w_u_kernel(
254
+ k,
255
+ v,
256
+ beta,
257
+ w,
258
+ u,
259
+ Aw,
260
+ Au,
261
+ offsets,
262
+ indices,
263
+ T,
264
+ H: tl.constexpr,
265
+ K: tl.constexpr,
266
+ V: tl.constexpr,
267
+ BT: tl.constexpr,
268
+ BK: tl.constexpr,
269
+ BV: tl.constexpr,
270
+ HEAD_FIRST: tl.constexpr,
271
+ USE_OFFSETS: tl.constexpr
272
+ ):
273
+ i_t, i_bh = tl.program_id(0), tl.program_id(1)
274
+ i_b, i_h = i_bh // H, i_bh % H
275
+ if USE_OFFSETS:
276
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
277
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
278
+ T = eos - bos
279
+ else:
280
+ bos, eos = i_b * T, i_b * T + T
281
+ if HEAD_FIRST:
282
+ p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
283
+ p_Au = tl.make_block_ptr(Au + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
284
+ else:
285
+ p_beta = tl.make_block_ptr(beta + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
286
+ p_Au = tl.make_block_ptr(Au + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
287
+ b_beta = tl.load(p_beta, boundary_check=(0,))
288
+ b_Au = tl.load(p_Au, boundary_check=(0, 1))
289
+
290
+ for i_v in range(tl.cdiv(V, BV)):
291
+ if HEAD_FIRST:
292
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
293
+ p_u = tl.make_block_ptr(u + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
294
+ else:
295
+ p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
296
+ p_u = tl.make_block_ptr(u + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
297
+ b_v = tl.load(p_v, boundary_check=(0, 1))
298
+ b_vb = (b_v * b_beta[:, None]).to(b_v.dtype)
299
+ b_u = tl.dot(b_Au, b_vb, allow_tf32=False)
300
+ tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1))
301
+
302
+ tl.debug_barrier()
303
+ b_Au = None
304
+ if HEAD_FIRST:
305
+ p_Aw = tl.make_block_ptr(Aw + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
306
+ else:
307
+ p_Aw = tl.make_block_ptr(Aw + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
308
+ b_Aw = tl.load(p_Aw, boundary_check=(0, 1))
309
+
310
+ for i_k in range(tl.cdiv(K, BK)):
311
+ if HEAD_FIRST:
312
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
313
+ p_w = tl.make_block_ptr(w + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
314
+ else:
315
+ p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
316
+ p_w = tl.make_block_ptr(w + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
317
+ b_k = tl.load(p_k, boundary_check=(0, 1))
318
+ b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)
319
+ b_w = tl.dot(b_Aw, b_kb)
320
+ tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))
321
+
322
+
323
+ def fwd_prepare_wy_repr(
324
+ k: torch.Tensor,
325
+ v: torch.Tensor,
326
+ g: torch.Tensor,
327
+ beta: torch.Tensor,
328
+ offsets: Optional[torch.LongTensor],
329
+ indices: Optional[torch.LongTensor],
330
+ head_first: bool = True,
331
+ chunk_size: int = 64
332
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
333
+ if head_first:
334
+ B, H, T, K = k.shape
335
+ else:
336
+ B, T, H, K = k.shape
337
+ BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
338
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
339
+ BC = min(BT, 32)
340
+ BK = min(triton.next_power_of_2(K), 64)
341
+ # bf16 should be good enough.
342
+ Aw = torch.empty(B, *((H, T) if head_first else (T, H)), BT, device=k.device, dtype=k.dtype)
343
+ Au = torch.empty(B, *((H, T) if head_first else (T, H)), BT, device=k.device, dtype=k.dtype)
344
+
345
+ fwd_fn = fwd_prepare_wy_repr_kernel_chunk64 if BT == 64 else fwd_prepare_wy_repr_kernel_chunk32
346
+ fwd_fn[(NT, B*H)](
347
+ k=k,
348
+ g=g,
349
+ beta=beta,
350
+ Aw=Aw,
351
+ Au=Au,
352
+ offsets=offsets,
353
+ indices=indices,
354
+ T=T,
355
+ H=H,
356
+ K=K,
357
+ BT=BT,
358
+ BK=BK,
359
+ BC=BC,
360
+ HEAD_FIRST=head_first
361
+ )
362
+ w, u = fwd_recompute_w_u(
363
+ k=k,
364
+ v=v,
365
+ beta=beta,
366
+ Aw=Aw,
367
+ Au=Au,
368
+ offsets=offsets,
369
+ indices=indices,
370
+ head_first=head_first,
371
+ chunk_size=chunk_size
372
+ )
373
+ return w, u, Aw, Au
374
+
375
+
376
+ def fwd_recompute_w_u(
377
+ k: torch.Tensor,
378
+ v: torch.Tensor,
379
+ beta: torch.Tensor,
380
+ Aw: torch.Tensor,
381
+ Au: torch.Tensor,
382
+ offsets: Optional[torch.LongTensor],
383
+ indices: Optional[torch.LongTensor],
384
+ head_first: bool,
385
+ chunk_size: int
386
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
387
+ if head_first:
388
+ B, H, T, K, V = *k.shape, v.shape[-1]
389
+ else:
390
+ B, T, H, K, V = *k.shape, v.shape[-1]
391
+ BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
392
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
393
+ BK = min(triton.next_power_of_2(K), 64)
394
+ BV = min(triton.next_power_of_2(V), 64)
395
+
396
+ u = torch.empty_like(v)
397
+ w = torch.empty_like(k)
398
+ fwd_recompute_w_u_kernel[(NT, B*H)](
399
+ k=k,
400
+ v=v,
401
+ beta=beta,
402
+ w=w,
403
+ u=u,
404
+ Aw=Aw,
405
+ Au=Au,
406
+ offsets=offsets,
407
+ indices=indices,
408
+ T=T,
409
+ H=H,
410
+ K=K,
411
+ V=V,
412
+ BT=BT,
413
+ BK=BK,
414
+ BV=BV,
415
+ HEAD_FIRST=head_first
416
+ )
417
+ return w, u
418
+
419
+
420
+ @triton.heuristics({
421
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
422
+ })
423
+ @triton.autotune(
424
+ configs=[
425
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
426
+ for num_warps in [2, 4]
427
+ for num_stages in [2, 3, 4]
428
+ ],
429
+ key=['H', 'K', 'V', 'BT', 'BK', 'BV', 'HEAD_FIRST', 'USE_OFFSETS']
430
+ )
431
+ @triton.jit(do_not_specialize=['T'])
432
+ def bwd_prepare_wy_repr_kernel(
433
+ k,
434
+ v,
435
+ beta,
436
+ g,
437
+ Aw,
438
+ Au,
439
+ dw,
440
+ du,
441
+ dk,
442
+ dv,
443
+ dbeta,
444
+ dg,
445
+ offsets,
446
+ indices,
447
+ T,
448
+ H: tl.constexpr,
449
+ K: tl.constexpr,
450
+ V: tl.constexpr,
451
+ BT: tl.constexpr,
452
+ BK: tl.constexpr,
453
+ BV: tl.constexpr,
454
+ HEAD_FIRST: tl.constexpr,
455
+ USE_OFFSETS: tl.constexpr
456
+ ):
457
+ i_t, i_bh = tl.program_id(0), tl.program_id(1)
458
+ i_b, i_h = i_bh // H, i_bh % H
459
+ if USE_OFFSETS:
460
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
461
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
462
+ T = eos - bos
463
+ else:
464
+ bos, eos = i_b * T, i_b * T + T
465
+
466
+ b_dbeta = tl.zeros([BT], dtype=tl.float32)
467
+ b_dA = tl.zeros([BT, BT], dtype=tl.float32)
468
+ if HEAD_FIRST:
469
+ p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
470
+ p_A = tl.make_block_ptr(Aw + i_bh * T * BT, (BT, T), (1, BT), (0, i_t * BT), (BT, BT), (0, 1))
471
+ else:
472
+ p_beta = tl.make_block_ptr(beta + (bos*H + i_h), (T,), (H,), (i_t * BT,), (BT,), (0,))
473
+ p_A = tl.make_block_ptr(Aw + (bos*H + i_h) * BT, (BT, T), (1, H*BT), (0, i_t * BT), (BT, BT), (0, 1))
474
+
475
+ b_A = tl.load(p_A, boundary_check=(0, 1))
476
+ b_beta = tl.load(p_beta, boundary_check=(0,))
477
+
478
+ for i_k in range(tl.cdiv(K, BK)):
479
+ if HEAD_FIRST:
480
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
481
+ p_dk = tl.make_block_ptr(dk + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
482
+ p_dw = tl.make_block_ptr(dw + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
483
+ else:
484
+ p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
485
+ p_dk = tl.make_block_ptr(dk + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
486
+ p_dw = tl.make_block_ptr(dw + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
487
+ b_k = tl.load(p_k, boundary_check=(0, 1))
488
+ b_k_beta = (b_k * b_beta[:, None]).to(b_k.dtype)
489
+ b_dw = tl.load(p_dw, boundary_check=(0, 1))
490
+ b_dA += tl.dot(b_dw, tl.trans(b_k_beta), allow_tf32=False)
491
+ b_dk_beta = tl.dot(b_A, b_dw, allow_tf32=False)
492
+ b_dk = b_dk_beta * b_beta[:, None]
493
+ b_dbeta += tl.sum(b_dk_beta * b_k, 1)
494
+ tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
495
+
496
+ b_dA = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_dA, 0)
497
+ b_dA = tl.dot(b_dA.to(b_A.dtype), b_A)
498
+ b_dA = tl.dot(b_A, b_dA.to(b_A.dtype))
499
+ b_dA = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], -b_dA, 0).to(k.dtype.element_ty)
500
+
501
+ if HEAD_FIRST:
502
+ p_A = tl.make_block_ptr(Au + i_bh * T * BT, (BT, T), (1, BT), (0, i_t * BT), (BT, BT), (0, 1))
503
+ else:
504
+ p_A = tl.make_block_ptr(Au + (bos*H + i_h) * BT, (BT, T), (1, H*BT), (0, i_t * BT), (BT, BT), (0, 1))
505
+ b_A = tl.load(p_A, boundary_check=(0, 1))
506
+ b_dA2 = tl.zeros([BT, BT], dtype=tl.float32)
507
+
508
+ for i_v in range(tl.cdiv(V, BV)):
509
+ if HEAD_FIRST:
510
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
511
+ p_dv = tl.make_block_ptr(dv + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
512
+ p_du = tl.make_block_ptr(du + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
513
+ else:
514
+ p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
515
+ p_dv = tl.make_block_ptr(dv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
516
+ p_du = tl.make_block_ptr(du + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
517
+ b_v = tl.load(p_v, boundary_check=(0, 1))
518
+ b_v_beta = (b_v * b_beta[:, None]).to(b_v.dtype)
519
+ b_du = tl.load(p_du, boundary_check=(0, 1))
520
+ b_dA2 += tl.dot(b_du, tl.trans(b_v_beta), allow_tf32=False)
521
+ b_dv_beta = tl.dot(b_A, b_du, allow_tf32=False)
522
+ b_dv = b_dv_beta * b_beta[:, None]
523
+ b_dbeta += tl.sum(b_dv_beta * b_v, 1)
524
+ tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
525
+
526
+ b_dA2 = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_dA2, 0)
527
+ b_dA2 = tl.dot(b_dA2.to(b_A.dtype), b_A)
528
+ b_dA2 = tl.dot(b_A, b_dA2.to(b_A.dtype))
529
+ b_dA2 = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], -b_dA2, 0).to(k.dtype.element_ty)
530
+ if HEAD_FIRST:
531
+ p_g = tl.make_block_ptr(g + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
532
+ else:
533
+ p_g = tl.make_block_ptr(g + (bos*H + i_h), (T,), (H,), (i_t * BT,), (BT,), (0,))
534
+ b_g = tl.load(p_g, boundary_check=(0,))
535
+ b_dA2 *= safe_exp(b_g[:, None] - b_g[None, :])
536
+ b_dA += b_dA2
537
+ b_dA = b_dA.to(k.dtype.element_ty)
538
+ b_A = tl.zeros([BT, BT], dtype=tl.float32)
539
+
540
+ for i_k in range(tl.cdiv(K, BK)):
541
+ if HEAD_FIRST:
542
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
543
+ p_dk = tl.make_block_ptr(dk + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
544
+ else:
545
+ p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
546
+ p_dk = tl.make_block_ptr(dk + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
547
+ b_k = tl.load(p_k, boundary_check=(0, 1))
548
+ b_dk = tl.load(p_dk, boundary_check=(0, 1))
549
+ b_k_beta = (b_k * b_beta[:, None]).to(b_k.dtype)
550
+ b_A += tl.dot(b_k_beta, tl.trans(b_k))
551
+ b_dk_beta = tl.dot(b_dA, b_k, allow_tf32=False)
552
+ b_dbeta += tl.sum(b_dk_beta * b_k, 1)
553
+ b_dk += tl.dot(tl.trans(b_dA), b_k_beta, allow_tf32=False)
554
+ b_dk += b_dk_beta * b_beta[:, None]
555
+ tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
556
+ b_dA2 *= b_A
557
+ b_dg = tl.sum(b_dA2, axis=1) - tl.sum(b_dA2, axis=0)
558
+ if HEAD_FIRST:
559
+ p_dg = tl.make_block_ptr(dg + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
560
+ p_dbeta = tl.make_block_ptr(dbeta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
561
+ else:
562
+ p_dg = tl.make_block_ptr(dg + (bos*H + i_h), (T,), (H,), (i_t * BT,), (BT,), (0,))
563
+ p_dbeta = tl.make_block_ptr(dbeta + (bos*H + i_h), (T,), (H,), (i_t * BT,), (BT,), (0,))
564
+ tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0,))
565
+ tl.store(p_dbeta, b_dbeta.to(p_dbeta.dtype.element_ty), boundary_check=(0,))
566
+
567
+
568
+ def bwd_prepare_wy_repr(
569
+ k: torch.Tensor,
570
+ v: torch.Tensor,
571
+ g: torch.Tensor,
572
+ beta: torch.Tensor,
573
+ Aw: torch.Tensor,
574
+ Au: torch.Tensor,
575
+ dw: torch.Tensor,
576
+ du: torch.Tensor,
577
+ offsets: Optional[torch.LongTensor],
578
+ indices: Optional[torch.LongTensor],
579
+ head_first: bool,
580
+ chunk_size: int
581
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
582
+ if head_first:
583
+ B, H, T, K, V = *k.shape, v.shape[-1]
584
+ else:
585
+ B, T, H, K, V = *k.shape, v.shape[-1]
586
+ BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
587
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
588
+ CONST_TILING = 64 if check_shared_mem() else 32
589
+ BK = min(triton.next_power_of_2(K), CONST_TILING)
590
+ BV = min(triton.next_power_of_2(V), CONST_TILING)
591
+
592
+ dk = torch.empty_like(k)
593
+ dv = torch.empty_like(v)
594
+ dbeta = torch.empty_like(beta)
595
+ dg = torch.empty_like(g)
596
+ bwd_prepare_wy_repr_kernel[(NT, B * H)](
597
+ k=k,
598
+ v=v,
599
+ beta=beta,
600
+ g=g,
601
+ Aw=Aw,
602
+ Au=Au,
603
+ dw=dw,
604
+ du=du,
605
+ dk=dk,
606
+ dv=dv,
607
+ dbeta=dbeta,
608
+ dg=dg,
609
+ offsets=offsets,
610
+ indices=indices,
611
+ T=T,
612
+ H=H,
613
+ K=K,
614
+ V=V,
615
+ BT=BT,
616
+ BK=BK,
617
+ BV=BV,
618
+ HEAD_FIRST=head_first
619
+ )
620
+ return dk, dv, dbeta, dg
fla/ops/generalized_delta_rule/README.md ADDED
@@ -0,0 +1,37 @@
+ # Generalized Delta Rule
+
+ In the delta rule we have the recurrence:
+
+ ```math
+ \mathbf{S}_t = \mathbf{S}_{t-1}(\mathbf{I}-\beta_t \mathbf{k}_t\mathbf{k}_t^T) + \beta_t \mathbf{v}_t\mathbf{k}_t^T
+ ```
+
+ This repository implements delta-rule variants in which $\mathbf{I}$ is not necessarily the identity matrix, and the $\mathbf{k}_t$ in $\mathbf{I} - \beta_t \mathbf{k}_t\mathbf{k}_t^T$ may differ from the input $\mathbf{k}_t$ in $\mathbf{v}_t\mathbf{k}_t^T$.
+
+ ## IPLR (Identity Plus Low Rank)
+
+ The first variant is IPLR, where we have:
+
+ ```math
+ \mathbf{S}_t = \mathbf{S}_{t-1}(\mathbf{I}+\mathbf{a}_t\mathbf{b}_t^T) + \mathbf{v}_t\mathbf{k}_t^T
+ ```
+
+ When $\mathbf{a}_t = -\beta_t \mathbf{k}_t$, $\mathbf{b}_t = \mathbf{k}_t$ and the values are rescaled to $\beta_t \mathbf{v}_t$, we recover the original delta rule. Since the transition matrix is identity plus low rank, we refer to this variant as IPLR.
+
+ ### Numerical Stability
+
+ $\mathbf{a}_t$ and $\mathbf{b}_t$ must point in opposite directions, that is, $\mathbf{b}_t = \lambda_t \mathbf{a}_t$ with $\lambda_t < 0$. To see why this is necessary, derive the eigenvalues of the transition matrix.
+
+ ## DPLR (Diagonal Plus Low Rank)
+
+ The second variant is DPLR, where we have:
+
+ ```math
+ \mathbf{S}_t = \mathbf{S}_{t-1}(\mathbf{D}_t+\mathbf{a}_t\mathbf{b}_t^T) + \mathbf{v}_t\mathbf{k}_t^T
+ ```
+
+ Here, the identity $\mathbf{I}$ is replaced by a diagonal matrix $\mathbf{D}_t$. This transition-matrix structure is used in RWKV7.
+
+ ## Efficient Chunkwise Implementation
+
+ For details on the efficient chunkwise implementation, please refer to our [technical note](https://drive.google.com/file/d/1rJbO3dU4fe7OKG3w7Yg058z_BNIuavNF/view?usp=sharing).
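To make the recurrences above concrete, here is a minimal naive PyTorch sketch of the DPLR update with a readout $\mathbf{o}_t = \mathbf{S}_t \mathbf{q}_t$; IPLR is the special case $\mathbf{D}_t = \mathbf{I}$. The function name, the single-head/unbatched shapes, and the readout convention are illustrative assumptions; the optimized chunkwise and recurrent kernels in this folder are the actual implementation.

```python
import torch

def naive_dplr_recurrence(q, k, v, a, b, d):
    """Sketch of S_t = S_{t-1} (diag(d_t) + a_t b_t^T) + v_t k_t^T, with o_t = S_t q_t.

    Assumed shapes (single head, no batch): q, k, a, b, d are [T, K]; v is [T, V].
    IPLR corresponds to d_t = 1; the vanilla delta rule to a_t = -beta_t * k_t,
    b_t = k_t and values pre-scaled by beta_t.
    """
    T, K = k.shape
    V = v.shape[-1]
    S = k.new_zeros(V, K, dtype=torch.float32)
    o = k.new_zeros(T, V, dtype=torch.float32)
    for t in range(T):
        # transition matrix: diagonal decay plus a rank-1 correction
        A_t = torch.diag(d[t].float()) + torch.outer(a[t].float(), b[t].float())
        S = S @ A_t + torch.outer(v[t].float(), k[t].float())
        o[t] = S @ q[t].float()
    return o, S
```

Note that the fused kernels in this folder take the diagonal decay in log space (`gk`), so `d_t` here corresponds to `exp(gk_t)` there.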
fla/ops/generalized_delta_rule/__init__.py ADDED
@@ -0,0 +1,9 @@
+ from .dplr import chunk_dplr_delta_rule, fused_recurrent_dplr_delta_rule
+ from .iplr import chunk_iplr_delta_rule, fused_recurrent_iplr_delta_rule
+
+ __all__ = [
+     'chunk_dplr_delta_rule',
+     'fused_recurrent_dplr_delta_rule',
+     'chunk_iplr_delta_rule',
+     'fused_recurrent_iplr_delta_rule'
+ ]
fla/ops/generalized_delta_rule/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (450 Bytes).
 
fla/ops/generalized_delta_rule/dplr/chunk.py ADDED
@@ -0,0 +1,388 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import triton
8
+
9
+ from fla.ops.common.utils import prepare_chunk_indices
10
+ from fla.ops.generalized_delta_rule.dplr.chunk_A_bwd import chunk_dplr_bwd_dqk_intra
11
+ from fla.ops.generalized_delta_rule.dplr.chunk_A_fwd import chunk_fwd_intra_dplr_fn
12
+ from fla.ops.generalized_delta_rule.dplr.chunk_h_bwd import chunk_dplr_bwd_dhu
13
+ from fla.ops.generalized_delta_rule.dplr.chunk_h_fwd import chunk_dplr_fwd_h
14
+ from fla.ops.generalized_delta_rule.dplr.chunk_o_bwd import chunk_dplr_bwd_dAu, chunk_dplr_bwd_dv, chunk_dplr_bwd_o
15
+ from fla.ops.generalized_delta_rule.dplr.chunk_o_fwd import chunk_dplr_fwd_o
16
+ from fla.ops.generalized_delta_rule.dplr.wy_fast_bwd import chunk_dplr_bwd_wy
17
+ from fla.ops.generalized_delta_rule.dplr.wy_fast_fwd import fwd_prepare_wy_repr
18
+ from fla.ops.rwkv6.chunk import chunk_rwkv6_fwd_cumsum
19
+ from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard
20
+
21
+
22
+ def chunk_dplr_fwd(
23
+ q: torch.Tensor,
24
+ k: torch.Tensor,
25
+ v: torch.Tensor,
26
+ a: torch.Tensor,
27
+ b: torch.Tensor,
28
+ gk: torch.Tensor,
29
+ scale: float,
30
+ initial_state: torch.Tensor,
31
+ output_final_state: bool,
32
+ offsets: Optional[torch.LongTensor] = None,
33
+ indices: Optional[torch.LongTensor] = None,
34
+ head_first: bool = True,
35
+ chunk_size: int = 64
36
+ ):
37
+ T = q.shape[2] if head_first else q.shape[1]
38
+ BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
39
+ gi, ge = chunk_rwkv6_fwd_cumsum(gk, BT, offsets=offsets, indices=indices, head_first=head_first)
40
+
41
+ A_ab, A_qk, A_ak, A_qb, qg, kg, ag, bg = chunk_fwd_intra_dplr_fn(
42
+ q=q,
43
+ k=k,
44
+ a=a,
45
+ b=b,
46
+ gi=gi,
47
+ ge=ge,
48
+ scale=scale,
49
+ offsets=offsets,
50
+ indices=indices,
51
+ chunk_size=BT,
52
+ head_first=head_first
53
+ )
54
+ del ge
55
+
56
+ # A_ab, A_ak, gi, ge torch.float32
57
+ # A_qk, A_qb, qg, kg, ag, bg, dtype=q.dtype, eg: bf16
58
+ w, u, _ = fwd_prepare_wy_repr(
59
+ ag=ag,
60
+ A_ab=A_ab,
61
+ A_ak=A_ak,
62
+ v=v,
63
+ offsets=offsets,
64
+ indices=indices,
65
+ head_first=head_first,
66
+ chunk_size=BT
67
+ )
68
+ del A_ab, A_ak
69
+ h, v_new, final_state = chunk_dplr_fwd_h(
70
+ kg=kg,
71
+ bg=bg,
72
+ v=v,
73
+ w=w,
74
+ u=u,
75
+ gk=gi,
76
+ initial_state=initial_state,
77
+ output_final_state=output_final_state,
78
+ offsets=offsets,
79
+ indices=indices,
80
+ head_first=head_first,
81
+ chunk_size=BT
82
+ )
83
+ del u, kg, bg, gi
84
+
85
+ o = chunk_dplr_fwd_o(
86
+ qg=qg,
87
+ v=v,
88
+ v_new=v_new,
89
+ A_qk=A_qk,
90
+ A_qb=A_qb,
91
+ h=h,
92
+ offsets=offsets,
93
+ indices=indices,
94
+ head_first=head_first,
95
+ chunk_size=BT
96
+ )
97
+ del v_new, h, A_qk, A_qb
98
+
99
+ return o, final_state
100
+
101
+
102
+ class ChunkDPLRDeltaRuleFunction(torch.autograd.Function):
103
+
104
+ @staticmethod
105
+ @input_guard
106
+ @autocast_custom_fwd
107
+ def forward(
108
+ ctx,
109
+ q: torch.Tensor,
110
+ k: torch.Tensor,
111
+ v: torch.Tensor,
112
+ a: torch.Tensor,
113
+ b: torch.Tensor,
114
+ gk: torch.Tensor,
115
+ scale: float,
116
+ initial_state: torch.Tensor,
117
+ output_final_state: bool,
118
+ offsets: Optional[torch.LongTensor] = None,
119
+ head_first: bool = True
120
+ ):
121
+ chunk_size = 16
122
+
123
+ # 2-d indices denoting the offsets of chunks in each sequence
124
+ # for example, if the passed `offsets` is [0, 100, 356] and `chunk_size` is 64,
125
+ # then there are 2 and 4 chunks in the 1st and 2nd sequences respectively, and `indices` will be
126
+ # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]]
127
+ indices = prepare_chunk_indices(offsets, chunk_size) if offsets is not None else None
128
+
129
+ o, final_state = chunk_dplr_fwd(
130
+ q=q,
131
+ k=k,
132
+ v=v,
133
+ a=a,
134
+ b=b,
135
+ gk=gk,
136
+ scale=scale,
137
+ initial_state=initial_state,
138
+ output_final_state=output_final_state,
139
+ offsets=offsets,
140
+ indices=indices,
141
+ head_first=head_first,
142
+ chunk_size=chunk_size
143
+ )
144
+ ctx.save_for_backward(q, k, v, a, b, gk, initial_state)
145
+ ctx.head_first = head_first
146
+ ctx.offsets = offsets
147
+ ctx.indices = indices
148
+ ctx.scale = scale
149
+ ctx.chunk_size = chunk_size
150
+ return o.to(q.dtype), final_state
151
+
152
+ @staticmethod
153
+ @input_guard
154
+ @autocast_custom_bwd
155
+ def backward(
156
+ ctx,
157
+ do: torch.Tensor,
158
+ dht: torch.Tensor
159
+ ):
160
+ q, k, v, a, b, gk, initial_state = ctx.saved_tensors
161
+ BT = ctx.chunk_size
162
+ head_first = ctx.head_first
163
+ offsets = ctx.offsets
164
+ indices = ctx.indices
165
+ scale = ctx.scale
166
+
167
+ # ******* start recomputing everything; otherwise GPU memory would likely be exhausted *******
168
+ gi, ge = chunk_rwkv6_fwd_cumsum(gk, BT, offsets=offsets, indices=indices, head_first=head_first)
169
+
170
+ A_ab, A_qk, A_ak, A_qb, qg, kg, ag, bg = chunk_fwd_intra_dplr_fn(
171
+ q=q,
172
+ k=k,
173
+ a=a,
174
+ b=b,
175
+ gi=gi,
176
+ ge=ge,
177
+ scale=scale,
178
+ offsets=offsets,
179
+ indices=indices,
180
+ chunk_size=BT,
181
+ head_first=head_first
182
+ )
183
+ w, u, A_ab_inv = fwd_prepare_wy_repr(
184
+ ag=ag,
185
+ A_ab=A_ab,
186
+ A_ak=A_ak,
187
+ v=v,
188
+ offsets=offsets,
189
+ indices=indices,
190
+ head_first=head_first,
191
+ chunk_size=BT
192
+ )
193
+ del A_ab
194
+ h, v_new, _ = chunk_dplr_fwd_h(
195
+ kg=kg,
196
+ bg=bg,
197
+ v=v,
198
+ w=w,
199
+ u=u,
200
+ gk=gi,
201
+ initial_state=initial_state,
202
+ offsets=offsets,
203
+ indices=indices,
204
+ head_first=head_first,
205
+ chunk_size=BT
206
+ )
207
+ del u
208
+ # ******* end of recomputation *******
209
+ # A_ak, A_ab_inv, gi, ge torch.float32
210
+ # A_qk, A_qb, qg, kg, ag, bg, v_new dtype=q.dtype, eg: bf16
211
+
212
+ dv_new_intra, dA_qk, dA_qb = chunk_dplr_bwd_dAu(
213
+ v=v,
214
+ v_new=v_new,
215
+ do=do,
216
+ A_qb=A_qb,
217
+ scale=scale,
218
+ offsets=offsets,
219
+ indices=indices,
220
+ head_first=head_first,
221
+ chunk_size=BT
222
+ )
223
+
224
+ dh, dh0, dv_new = chunk_dplr_bwd_dhu(
225
+ qg=qg,
226
+ bg=bg,
227
+ w=w,
228
+ gk=gi,
229
+ h0=initial_state,
230
+ dht=dht,
231
+ do=do,
232
+ dv=dv_new_intra,
233
+ offsets=offsets,
234
+ indices=indices,
235
+ head_first=head_first,
236
+ chunk_size=BT
237
+ )
238
+
239
+ dv = chunk_dplr_bwd_dv(
240
+ A_qk=A_qk,
241
+ kg=kg,
242
+ do=do,
243
+ dh=dh,
244
+ offsets=offsets,
245
+ indices=indices,
246
+ head_first=head_first,
247
+ chunk_size=BT
248
+ )
249
+ del A_qk
250
+
251
+ dqg, dkg, dw, dbg, dgk_last = chunk_dplr_bwd_o(
252
+ k=kg,
253
+ b=bg,
254
+ v=v,
255
+ v_new=v_new,
256
+ do=do,
257
+ h=h,
258
+ dh=dh,
259
+ dv=dv_new,
260
+ w=w,
261
+ gk=gi,
262
+ offsets=offsets,
263
+ indices=indices,
264
+ chunk_size=BT,
265
+ scale=scale,
266
+ head_first=head_first,
267
+ )
268
+ del v_new
269
+
270
+ dA_ab, dA_ak, dv, dag = chunk_dplr_bwd_wy(
271
+ A_ab_inv=A_ab_inv,
272
+ A_ak=A_ak,
273
+ v=v,
274
+ ag=ag,
275
+ dw=dw,
276
+ du=dv_new,
277
+ dv0=dv,
278
+ offsets=offsets,
279
+ indices=indices,
280
+ head_first=head_first,
281
+ chunk_size=BT
282
+ )
283
+ del A_ak
284
+
285
+ dq, dk, da, db, dgk = chunk_dplr_bwd_dqk_intra(
286
+ q=q,
287
+ k=k,
288
+ a=a,
289
+ b=b,
290
+ gi=gi,
291
+ ge=ge,
292
+ dAqk=dA_qk,
293
+ dAqb=dA_qb,
294
+ dAak=dA_ak,
295
+ dAab=dA_ab,
296
+ dgk_last=dgk_last,
297
+ dqg=dqg,
298
+ dkg=dkg,
299
+ dag=dag,
300
+ dbg=dbg,
301
+ chunk_size=BT,
302
+ scale=scale,
303
+ head_first=head_first,
304
+ offsets=offsets,
305
+ indices=indices
306
+ )
307
+
308
+ return dq.to(q), dk.to(k), dv.to(v), da.to(a), db.to(b), dgk.to(gk), None, dh0, None, None, None
309
+
310
+
311
+ @torch.compiler.disable
312
+ def chunk_dplr_delta_rule(
313
+ q: torch.Tensor,
314
+ k: torch.Tensor,
315
+ v: torch.Tensor,
316
+ a: torch.Tensor,
317
+ b: torch.Tensor,
318
+ gk: torch.Tensor,
319
+ scale: Optional[float] = None,
320
+ initial_state: Optional[torch.Tensor] = None,
321
+ output_final_state: bool = False,
322
+ cu_seqlens: Optional[torch.LongTensor] = None,
323
+ head_first: bool = False
324
+ ):
325
+ r"""
326
+ Args:
327
+ q (torch.Tensor):
328
+ queries of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
329
+ k (torch.Tensor):
330
+ keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
331
+ v (torch.Tensor):
332
+ values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`.
333
+ a (torch.Tensor):
334
+ activations of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
335
+ b (torch.Tensor):
336
+ betas of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
337
+ gk (torch.Tensor):
338
+ gk of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. decay term in log space!
339
+ scale (Optional[float]):
340
+ Scale factor for the RetNet attention scores.
341
+ If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
342
+ initial_state (Optional[torch.Tensor]):
343
+ Initial state of shape `[N, H, K, V]` for `N` input sequences.
344
+ For equal-length input sequences, `N` equals the batch size `B`.
345
+ Default: `None`.
346
+ output_final_state (Optional[bool]):
347
+ Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
348
+ cu_seqlens (torch.LongTensor):
349
+ Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
350
+ consistent with the FlashAttention API.
351
+ head_first (Optional[bool]):
352
+ Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
353
+ Default: `False`.
354
+
355
+ Returns:
356
+ o (torch.Tensor):
357
+ Outputs of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`.
358
+ final_state (torch.Tensor):
359
+ Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.
360
+ """
361
+ assert q.dtype == k.dtype == v.dtype
362
+ # assert q.dtype != torch.float32, "ChunkDeltaRuleFunction does not support float32. Please use bfloat16."
363
+ # gk = gk.float()
364
+
365
+ if cu_seqlens is not None:
366
+ if q.shape[0] != 1:
367
+ raise ValueError(f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`. "
368
+ f"Please flatten variable-length inputs before processing.")
369
+ if head_first:
370
+ raise RuntimeError("Sequences with variable lengths are not supported for head-first mode")
371
+ if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
372
+ raise ValueError(f"The number of initial states is expected to be equal to the number of input sequences, "
373
+ f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}.")
374
+ scale = k.shape[-1] ** -0.5 if scale is None else scale
375
+ o, final_state = ChunkDPLRDeltaRuleFunction.apply(
376
+ q,
377
+ k,
378
+ v,
379
+ a,
380
+ b,
381
+ gk,
382
+ scale,
383
+ initial_state,
384
+ output_final_state,
385
+ cu_seqlens,
386
+ head_first
387
+ )
388
+ return o, final_state
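Based on the docstring of `chunk_dplr_delta_rule` above, here is a hedged usage sketch. The concrete shapes, dtypes, and the particular way `a`, `b`, and `gk` are constructed below are illustrative assumptions, not requirements imposed by this file.

```python
import torch
import torch.nn.functional as F

from fla.ops.generalized_delta_rule import chunk_dplr_delta_rule

B, T, H, K, V = 2, 128, 4, 64, 64
device, dtype = 'cuda', torch.bfloat16

q = torch.randn(B, T, H, K, dtype=dtype, device=device)
k = torch.randn(B, T, H, K, dtype=dtype, device=device)
v = torch.randn(B, T, H, V, dtype=dtype, device=device)
# keep a and b pointing in opposite directions for numerical stability (see the README)
a = F.normalize(torch.randn(B, T, H, K, dtype=dtype, device=device), p=2, dim=-1)
b = -a
# gk is the diagonal decay in log space, hence non-positive
gk = F.logsigmoid(torch.randn(B, T, H, K, dtype=dtype, device=device))

o, final_state = chunk_dplr_delta_rule(
    q, k, v, a, b, gk,
    initial_state=None,
    output_final_state=True,
    head_first=False,  # inputs laid out as [B, T, H, ...]
)
print(o.shape)            # torch.Size([2, 128, 4, 64])
print(final_state.shape)  # torch.Size([2, 4, 64, 64])
```

For variable-length batches, flatten all sequences into a single batch of size 1 and pass `cu_seqlens`; as the checks above enforce, `head_first=True` is not supported in that case.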
fla/ops/generalized_delta_rule/dplr/chunk_A_bwd.py ADDED
@@ -0,0 +1,446 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.ops.utils.op import exp, gather
11
+ from fla.utils import check_shared_mem, is_gather_supported, use_cuda_graph
12
+
13
+
14
+ @triton.heuristics({
15
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
16
+ })
17
+ @triton.autotune(
18
+ configs=[
19
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
20
+ for num_warps in [2, 4, 8, 16, 32]
21
+ for num_stages in [2, 3, 4]
22
+ ],
23
+ key=['BK', 'NC', 'BT', 'K'],
24
+ use_cuda_graph=use_cuda_graph,
25
+ )
26
+ @triton.jit(do_not_specialize=['T'])
27
+ def chunk_dplr_bwd_kernel_intra(
28
+ q,
29
+ k,
30
+ a,
31
+ b,
32
+ gi,
33
+ ge,
34
+ dAqk,
35
+ dAqb,
36
+ dAak,
37
+ dAab,
38
+ dq,
39
+ dk,
40
+ da,
41
+ db,
42
+ dqg,
43
+ dkg,
44
+ dag,
45
+ dbg,
46
+ dgk,
47
+ dgk_offset,
48
+ offsets,
49
+ indices,
50
+ scale: tl.constexpr,
51
+ T,
52
+ H: tl.constexpr,
53
+ K: tl.constexpr,
54
+ BT: tl.constexpr,
55
+ BC: tl.constexpr,
56
+ BK: tl.constexpr,
57
+ NC: tl.constexpr,
58
+ USE_OFFSETS: tl.constexpr,
59
+ HEAD_FIRST: tl.constexpr,
60
+ GATHER_SUPPORTED: tl.constexpr
61
+ ):
62
+ i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
63
+ i_b, i_h = i_bh // H, i_bh % H
64
+ i_t, i_i = i_c // NC, i_c % NC
65
+ if USE_OFFSETS:
66
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
67
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
68
+ else:
69
+ bos, eos = i_b * T, i_b * T + T
70
+ T = eos - bos
71
+ if i_t * BT + i_i * BC >= T:
72
+ return
73
+
74
+ # offset calculation
75
+ ge += i_bh * T * K if HEAD_FIRST else (bos*H + i_h) * K
76
+ gi += i_bh * T * K if HEAD_FIRST else (bos*H + i_h) * K
77
+ q += i_bh * T * K if HEAD_FIRST else (bos*H + i_h) * K
78
+ a += i_bh * T * K if HEAD_FIRST else (bos*H + i_h) * K
79
+ b += i_bh * T * K if HEAD_FIRST else (bos*H + i_h) * K
80
+ k += i_bh * T * K if HEAD_FIRST else (bos*H + i_h) * K
81
+ dq += i_bh * T * K if HEAD_FIRST else (bos*H + i_h) * K
82
+ dk += i_bh * T * K if HEAD_FIRST else (bos*H + i_h) * K
83
+ da += i_bh * T * K if HEAD_FIRST else (bos*H + i_h) * K
84
+ db += i_bh * T * K if HEAD_FIRST else (bos*H + i_h) * K
85
+ dqg += i_bh * T * K if HEAD_FIRST else (bos*H + i_h) * K
86
+ dag += i_bh * T * K if HEAD_FIRST else (bos*H + i_h) * K
87
+ dkg += i_bh * T * K if HEAD_FIRST else (bos*H + i_h) * K
88
+ dbg += i_bh * T * K if HEAD_FIRST else (bos*H + i_h) * K
89
+ dgk += i_bh * T * K if HEAD_FIRST else (bos*H + i_h) * K
90
+ dgk_offset += i_bh * T * K if HEAD_FIRST else (bos*H + i_h) * K
91
+ dAqk += i_bh * T * BT if HEAD_FIRST else (bos*H + i_h) * BT
92
+ dAqb += i_bh * T * BT if HEAD_FIRST else (bos*H + i_h) * BT
93
+ dAak += i_bh * T * BT if HEAD_FIRST else (bos*H + i_h) * BT
94
+ dAab += i_bh * T * BT if HEAD_FIRST else (bos*H + i_h) * BT
95
+
96
+ stride_qk = K if HEAD_FIRST else H*K
97
+ stride_A = BT if HEAD_FIRST else H*BT
98
+
99
+ p_ge = tl.make_block_ptr(ge, (T, K), (stride_qk, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
100
+ p_gi = tl.make_block_ptr(gi, (T, K), (stride_qk, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
101
+ # [BC, BK]
102
+ b_ge = tl.load(p_ge, boundary_check=(0, 1))
103
+ b_gi = tl.load(p_gi, boundary_check=(0, 1))
104
+ b_dq = tl.zeros([BC, BK], dtype=tl.float32)
105
+ b_da = tl.zeros([BC, BK], dtype=tl.float32)
106
+ b_dk = tl.zeros([BC, BK], dtype=tl.float32)
107
+ b_db = tl.zeros([BC, BK], dtype=tl.float32)
108
+ # intra chunk gradient calculation
109
+ p_dAqk = tl.make_block_ptr(dAqk, (T, BT), (stride_A, 1), (i_t*BT + i_i*BC, i_i*BC), (BC, BC), (1, 0))
110
+ p_dAab = tl.make_block_ptr(dAab, (T, BT), (stride_A, 1), (i_t*BT + i_i*BC, i_i*BC), (BC, BC), (1, 0))
111
+ p_dAqb = tl.make_block_ptr(dAqb, (T, BT), (stride_A, 1), (i_t*BT + i_i*BC, i_i*BC), (BC, BC), (1, 0))
112
+ p_dAak = tl.make_block_ptr(dAak, (T, BT), (stride_A, 1), (i_t*BT + i_i*BC, i_i*BC), (BC, BC), (1, 0))
113
+ o_i = tl.arange(0, BC)
114
+ p_k = tl.make_block_ptr(k, (T, K), (stride_qk, 1), (i_t*BT + i_i*BC, i_k*BK), (BC, BK), (1, 0))
115
+ p_b = tl.make_block_ptr(b, (T, K), (stride_qk, 1), (i_t*BT + i_i*BC, i_k*BK), (BC, BK), (1, 0))
116
+ p_a = tl.make_block_ptr(a, (T, K), (stride_qk, 1), (i_t*BT + i_i*BC, i_k*BK), (BC, BK), (1, 0))
117
+ p_q = tl.make_block_ptr(q, (T, K), (stride_qk, 1), (i_t*BT + i_i*BC, i_k*BK), (BC, BK), (1, 0))
118
+ b_k = tl.load(p_k, boundary_check=(0, 1)).to(tl.float32)
119
+ b_b = tl.load(p_b, boundary_check=(0, 1)).to(tl.float32)
120
+ b_q = tl.load(p_q, boundary_check=(0, 1)).to(tl.float32)
121
+ b_a = tl.load(p_a, boundary_check=(0, 1)).to(tl.float32)
122
+ b_dAqk = tl.load(p_dAqk, boundary_check=(0, 1)).to(tl.float32)
123
+ b_dAab = tl.load(p_dAab, boundary_check=(0, 1)).to(tl.float32)
124
+ b_dAqb = tl.load(p_dAqb, boundary_check=(0, 1)).to(tl.float32)
125
+ b_dAak = tl.load(p_dAak, boundary_check=(0, 1)).to(tl.float32)
126
+
127
+ # inter chunk gradient calculation
128
+ o_k = i_k * BK + tl.arange(0, BK)
129
+ m_k = o_k < K
130
+ if i_i > 0:
131
+ p_gn = gi + (i_t * BT + i_i * BC - 1) * stride_qk + o_k
132
+ p_gn = tl.max_contiguous(tl.multiple_of(p_gn, BK), BK)
133
+ # [BK,]
134
+ b_gn = tl.load(p_gn, mask=m_k, other=0)
135
+ # [BK,]
136
+ for i_j in range(0, i_i):
137
+ p_kj = tl.make_block_ptr(k, (T, K), (stride_qk, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0))
138
+ p_bj = tl.make_block_ptr(b, (T, K), (stride_qk, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0))
139
+ p_gkj = tl.make_block_ptr(gi, (T, K), (stride_qk, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0))
140
+ p_dAqikj = tl.make_block_ptr(dAqk, (T, BT), (stride_A, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
141
+ p_dAaibj = tl.make_block_ptr(dAab, (T, BT), (stride_A, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
142
+ p_dAqibj = tl.make_block_ptr(dAqb, (T, BT), (stride_A, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
143
+ p_dAaikj = tl.make_block_ptr(dAak, (T, BT), (stride_A, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
144
+ # [BC, BK]
145
+ b_kj = tl.load(p_kj, boundary_check=(0, 1))
146
+ b_bj = tl.load(p_bj, boundary_check=(0, 1))
147
+ b_gkj = tl.load(p_gkj, boundary_check=(0, 1))
148
+ tmp = exp(b_gn[None, :] - b_gkj)
149
+ b_kjg = b_kj * tmp
150
+ b_bjg = b_bj * tmp
151
+ # [BC, BC]
152
+ b_dAqikj = tl.load(p_dAqikj, boundary_check=(0, 1))
153
+ b_dAaibj = tl.load(p_dAaibj, boundary_check=(0, 1))
154
+ b_dAqibj = tl.load(p_dAqibj, boundary_check=(0, 1))
155
+ b_dAaikj = tl.load(p_dAaikj, boundary_check=(0, 1))
156
+ # [BC, BK]
157
+ b_dq += tl.dot(b_dAqikj, b_kjg)
158
+ b_dq += tl.dot(b_dAqibj, b_bjg)
159
+ # [BC, BC]
160
+ b_da += tl.dot(b_dAaibj, b_bjg)
161
+ b_da += tl.dot(b_dAaikj, b_kjg)
162
+ b_dq *= exp(b_gi - b_gn[None, :])
163
+ b_da *= exp(b_ge - b_gn[None, :])
164
+
165
+ NC = min(NC, tl.cdiv(T - i_t * BT, BC))
166
+ if i_i < NC - 1:
167
+ p_gn = gi + (min(i_t * BT + i_i * BC + BC, T) - 1)*stride_qk + o_k
168
+ p_gn = tl.max_contiguous(tl.multiple_of(p_gn, BK), BK)
169
+ # [BK,]
170
+ b_gn = tl.load(p_gn, mask=m_k, other=0)
171
+ for i_j in range(i_i + 1, NC):
172
+ m_j = (i_t * BT + i_j * BC + tl.arange(0, BC)) < T
173
+ p_qj = tl.make_block_ptr(q, (T, K), (stride_qk, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0))
174
+ p_aj = tl.make_block_ptr(a, (T, K), (stride_qk, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0))
175
+ p_gij = tl.make_block_ptr(gi, (T, K), (stride_qk, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0))
176
+ p_gej = tl.make_block_ptr(ge, (T, K), (stride_qk, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0))
177
+ p_dAqjki = tl.make_block_ptr(dAqk, (BT, T), (1, stride_A), (i_i*BC, i_t*BT + i_j*BC), (BC, BC), (0, 1))
178
+ p_dAajbi = tl.make_block_ptr(dAab, (BT, T), (1, stride_A), (i_i*BC, i_t*BT + i_j*BC), (BC, BC), (0, 1))
179
+ p_dAqjbi = tl.make_block_ptr(dAqb, (BT, T), (1, stride_A), (i_i*BC, i_t*BT + i_j*BC), (BC, BC), (0, 1))
180
+ p_dAajki = tl.make_block_ptr(dAak, (BT, T), (1, stride_A), (i_i*BC, i_t*BT + i_j*BC), (BC, BC), (0, 1))
181
+ b_qj = tl.load(p_qj, boundary_check=(0, 1))
182
+ b_aj = tl.load(p_aj, boundary_check=(0, 1))
183
+ b_gij = tl.load(p_gij, boundary_check=(0, 1))
184
+ b_gej = tl.load(p_gej, boundary_check=(0, 1))
185
+ b_gij = tl.where(m_j[:, None] & m_k, b_gij, float('-inf'))
186
+ b_gej = tl.where(m_j[:, None] & m_k, b_gej, float('-inf'))
187
+ b_qjg = b_qj * exp(b_gij - b_gn[None, :])
188
+ b_ajg = b_aj * exp(b_gej - b_gn[None, :])
189
+ # [BC, BC]
190
+ b_dAqjki = tl.load(p_dAqjki, boundary_check=(0, 1))
191
+ b_dAajbi = tl.load(p_dAajbi, boundary_check=(0, 1))
192
+ b_dAqjbi = tl.load(p_dAqjbi, boundary_check=(0, 1))
193
+ b_dAajki = tl.load(p_dAajki, boundary_check=(0, 1))
194
+ b_dk += tl.dot(b_dAqjki, b_qjg)
195
+ b_dk += tl.dot(b_dAajki, b_ajg)
196
+ b_db += tl.dot(b_dAqjbi, b_qjg)
197
+ b_db += tl.dot(b_dAajbi, b_ajg)
198
+ tmp = exp(b_gn[None, :] - b_gi)
199
+ b_dk *= tmp
200
+ b_db *= tmp
201
+
202
+ # intra chunk gradient calculation
203
+ for j in range(0, min(BC, T - i_t * BT - i_i * BC)):
204
+ # trick to index the block
205
+ if GATHER_SUPPORTED:
206
+ row_idx = tl.full([1, BK], j, dtype=tl.int16)
207
+ col_idx = tl.full([BC, 1], j, dtype=tl.int16)
208
+ row_idx_bc = tl.full([1, BC], j, dtype=tl.int16)
209
+ # [1, BK]
210
+ b_kj = gather(b_k, row_idx, axis=0)
211
+ b_bj = gather(b_b, row_idx, axis=0)
212
+ b_gij = gather(b_gi, row_idx, axis=0)
213
+ b_gej = gather(b_ge, row_idx, axis=0)
214
+ b_qj = gather(b_q, row_idx, axis=0)
215
+ b_aj = gather(b_a, row_idx, axis=0)
216
+ # [BC, 1]
217
+ b_dAqk_j = gather(b_dAqk, col_idx, axis=1)
218
+ b_dAab_j = gather(b_dAab, col_idx, axis=1)
219
+ b_dAqb_j = gather(b_dAqb, col_idx, axis=1)
220
+ b_dAak_j = gather(b_dAak, col_idx, axis=1)
221
+ # [1, BC] -> [BC, 1]
222
+ b_dA_qk_j = tl.sum(gather(b_dAqk, row_idx_bc, axis=0), 0)[:, None]
223
+ b_dA_qk_j = tl.sum(gather(b_dAqk, row_idx_bc, axis=0), 0)[:, None]
224
+ b_dA_ab_j = tl.sum(gather(b_dAab, row_idx_bc, axis=0), 0)[:, None]
225
+ b_dA_qb_j = tl.sum(gather(b_dAqb, row_idx_bc, axis=0), 0)[:, None]
226
+ b_dA_ak_j = tl.sum(gather(b_dAak, row_idx_bc, axis=0), 0)[:, None]
227
+ else:
228
+ mask_idx = tl.arange(0, BC) == j
229
+ b_kj = tl.sum(tl.where(mask_idx[:, None], b_k, 0), 0)[None, :]
230
+ b_bj = tl.sum(tl.where(mask_idx[:, None], b_b, 0), 0)[None, :]
231
+ b_gij = tl.sum(tl.where(mask_idx[:, None], b_gi, 0), 0)[None, :]
232
+ b_gej = tl.sum(tl.where(mask_idx[:, None], b_ge, 0), 0)[None, :]
233
+ b_dAqk_j = tl.sum(tl.where(mask_idx[None, :], b_dAqk, 0), 1)[:, None]
234
+ b_dAab_j = tl.sum(tl.where(mask_idx[None, :], b_dAab, 0), 1)[:, None]
235
+ b_dAqb_j = tl.sum(tl.where(mask_idx[None, :], b_dAqb, 0), 1)[:, None]
236
+ b_dAak_j = tl.sum(tl.where(mask_idx[None, :], b_dAak, 0), 1)[:, None]
237
+ b_dA_qk_j = tl.sum(tl.where(mask_idx[:, None], b_dAqk, 0), 0)[:, None]
238
+ b_dA_ab_j = tl.sum(tl.where(mask_idx[:, None], b_dAab, 0), 0)[:, None]
239
+ b_dA_qb_j = tl.sum(tl.where(mask_idx[:, None], b_dAqb, 0), 0)[:, None]
240
+ b_dA_ak_j = tl.sum(tl.where(mask_idx[:, None], b_dAak, 0), 0)[:, None]
241
+ # [1, BK] b_qj, b_aj
242
+ b_qj = tl.sum(tl.where(mask_idx[:, None], b_q, 0), 0)[None, :]
243
+ b_aj = tl.sum(tl.where(mask_idx[:, None], b_a, 0), 0)[None, :]
244
+ # tl.static_print(b_kj)
245
+ m_e = o_i[:, None] > j
246
+ m_i = o_i[:, None] >= j
247
+ tmp1 = exp(b_gi - b_gij)
248
+ tmp2 = exp(b_ge - b_gij)
249
+ b_dq += tl.where(m_i, b_dAqk_j * b_kj * tmp1, 0.)
250
+ b_dq += tl.where(m_i, b_dAqb_j * b_bj * tmp1, 0.)
251
+ b_da += tl.where(m_e, b_dAab_j * b_bj * tmp2, 0.)
252
+ b_da += tl.where(m_e, b_dAak_j * b_kj * tmp2, 0.)
253
+
254
+ m_i = o_i[:, None] <= j
255
+ m_e = o_i[:, None] < j
256
+ tmp1 = exp(b_gij - b_gi)
257
+ tmp2 = exp(b_gej - b_gi)
258
+ b_dk += tl.where(m_i, b_dA_qk_j * b_qj * tmp1, 0.)
259
+ b_dk += tl.where(m_e, b_dA_ak_j * b_aj * tmp2, 0.)
260
+ b_db += tl.where(m_i, b_dA_qb_j * b_qj * tmp1, 0.)
261
+ b_db += tl.where(m_e, b_dA_ab_j * b_aj * tmp2, 0.)
262
+ # post processing
263
+ p_dq = tl.make_block_ptr(dq, (T, K), (stride_qk, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
264
+ p_dk = tl.make_block_ptr(dk, (T, K), (stride_qk, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
265
+ p_da = tl.make_block_ptr(da, (T, K), (stride_qk, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
266
+ p_db = tl.make_block_ptr(db, (T, K), (stride_qk, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
267
+ p_dgk = tl.make_block_ptr(dgk, (T, K), (stride_qk, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
268
+ p_dgk_offset = tl.make_block_ptr(dgk_offset, (T, K), (stride_qk, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
269
+ p_dqg = tl.make_block_ptr(dqg, (T, K), (stride_qk, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
270
+ p_dkg = tl.make_block_ptr(dkg, (T, K), (stride_qk, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
271
+ p_dag = tl.make_block_ptr(dag, (T, K), (stride_qk, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
272
+ p_dbg = tl.make_block_ptr(dbg, (T, K), (stride_qk, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
273
+ p_gn = gi + (min(i_t * BT + BT, T) - 1)*stride_qk + o_k
274
+ p_gn = tl.max_contiguous(tl.multiple_of(p_gn, BK), BK)
275
+ b_gn = tl.load(p_gn, mask=m_k, other=0)
276
+ b_da += tl.load(p_dag, boundary_check=(0, 1)) * exp(b_ge)
277
+ b_dq += tl.load(p_dqg, boundary_check=(0, 1)) * exp(b_gi) * scale
278
+ tmp = exp(b_gn[None, :] - b_gi)
279
+ b_dk += tl.load(p_dkg, boundary_check=(0, 1)) * tmp
280
+ b_db += tl.load(p_dbg, boundary_check=(0, 1)) * tmp
281
+ tl.store(p_dq, (b_dq).to(p_dq.dtype.element_ty), boundary_check=(0, 1))
282
+ tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
283
+ tl.store(p_da, b_da.to(p_da.dtype.element_ty), boundary_check=(0, 1))
284
+ tl.store(p_db, b_db.to(p_db.dtype.element_ty), boundary_check=(0, 1))
285
+ b_dgk = b_dq * b_q + b_da * b_a - b_dk * b_k - b_db * b_b
286
+ b_dgk_offset = b_da * b_a
287
+ tl.store(p_dgk, b_dgk.to(p_dgk.dtype.element_ty), boundary_check=(0, 1))
288
+ tl.store(p_dgk_offset, b_dgk_offset.to(p_dgk_offset.dtype.element_ty), boundary_check=(0, 1))
289
+
290
+
291
+ @triton.heuristics({
292
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None,
293
+ })
294
+ @triton.autotune(
295
+ configs=[
296
+ triton.Config({'BK': BK}, num_warps=num_warps, num_stages=num_stages)
297
+ for num_warps in [2, 4, 8, 16, 32]
298
+ for num_stages in [2, 3, 4]
299
+ for BK in [32, 64]
300
+ ],
301
+ key=['BK', 'BT', 'K'],
302
+ use_cuda_graph=use_cuda_graph,
303
+ )
304
+ @triton.jit(do_not_specialize=['T'])
305
+ def chunk_dplr_bwd_dgk_kernel(
306
+ dgk,
307
+ dgk_offset,
308
+ dgk_last,
309
+ dgk_output,
310
+ offsets,
311
+ indices,
312
+ T,
313
+ H: tl.constexpr,
314
+ K: tl.constexpr,
315
+ BT: tl.constexpr,
316
+ BK: tl.constexpr,
317
+ USE_OFFSETS: tl.constexpr,
318
+ HEAD_FIRST: tl.constexpr,
319
+ ):
320
+ i_t, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
321
+ i_b, i_h = i_bh // H, i_bh % H
322
+ if USE_OFFSETS:
323
+ i_tg = i_t
324
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
325
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
326
+ T = eos - bos
327
+ NT = tl.cdiv(T, BT)
328
+ else:
329
+ NT = tl.cdiv(T, BT)
330
+ i_tg = i_b * NT + i_t
331
+ bos, eos = i_b * T, i_b * T + T
332
+ T = eos - bos
333
+ stride_qk = K if HEAD_FIRST else H * K
334
+ dgk += i_bh * T * K if HEAD_FIRST else (bos * H + i_h) * K
335
+ dgk_offset += i_bh * T * K if HEAD_FIRST else (bos * H + i_h) * K
336
+ dgk_last += ((i_bh * NT + i_t) * K) if HEAD_FIRST else (i_tg * H + i_h) * K
337
+ dgk_output += i_bh * T * K if HEAD_FIRST else (bos * H + i_h) * K
338
+ p_dgk_last = dgk_last + tl.arange(0, BK) + i_k * BK
339
+ m_k = tl.arange(0, BK) + i_k * BK < K
340
+ b_dgk_last = tl.load(p_dgk_last, mask=m_k, other=0)
341
+ p_dgk_offset = tl.make_block_ptr(dgk_offset, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
342
+ p_dgk = tl.make_block_ptr(dgk, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
343
+ b_dgk = tl.load(p_dgk, boundary_check=(0, 1))
344
+ b_dgk_offset = tl.load(p_dgk_offset, boundary_check=(0, 1))
345
+ # m_inv_cumsum = (tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :]).to(tl.float32)
346
+ # b_dgk_cumsum = tl.dot(m_inv_cumsum, b_dgk, allow_tf32=False)
347
+ b_dgk_cumsum = tl.cumsum(b_dgk, 0, reverse=True)
348
+ b_dgk_cumsum += b_dgk_last[None, :]
349
+ b_dgk_cumsum -= b_dgk_offset
350
+ p_dgk_output = tl.make_block_ptr(dgk_output, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
351
+ tl.store(p_dgk_output, b_dgk_cumsum.to(p_dgk_output.dtype.element_ty), boundary_check=(0, 1))
352
+
353
+
354
+ def chunk_dplr_bwd_dqk_intra(
355
+ q: torch.Tensor,
356
+ k: torch.Tensor,
357
+ a: torch.Tensor,
358
+ b: torch.Tensor,
359
+ gi: torch.Tensor,
360
+ ge: torch.Tensor,
361
+ dAqk: torch.Tensor,
362
+ dAqb: torch.Tensor,
363
+ dAak: torch.Tensor,
364
+ dAab: torch.Tensor,
365
+ dqg: torch.Tensor,
366
+ dkg: torch.Tensor,
367
+ dag: torch.Tensor,
368
+ dbg: torch.Tensor,
369
+ dgk_last: torch.Tensor,
370
+ offsets: Optional[torch.LongTensor] = None,
371
+ indices: Optional[torch.LongTensor] = None,
372
+ head_first: bool = True,
373
+ scale: float = 1.0,
374
+ chunk_size: int = 64,
375
+ ):
376
+ if head_first:
377
+ B, H, T, K = q.shape
378
+ else:
379
+ B, T, H, K = q.shape
380
+ BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
381
+ BC = min(16, BT)
382
+ BK = min(64, triton.next_power_of_2(K)) if check_shared_mem() else min(32, triton.next_power_of_2(K))
383
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
384
+ NC = triton.cdiv(BT, BC)
385
+ NK = triton.cdiv(K, BK)
386
+
387
+ dq = torch.empty_like(q)
388
+ dk = torch.empty_like(k)
389
+ da = torch.empty_like(a)
390
+ db = torch.empty_like(b)
391
+ dgk = torch.empty_like(gi, dtype=torch.float)
392
+ dgk_offset = torch.empty_like(gi, dtype=torch.float)
393
+
394
+ grid = (NK, NT * NC, B * H)
395
+ chunk_dplr_bwd_kernel_intra[grid](
396
+ q=q,
397
+ k=k,
398
+ a=a,
399
+ b=b,
400
+ gi=gi,
401
+ ge=ge,
402
+ dAqk=dAqk,
403
+ dAqb=dAqb,
404
+ dAak=dAak,
405
+ dAab=dAab,
406
+ dq=dq,
407
+ dk=dk,
408
+ dgk=dgk,
409
+ dgk_offset=dgk_offset,
410
+ dqg=dqg,
411
+ dkg=dkg,
412
+ dag=dag,
413
+ dbg=dbg,
414
+ da=da,
415
+ db=db,
416
+ offsets=offsets,
417
+ indices=indices,
418
+ scale=scale,
419
+ T=T,
420
+ H=H,
421
+ K=K,
422
+ BT=BT,
423
+ BC=BC,
424
+ BK=BK,
425
+ NC=NC,
426
+ HEAD_FIRST=head_first,
427
+ GATHER_SUPPORTED=is_gather_supported
428
+ )
429
+
430
+ def grid2(meta): return (NT, triton.cdiv(K, meta['BK']), B * H)
431
+ dgk_output = torch.empty_like(dgk)
432
+
433
+ chunk_dplr_bwd_dgk_kernel[grid2](
434
+ dgk=dgk,
435
+ dgk_offset=dgk_offset,
436
+ dgk_last=dgk_last,
437
+ dgk_output=dgk_output,
438
+ offsets=offsets,
439
+ indices=indices,
440
+ T=T,
441
+ H=H,
442
+ K=K,
443
+ BT=BT,
444
+ HEAD_FIRST=head_first
445
+ )
446
+ return dq, dk, da, db, dgk_output
fla/ops/generalized_delta_rule/dplr/chunk_A_fwd.py ADDED
@@ -0,0 +1,324 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.ops.utils.op import exp, gather
11
+ from fla.utils import is_gather_supported, use_cuda_graph
12
+
13
+
14
+ @triton.heuristics({
15
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
16
+ })
17
+ @triton.autotune(
18
+ configs=[
19
+ triton.Config({'BK': BK}, num_warps=num_warps, num_stages=num_stages)
20
+ for BK in [32, 64]
21
+ for num_warps in [2, 4, 8, 16]
22
+ for num_stages in [2, 3, 4]
23
+ ],
24
+ key=['BC', 'K'],
25
+ use_cuda_graph=use_cuda_graph,
26
+ )
27
+ @triton.jit(do_not_specialize=['T'])
28
+ def chunk_dplr_fwd_A_kernel_intra_sub_inter(
29
+ q,
30
+ k,
31
+ a,
32
+ b,
33
+ gi, # cumsum
34
+ ge, # before cumsum
35
+ Aqk,
36
+ Aqb,
37
+ Aab,
38
+ Aak,
39
+ offsets,
40
+ indices,
41
+ scale: tl.constexpr,
42
+ T,
43
+ H: tl.constexpr,
44
+ K: tl.constexpr,
45
+ BT: tl.constexpr,
46
+ BC: tl.constexpr,
47
+ BK: tl.constexpr,
48
+ NC: tl.constexpr,
49
+ USE_OFFSETS: tl.constexpr,
50
+ HEAD_FIRST: tl.constexpr,
51
+ ):
52
+ i_t, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
53
+ i_b, i_h = i_bh // H, i_bh % H
54
+ i_i, i_j = i_c // NC, i_c % NC
55
+ if USE_OFFSETS:
56
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
57
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
58
+ T = eos - bos
59
+ else:
60
+ bos, eos = i_b * T, i_b * T + T
61
+
62
+ if i_t * BT + i_i * BC >= T:
63
+ return
64
+ if i_i <= i_j:
65
+ return
66
+
67
+ b_Aqk = tl.zeros([BC, BC], dtype=tl.float32)
68
+ b_Aqb = tl.zeros([BC, BC], dtype=tl.float32)
69
+ b_Aab = tl.zeros([BC, BC], dtype=tl.float32)
70
+ b_Aak = tl.zeros([BC, BC], dtype=tl.float32)
71
+ for i_k in range(tl.cdiv(K, BK)):
72
+ o_k = i_k * BK + tl.arange(0, BK)
73
+ m_k = o_k < K
74
+
75
+ if HEAD_FIRST:
76
+ p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
77
+ p_a = tl.make_block_ptr(a + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
78
+ p_gq_i = tl.make_block_ptr(gi + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
79
+ p_gq_e = tl.make_block_ptr(ge + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
80
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))
81
+ p_b = tl.make_block_ptr(b + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))
82
+ p_gk = tl.make_block_ptr(gi + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))
83
+ p_gn = tl.max_contiguous(tl.multiple_of(gi + (i_bh * T + i_t * BT + i_i * BC - 1) * K + o_k, BK), BK)
84
+ else:
85
+ p_q = tl.make_block_ptr(q + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
86
+ p_a = tl.make_block_ptr(a + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
87
+ p_gq_i = tl.make_block_ptr(gi + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
88
+ p_gq_e = tl.make_block_ptr(ge + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
89
+ p_k = tl.make_block_ptr(k + (bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))
90
+ p_b = tl.make_block_ptr(b + (bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))
91
+ p_gk = tl.make_block_ptr(gi + (bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))
92
+ p_gn = gi + (bos + i_t * BT + i_i * BC - 1) * H*K + i_h * K + o_k
93
+ # [BK,]
94
+ b_gn = tl.load(p_gn, mask=m_k, other=0).to(tl.float32)
95
+ # [BC, BK]
96
+ b_q = tl.load(p_q, boundary_check=(0, 1))
97
+ b_a = tl.load(p_a, boundary_check=(0, 1))
98
+ b_gq_i = tl.load(p_gq_i, boundary_check=(0, 1))
99
+ b_gq_e = tl.load(p_gq_e, boundary_check=(0, 1))
100
+ b_ag = b_a * exp(b_gq_e - b_gn[None, :])
101
+ b_qg = b_q * exp(b_gq_i - b_gn[None, :]) * scale
102
+ # [BK, BC]
103
+ b_k = tl.load(p_k, boundary_check=(0, 1))
104
+ b_b = tl.load(p_b, boundary_check=(0, 1))
105
+ b_gk = tl.load(p_gk, boundary_check=(0, 1)).to(tl.float32)
106
+ tmp = exp(b_gn[:, None] - b_gk)
107
+ b_kg = b_k * tmp
108
+ b_bg = b_b * tmp
109
+ # [BC, BC] using tf32 to improve precision here.
110
+ b_Aab += tl.dot(b_ag, b_bg)
111
+ b_Aak += tl.dot(b_ag, b_kg)
112
+ b_Aqk += tl.dot(b_qg, b_kg)
113
+ b_Aqb += tl.dot(b_qg, b_bg)
114
+
115
+ if HEAD_FIRST:
116
+ p_Aqk = tl.make_block_ptr(Aqk + i_bh*T*BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
117
+ p_Aqb = tl.make_block_ptr(Aqb + i_bh*T*BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
118
+ p_Aab = tl.make_block_ptr(Aab + i_bh*T*BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
119
+ p_Aak = tl.make_block_ptr(Aak + i_bh*T*BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
120
+ else:
121
+ p_Aqk = tl.make_block_ptr(Aqk + (bos*H+i_h)*BT, (T, BT), (H*BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
122
+ p_Aqb = tl.make_block_ptr(Aqb + (bos*H+i_h)*BT, (T, BT), (H*BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
123
+ p_Aab = tl.make_block_ptr(Aab + (bos*H+i_h)*BT, (T, BT), (H*BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
124
+ p_Aak = tl.make_block_ptr(Aak + (bos*H+i_h)*BT, (T, BT), (H*BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
125
+ tl.store(p_Aqk, b_Aqk.to(Aqk.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
126
+ tl.store(p_Aqb, b_Aqb.to(Aqb.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
127
+ tl.store(p_Aab, b_Aab.to(Aab.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
128
+ tl.store(p_Aak, b_Aak.to(Aak.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
129
+
130
+
131
+ @triton.heuristics({
132
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
133
+ })
134
+ @triton.autotune(
135
+ configs=[
136
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
137
+ for num_warps in [2, 4, 8, 16, 32]
138
+ for num_stages in [2, 3, 4]
139
+ ],
140
+ key=['BK', 'BT'],
141
+ use_cuda_graph=use_cuda_graph,
142
+ )
143
+ @triton.jit(do_not_specialize=['T'])
144
+ def chunk_dplr_fwd_A_kernel_intra_sub_intra(
145
+ q,
146
+ k,
147
+ a,
148
+ b,
149
+ gi,
150
+ ge,
151
+ qg,
152
+ kg,
153
+ ag,
154
+ bg,
155
+ Aqk,
156
+ Aqb,
157
+ Aab,
158
+ Aak,
159
+ offsets,
160
+ indices,
161
+ scale: tl.constexpr,
162
+ T,
163
+ H: tl.constexpr,
164
+ K: tl.constexpr,
165
+ BT: tl.constexpr,
166
+ BC: tl.constexpr,
167
+ BK: tl.constexpr,
168
+ NC: tl.constexpr,
169
+ USE_OFFSETS: tl.constexpr,
170
+ HEAD_FIRST: tl.constexpr,
171
+ GATHER_SUPPORTED: tl.constexpr
172
+ ):
173
+ i_t, i_i, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
174
+ i_b, i_h = i_bh // H, i_bh % H
175
+ i_j = i_i
176
+ if USE_OFFSETS:
177
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
178
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
179
+ T = eos - bos
180
+ else:
181
+ bos, eos = i_b * T, i_b * T + T
182
+
183
+ if i_t * BT + i_i * BC >= T:
184
+ return
185
+
186
+ o_i = tl.arange(0, BC)
187
+ o_k = tl.arange(0, BK)
188
+ m_k = o_k < K
189
+ m_A = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T
190
+ last_idx = min((i_t+1) * BT, T) - 1
191
+ if HEAD_FIRST:
192
+ o_A = i_bh * T*BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_j * BC
193
+ p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
194
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
195
+ p_a = tl.make_block_ptr(a + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
196
+ p_b = tl.make_block_ptr(b + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
197
+ p_gi = tl.make_block_ptr(gi + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
198
+ p_ge = tl.make_block_ptr(ge + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
199
+ p_g_last = gi + i_bh * T*K + last_idx * K + tl.arange(0, BK)
200
+ b_g_last = tl.load(p_g_last, mask=m_k, other=0)
201
+
202
+ p_qg = tl.make_block_ptr(qg + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
203
+ p_kg = tl.make_block_ptr(kg + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
204
+ p_ag = tl.make_block_ptr(ag + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
205
+ p_bg = tl.make_block_ptr(bg + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
206
+ else:
207
+ o_A = (bos + i_t * BT + i_i * BC + tl.arange(0, BC)) * H*BT + i_h * BT + i_j * BC
208
+ p_q = tl.make_block_ptr(q + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
209
+ p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
210
+ p_a = tl.make_block_ptr(a + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
211
+ p_b = tl.make_block_ptr(b + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
212
+ p_gi = tl.make_block_ptr(gi + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
213
+ p_ge = tl.make_block_ptr(ge + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
214
+ p_g_last = gi + (bos * H + i_h) * K + last_idx * H * K + tl.arange(0, BK)
215
+ b_g_last = tl.load(p_g_last, mask=m_k, other=0)
216
+ p_qg = tl.make_block_ptr(qg + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
217
+ p_kg = tl.make_block_ptr(kg + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
218
+ p_ag = tl.make_block_ptr(ag + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
219
+ p_bg = tl.make_block_ptr(bg + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
220
+
221
+ b_q = tl.load(p_q, boundary_check=(0, 1))
222
+ b_q = b_q * scale
223
+ b_k = tl.load(p_k, boundary_check=(0, 1))
224
+ b_a = tl.load(p_a, boundary_check=(0, 1))
225
+ b_b = tl.load(p_b, boundary_check=(0, 1))
226
+ b_gi = tl.load(p_gi, boundary_check=(0, 1)).to(tl.float32)
227
+ b_ge = tl.load(p_ge, boundary_check=(0, 1)).to(tl.float32)
228
+
229
+ # deal with decay term.
230
+ g_exp = exp(b_gi)
231
+ g_exp_inv = exp(-b_gi + b_g_last[None, :])
232
+ b_qg = b_q * g_exp
233
+ b_kg = b_k * g_exp_inv
234
+ b_bg = b_b * g_exp_inv
235
+ b_ag = b_a * exp(b_ge)
236
+ tl.store(p_qg, b_qg.to(p_qg.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
237
+ tl.store(p_bg, b_bg.to(p_bg.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
238
+ tl.store(p_ag, b_ag.to(p_ag.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
239
+ tl.store(p_kg, b_kg.to(p_kg.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
240
+ # tl.debug_barrier()
241
+
242
+ b_q = b_q.to(b_k.dtype)
243
+ # inner attn
244
+ for j in range(0, min(BC, T - i_t * BT - i_i * BC)):
245
+ # a trick to index the j-th row of b_k, b_g, b_b
246
+ if GATHER_SUPPORTED:
247
+ row_idx = tl.full([1, BK], j, dtype=tl.int16)
248
+ # [1, BK]
249
+ b_k_j = gather(b_k, row_idx, axis=0)
250
+ b_gk_j = gather(b_gi, row_idx, axis=0)
251
+ b_b_j = gather(b_b, row_idx, axis=0)
252
+ else:
253
+ mask = tl.arange(0, BC) == j
254
+ b_k_j = tl.sum(tl.where(mask[:, None], b_k, 0), 0)[None, :]
255
+ b_gk_j = tl.sum(tl.where(mask[:, None], b_gi, 0), 0)[None, :]
256
+ b_b_j = tl.sum(tl.where(mask[:, None], b_b, 0), 0)[None, :]
257
+ mask = tl.arange(0, BC) == j
258
+ tmp = exp(b_gi - b_gk_j)
259
+ b_A_qk = tl.sum(b_q * b_k_j * tmp, 1)
260
+ b_A_qk = tl.where(o_i >= j, b_A_qk, 0.)
261
+ b_A_qb = tl.sum(b_q * b_b_j * tmp, 1)
262
+ b_A_qb = tl.where(o_i >= j, b_A_qb, 0.)
263
+ tmp2 = exp(b_ge - b_gk_j)
264
+ b_A_ak = tl.sum(b_a * b_k_j * tmp2, 1)
265
+ b_A_ak = tl.where(o_i > j, b_A_ak, 0.)
266
+ b_A_ab = tl.sum(b_a * b_b_j * tmp2, 1)
267
+ b_A_ab = tl.where(o_i > j, b_A_ab, 0.)
268
+ tl.store(Aqk + o_A + j, b_A_qk.to(dtype=Aqk.dtype.element_ty, fp_downcast_rounding="rtne"), mask=m_A)
269
+ tl.store(Aqb + o_A + j, b_A_qb.to(dtype=Aqb.dtype.element_ty, fp_downcast_rounding="rtne"), mask=m_A)
270
+ tl.store(Aab + o_A + j, b_A_ab.to(dtype=Aqb.dtype.element_ty, fp_downcast_rounding="rtne"), mask=m_A)
271
+ tl.store(Aak + o_A + j, b_A_ak.to(dtype=Aqk.dtype.element_ty, fp_downcast_rounding="rtne"), mask=m_A)
272
+
273
+
274
+ def chunk_fwd_intra_dplr_fn(
275
+ q: torch.Tensor,
276
+ k: torch.Tensor,
277
+ a: torch.Tensor,
278
+ b: torch.Tensor,
279
+ gi: torch.Tensor,
280
+ ge: torch.Tensor,
281
+ scale: float,
282
+ chunk_size: int,
283
+ offsets: Optional[torch.LongTensor] = None,
284
+ indices: Optional[torch.LongTensor] = None,
285
+ head_first: bool = True,
286
+ ):
287
+ if head_first:
288
+ B, H, T, K = k.shape
289
+ else:
290
+ B, T, H, K = k.shape
291
+ BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
292
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
293
+ BC = min(16, BT)
294
+ NC = triton.cdiv(BT, BC)
295
+
296
+ Aqk = q.new_empty(B, *((H, T) if head_first else (T, H)), BT, dtype=q.dtype)
297
+ Aqb = q.new_empty(B, *((H, T) if head_first else (T, H)), BT, dtype=q.dtype)
298
+ # these blocks are involved in a matrix inverse, so it is better to keep them in float here.
299
+ Aab = q.new_empty(B, *((H, T) if head_first else (T, H)), BT, dtype=torch.float)
300
+ Aak = q.new_empty(B, *((H, T) if head_first else (T, H)), BT, dtype=torch.float)
301
+ grid = (NT, NC * NC, B * H)
302
+
303
+ chunk_dplr_fwd_A_kernel_intra_sub_inter[grid](
304
+ q=q, k=k, a=a, b=b, gi=gi, ge=ge, Aqk=Aqk, Aqb=Aqb, Aab=Aab, Aak=Aak,
305
+ offsets=offsets, indices=indices,
306
+ scale=scale,
307
+ T=T, H=H, K=K, BT=BT, BC=BC, NC=NC,
308
+ HEAD_FIRST=head_first
309
+ )
310
+ grid = (NT, NC, B * H)
311
+ BK = triton.next_power_of_2(K)
312
+ qg = torch.empty_like(q)
313
+ kg = torch.empty_like(k, dtype=q.dtype)
314
+ ag = torch.empty_like(a, dtype=q.dtype)
315
+ bg = torch.empty_like(b, dtype=q.dtype)
316
+ chunk_dplr_fwd_A_kernel_intra_sub_intra[grid](
317
+ q=q, k=k, a=a, b=b, gi=gi, ge=ge, Aqk=Aqk, Aqb=Aqb, Aab=Aab, Aak=Aak,
318
+ qg=qg, kg=kg, ag=ag, bg=bg,
319
+ offsets=offsets, indices=indices,
320
+ scale=scale,
321
+ T=T, H=H, K=K, BT=BT, BC=BC, BK=BK, HEAD_FIRST=head_first, NC=NC,
322
+ GATHER_SUPPORTED=is_gather_supported
323
+ )
324
+ return Aab, Aqk, Aak, Aqb, qg, kg, ag, bg
fla/ops/generalized_delta_rule/dplr/chunk_h_bwd.py ADDED
@@ -0,0 +1,196 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.ops.common.utils import prepare_chunk_offsets
11
+ from fla.ops.utils.op import exp
12
+ from fla.utils import check_shared_mem, use_cuda_graph
13
+
14
+
15
+ @triton.heuristics({
16
+ 'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None,
17
+ 'USE_INITIAL_STATE': lambda args: args['dh0'] is not None,
18
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None,
19
+ })
20
+ @triton.autotune(
21
+ configs=[
22
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
23
+ for num_warps in [2, 4, 8, 16, 32]
24
+ for num_stages in [2, 3, 4]
25
+ ],
26
+ key=['BT', 'BK', 'BV', "V"],
27
+ use_cuda_graph=use_cuda_graph,
28
+ )
29
+ @triton.jit(do_not_specialize=['T'])
30
+ def chunk_dplr_bwd_kernel_dhu(
31
+ qg,
32
+ bg,
33
+ w,
34
+ gk,
35
+ dht,
36
+ dh0,
37
+ do,
38
+ dh,
39
+ dv,
40
+ dv2,
41
+ offsets,
42
+ chunk_offsets,
43
+ T,
44
+ H: tl.constexpr,
45
+ K: tl.constexpr,
46
+ V: tl.constexpr,
47
+ BT: tl.constexpr,
48
+ BC: tl.constexpr,
49
+ BK: tl.constexpr,
50
+ BV: tl.constexpr,
51
+ USE_FINAL_STATE_GRADIENT: tl.constexpr,
52
+ USE_INITIAL_STATE: tl.constexpr,
53
+ USE_OFFSETS: tl.constexpr,
54
+ HEAD_FIRST: tl.constexpr
55
+ ):
56
+ i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
57
+ i_n, i_h = i_nh // H, i_nh % H
58
+ if USE_OFFSETS:
59
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
60
+ T = eos - bos
61
+ NT = tl.cdiv(T, BT)
62
+ boh = tl.load(chunk_offsets + i_n).to(tl.int32)
63
+ else:
64
+ bos, eos = i_n * T, i_n * T + T
65
+ NT = tl.cdiv(T, BT)
66
+ boh = i_n * NT
67
+
68
+ # [BK, BV]
69
+ b_dh = tl.zeros([BK, BV], dtype=tl.float32)
70
+ if USE_FINAL_STATE_GRADIENT:
71
+ p_dht = tl.make_block_ptr(dht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
72
+ b_dh += tl.load(p_dht, boundary_check=(0, 1))
73
+
74
+ mask_k = tl.arange(0, BK) < K
75
+ for i_t in range(NT - 1, -1, -1):
76
+ if HEAD_FIRST:
77
+ p_dh = tl.make_block_ptr(dh + (i_nh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
78
+ else:
79
+ p_dh = tl.make_block_ptr(dh + ((boh+i_t) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
80
+ tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))
81
+ b_dh_tmp = tl.zeros([BK, BV], dtype=tl.float32)
82
+ for i_c in range(tl.cdiv(BT, BC) - 1, -1, -1):
83
+ if HEAD_FIRST:
84
+ p_qg = tl.make_block_ptr(qg + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
85
+ p_bg = tl.make_block_ptr(bg + i_nh * T*K, (T, K), (K, 1), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))
86
+ p_w = tl.make_block_ptr(w + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
87
+ p_dv = tl.make_block_ptr(dv + i_nh * T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
88
+ p_do = tl.make_block_ptr(do + i_nh * T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
89
+ p_dv2 = tl.make_block_ptr(dv2 + i_nh * T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
90
+ else:
91
+ p_qg = tl.make_block_ptr(qg+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
92
+ p_bg = tl.make_block_ptr(bg+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))
93
+ p_w = tl.make_block_ptr(w+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
94
+ p_dv = tl.make_block_ptr(dv+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
95
+ p_do = tl.make_block_ptr(do+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
96
+ p_dv2 = tl.make_block_ptr(dv2+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
97
+ # [BK, BT]
98
+ b_qg = tl.load(p_qg, boundary_check=(0, 1))
99
+ # [BT, BK]
100
+ b_bg = tl.load(p_bg, boundary_check=(0, 1))
101
+ b_w = tl.load(p_w, boundary_check=(0, 1))
102
+ # [BT, V]
103
+ b_do = tl.load(p_do, boundary_check=(0, 1))
104
+ b_dv = tl.load(p_dv, boundary_check=(0, 1))
105
+ b_dv2 = b_dv + tl.dot(b_bg, b_dh.to(b_bg.dtype))
106
+ tl.store(p_dv2, b_dv2.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
107
+ # [BK, BV]
108
+ b_dh_tmp += tl.dot(b_qg, b_do.to(b_qg.dtype))
109
+ b_dh_tmp += tl.dot(b_w, b_dv2.to(b_qg.dtype))
110
+ last_idx = min((i_t + 1) * BT, T) - 1
111
+ if HEAD_FIRST:
112
+ bg_last = tl.load(gk + (i_nh * T + last_idx) * K + tl.arange(0, BK), mask=mask_k)
113
+ else:
114
+ bg_last = tl.load(gk + ((bos + last_idx) * H + i_h) * K + tl.arange(0, BK), mask=mask_k)
115
+ b_dh *= exp(bg_last)[:, None]
116
+ b_dh += b_dh_tmp
117
+
118
+ if USE_INITIAL_STATE:
119
+ p_dh0 = tl.make_block_ptr(dh0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
120
+ tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), boundary_check=(0, 1))
121
+
122
+
123
+ def chunk_dplr_bwd_dhu(
124
+ qg: torch.Tensor,
125
+ bg: torch.Tensor,
126
+ w: torch.Tensor,
127
+ gk: torch.Tensor,
128
+ h0: torch.Tensor,
129
+ dht: Optional[torch.Tensor],
130
+ do: torch.Tensor,
131
+ dv: torch.Tensor,
132
+ offsets: Optional[torch.LongTensor] = None,
133
+ indices: Optional[torch.LongTensor] = None,
134
+ head_first: bool = True,
135
+ chunk_size: int = 64
136
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
137
+ if head_first:
138
+ B, H, T, K, V = *qg.shape, do.shape[-1]
139
+ else:
140
+ B, T, H, K, V = *qg.shape, do.shape[-1]
141
+ BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
142
+ BK = triton.next_power_of_2(K)
143
+ assert BK <= 256, "current kernel does not support head dimension being larger than 256."
144
+ # H100
145
+ if check_shared_mem('hopper', qg.device.index):
146
+ BV = 64
147
+ BC = 64 if K <= 128 else 32
148
+ elif check_shared_mem('ampere', qg.device.index): # A100
149
+ BV = 32
150
+ BC = 32
151
+ else: # other GPUs with less shared memory, e.g. RTX 4090
152
+ BV = 16
153
+ BC = 16
154
+
155
+ # N: the actual number of sequences in the batch with either equal or variable lengths
156
+ if offsets is None:
157
+ N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
158
+ else:
159
+ N, NT, chunk_offsets = len(offsets) - 1, len(indices), prepare_chunk_offsets(offsets, BT)
160
+
161
+ BC = min(BT, BC)
162
+ NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
163
+ assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization'
164
+
165
+ if head_first:
166
+ dh = qg.new_empty(B, H, NT, K, V)
167
+ else:
168
+ dh = qg.new_empty(B, NT, H, K, V)
169
+ dh0 = torch.empty_like(h0, dtype=torch.float32) if h0 is not None else None
170
+ dv2 = torch.zeros_like(dv)
171
+
172
+ grid = (NK, NV, N * H)
173
+ chunk_dplr_bwd_kernel_dhu[grid](
174
+ qg=qg,
175
+ bg=bg,
176
+ w=w,
177
+ gk=gk,
178
+ dht=dht,
179
+ dh0=dh0,
180
+ do=do,
181
+ dh=dh,
182
+ dv=dv,
183
+ dv2=dv2,
184
+ offsets=offsets,
185
+ chunk_offsets=chunk_offsets,
186
+ T=T,
187
+ H=H,
188
+ K=K,
189
+ V=V,
190
+ BT=BT,
191
+ BC=BC,
192
+ BK=BK,
193
+ BV=BV,
194
+ HEAD_FIRST=head_first
195
+ )
196
+ return dh, dh0, dv2
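The chunk-level recurrence evaluated by `chunk_dplr_bwd_kernel_dhu` can be stated compactly in plain PyTorch. The sketch below is illustrative only: the name `chunk_dplr_bwd_dhu_ref` is invented, it assumes the `head_first` layout with a sequence length divisible by the chunk size, and it omits the sub-chunking and variable-length (`offsets`) handling of the Triton kernel.

```python
import torch

def chunk_dplr_bwd_dhu_ref(qg, bg, w, gk, dht, do, dv, chunk_size=64):
    # qg, bg, w, gk: [B, H, T, K]; do, dv: [B, H, T, V]; dht: [B, H, K, V] or None
    B, H, T, K = qg.shape
    V = do.shape[-1]
    NT = T // chunk_size                      # assume T % chunk_size == 0
    dh = qg.new_zeros(B, H, NT, K, V)
    dv2 = torch.zeros_like(dv)
    b_dh = dht.clone() if dht is not None else qg.new_zeros(B, H, K, V)
    for i_t in range(NT - 1, -1, -1):
        s = slice(i_t * chunk_size, (i_t + 1) * chunk_size)
        dh[:, :, i_t] = b_dh                  # buffer the running state gradient for this chunk
        # dv2 = dv + bg @ dh, as in the kernel
        dv2[:, :, s] = dv[:, :, s] + bg[:, :, s] @ b_dh
        # intra-chunk contributions to the state gradient
        b_dh_tmp = qg[:, :, s].transpose(-1, -2) @ do[:, :, s] \
                 + w[:, :, s].transpose(-1, -2) @ dv2[:, :, s]
        # decay by the last cumulative gate of the chunk, then accumulate
        b_dh = torch.exp(gk[:, :, s][..., -1, :])[..., None] * b_dh + b_dh_tmp
    return dh, b_dh, dv2                      # b_dh plays the role of dh0
```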
fla/ops/generalized_delta_rule/dplr/chunk_h_fwd.py ADDED
@@ -0,0 +1,197 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.ops.common.utils import prepare_chunk_offsets
11
+ from fla.ops.utils.op import exp
12
+ from fla.utils import check_shared_mem, use_cuda_graph
13
+
14
+
15
+ @triton.heuristics({
16
+ 'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
17
+ 'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
18
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None,
19
+ })
20
+ @triton.autotune(
21
+ configs=[
22
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
23
+ for num_warps in [2, 4, 8, 16, 32]
24
+ for num_stages in [2, 3, 4]
25
+ ],
26
+ key=['BT', 'BK', 'BV'],
27
+ use_cuda_graph=use_cuda_graph,
28
+ )
29
+ @triton.jit(do_not_specialize=['T'])
30
+ def chunk_dplr_fwd_kernel_h(
31
+ kg,
32
+ v,
33
+ w,
34
+ bg,
35
+ u,
36
+ v_new,
37
+ gk,
38
+ h,
39
+ h0,
40
+ ht,
41
+ offsets,
42
+ chunk_offsets,
43
+ T,
44
+ H: tl.constexpr,
45
+ K: tl.constexpr,
46
+ V: tl.constexpr,
47
+ BT: tl.constexpr,
48
+ BC: tl.constexpr,
49
+ BK: tl.constexpr,
50
+ BV: tl.constexpr,
51
+ NT: tl.constexpr,
52
+ USE_INITIAL_STATE: tl.constexpr,
53
+ STORE_FINAL_STATE: tl.constexpr,
54
+ USE_OFFSETS: tl.constexpr,
55
+ HEAD_FIRST: tl.constexpr,
56
+ ):
57
+ i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
58
+ i_n, i_h = i_nh // H, i_nh % H
59
+ if USE_OFFSETS:
60
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
61
+ T = eos - bos
62
+ NT = tl.cdiv(T, BT)
63
+ boh = tl.load(chunk_offsets + i_n).to(tl.int32)
64
+ else:
65
+ bos, eos = i_n * T, i_n * T + T
66
+ NT = tl.cdiv(T, BT)
67
+ boh = i_n * NT
68
+
69
+ # [BK, BV]
70
+ b_h = tl.zeros([BK, BV], dtype=tl.float32)
71
+ if USE_INITIAL_STATE:
72
+ p_h0 = tl.make_block_ptr(h0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
73
+ b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32)
74
+
75
+ for i_t in range(NT):
76
+ if HEAD_FIRST:
77
+ p_h = tl.make_block_ptr(h + (i_nh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
78
+ else:
79
+ p_h = tl.make_block_ptr(h + ((boh + i_t) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
80
+ tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
81
+
82
+ b_hc = tl.zeros([BK, BV], dtype=tl.float32)
83
+ # since we need to keep the whole K dimension (DK) in SRAM, we face a severe SRAM memory burden; subchunking alleviates it
84
+ for i_c in range(tl.cdiv(min(BT, T - i_t * BT), BC)):
85
+ if HEAD_FIRST:
86
+ p_kg = tl.make_block_ptr(kg + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
87
+ p_bg = tl.make_block_ptr(bg + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
88
+ p_w = tl.make_block_ptr(w + i_nh * T*K, (T, K), (K, 1), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))
89
+ p_v = tl.make_block_ptr(v + i_nh * T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
90
+ p_u = tl.make_block_ptr(u + i_nh * T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
91
+ p_v_new = tl.make_block_ptr(v_new+i_nh*T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
92
+ else:
93
+ p_kg = tl.make_block_ptr(kg+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
94
+ p_bg = tl.make_block_ptr(bg+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
95
+ p_w = tl.make_block_ptr(w+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))
96
+ p_v = tl.make_block_ptr(v+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
97
+ p_u = tl.make_block_ptr(u+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
98
+ p_v_new = tl.make_block_ptr(v_new+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT+i_c*BC, i_v * BV), (BC, BV), (1, 0))
99
+ # [BK, BC]
100
+ b_kg = tl.load(p_kg, boundary_check=(0, 1))
101
+ b_v = tl.load(p_v, boundary_check=(0, 1))
102
+ b_w = tl.load(p_w, boundary_check=(0, 1))
103
+ b_bg = tl.load(p_bg, boundary_check=(0, 1))
104
+ b_v2 = tl.dot(b_w, b_h.to(b_w.dtype)) + tl.load(p_u, boundary_check=(0, 1))
105
+ b_hc += tl.dot(b_kg, b_v)
106
+ b_hc += tl.dot(b_bg.to(b_hc.dtype), b_v2)
107
+ tl.store(p_v_new, b_v2.to(p_v_new.dtype.element_ty), boundary_check=(0, 1))
108
+
109
+ last_idx = min((i_t + 1) * BT, T) - 1
110
+ if HEAD_FIRST:
111
+ b_g_last = tl.load(gk + i_nh * T * K + last_idx * K + tl.arange(0, BK), mask=tl.arange(0, BK) < K).to(tl.float32)
112
+ else:
113
+ b_g_last = tl.load(gk + (bos + last_idx) * H * K + i_h * K +
114
+ tl.arange(0, BK), mask=tl.arange(0, BK) < K).to(tl.float32)
115
+ b_h *= exp(b_g_last[:, None])
116
+ b_h += b_hc
117
+
118
+ if STORE_FINAL_STATE:
119
+ p_ht = tl.make_block_ptr(ht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
120
+ tl.store(p_ht, b_h.to(p_ht.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
121
+
122
+
123
+ def chunk_dplr_fwd_h(
124
+ kg: torch.Tensor,
125
+ v: torch.Tensor,
126
+ w: torch.Tensor,
127
+ u: torch.Tensor,
128
+ bg: torch.Tensor,
129
+ gk: torch.Tensor,
130
+ initial_state: Optional[torch.Tensor] = None,
131
+ output_final_state: bool = False,
132
+ offsets: Optional[torch.LongTensor] = None,
133
+ indices: Optional[torch.LongTensor] = None,
134
+ head_first: bool = True,
135
+ chunk_size: int = 64
136
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
137
+ if head_first:
138
+ B, H, T, K, V = *kg.shape, u.shape[-1]
139
+ else:
140
+ B, T, H, K, V = *kg.shape, u.shape[-1]
141
+ BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
142
+ # N: the actual number of sequences in the batch with either equal or variable lengths
143
+ if offsets is None:
144
+ N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
145
+ else:
146
+ N, NT, chunk_offsets = len(offsets) - 1, len(indices), prepare_chunk_offsets(offsets, BT)
147
+ BK = triton.next_power_of_2(K)
148
+ assert BK <= 256, "current kernel does not support head dimension larger than 256."
149
+ # H100 can have larger block size
150
+
151
+ if check_shared_mem('hopper', kg.device.index):
152
+ BV = 64
153
+ BC = 64 if K <= 128 else 32
154
+ elif check_shared_mem('ampere', kg.device.index): # A100
155
+ BV = 32
156
+ BC = 32
157
+ else:
158
+ BV = 16
159
+ BC = 16
160
+
161
+ BC = min(BT, BC)
162
+ NK = triton.cdiv(K, BK)
163
+ NV = triton.cdiv(V, BV)
164
+ assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization'
165
+
166
+ if head_first:
167
+ h = kg.new_empty(B, H, NT, K, V)
168
+ else:
169
+ h = kg.new_empty(B, NT, H, K, V)
170
+ final_state = kg.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None
171
+ v_new = torch.empty_like(u)
172
+ grid = (NK, NV, N * H)
173
+ chunk_dplr_fwd_kernel_h[grid](
174
+ kg=kg,
175
+ v=v,
176
+ w=w,
177
+ bg=bg,
178
+ u=u,
179
+ v_new=v_new,
180
+ h=h,
181
+ gk=gk,
182
+ h0=initial_state,
183
+ ht=final_state,
184
+ offsets=offsets,
185
+ chunk_offsets=chunk_offsets,
186
+ T=T,
187
+ H=H,
188
+ K=K,
189
+ V=V,
190
+ BT=BT,
191
+ BC=BC,
192
+ BK=BK,
193
+ BV=BV,
194
+ NT=NT,
195
+ HEAD_FIRST=head_first
196
+ )
197
+ return h, v_new, final_state
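A simplified PyTorch sketch of the per-chunk state update in `chunk_dplr_fwd_kernel_h` (the name `chunk_dplr_fwd_h_ref` is invented; fixed-length sequences and the `head_first` layout are assumed, and sub-chunking is omitted):

```python
import torch

def chunk_dplr_fwd_h_ref(kg, v, w, u, bg, gk, h0=None, chunk_size=64):
    # kg, w, bg, gk: [B, H, T, K]; v, u: [B, H, T, V]; h0: [B, H, K, V] or None
    B, H, T, K = kg.shape
    V = v.shape[-1]
    NT = T // chunk_size                      # assume T % chunk_size == 0
    h = kg.new_zeros(B, H, NT, K, V)
    v_new = torch.empty_like(u)
    b_h = h0.clone() if h0 is not None else kg.new_zeros(B, H, K, V)
    for i_t in range(NT):
        s = slice(i_t * chunk_size, (i_t + 1) * chunk_size)
        h[:, :, i_t] = b_h                    # state entering this chunk
        v2 = u[:, :, s] + w[:, :, s] @ b_h    # pseudo values, stored as v_new
        v_new[:, :, s] = v2
        # decay by the chunk's last cumulative gate, then add the chunk update
        b_h = torch.exp(gk[:, :, s][..., -1, :])[..., None] * b_h \
            + kg[:, :, s].transpose(-1, -2) @ v[:, :, s] \
            + bg[:, :, s].transpose(-1, -2) @ v2
    return h, v_new, b_h                      # b_h is the final state
```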
fla/ops/generalized_delta_rule/dplr/chunk_o_fwd.py ADDED
@@ -0,0 +1,138 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.utils import check_shared_mem, use_cuda_graph
11
+
12
+ BK_LIST = [32, 64, 128] if check_shared_mem() else [16, 32]
13
+
14
+
15
+ @triton.heuristics({
16
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None,
17
+ })
18
+ @triton.autotune(
19
+ configs=[
20
+ triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
21
+ for BK in BK_LIST
22
+ for BV in BK_LIST
23
+ for num_warps in [2, 4, 8, 16, 32]
24
+ for num_stages in [2, 3, 4]
25
+ ],
26
+ key=['BT'],
27
+ use_cuda_graph=use_cuda_graph,
28
+ )
29
+ @triton.jit(do_not_specialize=['T'])
30
+ def chunk_dplr_fwd_kernel_o(
31
+ qg,
32
+ v,
33
+ v_new,
34
+ A_qk,
35
+ A_qb,
36
+ h,
37
+ o,
38
+ offsets,
39
+ indices,
40
+ T,
41
+ H: tl.constexpr,
42
+ K: tl.constexpr,
43
+ V: tl.constexpr,
44
+ BT: tl.constexpr,
45
+ BK: tl.constexpr,
46
+ BV: tl.constexpr,
47
+ USE_OFFSETS: tl.constexpr,
48
+ HEAD_FIRST: tl.constexpr,
49
+ ):
50
+ i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
51
+ i_b, i_h = i_bh // H, i_bh % H
52
+
53
+ if USE_OFFSETS:
54
+ i_tg = i_t
55
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
56
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
57
+ T = eos - bos
58
+ NT = tl.cdiv(T, BT)
59
+ else:
60
+ NT = tl.cdiv(T, BT)
61
+ i_tg = i_b * NT + i_t
62
+ bos, eos = i_b * T, i_b * T + T
63
+
64
+ b_o = tl.zeros([BT, BV], dtype=tl.float32)
65
+ for i_k in range(tl.cdiv(K, BK)):
66
+ if HEAD_FIRST:
67
+ p_qg = tl.make_block_ptr(qg + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
68
+ p_h = tl.make_block_ptr(h + (i_bh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
69
+ else:
70
+ p_qg = tl.make_block_ptr(qg + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
71
+ p_h = tl.make_block_ptr(h + (i_tg * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
72
+ b_qg = tl.load(p_qg, boundary_check=(0, 1))
73
+ b_h = tl.load(p_h, boundary_check=(0, 1))
74
+ b_o += tl.dot(b_qg, b_h)
75
+
76
+ if HEAD_FIRST:
77
+ p_Aqk = tl.make_block_ptr(A_qk + i_bh * T*BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
78
+ p_Aqb = tl.make_block_ptr(A_qb + i_bh * T*BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
79
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
80
+ p_v_new = tl.make_block_ptr(v_new + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
81
+ p_o = tl.make_block_ptr(o + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
82
+ else:
83
+ p_Aqk = tl.make_block_ptr(A_qk + (bos * H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
84
+ p_Aqb = tl.make_block_ptr(A_qb + (bos * H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
85
+ p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
86
+ p_v_new = tl.make_block_ptr(v_new + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
87
+ p_o = tl.make_block_ptr(o + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
88
+
89
+ m_s = tl.arange(0, BT)[:, None] >= tl.arange(0, BT)[None, :]
90
+ b_Aqk = tl.load(p_Aqk, boundary_check=(0, 1))
91
+ b_Aqb = tl.load(p_Aqb, boundary_check=(0, 1))
92
+ b_Aqk = tl.where(m_s, b_Aqk, 0)
93
+ b_Aqb = tl.where(m_s, b_Aqb, 0)
94
+ b_v = tl.load(p_v, boundary_check=(0, 1))
95
+ b_v_new = tl.load(p_v_new, boundary_check=(0, 1))
96
+ b_o = b_o + tl.dot(b_Aqk.to(b_v.dtype), b_v) + tl.dot(b_Aqb.to(b_v_new.dtype), b_v_new)
97
+ tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
98
+
99
+
100
+ def chunk_dplr_fwd_o(
101
+ qg: torch.Tensor,
102
+ v: torch.Tensor,
103
+ v_new: torch.Tensor,
104
+ A_qk: torch.Tensor,
105
+ A_qb: torch.Tensor,
106
+ h: torch.Tensor,
107
+ offsets: Optional[torch.LongTensor] = None,
108
+ indices: Optional[torch.LongTensor] = None,
109
+ head_first: bool = True,
110
+ chunk_size: int = 64
111
+ ) -> torch.Tensor:
112
+ if head_first:
113
+ B, H, T, K, V = *qg.shape, v.shape[-1]
114
+ else:
115
+ B, T, H, K, V = *qg.shape, v.shape[-1]
116
+ BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
117
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
118
+
119
+ o = torch.empty_like(v)
120
+ def grid(meta): return (triton.cdiv(V, meta['BV']), NT, B * H)
121
+ chunk_dplr_fwd_kernel_o[grid](
122
+ qg=qg,
123
+ v=v,
124
+ v_new=v_new,
125
+ A_qk=A_qk,
126
+ A_qb=A_qb,
127
+ h=h,
128
+ o=o,
129
+ offsets=offsets,
130
+ indices=indices,
131
+ T=T,
132
+ H=H,
133
+ K=K,
134
+ V=V,
135
+ BT=BT,
136
+ HEAD_FIRST=head_first
137
+ )
138
+ return o
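The output kernel combines an inter-chunk term (queries against the chunk-level state) with two causally masked intra-chunk terms. A minimal PyTorch sketch under the same fixed-length/`head_first` assumptions as above (`chunk_dplr_fwd_o_ref` is an invented name):

```python
import torch

def chunk_dplr_fwd_o_ref(qg, v, v_new, A_qk, A_qb, h, chunk_size=64):
    # qg: [B, H, T, K]; v, v_new: [B, H, T, V]; A_qk, A_qb: [B, H, T, chunk_size]; h: [B, H, NT, K, V]
    B, H, T, V = v.shape
    NT = T // chunk_size                      # assume T % chunk_size == 0
    mask = torch.tril(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=v.device))
    o = torch.empty_like(v)
    for i_t in range(NT):
        s = slice(i_t * chunk_size, (i_t + 1) * chunk_size)
        b_o = qg[:, :, s] @ h[:, :, i_t]                       # inter-chunk
        b_o = b_o + (A_qk[:, :, s] * mask) @ v[:, :, s]        # intra-chunk, q-k term
        b_o = b_o + (A_qb[:, :, s] * mask) @ v_new[:, :, s]    # intra-chunk, q-b term
        o[:, :, s] = b_o
    return o
```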
fla/ops/generalized_delta_rule/dplr/fused_recurrent.py ADDED
@@ -0,0 +1,292 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.ops.utils.op import exp
11
+ from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard, use_cuda_graph
12
+
13
+
14
+ @triton.heuristics({
15
+ 'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
16
+ 'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
17
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
18
+ })
19
+ @triton.autotune(
20
+ configs=[
21
+ triton.Config({'BV': BV}, num_warps=num_warps, num_stages=num_stages)
22
+ for BV in [16, 32, 64]
23
+ for num_warps in [2, 4, 8, 16]
24
+ for num_stages in [2, 3, 4]
25
+ ],
26
+ key=['BK'],
27
+ use_cuda_graph=use_cuda_graph,
28
+ )
29
+ @triton.jit(do_not_specialize=['T'])
30
+ def fused_recurrent_dplr_delta_rule_fwd_kernel(
31
+ q,
32
+ k,
33
+ v,
34
+ a,
35
+ b,
36
+ gk,
37
+ o,
38
+ h0,
39
+ ht,
40
+ offsets,
41
+ scale,
42
+ T,
43
+ B: tl.constexpr,
44
+ H: tl.constexpr,
45
+ K: tl.constexpr,
46
+ V: tl.constexpr,
47
+ BK: tl.constexpr,
48
+ BV: tl.constexpr,
49
+ REVERSE: tl.constexpr,
50
+ USE_INITIAL_STATE: tl.constexpr,
51
+ STORE_FINAL_STATE: tl.constexpr,
52
+ USE_OFFSETS: tl.constexpr,
53
+ HEAD_FIRST: tl.constexpr
54
+ ):
55
+ i_v, i_nh = tl.program_id(0).to(tl.int64), tl.program_id(1).to(tl.int64)
56
+ i_n, i_h = i_nh // H, i_nh % H
57
+
58
+ if USE_OFFSETS:
59
+ bos, eos = tl.load(offsets + i_n).to(tl.int64), tl.load(offsets + i_n + 1).to(tl.int64)
60
+ T = eos - bos
61
+ else:
62
+ bos, eos = i_n * T, i_n * T + T
63
+
64
+ o_k = tl.arange(0, BK)
65
+ o_v = i_v * BV + tl.arange(0, BV)
66
+ if HEAD_FIRST:
67
+ p_q = q + i_nh * T*K + ((T-1) * K if REVERSE else 0) + o_k
68
+ p_k = k + i_nh * T*K + ((T-1) * K if REVERSE else 0) + o_k
69
+ p_a = a + i_nh * T*K + ((T-1) * K if REVERSE else 0) + o_k
70
+ p_b = b + i_nh * T*K + ((T-1) * K if REVERSE else 0) + o_k
71
+ p_gk = gk + i_nh * T*K + ((T-1) * K if REVERSE else 0) + o_k
72
+ p_v = v + i_nh * T*V + ((T-1) * V if REVERSE else 0) + o_v
73
+ p_o = o + i_nh * T*V + ((T-1) * V if REVERSE else 0) + o_v
74
+
75
+ else:
76
+ p_q = q + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k
77
+ p_k = k + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k
78
+ p_a = a + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k
79
+ p_b = b + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k
80
+ p_gk = gk + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k
81
+ p_v = v + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + o_v
82
+ p_o = o + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + o_v
83
+
84
+ mask_k = o_k < K
85
+ mask_v = o_v < V
86
+ mask_h = mask_k[None, :] & mask_v[:, None]
87
+ b_h = tl.zeros([BV, BK], dtype=tl.float32)
88
+
89
+ if USE_INITIAL_STATE:
90
+ p_h0 = h0 + i_nh * K*V + o_k[None, :] * V + o_v[:, None]
91
+ b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)
92
+
93
+ for _ in range(0, T):
94
+ b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) * scale
95
+ b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
96
+ b_a = tl.load(p_a, mask=mask_k, other=0).to(tl.float32)
97
+ b_b = tl.load(p_b, mask=mask_k, other=0).to(tl.float32)
98
+ b_gk = tl.load(p_gk, mask=mask_k, other=0).to(tl.float32)
99
+ b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
100
+
101
+ tmp = tl.sum(b_h * b_a[None, :], axis=1)
102
+ b_h = exp(b_gk)[None, :] * b_h + (tmp[:, None] * b_b[None, :] + b_k[None, :] * b_v[:, None])
103
+ b_o = tl.sum(b_h * b_q[None, :], axis=1)
104
+
105
+ tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)
106
+ p_q += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * K
107
+ p_k += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * K
108
+ p_a += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * K
109
+ p_b += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * K
110
+ p_gk += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * K
111
+ p_v += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * V
112
+ p_o += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * V
113
+
114
+ if STORE_FINAL_STATE:
115
+ p_ht = ht + i_nh * K*V + o_k[None, :] * V + o_v[:, None]
116
+ tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
117
+
118
+
119
+ def fused_recurrent_dplr_delta_rule_fwd(
120
+ q: torch.Tensor,
121
+ k: torch.Tensor,
122
+ v: torch.Tensor,
123
+ a: torch.Tensor,
124
+ b: torch.Tensor,
125
+ gk: torch.Tensor,
126
+ scale: Optional[float] = 1.0,
127
+ initial_state: Optional[torch.Tensor] = None,
128
+ output_final_state: bool = False,
129
+ reverse: bool = False,
130
+ offsets: Optional[torch.LongTensor] = None,
131
+ head_first: bool = True
132
+ ):
133
+ if head_first:
134
+ B, H, T, K, V = *k.shape, v.shape[-1]
135
+ else:
136
+ B, T, H, K, V = *k.shape, v.shape[-1]
137
+ N = B if offsets is None else len(offsets) - 1
138
+ BK = triton.next_power_of_2(K)
139
+
140
+ h0 = initial_state
141
+ if output_final_state:
142
+ ht = q.new_empty(N, H, K, V, dtype=torch.float32)
143
+ else:
144
+ ht = None
145
+ o = torch.empty_like(v)
146
+
147
+ def grid(meta): return (triton.cdiv(V, meta['BV']), N * H)
148
+ fused_recurrent_dplr_delta_rule_fwd_kernel[grid](
149
+ q,
150
+ k,
151
+ v,
152
+ a,
153
+ b,
154
+ gk,
155
+ o,
156
+ h0,
157
+ ht,
158
+ offsets,
159
+ scale,
160
+ T=T,
161
+ B=B,
162
+ H=H,
163
+ K=K,
164
+ V=V,
165
+ BK=BK,
166
+ REVERSE=reverse,
167
+ HEAD_FIRST=head_first
168
+ )
169
+ return o, ht
170
+
171
+
172
+ class FusedRecurrentDPLRDeltaRuleFunction(torch.autograd.Function):
173
+
174
+ @staticmethod
175
+ @input_guard
176
+ @autocast_custom_fwd
177
+ def forward(
178
+ ctx,
179
+ q: torch.Tensor,
180
+ k: torch.Tensor,
181
+ v: torch.Tensor,
182
+ a: torch.Tensor,
183
+ b: torch.Tensor,
184
+ gk: torch.Tensor,
185
+ scale: Optional[float] = 1.0,
186
+ initial_state: Optional[torch.Tensor] = None,
187
+ output_final_state: bool = False,
188
+ reverse: bool = False,
189
+ offsets: Optional[torch.LongTensor] = None,
190
+ head_first: bool = False
191
+ ):
192
+ o, ht = fused_recurrent_dplr_delta_rule_fwd(
193
+ q=q,
194
+ k=k,
195
+ v=v,
196
+ a=a,
197
+ b=b,
198
+ gk=gk,
199
+ scale=scale,
200
+ initial_state=initial_state,
201
+ output_final_state=output_final_state,
202
+ reverse=reverse,
203
+ offsets=offsets,
204
+ head_first=head_first
205
+ )
206
+ return o, ht
207
+
208
+ @staticmethod
209
+ @input_guard
210
+ @autocast_custom_bwd
211
+ def backward(ctx, do, dht):
212
+ raise NotImplementedError(
213
+ "Backward pass for fused_recurrent_dplr_delta_rule is not implemented and will not be supported. "
214
+ "This kernel is only for inference. "
215
+ "For training, please use `chunk_dplr_delta_rule`."
216
+ )
217
+
218
+
219
+ def fused_recurrent_dplr_delta_rule(
220
+ q: torch.Tensor,
221
+ k: torch.Tensor,
222
+ v: torch.Tensor,
223
+ a: torch.Tensor,
224
+ b: torch.Tensor,
225
+ gk: torch.Tensor,
226
+ scale: Optional[float] = 1.0,
227
+ initial_state: Optional[torch.Tensor] = None,
228
+ output_final_state: bool = False,
229
+ reverse: bool = False,
230
+ cu_seqlens: Optional[torch.Tensor] = None,
231
+ head_first: bool = False
232
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
233
+ r"""
234
+ This function computes the recurrence S_t = S_{t-1} @ (I + a_t b_t^T) + v_t k_t^T in a recurrent manner.
235
+
236
+ Args:
237
+ q (torch.Tensor):
238
+ queries of shape `[B, H, T, K]`
239
+ k (torch.Tensor):
240
+ keys of shape `[B, H, T, K]`
241
+ v (torch.Tensor):
242
+ values of shape `[B, H, T, V]`
243
+ a (torch.Tensor):
244
+ as of shape `[B, H, T, K]`
245
+ b (torch.Tensor):
246
+ bs of shape `[B, H, T, K]`
247
+ gk (torch.Tensor):
248
+ gk of shape `[B, H, T, K]`
249
+ scale (Optional[float]):
250
+ Scale factor for the attention scores.
251
+ If None, it will default to `1 / sqrt(K)`. Default: `1.0`.
252
+ initial_state (Optional[torch.Tensor]):
253
+ Initial state of shape `[B, H, K, V]`. Default: `None`.
254
+ output_final_state (Optional[bool]):
255
+ Whether to output the final state of shape `[B, H, K, V]`. Default: `False`.
256
+ reverse (Optional[bool]):
257
+ If `True`, process the state passing in reverse order. Default: `False`.
258
+ cu_seqlens (Optional[torch.Tensor]):
259
+ Cumulative sequence lengths of shape `[N + 1]` used for variable-length training,
260
+ consistent with the FlashAttention API.
261
+ head_first (Optional[bool]):
262
+ Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
263
+ Default: `False`.
264
+ """
265
+ if cu_seqlens is not None:
266
+ if q.shape[0] != 1:
267
+ raise ValueError(f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
268
+ f"Please flatten variable-length inputs before processing.")
269
+ if head_first:
270
+ raise RuntimeError("Sequences with variable lengths are not supported for head-first mode")
271
+ if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
272
+ raise ValueError(f"The number of initial states is expected to be equal to the number of input sequences, "
273
+ f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}.")
274
+ if scale is None:
275
+ scale = q.shape[-1] ** -0.5
276
+ else:
277
+ assert scale > 0, "scale must be positive"
278
+ o, final_state = FusedRecurrentDPLRDeltaRuleFunction.apply(
279
+ q,
280
+ k,
281
+ v,
282
+ a,
283
+ b,
284
+ gk,
285
+ scale,
286
+ initial_state,
287
+ output_final_state,
288
+ reverse,
289
+ cu_seqlens,
290
+ head_first
291
+ )
292
+ return o, final_state
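A minimal usage sketch for the inference-only recurrent kernel above. The shapes follow the docstring with `head_first=True`; the random tensors, the bfloat16 dtype, and the module import path are assumptions made for illustration, and a CUDA device is required by Triton:

```python
import torch
from fla.ops.generalized_delta_rule.dplr.fused_recurrent import fused_recurrent_dplr_delta_rule

B, H, T, K, V = 2, 4, 128, 64, 64
q, k, a, b = (torch.randn(B, H, T, K, device='cuda', dtype=torch.bfloat16) for _ in range(4))
v = torch.randn(B, H, T, V, device='cuda', dtype=torch.bfloat16)
# gk is a log-space gate, so keep it non-positive
gk = torch.nn.functional.logsigmoid(torch.randn(B, H, T, K, device='cuda')).to(torch.bfloat16)

o, ht = fused_recurrent_dplr_delta_rule(
    q, k, v, a, b, gk,
    scale=K ** -0.5,
    output_final_state=True,
    head_first=True,
)
print(o.shape, ht.shape)  # (B, H, T, V) and (B, H, K, V)
```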
fla/ops/generalized_delta_rule/dplr/naive.py ADDED
@@ -0,0 +1,96 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import torch
4
+ from einops import rearrange
5
+
6
+ # S_t = S_{t-1} @ (I + alpha_t beta_t^T) + v_t k_t^T
7
+ # q, k, alpha, beta [B, H, L, D_K]
8
+ # v [B, H, L, D_V]
9
+
10
+
11
+ def dplr_recurrence(q, k, v, alpha, beta, gk, initial_state=None, output_final_state=True):
12
+ orig_dtype = q.dtype
13
+ b, h, l, d_k = q.shape
14
+ q, k, v, alpha, beta, gk = map(lambda x: x.float(), [q, k, v, alpha, beta, gk])
15
+ d_v = v.shape[-1]
16
+ o = torch.zeros_like(v)
17
+ S = torch.zeros(b, h, d_k, d_v).to(v)
18
+ q = q * (d_k ** -0.5)
19
+
20
+ if initial_state is not None:
21
+ S += initial_state
22
+
23
+ for i in range(l):
24
+ _k = k[:, :, i]
25
+ _q = q[:, :, i]
26
+ _v = v[:, :, i]
27
+ _alpha = alpha[:, :, i].clone()
28
+ _beta = beta[:, :, i].clone()
29
+ _kv = _k[..., None] * _v[..., None, :] + (S.clone() * _alpha[..., None]).sum(-2, keepdim=True) * _beta[..., None]
30
+ S = S.clone() * gk[:, :, i].exp()[..., None] + _kv
31
+ o[:, :, i] = torch.einsum('bhd,bhdm->bhm', _q, S)
32
+ S = None if output_final_state is False else S
33
+ return o.to(orig_dtype), S
34
+
35
+
36
+ def dplr_chunkwise(q, k, v, alpha, beta, gk, initial_state=None, output_final_state=True, chunk_size=32):
37
+ b, h, l, d_k = q.shape
38
+ d_v = v.shape[-1]
39
+ q = q * (d_k ** -0.5)
40
+ v = v
41
+ assert l % chunk_size == 0
42
+
43
+ S = k.new_zeros(b, h, d_k, d_v).to(q)
44
+ if initial_state is not None:
45
+ S += initial_state
46
+
47
+ # note that diagonal is masked.
48
+ mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device), diagonal=0)
49
+ q, k, v, alpha, beta, gk = map(lambda x: rearrange(x, 'b h (n c) d -> b h n c d',
50
+ c=chunk_size).float(), [q, k, v, alpha, beta, gk])
51
+
52
+ gk_cumsum = gk.cumsum(-2)
53
+
54
+ # v2 = (alpha @ k.transpose(-1, -2)).masked_fill_(mask, 0) @ v
55
+ A_ab = torch.zeros(b, h, l // chunk_size, chunk_size, chunk_size).to(q.device)
56
+ A_qk = torch.zeros(b, h, l // chunk_size, chunk_size, chunk_size).to(q.device)
57
+ A_ak = torch.zeros(b, h, l // chunk_size, chunk_size, chunk_size).to(q.device)
58
+ A_qb = torch.zeros(b, h, l // chunk_size, chunk_size, chunk_size).to(q.device)
59
+
60
+ for i in range(chunk_size):
61
+ alpha_i = alpha[:, :, :, i, None]
62
+ q_i = q[:, :, :, i, None]
63
+ gk_i = gk_cumsum[:, :, :, i, None]
64
+ mask = (torch.arange(chunk_size) <= i).to(q.device)
65
+ attn_i = (gk_i - gk_cumsum).masked_fill(~mask.unsqueeze(-1), float('-inf')).exp()
66
+ A_qk[:, :, :, i, :] = (q_i * k * attn_i).sum(-1).clone()
67
+ A_qb[:, :, :, i, :] = (q_i * beta * attn_i).sum(-1).clone()
68
+ mask = (torch.arange(chunk_size) < i).to(q.device)
69
+ # shift by one.
70
+ attn_i = (gk_i - gk[:, :, :, i, None] - gk_cumsum).masked_fill(~mask.unsqueeze(-1), float('-inf')).exp()
71
+ A_ab[:, :, :, i, :] = (alpha_i * beta * attn_i).sum(-1).clone()
72
+ A_ak[:, :, :, i, :] = (alpha_i * k * attn_i).sum(-1).clone()
73
+
74
+ A_ab = A_ab
75
+ for i in range(1, chunk_size):
76
+ A_ab[..., i, :i] = A_ab[..., i, :i].clone() + (A_ab[..., i, :, None].clone() * A_ab[..., :, :i].clone()).sum(-2)
77
+
78
+ A_ab = A_ab + torch.eye(chunk_size, dtype=torch.float, device=q.device)
79
+ u = A_ab @ (A_ak @ v)
80
+ w = A_ab @ ((gk_cumsum-gk).exp() * alpha)
81
+
82
+ o = torch.zeros_like(v)
83
+ mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device), diagonal=1)
84
+ for i in range(0, l // chunk_size):
85
+ q_i, k_i, v_i, u_i, w_i, beta_i = q[:, :, i], k[:, :, i], v[:, :, i], u[:, :, i], w[:, :, i], beta[:, :, i]
86
+ v2_i = u_i + w_i @ S
87
+
88
+ o_1 = A_qk[:, :, i] @ v_i
89
+ o_2 = A_qb[:, :, i] @ v2_i
90
+ o_3 = (q_i * gk_cumsum[:, :, i].exp()) @ S
91
+ o[:, :, i] = o_1 + o_2 + o_3
92
+ decay = (gk_cumsum[:, :, i, -1, None] - gk_cumsum[:, :, i]).exp()
93
+ S = S*gk_cumsum[:, :, i, -1, :, None].exp() + (k_i * decay).transpose(-1, -2) @ v_i + \
94
+ (beta_i * decay).transpose(-1, -2) @ v2_i
95
+ S = None if output_final_state is False else S
96
+ return rearrange(o, 'b h n c d -> b h (n c) d'), S
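Since both reference implementations above target the same operator, they can be cross-checked on small random inputs; the outputs and final states should agree up to numerical tolerance (the negative `gk` keeps the exponentiated decays bounded). The import path below is an assumption matching this commit's layout:

```python
import torch
from fla.ops.generalized_delta_rule.dplr.naive import dplr_recurrence, dplr_chunkwise

b, h, l, d_k, d_v = 1, 2, 64, 32, 32
q, k, alpha, beta = (torch.randn(b, h, l, d_k) for _ in range(4))
v = torch.randn(b, h, l, d_v)
gk = torch.nn.functional.logsigmoid(torch.randn(b, h, l, d_k))  # log-space decay <= 0

o_rec, S_rec = dplr_recurrence(q, k, v, alpha, beta, gk)
o_chk, S_chk = dplr_chunkwise(q, k, v, alpha, beta, gk, chunk_size=32)
print(torch.allclose(o_rec, o_chk, atol=1e-3), torch.allclose(S_rec, S_chk, atol=1e-3))
```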
fla/ops/generalized_delta_rule/dplr/wy_fast_bwd.py ADDED
@@ -0,0 +1,184 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.utils import check_shared_mem, is_intel_alchemist, use_cuda_graph
11
+
12
+ # https://github.com/intel/intel-xpu-backend-for-triton/issues/3449
13
+ triton_config = {'grf_mode': 'large'} if is_intel_alchemist else {}
14
+
15
+
16
+ @triton.heuristics({
17
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
18
+ })
19
+ @triton.autotune(
20
+ configs=[
21
+ triton.Config(triton_config, num_warps=num_warps, num_stages=num_stages)
22
+ for num_warps in [2, 4, 8, 16, 32]
23
+ for num_stages in [2, 3, 4]
24
+ ],
25
+ key=['BT', 'BK', 'BV'],
26
+ use_cuda_graph=use_cuda_graph,
27
+ )
28
+ @triton.jit(do_not_specialize=['T'])
29
+ def bwd_prepare_wy_repr_kernel(
30
+ A_ab_inv,
31
+ A_ak,
32
+ ag,
33
+ v,
34
+ dw,
35
+ du,
36
+ dv,
37
+ dv0,
38
+ dag,
39
+ dAak,
40
+ dAab,
41
+ offsets,
42
+ indices,
43
+ T,
44
+ H: tl.constexpr,
45
+ K: tl.constexpr,
46
+ V: tl.constexpr,
47
+ BT: tl.constexpr,
48
+ BK: tl.constexpr,
49
+ BV: tl.constexpr,
50
+ USE_OFFSETS: tl.constexpr,
51
+ HEAD_FIRST: tl.constexpr
52
+ ):
53
+ i_t, i_bh = tl.program_id(0), tl.program_id(1)
54
+ i_b, i_h = i_bh // H, i_bh % H
55
+ if USE_OFFSETS:
56
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
57
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
58
+ T = eos - bos
59
+ else:
60
+ bos, eos = i_b * T, i_b * T + T
61
+
62
+ if HEAD_FIRST:
63
+ p_Aab_inv_t = tl.make_block_ptr(A_ab_inv + i_bh * T * BT, (BT, T), (1, BT), (0, i_t * BT), (BT, BT), (0, 1))
64
+ p_Aak_t = tl.make_block_ptr(A_ak + i_bh * T * BT, (BT, T), (1, BT), (0, i_t * BT), (BT, BT), (0, 1))
65
+ p_dAak = tl.make_block_ptr(dAak + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
66
+ p_dAab = tl.make_block_ptr(dAab + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
67
+ else:
68
+ p_Aak_t = tl.make_block_ptr(A_ak + (bos*H + i_h) * BT, (BT, T), (1, H*BT), (0, i_t * BT), (BT, BT), (0, 1))
69
+ p_Aab_inv_t = tl.make_block_ptr(A_ab_inv + (bos*H + i_h) * BT, (BT, T), (1, H*BT), (0, i_t * BT), (BT, BT), (0, 1))
70
+ p_dAak = tl.make_block_ptr(dAak + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
71
+ p_dAab = tl.make_block_ptr(dAab + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
72
+
73
+ b_A_ab_inv_t = tl.load(p_Aab_inv_t, boundary_check=(0, 1))
74
+ b_A_ak_t = tl.load(p_Aak_t, boundary_check=(0, 1))
75
+ b_A_ak_t = tl.where(tl.arange(0, BT)[:, None] < tl.arange(0, BT)[None, :], b_A_ak_t, 0)
76
+ b_A_ab_inv_t = tl.where(tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :], b_A_ab_inv_t, 0)
77
+ b_A_tmp_t = tl.dot(b_A_ak_t, b_A_ab_inv_t).to(v.dtype.element_ty)
78
+ b_dA_tmp = tl.zeros([BT, BT], dtype=tl.float32)
79
+
80
+ for i_v in range(tl.cdiv(V, BV)):
81
+ if HEAD_FIRST:
82
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
83
+ p_dv = tl.make_block_ptr(dv + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
84
+ p_dv0 = tl.make_block_ptr(dv0 + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
85
+ p_du = tl.make_block_ptr(du + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
86
+ else:
87
+ p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
88
+ p_dv = tl.make_block_ptr(dv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
89
+ p_dv0 = tl.make_block_ptr(dv0 + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
90
+ p_du = tl.make_block_ptr(du + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
91
+ b_v = tl.load(p_v, boundary_check=(0, 1))
92
+ b_du = tl.load(p_du, boundary_check=(0, 1))
93
+ b_dA_tmp += tl.dot(b_du.to(b_v.dtype), tl.trans(b_v))
94
+ b_dv0 = tl.load(p_dv0, boundary_check=(0, 1))
95
+ b_dv = b_dv0 + tl.dot(b_A_tmp_t, b_du)
96
+ tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
97
+
98
+ b_dA_tmp = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_dA_tmp, 0)
99
+ b_dA_ak = tl.dot(b_A_ab_inv_t, b_dA_tmp)
100
+ b_dA_ak = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_dA_ak, 0)
101
+ tl.store(p_dAak, b_dA_ak, boundary_check=(0, 1))
102
+ b_dA_ab_inv = tl.dot(b_dA_tmp, b_A_ak_t)
103
+
104
+ for i_k in range(tl.cdiv(K, BK)):
105
+ if HEAD_FIRST:
106
+ p_ag = tl.make_block_ptr(ag + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
107
+ p_dag = tl.make_block_ptr(dag + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
108
+ p_dw = tl.make_block_ptr(dw + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
109
+ else:
110
+ p_ag = tl.make_block_ptr(ag + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
111
+ p_dag = tl.make_block_ptr(dag + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
112
+ p_dw = tl.make_block_ptr(dw + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
113
+ b_ag = tl.load(p_ag, boundary_check=(0, 1))
114
+ b_dw = tl.load(p_dw, boundary_check=(0, 1))
115
+ b_dA_ab_inv += tl.dot(b_dw, tl.trans(b_ag))
116
+ b_dag = tl.dot(b_A_ab_inv_t.to(b_dw.dtype), b_dw)
117
+ tl.store(p_dag, b_dag.to(p_dag.dtype.element_ty), boundary_check=(0, 1))
118
+
119
+ # if we know dL/dA^(-1), for dL/dA, we can use the following formula:
120
+ # dL/dA = -(A^(-1))^T @ (dL/dA^(-1)) @ (A^(-1))^T
121
+ # in the fwd pass we use fwd substitution to calculate (I-lower(A_ab))^-1.
122
+ # denote A = I - lower(A_ab), B = A^-1
123
+ # in the backward pass:
124
+ # dL/dA = -(B)^T @ (dL/dB) @ B^T
125
+ # dL/dA_ab = lower(B^T @ dL/dB @ B^T)
126
+ b_dA_ab_inv = tl.where(tl.arange(0, BT)[:, None] >= tl.arange(0, BT)[None, :], b_dA_ab_inv, 0)
127
+ b_dA_ab_inv = tl.dot(b_A_ab_inv_t, b_dA_ab_inv)
128
+ b_dA_ab_inv = tl.dot(b_dA_ab_inv, b_A_ab_inv_t)
129
+ b_dA_ab_inv = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_dA_ab_inv, 0)
130
+ tl.store(p_dAab, b_dA_ab_inv, boundary_check=(0, 1))
131
+
132
+
133
+ def chunk_dplr_bwd_wy(
134
+ A_ab_inv: torch.Tensor,
135
+ A_ak: torch.Tensor,
136
+ v: torch.Tensor,
137
+ ag: torch.Tensor,
138
+ dw: torch.Tensor,
139
+ du: torch.Tensor,
140
+ dv0: torch.Tensor,
141
+ offsets: Optional[torch.LongTensor],
142
+ indices: Optional[torch.LongTensor],
143
+ head_first: bool,
144
+ chunk_size: int,
145
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
146
+ A_ab_inv, A_ak, v, ag, dw, du = map(lambda x: x.contiguous(), [A_ab_inv, A_ak, v, ag, dw, du])
147
+ if head_first:
148
+ B, H, T, K, V = *dw.shape, du.shape[-1]
149
+ else:
150
+ B, T, H, K, V = *dw.shape, du.shape[-1]
151
+ BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
152
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
153
+ BK = min(triton.next_power_of_2(K), 64)
154
+ BV = min(triton.next_power_of_2(V), 64) if check_shared_mem() else min(triton.next_power_of_2(V), 32)
155
+
156
+ dA_ab = torch.empty_like(A_ab_inv, dtype=torch.float)
157
+ dA_ak = torch.empty_like(A_ak, dtype=torch.float)
158
+ dv = torch.empty_like(v)
159
+ dag = torch.empty_like(ag)
160
+
161
+ bwd_prepare_wy_repr_kernel[(NT, B * H)](
162
+ A_ab_inv=A_ab_inv,
163
+ A_ak=A_ak,
164
+ ag=ag,
165
+ v=v,
166
+ dw=dw,
167
+ du=du,
168
+ dv=dv,
169
+ dv0=dv0,
170
+ dag=dag,
171
+ dAak=dA_ak,
172
+ dAab=dA_ab,
173
+ offsets=offsets,
174
+ indices=indices,
175
+ T=T,
176
+ H=H,
177
+ K=K,
178
+ V=V,
179
+ BT=BT,
180
+ BK=BK,
181
+ BV=BV,
182
+ HEAD_FIRST=head_first
183
+ )
184
+ return dA_ab, dA_ak, dv, dag
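The comment in `bwd_prepare_wy_repr_kernel` relies on the standard identity for differentiating through a matrix inverse, dL/dA = -B^T (dL/dB) B^T with B = A^(-1); the kernel then keeps only the strictly lower-triangular part for dL/dA_ab. A small autograd check of the identity on a random unit lower-triangular A and an arbitrary upstream gradient G:

```python
import torch

n = 8
A = torch.eye(n) + torch.tril(torch.randn(n, n), diagonal=-1)  # unit lower-triangular, like I - lower(A_ab)
A.requires_grad_(True)
B = torch.inverse(A)
G = torch.randn(n, n)              # plays the role of dL/dB
(B * G).sum().backward()           # L = <B, G>

manual = -(B.T @ G @ B.T)          # dL/dA = -B^T (dL/dB) B^T
print(torch.allclose(A.grad, manual.detach(), atol=1e-4))
```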
fla/ops/generalized_delta_rule/dplr/wy_fast_fwd.py ADDED
@@ -0,0 +1,318 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2024, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.ops.utils.op import gather
11
+ from fla.utils import is_gather_supported, use_cuda_graph
12
+
13
+
14
+ @triton.heuristics({
15
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
16
+ })
17
+ @triton.autotune(
18
+ configs=[
19
+ triton.Config({}, num_warps=num_warps)
20
+ for num_warps in [1, 2, 4, 8, 16]
21
+ ],
22
+ key=['BT'],
23
+ use_cuda_graph=use_cuda_graph,
24
+ )
25
+ @triton.jit(do_not_specialize=['T'])
26
+ def fwd_prepare_wy_repr_kernel_chunk32(
27
+ A_ab,
28
+ A_ab_inv,
29
+ offsets,
30
+ indices,
31
+ T,
32
+ H: tl.constexpr,
33
+ BT: tl.constexpr,
34
+ BC: tl.constexpr, # placeholder, do not delete
35
+ USE_OFFSETS: tl.constexpr,
36
+ HEAD_FIRST: tl.constexpr
37
+ ):
38
+ i_t, i_bh = tl.program_id(0), tl.program_id(1)
39
+ i_b, i_h = i_bh // H, i_bh % H
40
+ if USE_OFFSETS:
41
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
42
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
43
+ T = eos - bos
44
+ else:
45
+ bos, eos = i_b * T, i_b * T + T
46
+ if HEAD_FIRST:
47
+ p_Aab = tl.make_block_ptr(A_ab + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
48
+ p_Aab_inv = tl.make_block_ptr(A_ab_inv + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
49
+ else:
50
+ p_Aab = tl.make_block_ptr(A_ab + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
51
+ p_Aab_inv = tl.make_block_ptr(A_ab_inv + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
52
+ b_A_ab = tl.load(p_Aab, boundary_check=(0, 1))
53
+ b_A_ab = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_A_ab, 0)
54
+ for i in range(1, BT):
55
+ mask = tl.arange(0, BT) == i
56
+ b_a = tl.sum(tl.where(mask[:, None], b_A_ab, 0), 0)
57
+ b_a = b_a + tl.sum(b_a[:, None] * b_A_ab, 0) * (tl.arange(0, BT) < i)
58
+ b_A_ab = tl.where(mask[:, None], b_a, b_A_ab)
59
+ b_A_ab += tl.arange(0, BT)[:, None] == tl.arange(0, BT)[None, :]
60
+ tl.store(p_Aab_inv, b_A_ab.to(p_Aab_inv.dtype.element_ty), boundary_check=(0, 1))
61
+
62
+
63
+ @triton.heuristics({
64
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
65
+ })
66
+ @triton.autotune(
67
+ configs=[
68
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
69
+ for num_warps in [2, 4, 8]
70
+ for num_stages in [2, 3, 4]
71
+ ],
72
+ key=['BC'],
73
+ use_cuda_graph=use_cuda_graph,
74
+ )
75
+ @triton.jit(do_not_specialize=['T'])
76
+ def fwd_prepare_wy_repr_kernel_chunk64(
77
+ A_ab,
78
+ A_ab_inv,
79
+ offsets,
80
+ indices,
81
+ T,
82
+ H: tl.constexpr,
83
+ BT: tl.constexpr,
84
+ BC: tl.constexpr,
85
+ USE_OFFSETS: tl.constexpr,
86
+ HEAD_FIRST: tl.constexpr,
87
+ GATHER_SUPPORTED: tl.constexpr = is_gather_supported
88
+ ):
89
+ i_t, i_bh = tl.program_id(0), tl.program_id(1)
90
+ i_b, i_h = i_bh // H, i_bh % H
91
+ if USE_OFFSETS:
92
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
93
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
94
+ T = eos - bos
95
+ else:
96
+ bos, eos = i_b * T, i_b * T + T
97
+
98
+ if HEAD_FIRST:
99
+
100
+ p_A1 = tl.make_block_ptr(A_ab + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BC, BC), (1, 0))
101
+ p_A2 = tl.make_block_ptr(A_ab + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + BC, BC), (BC, BC), (1, 0))
102
+ p_A3 = tl.make_block_ptr(A_ab + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + BC, 0), (BC, BC), (1, 0))
103
+ p_A_inv1 = tl.make_block_ptr(A_ab_inv + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BC, BC), (1, 0))
104
+ p_A_inv2 = tl.make_block_ptr(A_ab_inv + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + BC, BC), (BC, BC), (1, 0))
105
+ p_A_inv3 = tl.make_block_ptr(A_ab_inv + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + BC, 0), (BC, BC), (1, 0))
106
+ p_A_inv4 = tl.make_block_ptr(A_ab_inv + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, BC), (BC, BC), (1, 0))
107
+ else:
108
+ p_A1 = tl.make_block_ptr(A_ab + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BC, BC), (1, 0))
109
+ p_A2 = tl.make_block_ptr(A_ab + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, BC), (BC, BC), (1, 0))
110
+ p_A3 = tl.make_block_ptr(A_ab + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, 0), (BC, BC), (1, 0))
111
+ p_A_inv1 = tl.make_block_ptr(A_ab_inv + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BC, BC), (1, 0))
112
+ p_A_inv2 = tl.make_block_ptr(A_ab_inv + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, BC), (BC, BC), (1, 0))
113
+ p_A_inv3 = tl.make_block_ptr(A_ab_inv + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, 0), (BC, BC), (1, 0))
114
+ p_A_inv4 = tl.make_block_ptr(A_ab_inv + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, BC), (BC, BC), (1, 0))
115
+
116
+ b_A = tl.load(p_A1, boundary_check=(0, 1))
117
+ b_A2 = tl.load(p_A2, boundary_check=(0, 1))
118
+ b_A3 = tl.load(p_A3, boundary_check=(0, 1))
119
+ b_A = tl.where(tl.arange(0, BC)[:, None] > tl.arange(0, BC)[None, :], b_A, 0)
120
+ b_A2 = tl.where(tl.arange(0, BC)[:, None] > tl.arange(0, BC)[None, :], b_A2, 0)
121
+
122
+ for i in range(1, BC):
123
+ if GATHER_SUPPORTED:
124
+ row_idx = tl.full([1, BC], i, dtype=tl.int16)
125
+ # [1, BK] -> [BK]
126
+ b_a = tl.sum(gather(b_A, row_idx, axis=0), 0)
127
+ b_a2 = tl.sum(gather(b_A2, row_idx, axis=0), 0)
128
+ else:
129
+ mask = tl.arange(0, BC) == i
130
+ b_a = tl.sum(tl.where(mask[:, None], b_A, 0), 0)
131
+ b_a2 = tl.sum(tl.where(mask[:, None], b_A2, 0), 0)
132
+ mask = tl.arange(0, BC) == i
133
+ # b_a = tl.sum(tl.where(mask[:, None], b_A, 0), 0)
134
+ # b_a2 = tl.sum(tl.where(mask[:, None], b_A2, 0), 0)
135
+ b_a = b_a + tl.sum(b_a[:, None] * b_A, 0) * (tl.arange(0, BC) < i)
136
+ b_a2 = b_a2 + tl.sum(b_a2[:, None] * b_A2, 0) * (tl.arange(0, BC) < i)
137
+ b_A = tl.where(mask[:, None], b_a, b_A)
138
+ b_A2 = tl.where(mask[:, None], b_a2, b_A2)
139
+
140
+ # blockwise computation of lower triangular matrix's inverse
141
+ # i.e., [A11, 0; A21, A22]^-1 = [A11^-1, 0; -A22^-1 A21 A11^-1, A22^-1]
142
+ b_A += tl.arange(0, BC)[:, None] == tl.arange(0, BC)[None, :]
143
+ b_A2 += tl.arange(0, BC)[:, None] == tl.arange(0, BC)[None, :]
144
+ b_A3 = tl.dot(tl.dot(b_A2, b_A3), b_A)
145
+ # tl.debug_barrier()
146
+ tl.store(p_A_inv1, b_A.to(p_A_inv1.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
147
+ tl.store(p_A_inv2, b_A2.to(p_A_inv2.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
148
+ tl.store(p_A_inv3, b_A3.to(p_A_inv3.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
149
+ # causal mask
150
+ tl.store(p_A_inv4, tl.zeros([BC, BC], dtype=tl.float32).to(p_A_inv4.dtype.element_ty), boundary_check=(0, 1))
151
+
152
+
153
+ @triton.heuristics({
154
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
155
+ })
156
+ @triton.autotune(
157
+ configs=[
158
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
159
+ for num_warps in [2, 4, 8, 16, 32]
160
+ for num_stages in [2, 3, 4]
161
+ ],
162
+ key=['BT', 'BK', 'BV'],
163
+ use_cuda_graph=use_cuda_graph,
164
+ )
165
+ @triton.jit(do_not_specialize=['T'])
166
+ def fwd_wu_kernel(
167
+ u,
168
+ w,
169
+ ag,
170
+ v,
171
+ A_ab_inv,
172
+ A_ak,
173
+ offsets,
174
+ indices,
175
+ T,
176
+ H: tl.constexpr,
177
+ K: tl.constexpr,
178
+ V: tl.constexpr,
179
+ BT: tl.constexpr,
180
+ BK: tl.constexpr,
181
+ BV: tl.constexpr,
182
+ USE_OFFSETS: tl.constexpr,
183
+ HEAD_FIRST: tl.constexpr,
184
+ ):
185
+ i_t, i_bh = tl.program_id(0), tl.program_id(1)
186
+ i_b, i_h = i_bh // H, i_bh % H
187
+ if USE_OFFSETS:
188
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
189
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
190
+ T = eos - bos
191
+ else:
192
+ bos, eos = i_b * T, i_b * T + T
193
+
194
+ if HEAD_FIRST:
195
+ p_A_ab_inv = tl.make_block_ptr(A_ab_inv + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
196
+ p_A_ak = tl.make_block_ptr(A_ak + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
197
+ else:
198
+ p_A_ab_inv = tl.make_block_ptr(A_ab_inv + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
199
+ p_A_ak = tl.make_block_ptr(A_ak + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
200
+ b_Aab_inv = tl.load(p_A_ab_inv, boundary_check=(0, 1))
201
+ b_Aak = tl.load(p_A_ak, boundary_check=(0, 1))
202
+ o_s = tl.arange(0, BT)
203
+ b_Aab_inv = tl.where(o_s[:, None] >= o_s[None, :], b_Aab_inv, 0)
204
+ b_Aak = tl.where(o_s[:, None] > o_s[None, :], b_Aak, 0)
205
+ # let's use tf32 here
206
+ b_Aak = tl.dot(b_Aab_inv, b_Aak)
207
+ # (SY 01/04) should be bf16 or tf32? To verify.
208
+ b_Aak = b_Aak.to(v.dtype.element_ty, fp_downcast_rounding="rtne")
209
+ b_Aab_inv = b_Aab_inv.to(ag.dtype.element_ty, fp_downcast_rounding="rtne")
210
+
211
+ for i_k in range(tl.cdiv(K, BK)):
212
+ if HEAD_FIRST:
213
+ p_ag = tl.make_block_ptr(ag + i_bh * T * K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
214
+ p_w = tl.make_block_ptr(w + i_bh * T * K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
215
+ else:
216
+ p_ag = tl.make_block_ptr(ag + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
217
+ p_w = tl.make_block_ptr(w + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
218
+ b_ag = tl.load(p_ag, boundary_check=(0, 1))
219
+ b_w = tl.dot(b_Aab_inv, b_ag) # both bf16 or fp16
220
+ tl.store(p_w, b_w.to(p_w.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
221
+
222
+ for i_v in range(tl.cdiv(V, BV)):
223
+ if HEAD_FIRST:
224
+ p_v = tl.make_block_ptr(v + i_bh * T * V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
225
+ p_u = tl.make_block_ptr(u + i_bh * T * V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
226
+ else:
227
+ p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
228
+ p_u = tl.make_block_ptr(u + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
229
+ b_v = tl.load(p_v, boundary_check=(0, 1))
230
+ b_u = tl.dot(b_Aak, b_v) # both bf16 or fp16
231
+ tl.store(p_u, b_u.to(p_u.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
232
+
233
+
234
+ def fwd_prepare_wy_repr(
235
+ ag: torch.Tensor,
236
+ v: torch.Tensor,
237
+ A_ak: torch.Tensor,
238
+ A_ab: torch.Tensor,
239
+ offsets: Optional[torch.LongTensor],
240
+ indices: Optional[torch.LongTensor],
241
+ head_first: bool = True,
242
+ chunk_size: int = 64
243
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
244
+ if head_first:
245
+ B, H, T, K = ag.shape
246
+ else:
247
+ B, T, H, K = ag.shape
248
+ BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
249
+
250
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
251
+ BC = min(BT, 32)
252
+ fwd_fn = fwd_prepare_wy_repr_kernel_chunk64 if BT == 64 else fwd_prepare_wy_repr_kernel_chunk32
253
+ A_ab_inv = torch.empty_like(A_ab)
254
+ fwd_fn[(NT, B * H)](
255
+ A_ab=A_ab,
256
+ A_ab_inv=A_ab_inv,
257
+ offsets=offsets,
258
+ indices=indices,
259
+ T=T,
260
+ H=H,
261
+ BT=BT,
262
+ BC=BC,
263
+ HEAD_FIRST=head_first
264
+ )
265
+ w, u = fwd_wu(
266
+ ag=ag,
267
+ v=v,
268
+ A_ak=A_ak,
269
+ A_ab_inv=A_ab_inv,
270
+ offsets=offsets,
271
+ indices=indices,
272
+ head_first=head_first,
273
+ chunk_size=BT
274
+ )
275
+ return w, u, A_ab_inv
276
+
277
+
278
+ def fwd_wu(
279
+ ag: torch.Tensor,
280
+ v: torch.Tensor,
281
+ A_ak: torch.Tensor,
282
+ A_ab_inv: torch.Tensor,
283
+ offsets: Optional[torch.LongTensor],
284
+ indices: Optional[torch.LongTensor],
285
+ head_first: bool,
286
+ chunk_size: int
287
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
288
+ if head_first:
289
+ B, H, T, K, V = *ag.shape, v.shape[-1]
290
+ else:
291
+ B, T, H, K, V = *ag.shape, v.shape[-1]
292
+ BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
293
+
294
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
295
+ BK = min(triton.next_power_of_2(K), 64)
296
+ BV = min(triton.next_power_of_2(V), 64)
297
+
298
+ u = torch.empty_like(v)
299
+ w = torch.empty_like(ag)
300
+ fwd_wu_kernel[(NT, B*H)](
301
+ ag=ag,
302
+ v=v,
303
+ A_ak=A_ak,
304
+ A_ab_inv=A_ab_inv,
305
+ w=w,
306
+ u=u,
307
+ offsets=offsets,
308
+ indices=indices,
309
+ T=T,
310
+ H=H,
311
+ K=K,
312
+ V=V,
313
+ BT=BT,
314
+ BK=BK,
315
+ BV=BV,
316
+ HEAD_FIRST=head_first
317
+ )
318
+ return w, u
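`fwd_prepare_wy_repr_kernel_chunk64` assembles the inverse of I - lower(A_ab) from two BC-sized diagonal blocks via the 2x2 block identity for lower-triangular matrices; the kernel needs no explicit minus sign because the off-diagonal block of I - lower(A_ab) already carries one. The identity itself, checked numerically on random blocks:

```python
import torch

c = 4
A11 = torch.eye(c) + torch.tril(torch.randn(c, c), diagonal=-1)
A22 = torch.eye(c) + torch.tril(torch.randn(c, c), diagonal=-1)
A21 = torch.randn(c, c)

A = torch.zeros(2 * c, 2 * c)
A[:c, :c], A[c:, c:], A[c:, :c] = A11, A22, A21

inv11, inv22 = torch.inverse(A11), torch.inverse(A22)
blockwise = torch.zeros_like(A)
blockwise[:c, :c], blockwise[c:, c:] = inv11, inv22
blockwise[c:, :c] = -inv22 @ A21 @ inv11    # lower-left block of [A11 0; A21 A22]^-1
print(torch.allclose(blockwise, torch.inverse(A), atol=1e-4))
```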
fla/ops/generalized_delta_rule/iplr/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ from .chunk import chunk_iplr_delta_rule
2
+ from .fused_recurrent import fused_recurrent_iplr_delta_rule
3
+
4
+ __all__ = [
5
+ 'chunk_iplr_delta_rule',
6
+ 'fused_recurrent_iplr_delta_rule'
7
+ ]
fla/ops/generalized_delta_rule/iplr/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (367 Bytes).
 
fla/ops/generalized_delta_rule/iplr/__pycache__/chunk.cpython-311.pyc ADDED
Binary file (27.7 kB).
 
fla/ops/generalized_delta_rule/iplr/__pycache__/fused_recurrent.cpython-311.pyc ADDED
Binary file (28.5 kB).
 
fla/ops/generalized_delta_rule/iplr/__pycache__/wy_fast.cpython-311.pyc ADDED
Binary file (23.5 kB).
 
fla/ops/generalized_delta_rule/iplr/chunk.py ADDED
@@ -0,0 +1,528 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.ops.common.chunk_delta_h import prepare_chunk_offsets
11
+ from fla.ops.generalized_delta_rule.iplr.wy_fast import fwd_prepare_wy_repr
12
+ from fla.utils import autocast_custom_bwd, autocast_custom_fwd, check_shared_mem, input_guard, use_cuda_graph
13
+
14
+ BKV_LIST = [64, 128] if check_shared_mem() else [32, 64]
15
+
16
+
17
+ @triton.heuristics({
18
+ 'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
19
+ 'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
20
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None,
21
+ })
22
+ @triton.autotune(
23
+ configs=[
24
+ triton.Config({}, num_warps=num_warps)
25
+ for num_warps in [2, 4, 8, 16]
26
+ ],
27
+ key=['BT', 'BK', 'BV'],
28
+ use_cuda_graph=use_cuda_graph,
29
+ )
30
+ @triton.jit(do_not_specialize=['T'])
31
+ def chunk_generalized_iplr_delta_rule_fwd_kernel_h(
32
+ k,
33
+ v,
34
+ d,
35
+ b,
36
+ u,
37
+ v_new,
38
+ h,
39
+ h0,
40
+ ht,
41
+ offsets,
42
+ chunk_offsets,
43
+ T,
44
+ H: tl.constexpr,
45
+ K: tl.constexpr,
46
+ V: tl.constexpr,
47
+ BT: tl.constexpr,
48
+ BC: tl.constexpr,
49
+ BK: tl.constexpr,
50
+ BV: tl.constexpr,
51
+ NT: tl.constexpr,
52
+ USE_INITIAL_STATE: tl.constexpr,
53
+ STORE_FINAL_STATE: tl.constexpr,
54
+ USE_OFFSETS: tl.constexpr,
55
+ HEAD_FIRST: tl.constexpr,
56
+ ):
57
+ i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
58
+ i_n, i_h = i_nh // H, i_nh % H
59
+ if USE_OFFSETS:
60
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
61
+ T = eos - bos
62
+ NT = tl.cdiv(T, BT)
63
+ boh = tl.load(chunk_offsets + i_n).to(tl.int32)
64
+ else:
65
+ bos, eos = i_n * T, i_n * T + T
66
+ NT = tl.cdiv(T, BT)
67
+ boh = i_n * NT
68
+
69
+ # [BK, BV]
70
+ b_h = tl.zeros([BK, BV], dtype=tl.float32)
71
+ if USE_INITIAL_STATE:
72
+ p_h0 = tl.make_block_ptr(h0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
73
+ b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32)
74
+
75
+ for i_t in range(NT):
76
+ if HEAD_FIRST:
77
+ p_h = tl.make_block_ptr(h + (i_nh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
78
+ else:
79
+ p_h = tl.make_block_ptr(h + ((boh + i_t) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
80
+ tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
81
+ b_hc = tl.zeros([BK, BV], dtype=tl.float32)
82
+ # Since we need to keep the full DK block in SRAM, we face a severe SRAM memory burden; subchunking alleviates it.
83
+ for i_c in range(tl.cdiv(min(BT, T - i_t * BT), BC)):
84
+ if HEAD_FIRST:
85
+ p_k = tl.make_block_ptr(k + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
86
+ p_b = tl.make_block_ptr(b + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
87
+ p_d = tl.make_block_ptr(d + i_nh * T*K, (T, K), (K, 1), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))
88
+ p_v = tl.make_block_ptr(v + i_nh * T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
89
+ p_u = tl.make_block_ptr(u + i_nh * T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
90
+ p_v_new = tl.make_block_ptr(v_new+i_nh*T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
91
+ else:
92
+ p_k = tl.make_block_ptr(k+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
93
+ p_b = tl.make_block_ptr(b+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
94
+ p_d = tl.make_block_ptr(d+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))
95
+ p_v = tl.make_block_ptr(v+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
96
+ p_u = tl.make_block_ptr(u+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
97
+ p_v_new = tl.make_block_ptr(v_new+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT+i_c*BC, i_v * BV), (BC, BV), (1, 0))
98
+ # [BK, BC]
99
+ b_k = tl.load(p_k, boundary_check=(0, 1))
100
+ b_v = tl.load(p_v, boundary_check=(0, 1))
101
+ b_d = tl.load(p_d, boundary_check=(0, 1))
102
+ b_b = tl.load(p_b, boundary_check=(0, 1))
103
+ b_v2 = tl.dot(b_d, b_h.to(b_d.dtype)) + tl.load(p_u, boundary_check=(0, 1))
104
+ b_hc += tl.dot(b_k, b_v)
105
+ b_hc += tl.dot(b_b, b_v2.to(b_k.dtype))
106
+ tl.store(p_v_new, b_v2.to(p_v_new.dtype.element_ty), boundary_check=(0, 1))
107
+ b_h += b_hc
108
+
109
+ if STORE_FINAL_STATE:
110
+ p_ht = tl.make_block_ptr(ht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
111
+ tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
112
+
113
+
114
+ @triton.heuristics({
115
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None,
116
+ })
117
+ @triton.autotune(
118
+ configs=[
119
+ triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
120
+ for BK in BKV_LIST
121
+ for BV in BKV_LIST
122
+ for num_warps in [2, 4, 8]
123
+ for num_stages in [2, 3]
124
+ ],
125
+ key=['BT'],
126
+ use_cuda_graph=use_cuda_graph,
127
+ )
128
+ @triton.jit(do_not_specialize=['T'])
129
+ def chunk_generalized_iplr_delta_rule_fwd_kernel_o(
130
+ q,
131
+ k,
132
+ v,
133
+ u,
134
+ b,
135
+ h,
136
+ o,
137
+ offsets,
138
+ indices,
139
+ scale,
140
+ T,
141
+ H: tl.constexpr,
142
+ K: tl.constexpr,
143
+ V: tl.constexpr,
144
+ BT: tl.constexpr,
145
+ BK: tl.constexpr,
146
+ BV: tl.constexpr,
147
+ USE_OFFSETS: tl.constexpr,
148
+ HEAD_FIRST: tl.constexpr,
149
+ ):
150
+ i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
151
+ i_b, i_h = i_bh // H, i_bh % H
152
+
153
+ if USE_OFFSETS:
154
+ i_tg = i_t
155
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
156
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
157
+ T = eos - bos
158
+ NT = tl.cdiv(T, BT)
159
+ else:
160
+ NT = tl.cdiv(T, BT)
161
+ i_tg = i_b * NT + i_t
162
+ bos, eos = i_b * T, i_b * T + T
163
+
164
+ # offset calculation
165
+ q += (i_bh * T * K) if HEAD_FIRST else ((bos * H + i_h) * K)
166
+ k += (i_bh * T * K) if HEAD_FIRST else ((bos * H + i_h) * K)
167
+ b += (i_bh * T * K) if HEAD_FIRST else ((bos * H + i_h) * K)
168
+ v += (i_bh * T * V) if HEAD_FIRST else ((bos * H + i_h) * V)
169
+ u += (i_bh * T * V) if HEAD_FIRST else ((bos * H + i_h) * V)
170
+ o += (i_bh * T * V) if HEAD_FIRST else ((bos * H + i_h) * V)
171
+ h += ((i_bh * NT + i_t) * K * V) if HEAD_FIRST else ((i_tg * H + i_h) * K * V)
172
+ stride_qk = K if HEAD_FIRST else H*K
173
+ stride_vo = V if HEAD_FIRST else H*V
174
+
175
+ b_o = tl.zeros([BT, BV], dtype=tl.float32)
176
+ b_Aqk = tl.zeros([BT, BT], dtype=tl.float32)
177
+ b_Aqb = tl.zeros([BT, BT], dtype=tl.float32)
178
+
179
+ for i_k in range(tl.cdiv(K, BK)):
180
+ p_q = tl.make_block_ptr(q, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
181
+ p_k = tl.make_block_ptr(k, (K, T), (1, stride_qk), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
182
+ p_h = tl.make_block_ptr(h, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
183
+ p_b = tl.make_block_ptr(b, (K, T), (1, stride_qk), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
184
+ # [BT, BK]
185
+ b_q = tl.load(p_q, boundary_check=(0, 1))
186
+ # [BK, BT]
187
+ b_k = tl.load(p_k, boundary_check=(0, 1))
188
+ b_b = tl.load(p_b, boundary_check=(0, 1))
189
+ # [BK, BV]
190
+ b_h = tl.load(p_h, boundary_check=(0, 1))
191
+ # [BT, BK] @ [BK, BV] -> [BT, BV]
192
+ b_o += tl.dot(b_q, b_h)
193
+ # [BT, BK] @ [BK, BT] -> [BT, BT]
194
+ b_Aqk += tl.dot(b_q, b_k)
195
+ # [BT, BK] @ [BK, BT] -> [BT, BT]
196
+ b_Aqb += tl.dot(b_q, b_b)
197
+
198
+ o_i = tl.arange(0, BT)
199
+ m_A = o_i[:, None] >= o_i[None, :]
200
+ b_Aqk = tl.where(m_A, b_Aqk, 0)
201
+ b_Aqb = tl.where(m_A, b_Aqb, 0)
202
+
203
+ p_v = tl.make_block_ptr(v, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
204
+ p_u = tl.make_block_ptr(u, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
205
+ p_o = tl.make_block_ptr(o, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
206
+ b_v = tl.load(p_v, boundary_check=(0, 1))
207
+ b_u = tl.load(p_u, boundary_check=(0, 1))
208
+ b_o = (b_o + tl.dot(b_Aqk.to(b_v.dtype), b_v) + tl.dot(b_Aqb.to(b_u.dtype), b_u)) * scale
209
+ tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
210
+
211
+
212
+ def chunk_generalized_iplr_delta_rule_fwd_o(
213
+ q: torch.Tensor,
214
+ k: torch.Tensor,
215
+ v: torch.Tensor,
216
+ v_new: torch.Tensor,
217
+ b: torch.Tensor,
218
+ h: torch.Tensor,
219
+ scale: Optional[float] = None,
220
+ offsets: Optional[torch.LongTensor] = None,
221
+ indices: Optional[torch.LongTensor] = None,
222
+ head_first: bool = True,
223
+ chunk_size: int = 64
224
+ ) -> torch.Tensor:
225
+ if head_first:
226
+ B, H, T, K, V = *q.shape, v.shape[-1]
227
+ else:
228
+ B, T, H, K, V = *q.shape, v.shape[-1]
229
+ if scale is None:
230
+ scale = k.shape[-1] ** -0.5
231
+ BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
232
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
233
+
234
+ o = torch.empty_like(v)
235
+
236
+ def grid(meta): return (
237
+ triton.cdiv(V, meta['BV']),
238
+ NT,
239
+ B * H
240
+ )
241
+ chunk_generalized_iplr_delta_rule_fwd_kernel_o[grid](
242
+ q=q,
243
+ k=k,
244
+ v=v,
245
+ u=v_new,
246
+ b=b,
247
+ h=h,
248
+ o=o,
249
+ offsets=offsets,
250
+ indices=indices,
251
+ scale=scale,
252
+ T=T,
253
+ H=H,
254
+ K=K,
255
+ V=V,
256
+ BT=BT,
257
+ HEAD_FIRST=head_first
258
+ )
259
+ return o
260
+
261
+
262
+ def chunk_generalized_iplr_delta_rule_fwd_h(
263
+ k: torch.Tensor,
264
+ v: torch.Tensor,
265
+ w: torch.Tensor,
266
+ u: torch.Tensor,
267
+ b: torch.Tensor,
268
+ initial_state: Optional[torch.Tensor] = None,
269
+ output_final_state: bool = False,
270
+ offsets: Optional[torch.LongTensor] = None,
271
+ indices: Optional[torch.LongTensor] = None,
272
+ head_first: bool = True,
273
+ chunk_size: int = 64
274
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
275
+ if head_first:
276
+ B, H, T, K, V = *k.shape, u.shape[-1]
277
+ else:
278
+ B, T, H, K, V = *k.shape, u.shape[-1]
279
+ BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
280
+ # N: the actual number of sequences in the batch with either equal or variable lengths
281
+ if offsets is None:
282
+ N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
283
+ else:
284
+ N, NT, chunk_offsets = len(offsets) - 1, len(indices), prepare_chunk_offsets(offsets, BT)
285
+
286
+ BK = triton.next_power_of_2(K)
287
+ assert BK <= 256, "current kernel does not support head dimension larger than 256."
288
+ # H100 can have larger block size
289
+
290
+ if check_shared_mem('hopper', k.device.index):
291
+ BV = 64
292
+ BC = 64 if K <= 128 else 32
293
+ elif check_shared_mem('ampere', k.device.index): # A100
294
+ BV = 32
295
+ BC = 32
296
+ else:
297
+ BV = 16
298
+ BC = 16
299
+
300
+ BC = min(BT, BC)
301
+ NK = triton.cdiv(K, BK)
302
+ NV = triton.cdiv(V, BV)
303
+
304
+ assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization'
305
+
306
+ if head_first:
307
+ h = k.new_empty(B, H, NT, K, V)
308
+ else:
309
+ h = k.new_empty(B, NT, H, K, V)
310
+ final_state = k.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None
311
+
312
+ v_new = torch.empty_like(u)
313
+ grid = (NK, NV, N * H)
314
+
315
+ chunk_generalized_iplr_delta_rule_fwd_kernel_h[grid](
316
+ k=k,
317
+ v=v,
318
+ d=w,
319
+ b=b,
320
+ u=u,
321
+ v_new=v_new,
322
+ h=h,
323
+ h0=initial_state,
324
+ ht=final_state,
325
+ offsets=offsets,
326
+ chunk_offsets=chunk_offsets,
327
+ T=T,
328
+ H=H,
329
+ K=K,
330
+ V=V,
331
+ BT=BT,
332
+ BC=BC,
333
+ BK=BK,
334
+ BV=BV,
335
+ NT=NT,
336
+ HEAD_FIRST=head_first
337
+ )
338
+ return h, v_new, final_state
339
+
340
+
341
+ def chunk_generalized_iplr_delta_rule_fwd(
342
+ q: torch.Tensor,
343
+ k: torch.Tensor,
344
+ v: torch.Tensor,
345
+ a: torch.Tensor,
346
+ b: torch.Tensor,
347
+ scale: float,
348
+ initial_state: torch.Tensor,
349
+ output_final_state: bool,
350
+ offsets: Optional[torch.LongTensor] = None,
351
+ indices: Optional[torch.LongTensor] = None,
352
+ head_first: bool = True,
353
+ chunk_size: int = 64
354
+ ):
355
+ T = q.shape[2] if head_first else q.shape[1]
356
+ BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
357
+ w, u, _ = fwd_prepare_wy_repr(
358
+ a=a,
359
+ b=b,
360
+ k=k,
361
+ v=v,
362
+ offsets=offsets,
363
+ indices=indices,
364
+ head_first=head_first,
365
+ chunk_size=BT
366
+ )
367
+
368
+ h, v_new, final_state = chunk_generalized_iplr_delta_rule_fwd_h(
369
+ k=k,
370
+ v=v,
371
+ b=b,
372
+ w=w,
373
+ u=u,
374
+ initial_state=initial_state,
375
+ output_final_state=output_final_state,
376
+ offsets=offsets,
377
+ indices=indices,
378
+ head_first=head_first,
379
+ chunk_size=BT
380
+ )
381
+ o = chunk_generalized_iplr_delta_rule_fwd_o(
382
+ q=q,
383
+ k=k,
384
+ v=v,
385
+ v_new=v_new,
386
+ b=b,
387
+ h=h,
388
+ scale=scale,
389
+ offsets=offsets,
390
+ indices=indices,
391
+ head_first=head_first,
392
+ chunk_size=BT
393
+ )
394
+ return o, final_state
395
+
396
+
397
+ class ChunkGeneralizedIPLRDeltaRuleFunction(torch.autograd.Function):
398
+
399
+ @staticmethod
400
+ @input_guard
401
+ @autocast_custom_fwd
402
+ def forward(
403
+ ctx,
404
+ q: torch.Tensor,
405
+ k: torch.Tensor,
406
+ v: torch.Tensor,
407
+ a: torch.Tensor,
408
+ b: torch.Tensor,
409
+ scale: float,
410
+ initial_state: torch.Tensor,
411
+ output_final_state: bool,
412
+ offsets: Optional[torch.LongTensor] = None,
413
+ head_first: bool = True
414
+ ):
415
+ chunk_size = 64
416
+
417
+ # 2-d indices denoting the offsets of chunks in each sequence
418
+ # for example, if the passed `offsets` is [0, 100, 356] and `chunk_size` is 64,
419
+ # then there are 2 and 4 chunks in the 1st and 2nd sequences respectively, and `indices` will be
420
+ # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]]
421
+ indices = None
422
+ if offsets is not None:
423
+ indices = torch.cat([torch.arange(n) for n in triton.cdiv(offsets[1:] - offsets[:-1], chunk_size).tolist()])
424
+ indices = torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(offsets)
425
+
426
+ o, final_state = chunk_generalized_iplr_delta_rule_fwd(
427
+ q=q,
428
+ k=k,
429
+ v=v,
430
+ a=a,
431
+ b=b,
432
+ scale=scale,
433
+ initial_state=initial_state,
434
+ output_final_state=output_final_state,
435
+ offsets=offsets,
436
+ indices=indices,
437
+ head_first=head_first,
438
+ chunk_size=chunk_size
439
+ )
440
+ return o.to(q.dtype), final_state
441
+
442
+ @staticmethod
443
+ @input_guard
444
+ @autocast_custom_bwd
445
+ def backward(
446
+ ctx,
447
+ do: torch.Tensor,
448
+ dht: torch.Tensor
449
+ ):
450
+ raise NotImplementedError(
451
+ "Backward pass for ChunkGeneralizedIPLRDeltaRuleFunction is not implemented yet. "
452
+ "Stay tuned!"
453
+ )
454
+
455
+
456
+ @torch.compiler.disable
457
+ def chunk_iplr_delta_rule(
458
+ q: torch.Tensor,
459
+ k: torch.Tensor,
460
+ v: torch.Tensor,
461
+ a: torch.Tensor,
462
+ b: torch.Tensor,
463
+ scale: float = None,
464
+ initial_state: torch.Tensor = None,
465
+ output_final_state: bool = False,
466
+ cu_seqlens: Optional[torch.LongTensor] = None,
467
+ head_first: bool = True
468
+ ):
469
+ r"""
470
+ Args:
471
+ q (torch.Tensor):
472
+ queries of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
473
+ k (torch.Tensor):
474
+ keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
475
+ v (torch.Tensor):
476
+ values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`.
477
+ a (torch.Tensor):
478
+ alphas of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
479
+ b (torch.Tensor):
480
+ betas of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
481
+ scale (Optional[float]):
482
+ Scale factor for the attention scores.
483
+ If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
484
+ initial_state (Optional[torch.Tensor]):
485
+ Initial state of shape `[N, H, K, V]` for `N` input sequences.
486
+ For equal-length input sequences, `N` equals the batch size `B`.
487
+ Default: `None`.
488
+ output_final_state (Optional[bool]):
489
+ Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
490
+ cu_seqlens (torch.LongTensor):
491
+ Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
492
+ consistent with the FlashAttention API.
493
+ head_first (Optional[bool]):
494
+ Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
495
+ Default: `True`.
496
+
497
+ Returns:
498
+ o (torch.Tensor):
499
+ Outputs of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`.
500
+ final_state (torch.Tensor):
501
+ Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.
502
+ """
503
+ assert q.dtype == k.dtype == v.dtype
504
+ assert q.dtype != torch.float32, "ChunkGeneralizedIPLRDeltaRuleFunction does not support float32. Please use bfloat16."
505
+
506
+ if cu_seqlens is not None:
507
+ if q.shape[0] != 1:
508
+ raise ValueError(f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
509
+ f" Please flatten variable-length inputs before processing.")
510
+ if head_first:
511
+ raise RuntimeError("Sequences with variable lengths are not supported for head-first mode")
512
+ if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
513
+ raise ValueError(f"The number of initial states is expected to be equal to the number of input sequences, "
514
+ f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}.")
515
+ scale = k.shape[-1] ** -0.5 if scale is None else scale
516
+ o, final_state = ChunkGeneralizedIPLRDeltaRuleFunction.apply(
517
+ q,
518
+ k,
519
+ v,
520
+ a,
521
+ b,
522
+ scale,
523
+ initial_state,
524
+ output_final_state,
525
+ cu_seqlens,
526
+ head_first
527
+ )
528
+ return o, final_state
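For readers of this diff, a minimal usage sketch of `chunk_iplr_delta_rule` (not part of the committed file; it assumes a CUDA device with Triton available and follows the `head_first=True` shapes documented above):

import torch
from fla.ops.generalized_delta_rule.iplr import chunk_iplr_delta_rule

B, H, T, K, V = 2, 4, 128, 64, 64
# bfloat16 inputs, since the function asserts against float32
q, k, a, b = (torch.randn(B, H, T, K, dtype=torch.bfloat16, device='cuda') for _ in range(4))
v = torch.randn(B, H, T, V, dtype=torch.bfloat16, device='cuda')
o, final_state = chunk_iplr_delta_rule(q, k, v, a, b, output_final_state=True, head_first=True)
print(o.shape, final_state.shape)  # (B, H, T, V) and (B, H, K, V)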
fla/ops/generalized_delta_rule/iplr/fused_recurrent.py ADDED
@@ -0,0 +1,451 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2024-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.utils import input_guard
11
+
12
+
13
+ @triton.heuristics({
14
+ 'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
15
+ 'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
16
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
17
+ })
18
+ @triton.autotune(
19
+ configs=[
20
+ triton.Config({'BV': BV}, num_warps=num_warps, num_stages=num_stages)
21
+ for BV in [32, 64]
22
+ for num_warps in [2, 4, 8, 16]
23
+ for num_stages in [2, 3, 4]
24
+ ],
25
+ key=["BK"],
26
+ )
27
+ @triton.jit
28
+ def fused_recurrent_fwd_kernel(
29
+ q, # query [B, H, L, K]
30
+ k, # key [B, H, L, K]
31
+ v, # value [B, H, L, V].
32
+ a, # a [B, H, L, K]
33
+ b, # b [B, H, L, K]
34
+ o, # output [B, H, L, V]
35
+ ha, # tmp variable [B, H, L, V] for storing intermediate results of (h * a[None, :]).sum(0)
36
+ h0, # initial hidden state [B, H, K, V]
37
+ ht, # final hidden state [B, H, K, V]
38
+ offsets, # varlen offsets
39
+ scale, # K ** -0.5
40
+ H, # n_heads
41
+ T, # seq_len
42
+ K: tl.constexpr, # K
43
+ V: tl.constexpr, # V
44
+ BK: tl.constexpr, # BLOCK SIZE along the K dimension
45
+ BV: tl.constexpr, # BLOCK SIZE along the V dimension
46
+ USE_INITIAL_STATE: tl.constexpr, # whether to use initial state
47
+ STORE_FINAL_STATE: tl.constexpr, # whether to store final state
48
+ USE_OFFSETS: tl.constexpr,
49
+ HEAD_FIRST: tl.constexpr
50
+ ):
51
+ # indices
52
+ i_v, i_nh = tl.program_id(0), tl.program_id(1)
53
+ i_n, i_h = i_nh // H, i_nh % H
54
+
55
+ if USE_OFFSETS:
56
+ bos, eos = tl.load(offsets + i_n).to(tl.int64), tl.load(offsets + i_n + 1).to(tl.int64)
57
+ T = eos - bos
58
+ else:
59
+ bos, eos = i_n * T, i_n * T + T
60
+
61
+ if HEAD_FIRST:
62
+ p_q = q + i_nh * T*K + tl.arange(0, BK)
63
+ p_k = k + i_nh * T*K + tl.arange(0, BK)
64
+ p_a = a + i_nh * T*K + tl.arange(0, BK)
65
+ p_b = b + i_nh * T*K + tl.arange(0, BK)
66
+ p_o = o + i_nh * T*V + i_v * BV + tl.arange(0, BV)
67
+ p_v = v + i_nh * T*V + i_v * BV + tl.arange(0, BV)
68
+ p_ha = ha + i_nh * T*V + i_v * BV + tl.arange(0, BV)
69
+ else:
70
+ p_q = q + (bos * H + i_h) * K + tl.arange(0, BK)
71
+ p_k = k + (bos * H + i_h) * K + tl.arange(0, BK)
72
+ p_a = a + (bos * H + i_h) * K + tl.arange(0, BK)
73
+ p_b = b + (bos * H + i_h) * K + tl.arange(0, BK)
74
+ p_ha = ha + (bos * H + i_h) * V + i_v * BV + tl.arange(0, BV)
75
+ p_v = v + (bos * H + i_h) * V + i_v * BV + tl.arange(0, BV)
76
+ p_o = o + (bos * H + i_h) * V + i_v * BV + tl.arange(0, BV)
77
+
78
+ mask_k = tl.arange(0, BK) < K
79
+ mask_v = (i_v * BV + tl.arange(0, BV)) < V
80
+ mask_h = mask_k[None, :] & mask_v[:, None]
81
+
82
+ b_h = tl.zeros([BV, BK], dtype=tl.float32)
83
+
84
+ if USE_INITIAL_STATE:
85
+ p_h0 = h0 + i_nh * K * V + (tl.arange(0, BK)[None, :]) * V + ((i_v * BV + tl.arange(0, BV))[:, None])
86
+ b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)
87
+
88
+ for _ in range(0, T):
89
+ b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
90
+ b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
91
+ b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) * scale
92
+ b_a = tl.load(p_a, mask=mask_k, other=0).to(tl.float32)
93
+ b_b = tl.load(p_b, mask=mask_k, other=0).to(tl.float32)
94
+ # to store
95
+ tmp = tl.sum(b_h * b_a[None, :], axis=1)
96
+ b_h += (tmp[:, None] * b_b[None, :] + b_k[None, :] * b_v[:, None])
97
+ _o = b_h * b_q[None, :]
98
+ _o = tl.sum(_o, axis=1)
99
+ tl.store(p_o, _o.to(p_o.dtype.element_ty), mask=mask_v)
100
+ tl.store(p_ha, tmp.to(p_ha.dtype.element_ty), mask=mask_v)
101
+ p_q += K if HEAD_FIRST else K*H
102
+ p_k += K if HEAD_FIRST else K*H
103
+ p_o += V if HEAD_FIRST else V*H
104
+ p_v += V if HEAD_FIRST else V*H
105
+ p_ha += V if HEAD_FIRST else V*H
106
+ p_a += K if HEAD_FIRST else K*H
107
+ p_b += K if HEAD_FIRST else K*H
108
+
109
+ if STORE_FINAL_STATE:
110
+ p_ht = ht + i_nh * K * V + (tl.arange(0, BK)[None, :]) * V + ((i_v * BV + tl.arange(0, BV))[:, None])
111
+ tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
112
+
113
+
114
+ @triton.heuristics({
115
+ 'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
116
+ 'USE_DHT': lambda args: args['dht'] is not None,
117
+ 'USE_DH0': lambda args: args['dh0'] is not None,
118
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
119
+ })
120
+ @triton.autotune(
121
+ configs=[
122
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
123
+ for num_warps in [2, 4, 8, 16]
124
+ for num_stages in [2, 3]
125
+ ],
126
+ key=["BK", "BV"],
127
+ )
128
+ @triton.jit
129
+ def fused_recurrent_bwd_kernel(
130
+ # B: batch_size, H: n_heads, T: seq_len, K/V: head dimensions
131
+ # NV: number of splits along the V dimension. NK: number of splits along the K dimension
132
+ q, # query [B, H, L, K]
133
+ k, # key [B, H, L, K]
134
+ v, # value [B, H, L, V]
135
+ a, # a [B, H, L, K]
136
+ b, # b [B, H, L, K]
137
+ ha, # ha [B, H, L, V]
138
+ dht, # gradient of final state [B, H, K, V]
139
+ dh0, # gradient of initial state [B, H, K, V]
140
+ do, # gradient of output [B, H, L, V]
141
+ dq, # gradient of query [NV, B, H, L, K]
142
+ dk, # gradient of key [NV, B, H, L, K]
143
+ dv, # gradient of value [NK, B, H, L, V]
144
+ da, # gradient of a [NV, B, H, L, K]
145
+ db, # gradient of b [NV, B, H, L, K]
146
+ dha, # gradient of ha [NK, B, H, L, V]
147
+ h0, # initial state [B, H, K, V]
148
+ scale, # K ** -0.5
149
+ offsets, # offsets
150
+ B, # batch_size
151
+ H, # n_heads
152
+ T, # seq_len
153
+ K: tl.constexpr, # K
154
+ V: tl.constexpr, # V
155
+ BK: tl.constexpr, # BLOCK SIZE along the K dimension
156
+ BV: tl.constexpr, # BLOCK SIZE along the V dimension
157
+ USE_INITIAL_STATE: tl.constexpr, # whether to use initial state h0
158
+ USE_DH0: tl.constexpr, # whether to use dh0
159
+ USE_DHT: tl.constexpr, # whether to use dht
160
+ USE_OFFSETS: tl.constexpr,
161
+ HEAD_FIRST: tl.constexpr
162
+ ):
163
+ i_v, i_nh = tl.program_id(0), tl.program_id(1)
164
+ i_n, i_h = i_nh // H, i_nh % H
165
+ dk += i_v * B * H * K * T
166
+ db += i_v * B * H * K * T
167
+ dq += i_v * B * H * K * T
168
+ da += i_v * B * H * K * T
169
+ if USE_OFFSETS:
170
+ bos, eos = tl.load(offsets + i_n).to(tl.int64), tl.load(offsets + i_n + 1).to(tl.int64)
171
+ T = eos - bos
172
+ else:
173
+ bos, eos = i_n * T, i_n * T + T
174
+ mask_k = tl.arange(0, BK) < K
175
+ mask_v = (tl.arange(0, BV) + i_v * BV) < V
176
+
177
+ q += (i_nh * T*K) if HEAD_FIRST else ((bos * H + i_h) * K)
178
+ k += (i_nh * T*K) if HEAD_FIRST else ((bos * H + i_h) * K)
179
+ v += (i_nh * T*V + i_v * BV) if HEAD_FIRST else ((bos * H + i_h) * V + i_v * BV)
180
+ ha += (i_nh * T*V + i_v * BV) if HEAD_FIRST else ((bos * H + i_h) * V + i_v * BV)
181
+ a += (i_nh * T*K) if HEAD_FIRST else ((bos * H + i_h) * K)
182
+ b += (i_nh * T*K) if HEAD_FIRST else ((bos * H + i_h) * K)
183
+ do += (i_nh * T*V + i_v * BV) if HEAD_FIRST else ((bos * H + i_h) * V + i_v * BV)
184
+ dq += (i_nh * T*K) if HEAD_FIRST else ((bos * H + i_h) * K)
185
+ dk += (i_nh * T*K) if HEAD_FIRST else ((bos * H + i_h) * K)
186
+ dv += (i_nh * T*V + i_v * BV) if HEAD_FIRST else ((bos * H + i_h) * V + i_v * BV)
187
+ da += (i_nh * T*K) if HEAD_FIRST else ((bos * H + i_h) * K)
188
+ db += (i_nh * T*K) if HEAD_FIRST else ((bos * H + i_h) * K)
189
+ dha += (i_nh * T*V + i_v * BV) if HEAD_FIRST else ((bos * H + i_h) * V + i_v * BV)
190
+
191
+ p_q = q + tl.arange(0, BK) + (T - 1) * K * (1 if HEAD_FIRST else H)
192
+ p_k = k + tl.arange(0, BK) + (T - 1) * K * (1 if HEAD_FIRST else H)
193
+ p_v = v + tl.arange(0, BV) + (T - 1) * V * (1 if HEAD_FIRST else H)
194
+ p_ha = ha + tl.arange(0, BV) + (T - 1) * V * (1 if HEAD_FIRST else H)
195
+ p_a = a + tl.arange(0, BK) + (T - 1) * K * (1 if HEAD_FIRST else H)
196
+ p_b = b + tl.arange(0, BK) + (T - 1) * K * (1 if HEAD_FIRST else H)
197
+ p_do = do + tl.arange(0, BV) + (T - 1) * V * (1 if HEAD_FIRST else H)
198
+ p_dk = dk + tl.arange(0, BK) + (T - 1) * K * (1 if HEAD_FIRST else H)
199
+ p_dv = dv + tl.arange(0, BV) + (T - 1) * V * (1 if HEAD_FIRST else H)
200
+ p_dha = dha + tl.arange(0, BV) + (T - 1) * V * (1 if HEAD_FIRST else H)
201
+ p_db = db + tl.arange(0, BK) + (T - 1) * K * (1 if HEAD_FIRST else H)
202
+ p_da = da + tl.arange(0, BK) + (T - 1) * K * (1 if HEAD_FIRST else H)
203
+ p_dq = dq + tl.arange(0, BK) + (T - 1) * K * (1 if HEAD_FIRST else H)
204
+
205
+ b_dh = tl.zeros([BK, BV], dtype=tl.float32)
206
+ if USE_DHT:
207
+ p_ht = dht + i_nh * K * V + (tl.arange(0, BK)[:, None]) * V + ((i_v * BV + tl.arange(0, BV))[None, :])
208
+ b_dh += tl.load(p_ht, mask=mask_k[:, None] & mask_v[None, :], other=0).to(tl.float32)
209
+
210
+ for _ in range(T):
211
+ b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) * scale
212
+ b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
213
+ b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
214
+ b_do = tl.load(p_do, mask=mask_v, other=0).to(tl.float32)
215
+ b_b = tl.load(p_b, mask=mask_k, other=0).to(tl.float32)
216
+ b_a = tl.load(p_a, mask=mask_k, other=0).to(tl.float32)
217
+ b_ha = tl.load(p_ha, mask=mask_v, other=0).to(tl.float32)
218
+
219
+ b_dh += b_q[:, None] * b_do[None, :]
220
+ d_k = tl.sum(b_dh * b_v[None, :], axis=1)
221
+ d_v = tl.sum(b_dh * b_k[:, None], axis=0)
222
+ tl.store(p_dk, d_k.to(p_dk.dtype.element_ty), mask=mask_k)
223
+ tl.store(p_dv, d_v.to(p_dv.dtype.element_ty), mask=mask_v)
224
+
225
+ b_dha = tl.sum(b_dh * b_b[:, None], axis=0)
226
+ tl.store(p_dha, b_dha.to(p_dha.dtype.element_ty), mask=mask_v)
227
+ b_db = tl.sum(b_dh * b_ha[None, :], axis=1)
228
+ tl.store(p_db, b_db.to(p_db.dtype.element_ty), mask=mask_k)
229
+
230
+ b_dh += b_dha[None, :] * b_a[:, None]
231
+ p_do -= V if HEAD_FIRST else V*H
232
+ p_q -= K if HEAD_FIRST else K*H
233
+ p_k -= K if HEAD_FIRST else K*H
234
+ p_v -= V if HEAD_FIRST else V*H
235
+ p_dk -= K if HEAD_FIRST else K*H
236
+ p_dv -= V if HEAD_FIRST else V*H
237
+ p_b -= K if HEAD_FIRST else K*H
238
+ p_db -= K if HEAD_FIRST else K*H
239
+ p_a -= K if HEAD_FIRST else K*H
240
+ p_dha -= V if HEAD_FIRST else V*H
241
+ p_ha -= V if HEAD_FIRST else V*H
242
+
243
+ if USE_DH0:
244
+ p_dh0 = dh0 + i_nh * K * V + (tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :])
245
+ tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), mask=mask_k[:, None] & mask_v[None, :])
246
+
247
+ tl.debug_barrier()
248
+
249
+ b_h = tl.zeros([BK, BV], dtype=tl.float32)
250
+
251
+ if USE_INITIAL_STATE:
252
+ mask_kv = mask_k[:, None] & mask_v[None, :]
253
+ p_h0 = h0 + i_nh * K * V + (tl.arange(0, BK)[:, None]) * V + ((i_v * BV + tl.arange(0, BV))[None, :])
254
+ b_h += tl.load(p_h0, mask=mask_kv, other=0).to(tl.float32)
255
+
256
+ p_k = k + tl.arange(0, BK)
257
+ p_v = v + tl.arange(0, BV)
258
+ p_ha = ha + tl.arange(0, BV)
259
+ p_do = do + tl.arange(0, BV)
260
+ p_dha = dha + tl.arange(0, BV)
261
+ p_da = da + tl.arange(0, BK)
262
+ p_dq = dq + tl.arange(0, BK)
263
+ p_b = b + tl.arange(0, BK)
264
+
265
+ for i in range(0, T):
266
+ b_dha = tl.load(p_dha, mask=mask_v, other=0).to(tl.float32)
267
+ d_a = tl.sum(b_dha[None, :] * b_h, axis=1)
268
+ tl.store(p_da, d_a.to(p_da.dtype.element_ty), mask=mask_k)
269
+ b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
270
+ b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
271
+ b_do = tl.load(p_do, mask=mask_v, other=0).to(tl.float32)
272
+ b_b = tl.load(p_b, mask=mask_k, other=0).to(tl.float32)
273
+ b_ha = tl.load(p_ha, mask=mask_v, other=0).to(tl.float32)
274
+ b_h += b_k[:, None] * b_v[None, :] + b_b[:, None] * b_ha[None, :]
275
+ _d_q = b_h * b_do[None, :]
276
+ d_q = tl.sum(_d_q, axis=1) * scale
277
+ tl.store(p_dq, d_q.to(p_dq.dtype.element_ty), mask=mask_k)
278
+
279
+ p_k += K if HEAD_FIRST else K*H
280
+ p_do += V if HEAD_FIRST else V*H
281
+ p_v += V if HEAD_FIRST else V*H
282
+ p_da += K if HEAD_FIRST else K*H
283
+ p_dha += V if HEAD_FIRST else V*H
284
+ p_ha += V if HEAD_FIRST else V*H
285
+ p_dq += K if HEAD_FIRST else K*H
286
+ p_b += K if HEAD_FIRST else K*H
287
+
288
+
289
+ class FusedRecurrentIPLRDeltaRuleFunction(torch.autograd.Function):
290
+
291
+ @staticmethod
292
+ @input_guard
293
+ def forward(ctx, q, k, v, a, b, scale=None, initial_state=None, output_final_state=False, offsets=None, head_first=False):
294
+ if head_first:
295
+ B, H, T, K, V = *k.shape, v.shape[-1]
296
+ else:
297
+ B, T, H, K, V = *k.shape, v.shape[-1]
298
+ N = B if offsets is None else len(offsets) - 1
299
+
300
+ BK = triton.next_power_of_2(K)
301
+ if output_final_state:
302
+ final_state = q.new_empty(B, H, K, V, dtype=torch.float32)
303
+ else:
304
+ final_state = None
305
+
306
+ ha = torch.empty_like(v, dtype=torch.float32)
307
+
308
+ def grid(meta): return (
309
+ triton.cdiv(V, meta['BV']),
310
+ N * H
311
+ )
312
+ o = torch.empty_like(v)
313
+ fused_recurrent_fwd_kernel[grid](
314
+ q=q,
315
+ k=k,
316
+ v=v,
317
+ a=a,
318
+ b=b,
319
+ o=o,
320
+ ha=ha,
321
+ h0=initial_state,
322
+ ht=final_state,
323
+ scale=scale,
324
+ offsets=offsets,
325
+ H=H,
326
+ T=T,
327
+ K=K,
328
+ V=V,
329
+ BK=BK,
330
+ HEAD_FIRST=head_first
331
+ )
332
+ ctx.save_for_backward(q, k, v, a, b, ha, initial_state)
333
+ ctx.scale = scale
334
+ ctx.head_first = head_first
335
+ ctx.offsets = offsets
336
+ return o, final_state
337
+
338
+ @staticmethod
339
+ @input_guard
340
+ def backward(ctx, do, dht):
341
+ q, k, v, a, b, ha, initial_state = ctx.saved_tensors
342
+ if ctx.head_first:
343
+ B, H, T, K, V = *q.shape, v.shape[-1]
344
+ else:
345
+ B, T, H, K, V = *q.shape, v.shape[-1]
346
+
347
+ N = B if ctx.offsets is None else len(ctx.offsets) - 1
348
+ scale = ctx.scale
349
+ BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 64)
350
+ NV = triton.cdiv(V, BV)
351
+
352
+ dq = q.new_empty(NV, *q.shape)
353
+ dk = k.new_empty(NV, *k.shape)
354
+ da = a.new_empty(NV, *a.shape)
355
+ db = b.new_empty(NV, *b.shape)
356
+ dv = torch.empty_like(v)
357
+ dha = torch.empty_like(ha)
358
+ grid = (NV, N * H)
359
+
360
+ if initial_state is not None and initial_state.requires_grad:
361
+ dh0 = torch.empty_like(initial_state, dtype=torch.float32)
362
+ else:
363
+ dh0 = None
364
+
365
+ fused_recurrent_bwd_kernel[grid](
366
+ q=q,
367
+ k=k,
368
+ v=v,
369
+ a=a,
370
+ b=b,
371
+ ha=ha,
372
+ dht=dht,
373
+ dh0=dh0,
374
+ do=do,
375
+ dq=dq,
376
+ dk=dk,
377
+ dv=dv,
378
+ da=da,
379
+ db=db,
380
+ dha=dha,
381
+ h0=initial_state,
382
+ scale=scale,
383
+ offsets=ctx.offsets,
384
+ B=B,
385
+ H=H,
386
+ T=T,
387
+ K=K,
388
+ V=V,
389
+ BK=BK,
390
+ BV=BV,
391
+ HEAD_FIRST=ctx.head_first
392
+ )
393
+ dq = dq.sum(0)
394
+ dk = dk.sum(0)
395
+ da = da.sum(0)
396
+ db = db.sum(0)
397
+ return dq.to(q), dk.to(k), dv.to(v), da.to(a), db.to(b), None, dh0, None, None, None
398
+
399
+
400
+ def fused_recurrent_iplr_delta_rule(
401
+ q: torch.Tensor,
402
+ k: torch.Tensor,
403
+ v: torch.Tensor,
404
+ a: torch.Tensor,
405
+ b: torch.Tensor,
406
+ scale: float = None,
407
+ initial_state: torch.Tensor = None,
408
+ output_final_state: bool = False,
409
+ offsets: torch.Tensor = None,
410
+ head_first: bool = False
411
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
412
+ r"""
413
+ This function computes the recurrence S_t = S_{t-1} @ (I + a_t b_t^T) + v_t k_t^T in a token-by-token recurrent manner.
414
+
415
+ Args:
416
+ q (torch.Tensor):
417
+ queries of shape `[B, H, T, K]`
418
+ k (torch.Tensor):
419
+ keys of shape `[B, H, T, K]`
420
+ v (torch.Tensor):
421
+ values of shape `[B, H, T, V]`
422
+ a (torch.Tensor):
423
+ alphas of shape `[B, H, T, K]`
424
+ b (torch.Tensor):
425
+ betas of shape `[B, H, T, K]`
426
+ scale (Optional[float]):
427
+ Scale factor for the attention scores.
428
+ If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
429
+ initial_state (Optional[torch.Tensor]):
430
+ Initial state of shape `[B, H, K, V]`. Default: `None`.
431
+ output_final_state (Optional[bool]):
432
+ Whether to output the final state of shape `[B, H, K, V]`. Default: `False`.
433
+ offsets (Optional[torch.Tensor]):
434
+ Cumulative sequence lengths of shape `[N+1]` used for variable-length training, consistent with the FlashAttention API. Default: `None`.
435
+ """
436
+ if offsets is not None:
437
+ if q.shape[0] != 1:
438
+ raise ValueError(f"The batch size is expected to be 1 rather than {q.shape[0]} when using `offsets`."
439
+ f" Please flatten variable-length inputs before processing.")
440
+ if head_first:
441
+ raise RuntimeError("Sequences with variable lengths are not supported for head-first mode")
442
+ if initial_state is not None and initial_state.shape[0] != len(offsets) - 1:
443
+ raise ValueError(f"The number of initial states is expected to be equal to the number of input sequences, "
444
+ f"i.e., {len(offsets) - 1} rather than {initial_state.shape[0]}.")
445
+ if scale is None:
446
+ scale = q.shape[-1] ** -0.5
447
+ else:
448
+ assert scale > 0, "scale must be positive"
449
+ o, final_state = FusedRecurrentIPLRDeltaRuleFunction.apply(
450
+ q, k, v, a, b, scale, initial_state, output_final_state, offsets, head_first)
451
+ return o, final_state
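A hedged usage sketch of `fused_recurrent_iplr_delta_rule` (not part of the committed file; it assumes a CUDA device and Triton). Unlike the chunked variant above, this path implements a backward kernel, so gradients can be taken:

import torch
from fla.ops.generalized_delta_rule.iplr import fused_recurrent_iplr_delta_rule

B, T, H, K, V = 2, 128, 4, 64, 64
# non-head-first layout: [B, T, H, K] / [B, T, H, V]
q, k, a, b = (torch.randn(B, T, H, K, device='cuda', dtype=torch.bfloat16, requires_grad=True)
              for _ in range(4))
v = torch.randn(B, T, H, V, device='cuda', dtype=torch.bfloat16, requires_grad=True)
o, final_state = fused_recurrent_iplr_delta_rule(q, k, v, a, b, output_final_state=True, head_first=False)
o.sum().backward()
print(q.grad.shape, v.grad.shape)  # [B, T, H, K] and [B, T, H, V]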
fla/ops/generalized_delta_rule/iplr/naive.py ADDED
@@ -0,0 +1,69 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import torch
4
+ from einops import rearrange
5
+
6
+
7
+ # S_t = S_{t-1} @ (I + alpha_t beta_t^T) + v_t k_t^T
8
+ # q, k, alpha, beta [B, H, L, D_K]
9
+ # v [B, H, L, D_V]
10
+ def iplr_recurrence(q, k, v, alpha, beta, initial_state=None, output_final_state=True):
11
+ orig_dtype = q.dtype
12
+ b, h, l, d_k = q.shape
13
+ q, k, v, beta = map(lambda x: x.float(), [q, k, v, beta])
14
+ d_v = v.shape[-1]
15
+ o = torch.zeros_like(v)
16
+ S = torch.zeros(b, h, d_k, d_v).to(v)
17
+ q = q * (d_k ** -0.5)
18
+
19
+ if initial_state is not None:
20
+ S += initial_state
21
+
22
+ for i in range(l):
23
+ _k = k[:, :, i]
24
+ _q = q[:, :, i]
25
+ _v = v[:, :, i]
26
+ _alpha = alpha[:, :, i]
27
+ _beta = beta[:, :, i]
28
+ _kv = _k[..., None] * _v[..., None, :] + (S.clone() * _alpha[..., None]).sum(-2, keepdim=True) * _beta[..., None]
29
+ S = S + _kv
30
+ o[:, :, i] = torch.einsum('bhd,bhdm->bhm', _q, S)
31
+ S = None if output_final_state is False else S
32
+ return o.to(orig_dtype), S
33
+
34
+
35
+ def iplr_chunkwise(q, k, v, alpha, beta, initial_state=None, output_final_state=True, chunk_size=32):
36
+ b, h, l, d_k = q.shape
37
+ d_v = v.shape[-1]
38
+ q = q * (d_k ** -0.5)
39
+ v = v
40
+ assert l % chunk_size == 0
41
+
42
+ S = k.new_zeros(b, h, d_k, d_v)
43
+ if initial_state is not None:
44
+ S += initial_state
45
+
46
+ # note that diagonal is masked.
47
+ mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device), diagonal=0)
48
+ q, k, v, alpha, beta = map(lambda x: rearrange(x, 'b h (n c) d -> b h n c d', c=chunk_size), [q, k, v, alpha, beta])
49
+
50
+ v2 = (alpha @ k.transpose(-1, -2)).masked_fill_(mask, 0) @ v
51
+ attn = (alpha @ beta.transpose(-1, -2)).masked_fill(mask, 0)
52
+ for i in range(1, chunk_size):
53
+ attn[..., i, :i] = attn[..., i, :i] + (attn[..., i, :, None].clone() * attn[..., :, :i].clone()).sum(-2)
54
+
55
+ attn = attn + torch.eye(chunk_size, dtype=torch.float, device=q.device)
56
+ u = attn @ v2
57
+ w = attn @ alpha
58
+ o = torch.zeros_like(v)
59
+ mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device), diagonal=1)
60
+ for i in range(0, l // chunk_size):
61
+ q_i, k_i, v_i, u_i, w_i, beta_i = q[:, :, i], k[:, :, i], v[:, :, i], u[:, :, i], w[:, :, i], beta[:, :, i]
62
+ o_1 = (q_i @ k_i.transpose(-1, -2)).masked_fill_(mask, 0) @ v_i
63
+ v2_i = u_i + w_i @ S
64
+ o_2 = (q_i @ beta_i.transpose(-1, -2)).masked_fill_(mask, 0) @ (v2_i)
65
+ o_3 = q_i @ S
66
+ o[:, :, i] = o_1 + o_2 + o_3
67
+ S = S + k_i.transpose(-1, -2) @ v_i + beta_i.transpose(-1, -2) @ v2_i
68
+ S = None if output_final_state is False else S
69
+ return rearrange(o, 'b h n c d -> b h (n c) d'), S
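A small consistency sketch for the two references above (an illustration, not a repo test): with modest alpha/beta magnitudes, the token-by-token recurrence and the chunkwise form should agree up to floating-point error.

import torch
from fla.ops.generalized_delta_rule.iplr.naive import iplr_chunkwise, iplr_recurrence

B, H, T, K, V = 1, 2, 64, 32, 32
q, k = torch.randn(B, H, T, K), torch.randn(B, H, T, K)
v = torch.randn(B, H, T, V)
# keep I + beta alpha^T close to the identity so the state stays well behaved
alpha, beta = 0.1 * torch.randn(B, H, T, K), 0.1 * torch.randn(B, H, T, K)
o_rec, S_rec = iplr_recurrence(q, k, v, alpha, beta)
o_chk, S_chk = iplr_chunkwise(q, k, v, alpha, beta, chunk_size=32)
print((o_rec - o_chk).abs().max(), (S_rec - S_chk).abs().max())  # both should be near zero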
fla/ops/generalized_delta_rule/iplr/wy_fast.py ADDED
@@ -0,0 +1,338 @@
1
+
2
+ # -*- coding: utf-8 -*-
3
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
4
+
5
+ from typing import Optional, Tuple
6
+
7
+ import torch
8
+ import triton
9
+ import triton.language as tl
10
+
11
+ from fla.utils import check_shared_mem, is_nvidia_hopper
12
+
13
+ NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8]
14
+
15
+
16
+ @triton.heuristics({
17
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
18
+ })
19
+ @triton.autotune(
20
+ configs=[
21
+ triton.Config({}, num_warps=num_warps)
22
+ for num_warps in [1, 2, 4, 8, 16]
23
+ ],
24
+ key=['BK']
25
+ )
26
+ @triton.jit(do_not_specialize=['T'])
27
+ def fwd_prepare_wy_repr_kernel_chunk32(
28
+ a,
29
+ b,
30
+ A,
31
+ offsets,
32
+ indices,
33
+ T,
34
+ H: tl.constexpr,
35
+ K: tl.constexpr,
36
+ BT: tl.constexpr,
37
+ BK: tl.constexpr,
38
+ BC: tl.constexpr, # dummy placeholder
39
+ USE_OFFSETS: tl.constexpr,
40
+ HEAD_FIRST: tl.constexpr,
41
+ ):
42
+ i_t, i_bh = tl.program_id(0), tl.program_id(1)
43
+ i_b, i_h = i_bh // H, i_bh % H
44
+ if USE_OFFSETS:
45
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
46
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
47
+ T = eos - bos
48
+ else:
49
+ bos, eos = i_b * T, i_b * T + T
50
+
51
+ b_A = tl.zeros([BT, BT], dtype=tl.float32)
52
+ for i_k in range(tl.cdiv(K, BK)):
53
+ if HEAD_FIRST:
54
+ p_a = tl.make_block_ptr(a + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
55
+ p_b = tl.make_block_ptr(b + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
56
+ else:
57
+ p_a = tl.make_block_ptr(a + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
58
+ p_b = tl.make_block_ptr(b + (bos * H + i_h) * K, (K, T), (1, K*H), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
59
+ b_a = tl.load(p_a, boundary_check=(0, 1))
60
+ b_b = tl.load(p_b, boundary_check=(0, 1))
61
+ b_A += tl.dot(b_a, b_b)
62
+
63
+ b_A = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_A, 0)
64
+ for i in range(1, BT):
65
+ mask = tl.arange(0, BT) == i
66
+ b_a = tl.sum(tl.where(mask[:, None], b_A, 0), 0)
67
+ b_a = b_a + tl.sum(b_a[:, None] * b_A, 0) * (tl.arange(0, BT) < i)
68
+ b_A = tl.where(mask[:, None], b_a, b_A)
69
+ b_A += tl.arange(0, BT)[:, None] == tl.arange(0, BT)[None, :]
70
+
71
+ if HEAD_FIRST:
72
+ p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
73
+ else:
74
+ p_A = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
75
+ tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1))
76
+
77
+
78
+ @triton.heuristics({
79
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
80
+ })
81
+ @triton.autotune(
82
+ configs=[
83
+ triton.Config({}, num_warps=num_warps)
84
+ for num_warps in [1, 2, 4, 8, 16]
85
+ ],
86
+ key=['BK']
87
+ )
88
+ @triton.jit(do_not_specialize=['T'])
89
+ def fwd_prepare_wy_repr_kernel_chunk64(
90
+ a,
91
+ b,
92
+ A,
93
+ offsets,
94
+ indices,
95
+ T,
96
+ H: tl.constexpr,
97
+ K: tl.constexpr,
98
+ BT: tl.constexpr,
99
+ BK: tl.constexpr,
100
+ BC: tl.constexpr,
101
+ USE_OFFSETS: tl.constexpr,
102
+ HEAD_FIRST: tl.constexpr
103
+ ):
104
+ i_t, i_bh = tl.program_id(0), tl.program_id(1)
105
+ i_b, i_h = i_bh // H, i_bh % H
106
+ if USE_OFFSETS:
107
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
108
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
109
+ T = eos - bos
110
+ else:
111
+ bos, eos = i_b * T, i_b * T + T
112
+
113
+ b_A = tl.zeros([BC, BC], dtype=tl.float32)
114
+ b_A2 = tl.zeros([BC, BC], dtype=tl.float32)
115
+ b_A3 = tl.zeros([BC, BC], dtype=tl.float32)
116
+
117
+ for i_k in range(tl.cdiv(K, BK)):
118
+ if HEAD_FIRST:
119
+ p_a1 = tl.make_block_ptr(a + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0))
120
+ p_a2 = tl.make_block_ptr(a + i_bh * T*K, (T, K), (K, 1), (i_t * BT + BC, i_k * BK), (BC, BK), (1, 0))
121
+ p_b1 = tl.make_block_ptr(b + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BC), (0, 1))
122
+ p_b2 = tl.make_block_ptr(b + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + BC), (BK, BC), (0, 1))
123
+ else:
124
+ p_a1 = tl.make_block_ptr(a + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0))
125
+ p_a2 = tl.make_block_ptr(a + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + BC, i_k * BK), (BC, BK), (1, 0))
126
+ p_b1 = tl.make_block_ptr(b + (bos * H + i_h) * K, (K, T), (1, K*H), (i_k * BK, i_t * BT), (BK, BC), (0, 1))
127
+ p_b2 = tl.make_block_ptr(b + (bos * H + i_h) * K, (K, T), (1, K*H), (i_k * BK, i_t * BT + BC), (BK, BC), (0, 1))
128
+ b_a1 = tl.load(p_a1, boundary_check=(0, 1))
129
+ b_a2 = tl.load(p_a2, boundary_check=(0, 1))
130
+ b_b1 = tl.load(p_b1, boundary_check=(0, 1))
131
+ b_b2 = tl.load(p_b2, boundary_check=(0, 1))
132
+ b_A += tl.dot(b_a1, b_b1, allow_tf32=False)
133
+ b_A2 += tl.dot(b_a2, b_b2, allow_tf32=False)
134
+ b_A3 += tl.dot(b_a2, b_b1, allow_tf32=False)
135
+
136
+ b_A = tl.where(tl.arange(0, BC)[:, None] > tl.arange(0, BC)[None, :], b_A, 0)
137
+ b_A2 = tl.where(tl.arange(0, BC)[:, None] > tl.arange(0, BC)[None, :], b_A2, 0)
138
+
139
+ for i in range(1, BC):
140
+ mask = tl.arange(0, BC) == i
141
+ b_a = tl.sum(tl.where(mask[:, None], b_A, 0), 0)
142
+ b_a2 = tl.sum(tl.where(mask[:, None], b_A2, 0), 0)
143
+ b_a = b_a + tl.sum(b_a[:, None] * b_A, 0) * (tl.arange(0, BC) < i)
144
+ b_a2 = b_a2 + tl.sum(b_a2[:, None] * b_A2, 0) * (tl.arange(0, BC) < i)
145
+ b_A = tl.where(mask[:, None], b_a, b_A)
146
+ b_A2 = tl.where(mask[:, None], b_a2, b_A2)
147
+
148
+ # blockwise computation of lower triangular matrix's inverse
149
+ # i.e., [A11, 0; A21, A22]^-1 = [A11^-1, 0; -A22^-1 A21 A11^-1, A22^-1]
150
+ b_A += tl.arange(0, BC)[:, None] == tl.arange(0, BC)[None, :]
151
+ b_A2 += tl.arange(0, BC)[:, None] == tl.arange(0, BC)[None, :]
152
+ b_A3 = tl.dot(tl.dot(b_A2, b_A3, allow_tf32=False), b_A, allow_tf32=False)
153
+
154
+ if HEAD_FIRST:
155
+ p_A1 = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BC, BC), (1, 0))
156
+ p_A2 = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + BC, BC), (BC, BC), (1, 0))
157
+ p_A3 = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + BC, 0), (BC, BC), (1, 0))
158
+ p_A4 = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, BC), (BC, BC), (1, 0))
159
+ else:
160
+ p_A1 = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BC, BC), (1, 0))
161
+ p_A2 = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, BC), (BC, BC), (1, 0))
162
+ p_A3 = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, 0), (BC, BC), (1, 0))
163
+ p_A4 = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, BC), (BC, BC), (1, 0))
164
+ tl.store(p_A1, b_A.to(p_A1.dtype.element_ty), boundary_check=(0, 1))
165
+ tl.store(p_A2, b_A2.to(p_A2.dtype.element_ty), boundary_check=(0, 1))
166
+ tl.store(p_A3, b_A3.to(p_A3.dtype.element_ty), boundary_check=(0, 1))
167
+ # causal mask
168
+ tl.store(p_A4, tl.zeros([BC, BC], dtype=tl.float32).to(p_A4.dtype.element_ty), boundary_check=(0, 1))
169
+
170
+
171
+ @triton.heuristics({
172
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
173
+ })
174
+ @triton.autotune(
175
+ configs=[
176
+ triton.Config({}, num_warps=num_warps)
177
+ for num_warps in NUM_WARPS
178
+ ],
179
+ key=['BT', 'BK', 'BV']
180
+ )
181
+ @triton.jit(do_not_specialize=['T'])
182
+ def fwd_wu_kernel(
183
+ w,
184
+ u,
185
+ a,
186
+ k,
187
+ v,
188
+ A,
189
+ offsets,
190
+ indices,
191
+ T,
192
+ H: tl.constexpr,
193
+ K: tl.constexpr,
194
+ V: tl.constexpr,
195
+ BT: tl.constexpr,
196
+ BK: tl.constexpr,
197
+ BV: tl.constexpr,
198
+ USE_OFFSETS: tl.constexpr,
199
+ HEAD_FIRST: tl.constexpr
200
+ ):
201
+ i_t, i_bh = tl.program_id(0), tl.program_id(1)
202
+ i_b, i_h = i_bh // H, i_bh % H
203
+ if USE_OFFSETS:
204
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
205
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
206
+ T = eos - bos
207
+ else:
208
+ bos, eos = i_b * T, i_b * T + T
209
+
210
+ if HEAD_FIRST:
211
+ p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
212
+ else:
213
+ p_A = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
214
+
215
+ b_A = tl.load(p_A, boundary_check=(0, 1))
216
+ b_Aak = tl.zeros([BT, BT], dtype=tl.float32)
217
+
218
+ for i_k in range(tl.cdiv(K, BK)):
219
+ if HEAD_FIRST:
220
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
221
+ p_a = tl.make_block_ptr(a + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
222
+ p_w = tl.make_block_ptr(w + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
223
+ else:
224
+ p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
225
+ p_a = tl.make_block_ptr(a + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
226
+ p_w = tl.make_block_ptr(w + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
227
+ b_k = tl.load(p_k, boundary_check=(0, 1))
228
+ b_a = tl.load(p_a, boundary_check=(0, 1))
229
+ b_w = tl.dot(b_A, b_a)
230
+ b_Aak += tl.dot(b_a, tl.trans(b_k))
231
+ tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))
232
+
233
+ b_Aak = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_Aak, 0)
234
+ b_Aak = b_Aak.to(k.dtype.element_ty)
235
+
236
+ for i_v in range(tl.cdiv(V, BV)):
237
+ if HEAD_FIRST:
238
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
239
+ p_u = tl.make_block_ptr(u + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
240
+ else:
241
+ p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
242
+ p_u = tl.make_block_ptr(u + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
243
+ b_v = tl.load(p_v, boundary_check=(0, 1))
244
+ b_v = tl.dot(b_Aak, b_v).to(v.dtype.element_ty)
245
+ b_u = tl.dot(b_A, b_v)
246
+ tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1))
247
+
248
+
249
+ def fwd_prepare_wy_repr(
250
+ a: torch.Tensor,
251
+ b: torch.Tensor,
252
+ v: torch.Tensor,
253
+ k: torch.Tensor,
254
+ offsets: Optional[torch.LongTensor],
255
+ indices: Optional[torch.LongTensor],
256
+ head_first: bool = True,
257
+ chunk_size: int = 64
258
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
259
+ if head_first:
260
+ B, H, T, K = a.shape
261
+ else:
262
+ B, T, H, K = a.shape
263
+ BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
264
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
265
+ BC = min(BT, 32)
266
+ BK = min(triton.next_power_of_2(K), 64)
267
+
268
+ A = torch.empty(B, *((H, T) if head_first else (T, H)), BT, device=a.device, dtype=a.dtype)
269
+ fwd_fn = fwd_prepare_wy_repr_kernel_chunk64 if BT == 64 else fwd_prepare_wy_repr_kernel_chunk32
270
+
271
+ fwd_fn[(NT, B * H)](
272
+ a=a,
273
+ b=b,
274
+ A=A,
275
+ offsets=offsets,
276
+ indices=indices,
277
+ T=T,
278
+ H=H,
279
+ K=K,
280
+ BT=BT,
281
+ BK=BK,
282
+ BC=BC,
283
+ HEAD_FIRST=head_first
284
+ )
285
+ w, u = fwd_wu(
286
+ a=a,
287
+ v=v,
288
+ k=k,
289
+ A=A,
290
+ offsets=offsets,
291
+ indices=indices,
292
+ head_first=head_first,
293
+ chunk_size=chunk_size
294
+ )
295
+ return w, u, A
296
+
297
+
298
+ def fwd_wu(
299
+ a: torch.Tensor,
300
+ v: torch.Tensor,
301
+ k: torch.Tensor,
302
+ A: torch.Tensor,
303
+ offsets: Optional[torch.LongTensor],
304
+ indices: Optional[torch.LongTensor],
305
+ head_first: bool,
306
+ chunk_size: int
307
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
308
+ if head_first:
309
+ B, H, T, K, V = *a.shape, v.shape[-1]
310
+ else:
311
+ B, T, H, K, V = *a.shape, v.shape[-1]
312
+ BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
313
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
314
+ CONST_TILING = 64 if check_shared_mem() else 32
315
+ BK = min(triton.next_power_of_2(K), CONST_TILING)
316
+ BV = min(triton.next_power_of_2(V), CONST_TILING)
317
+
318
+ u = torch.empty_like(v)
319
+ w = torch.empty_like(a)
320
+ fwd_wu_kernel[(NT, B*H)](
321
+ a=a,
322
+ v=v,
323
+ w=w,
324
+ u=u,
325
+ A=A,
326
+ k=k,
327
+ offsets=offsets,
328
+ indices=indices,
329
+ T=T,
330
+ H=H,
331
+ K=K,
332
+ V=V,
333
+ BT=BT,
334
+ BK=BK,
335
+ BV=BV,
336
+ HEAD_FIRST=head_first
337
+ )
338
+ return w, u
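The chunk-64 kernel above leans on the block lower-triangular inverse identity quoted in its comments; here is a standalone PyTorch check of that identity (a sketch, not repo code), using unit lower-triangular diagonal blocks matching the structure the forward-substitution loops produce:

import torch

n = 32
A11 = torch.eye(n) + torch.tril(torch.randn(n, n), -1)
A22 = torch.eye(n) + torch.tril(torch.randn(n, n), -1)
A21 = torch.randn(n, n)
L = torch.cat([torch.cat([A11, torch.zeros(n, n)], 1),
               torch.cat([A21, A22], 1)], 0)
inv11, inv22 = torch.linalg.inv(A11), torch.linalg.inv(A22)
# [A11, 0; A21, A22]^-1 = [A11^-1, 0; -A22^-1 A21 A11^-1, A22^-1]
L_inv = torch.cat([torch.cat([inv11, torch.zeros(n, n)], 1),
                   torch.cat([-inv22 @ A21 @ inv11, inv22], 1)], 0)
print(torch.dist(L_inv, torch.linalg.inv(L)))  # ~0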
fla/ops/gla/__init__.py ADDED
@@ -0,0 +1,11 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from .chunk import chunk_gla
4
+ from .fused_chunk import fused_chunk_gla
5
+ from .fused_recurrent import fused_recurrent_gla
6
+
7
+ __all__ = [
8
+ 'chunk_gla',
9
+ 'fused_chunk_gla',
10
+ 'fused_recurrent_gla'
11
+ ]
fla/ops/gla/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (397 Bytes). View file
 
fla/ops/gla/__pycache__/chunk.cpython-311.pyc ADDED
Binary file (83.1 kB). View file
 
fla/ops/gla/__pycache__/fused_chunk.cpython-311.pyc ADDED
Binary file (36.3 kB). View file