zaydzuhri committed
Commit 8d737a7 · verified · 1 Parent(s): 7be73e9

Add files using upload-large-folder tool

This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
Files changed (50)
  1. fla/ops/abc/__pycache__/__init__.cpython-312.pyc +0 -0
  2. fla/ops/abc/__pycache__/chunk.cpython-312.pyc +0 -0
  3. fla/ops/attn/__pycache__/parallel.cpython-312.pyc +0 -0
  4. fla/ops/based/__pycache__/fused_chunk.cpython-312.pyc +0 -0
  5. fla/ops/common/__pycache__/chunk_delta_h.cpython-312.pyc +0 -0
  6. fla/ops/common/__pycache__/chunk_o.cpython-312.pyc +0 -0
  7. fla/ops/common/chunk_scaled_dot_kkt.py +126 -0
  8. fla/ops/delta_rule/__pycache__/__init__.cpython-312.pyc +0 -0
  9. fla/ops/delta_rule/fused_chunk.py +6 -0
  10. fla/ops/forgetting_attn/parallel.py +708 -0
  11. fla/ops/gated_delta_rule/__pycache__/__init__.cpython-312.pyc +0 -0
  12. fla/ops/gated_delta_rule/__pycache__/fused_recurrent.cpython-312.pyc +0 -0
  13. fla/ops/generalized_delta_rule/__pycache__/__init__.cpython-312.pyc +0 -0
  14. fla/ops/generalized_delta_rule/dplr/__pycache__/chunk.cpython-312.pyc +0 -0
  15. fla/ops/generalized_delta_rule/dplr/__pycache__/chunk_A_bwd.cpython-312.pyc +0 -0
  16. fla/ops/generalized_delta_rule/dplr/__pycache__/chunk_A_fwd.cpython-312.pyc +0 -0
  17. fla/ops/generalized_delta_rule/dplr/__pycache__/chunk_h_fwd.cpython-312.pyc +0 -0
  18. fla/ops/generalized_delta_rule/dplr/__pycache__/chunk_o_bwd.cpython-312.pyc +0 -0
  19. fla/ops/generalized_delta_rule/dplr/__pycache__/chunk_o_fwd.cpython-312.pyc +0 -0
  20. fla/ops/generalized_delta_rule/dplr/__pycache__/fused_recurrent.cpython-312.pyc +0 -0
  21. fla/ops/generalized_delta_rule/dplr/__pycache__/wy_fast_bwd.cpython-312.pyc +0 -0
  22. fla/ops/generalized_delta_rule/dplr/chunk.py +388 -0
  23. fla/ops/generalized_delta_rule/dplr/chunk_A_bwd.py +446 -0
  24. fla/ops/generalized_delta_rule/dplr/chunk_A_fwd.py +324 -0
  25. fla/ops/generalized_delta_rule/dplr/chunk_h_bwd.py +196 -0
  26. fla/ops/generalized_delta_rule/dplr/chunk_h_fwd.py +197 -0
  27. fla/ops/generalized_delta_rule/dplr/chunk_o_fwd.py +138 -0
  28. fla/ops/generalized_delta_rule/dplr/fused_recurrent.py +292 -0
  29. fla/ops/generalized_delta_rule/dplr/naive.py +96 -0
  30. fla/ops/generalized_delta_rule/dplr/wy_fast_bwd.py +184 -0
  31. fla/ops/generalized_delta_rule/dplr/wy_fast_fwd.py +318 -0
  32. fla/ops/generalized_delta_rule/iplr/__pycache__/__init__.cpython-312.pyc +0 -0
  33. fla/ops/generalized_delta_rule/iplr/__pycache__/fused_recurrent.cpython-312.pyc +0 -0
  34. fla/ops/generalized_delta_rule/iplr/fused_recurrent.py +451 -0
  35. fla/ops/generalized_delta_rule/iplr/wy_fast.py +338 -0
  36. fla/ops/gla/__pycache__/__init__.cpython-312.pyc +0 -0
  37. fla/ops/gla/__pycache__/chunk.cpython-312.pyc +0 -0
  38. fla/ops/gla/__pycache__/fused_chunk.cpython-312.pyc +0 -0
  39. fla/ops/gsa/__init__.py +9 -0
  40. fla/ops/gsa/__pycache__/__init__.cpython-312.pyc +0 -0
  41. fla/ops/gsa/__pycache__/chunk.cpython-312.pyc +0 -0
  42. fla/ops/gsa/chunk.py +1264 -0
  43. fla/ops/gsa/naive.py +68 -0
  44. fla/ops/hgrn/__pycache__/__init__.cpython-312.pyc +0 -0
  45. fla/ops/hgrn/fused_recurrent.py +308 -0
  46. fla/ops/hgrn/naive.py +63 -0
  47. fla/ops/lightning_attn/__pycache__/fused_recurrent.cpython-312.pyc +0 -0
  48. fla/ops/linear_attn/__pycache__/fused_chunk.cpython-312.pyc +0 -0
  49. fla/ops/linear_attn/__pycache__/fused_recurrent.cpython-312.pyc +0 -0
  50. fla/ops/nsa/__pycache__/__init__.cpython-312.pyc +0 -0
fla/ops/abc/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (212 Bytes).
 
fla/ops/abc/__pycache__/chunk.cpython-312.pyc ADDED
Binary file (72 kB).
 
fla/ops/attn/__pycache__/parallel.cpython-312.pyc ADDED
Binary file (33.1 kB).
 
fla/ops/based/__pycache__/fused_chunk.cpython-312.pyc ADDED
Binary file (22.4 kB).
 
fla/ops/common/__pycache__/chunk_delta_h.cpython-312.pyc ADDED
Binary file (23.9 kB).
 
fla/ops/common/__pycache__/chunk_o.cpython-312.pyc ADDED
Binary file (37 kB).
 
fla/ops/common/chunk_scaled_dot_kkt.py ADDED
@@ -0,0 +1,126 @@
+ # -*- coding: utf-8 -*-
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
+
+ from typing import Optional
+
+ import torch
+ import triton
+ import triton.language as tl
+
+ from fla.ops.common.utils import prepare_chunk_indices
+
+
+ @triton.heuristics({
+     'USE_OFFSETS': lambda args: args['offsets'] is not None
+ })
+ @triton.autotune(
+     configs=[
+         triton.Config({'BK': BK}, num_warps=num_warps, num_stages=num_stages)
+         for BK in [32, 64, 128]
+         for num_warps in [2, 4, 8]
+         for num_stages in [2, 3, 4]
+     ],
+     key=['H', 'K', 'BT', 'USE_OFFSETS'],
+ )
+ @triton.jit(do_not_specialize=['T'])
+ def chunk_scaled_dot_kkt_fwd_kernel(
+     k,
+     beta,
+     A,
+     offsets,
+     indices,
+     T,
+     H: tl.constexpr,
+     K: tl.constexpr,
+     BT: tl.constexpr,
+     BK: tl.constexpr,
+     HEAD_FIRST: tl.constexpr,
+     USE_OFFSETS: tl.constexpr,
+ ):
+     i_t, i_bh = tl.program_id(0), tl.program_id(1)
+     i_b, i_h = i_bh // H, i_bh % H
+     if USE_OFFSETS:
+         i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
+         bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
+         T = eos - bos
+     else:
+         bos, eos = i_b * T, i_b * T + T
+     o_t = tl.arange(0, BT)
+
+     if HEAD_FIRST:
+         p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
+     else:
+         p_beta = tl.make_block_ptr(beta + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
+     b_beta = tl.load(p_beta, boundary_check=(0,))
+
+     b_A = tl.zeros([BT, BT], dtype=tl.float32)
+     for i_k in range(tl.cdiv(K, BK)):
+         if HEAD_FIRST:
+             p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
+         else:
+             p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
+         b_k = tl.load(p_k, boundary_check=(0, 1))
+         b_kb = b_k * b_beta[:, None]
+         b_A += tl.dot(b_kb.to(b_k.dtype), tl.trans(b_k))
+
+     b_A = tl.where(o_t[:, None] > o_t[None, :], b_A, 0)
+     if HEAD_FIRST:
+         p_A = tl.make_block_ptr(A + i_bh * T*BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
+     else:
+         p_A = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (BT*H, 1), (i_t * BT, 0), (BT, BT), (1, 0))
+     tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1))
+
+
+ def chunk_scaled_dot_kkt_fwd(
+     k: torch.Tensor,
+     beta: torch.Tensor,
+     cu_seqlens: Optional[torch.LongTensor],
+     head_first: bool = False,
+     chunk_size: int = 64,
+     output_dtype: torch.dtype = torch.float32
+ ) -> torch.Tensor:
+     r"""
+     Compute the strictly lower triangular part of (beta * K) @ K^T within each chunk of size `BT`.
+
+     Args:
+         k (torch.Tensor):
+             The key tensor of shape `[B, T, H, K]` if not `head_first` else `[B, H, T, K]`.
+         beta (torch.Tensor):
+             The beta tensor of shape `[B, T, H]` if not `head_first` else `[B, H, T]`.
+         cu_seqlens (torch.LongTensor):
+             The cumulative sequence lengths of the input tensor.
+             Default: None
+         head_first (bool):
+             If False, the input/output tensor is in the shape of `[B, T, H, K]`.
+             If True, the input/output tensor is in the shape of `[B, H, T, K]`.
+             Default: False
+         chunk_size (int):
+             The chunk size. Default: 64.
+         output_dtype (torch.dtype):
+             The dtype of the output tensor. Default: `torch.float32`
+
+     Returns:
+         beta * K * K^T of shape `[B, T, H, BT]` if not `head_first` else `[B, H, T, BT]`,
+         where `BT` is the chunk size.
+     """
+     if head_first:
+         B, H, T, K = k.shape
+     else:
+         B, T, H, K = k.shape
+     BT = chunk_size
+     indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
+     NT = triton.cdiv(T, BT) if cu_seqlens is None else len(indices)
+     A = torch.empty(B, *((H, T) if head_first else (T, H)), BT, device=k.device, dtype=output_dtype)
+     chunk_scaled_dot_kkt_fwd_kernel[(NT, B * H)](
+         k=k,
+         beta=beta,
+         A=A,
+         offsets=cu_seqlens,
+         indices=indices,
+         T=T,
+         H=H,
+         K=K,
+         BT=BT,
+         HEAD_FIRST=head_first
+     )
+     return A
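A minimal usage sketch for the helper added in this file (illustration only, not part of the commit), assuming a CUDA device and that the package is importable under the path shown in the diff; shapes follow the docstring above.

import torch

from fla.ops.common.chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd

# B=2 sequences, T=128 steps, H=4 heads, K=64 head dim, chunk size BT=64
k = torch.randn(2, 128, 4, 64, device='cuda', dtype=torch.bfloat16)
beta = torch.rand(2, 128, 4, device='cuda', dtype=torch.bfloat16)

# A has shape [B, T, H, BT]; each [BT, BT] block holds the strictly lower
# triangular part of (beta * K) @ K^T within its chunk
A = chunk_scaled_dot_kkt_fwd(k, beta, cu_seqlens=None, head_first=False, chunk_size=64)
print(A.shape)  # torch.Size([2, 128, 4, 64])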
fla/ops/delta_rule/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (361 Bytes).
 
fla/ops/delta_rule/fused_chunk.py ADDED
@@ -0,0 +1,6 @@
+ # -*- coding: utf-8 -*-
+
+ def fused_chunk_delta_rule(
+     **kwargs
+ ):
+     raise NotImplementedError("fused_chunk_delta_rule is deprecated. Please use chunk_delta_rule instead.")
fla/ops/forgetting_attn/parallel.py ADDED
@@ -0,0 +1,708 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+ from einops import rearrange, reduce
10
+
11
+ from fla.ops.common.utils import prepare_chunk_indices
12
+ from fla.ops.utils import chunk_global_cumsum, chunk_local_cumsum
13
+ from fla.ops.utils.op import div, exp, log
14
+ from fla.utils import autocast_custom_bwd, autocast_custom_fwd, check_shared_mem, input_guard
15
+
16
+
17
+ @triton.heuristics({
18
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
19
+ })
20
+ @triton.autotune(
21
+ configs=[
22
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
23
+ for num_warps in [1, 2, 4] + ([8] if check_shared_mem('hopper') else [])
24
+ for num_stages in [2, 3, 4, 5]
25
+ ],
26
+ key=['B', 'H', 'G', 'K', 'V', 'BK', 'BV'],
27
+ )
28
+ @triton.jit
29
+ def parallel_forgetting_attn_fwd_kernel(
30
+ q,
31
+ k,
32
+ v,
33
+ g,
34
+ o,
35
+ lse,
36
+ scale,
37
+ offsets,
38
+ indices,
39
+ T,
40
+ B: tl.constexpr,
41
+ H: tl.constexpr,
42
+ HQ: tl.constexpr,
43
+ G: tl.constexpr,
44
+ K: tl.constexpr,
45
+ V: tl.constexpr,
46
+ BT: tl.constexpr,
47
+ BS: tl.constexpr,
48
+ BK: tl.constexpr,
49
+ BV: tl.constexpr,
50
+ USE_OFFSETS: tl.constexpr
51
+ ):
52
+ i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
53
+ i_b, i_hq = i_bh // HQ, i_bh % HQ
54
+ i_h = i_hq // G
55
+
56
+ if USE_OFFSETS:
57
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
58
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
59
+ T = eos - bos
60
+ else:
61
+ i_n = i_b
62
+ bos, eos = i_n * T, i_n * T + T
63
+
64
+ p_q = tl.make_block_ptr(q + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, 0), (BT, BK), (1, 0))
65
+ p_g = tl.make_block_ptr(g + bos * HQ + i_hq, (T,), (HQ,), (i_t * BT,), (BT,), (0,))
66
+ p_o = tl.make_block_ptr(o + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
67
+ p_lse = tl.make_block_ptr(lse + bos * HQ + i_hq, (T,), (HQ,), (i_t * BT,), (BT,), (0,))
68
+
69
+ # the Q block is kept in the shared memory throughout the whole kernel
70
+ # [BT, BK]
71
+ b_q = tl.load(p_q, boundary_check=(0, 1))
72
+ b_q = (b_q * scale).to(b_q.dtype)
73
+ # [BT,]
74
+ b_gq = tl.load(p_g, boundary_check=(0,)).to(tl.float32)
75
+ # [BT, BV]
76
+ b_o = tl.zeros([BT, BV], dtype=tl.float32)
77
+
78
+ b_m = tl.full([BT], float('-inf'), dtype=tl.float32)
79
+ b_acc = tl.zeros([BT], dtype=tl.float32)
80
+
81
+ # [BT]
82
+ o_q = i_t * BT + tl.arange(0, BT)
83
+ for i_s in range(i_t * BT, min((i_t + 1) * BT, T), BS):
84
+ p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (K, T), (1, H*K), (0, i_s), (BK, BS), (0, 1))
85
+ p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
86
+ p_gk = tl.make_block_ptr(g + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,))
87
+
88
+ # [BS]
89
+ o_k = i_s + tl.arange(0, BS)
90
+ # [BK, BS]
91
+ b_k = tl.load(p_k, boundary_check=(0, 1))
92
+ # [BS, BV]
93
+ b_v = tl.load(p_v, boundary_check=(0, 1))
94
+ # [BS,]
95
+ b_gk = tl.load(p_gk, boundary_check=(0,))
96
+ # [BT, BS]
97
+ b_s = tl.dot(b_q, b_k) + b_gq[:, None] - b_gk[None, :]
98
+ b_s = tl.where(o_q[:, None] >= o_k[None, :], b_s, float('-inf'))
99
+
100
+ # [BT]
101
+ b_m, b_mp = tl.maximum(b_m, tl.max(b_s, 1)), b_m
102
+ b_r = exp(b_mp - b_m)
103
+ # [BT, BS]
104
+ b_p = exp(b_s - b_m[:, None])
105
+ # [BT]
106
+ b_acc = b_acc * b_r + tl.sum(b_p, 1)
107
+ # [BT, BV]
108
+ b_o = b_o * b_r[:, None] + tl.dot(b_p.to(b_q.dtype), b_v)
109
+
110
+ b_mp = b_m
111
+
112
+ for i_s in range(i_t * BT - BS, -BS, -BS):
113
+ p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (K, T), (1, H*K), (0, i_s), (BK, BS), (0, 1))
114
+ p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
115
+ p_gk = tl.make_block_ptr(g + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,))
116
+
117
+ # [BK, BS]
118
+ b_k = tl.load(p_k, boundary_check=(0, 1))
119
+ # [BS, BV]
120
+ b_v = tl.load(p_v, boundary_check=(0, 1))
121
+ # [BS,]
122
+ b_gk = tl.load(p_gk, boundary_check=(0,)).to(tl.float32)
123
+
124
+ b_gn = tl.load(g + (bos + min(i_s + BS, T) - 1) * HQ + i_hq).to(tl.float32)
125
+ b_gp = tl.load(g + (bos + i_s - 1) * HQ + i_hq).to(tl.float32) if i_s % BT > 0 else 0.
126
+ # [BT, BS]
127
+ b_s = tl.dot(b_q, b_k) + b_gq[:, None] + (b_gn - b_gk)[None, :]
128
+
129
+ b_gq += b_gn - b_gp
130
+ b_m, b_mp = tl.maximum(b_m, tl.max(b_s, 1)), b_m
131
+ b_r = exp(b_mp - b_m)
132
+ # [BT, BS]
133
+ b_p = exp(b_s - b_m[:, None])
134
+ # [BT]
135
+ b_acc = b_acc * b_r + tl.sum(b_p, 1)
136
+ # [BT, BV]
137
+ b_o = b_o * b_r[:, None] + tl.dot(b_p.to(b_q.dtype), b_v)
138
+
139
+ b_mp = b_m
140
+
141
+ b_o = div(b_o, b_acc[:, None])
142
+ b_m += log(b_acc)
143
+ tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
144
+ tl.store(p_lse, b_m.to(p_lse.dtype.element_ty), boundary_check=(0,))
145
+
146
+
147
+ @triton.jit
148
+ def parallel_forgetting_attn_bwd_kernel_preprocess(
149
+ o,
150
+ do,
151
+ delta,
152
+ B: tl.constexpr,
153
+ V: tl.constexpr
154
+ ):
155
+ i_n = tl.program_id(0)
156
+ o_d = tl.arange(0, B)
157
+ m_d = o_d < V
158
+
159
+ b_o = tl.load(o + i_n * V + o_d, mask=m_d, other=0)
160
+ b_do = tl.load(do + i_n * V + o_d, mask=m_d, other=0).to(tl.float32)
161
+ b_delta = tl.sum(b_o * b_do)
162
+
163
+ tl.store(delta + i_n, b_delta.to(delta.dtype.element_ty))
164
+
165
+
166
+ @triton.heuristics({
167
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
168
+ })
169
+ @triton.autotune(
170
+ configs=[
171
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
172
+ for num_warps in [1, 2, 4] + ([8] if check_shared_mem('hopper') else [])
173
+ for num_stages in [2, 3, 4]
174
+ ],
175
+ key=['B', 'H', 'G', 'K', 'V', 'BK', 'BV'],
176
+ )
177
+ @triton.jit(do_not_specialize=['T'])
178
+ def parallel_forgetting_attn_bwd_kernel_dq(
179
+ q,
180
+ k,
181
+ v,
182
+ g,
183
+ lse,
184
+ delta,
185
+ do,
186
+ dq,
187
+ dg,
188
+ scale,
189
+ offsets,
190
+ indices,
191
+ T,
192
+ B: tl.constexpr,
193
+ H: tl.constexpr,
194
+ HQ: tl.constexpr,
195
+ G: tl.constexpr,
196
+ K: tl.constexpr,
197
+ V: tl.constexpr,
198
+ BT: tl.constexpr,
199
+ BS: tl.constexpr,
200
+ BK: tl.constexpr,
201
+ BV: tl.constexpr,
202
+ USE_OFFSETS: tl.constexpr
203
+ ):
204
+ i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
205
+ i_b, i_hq = i_bh // HQ, i_bh % HQ
206
+ i_h = i_hq // G
207
+
208
+ if USE_OFFSETS:
209
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
210
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
211
+ T = eos - bos
212
+ else:
213
+ i_n = i_b
214
+ bos, eos = i_n * T, i_n * T + T
215
+
216
+ p_q = tl.make_block_ptr(q + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, 0), (BT, BK), (1, 0))
217
+ p_g = tl.make_block_ptr(g + bos * HQ + i_hq, (T,), (HQ,), (i_t * BT,), (BT,), (0,))
218
+ p_dq = tl.make_block_ptr(dq + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, 0), (BT, BK), (1, 0))
219
+ p_dg = tl.make_block_ptr(dg + (bos * HQ + i_hq), (T,), (HQ,), (i_t * BT,), (BT,), (0,))
220
+ p_do = tl.make_block_ptr(do + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
221
+ p_lse = tl.make_block_ptr(lse + bos * HQ + i_hq, (T,), (HQ,), (i_t * BT,), (BT,), (0,))
222
+ p_delta = tl.make_block_ptr(delta + bos * HQ + i_hq, (T,), (HQ,), (i_t * BT,), (BT,), (0,))
223
+
224
+ # [BT, BK]
225
+ b_q = tl.load(p_q, boundary_check=(0, 1))
226
+ b_q = (b_q * scale).to(b_q.dtype)
227
+ # [BT, BV]
228
+ b_do = tl.load(p_do, boundary_check=(0, 1))
229
+ # [BT]
230
+ b_gq = tl.load(p_g, boundary_check=(0,)).to(tl.float32)
231
+ b_lse = tl.load(p_lse, boundary_check=(0,))
232
+ b_delta = tl.load(p_delta, boundary_check=(0,))
233
+
234
+ # [BT]
235
+ o_q = i_t * BT + tl.arange(0, BT)
236
+ # [BT, BK]
237
+ b_dq = tl.zeros([BT, BK], dtype=tl.float32)
238
+ # [BT]
239
+ b_dg = tl.zeros([BT,], dtype=tl.float32)
240
+ for i_s in range(i_t * BT, min((i_t + 1) * BT, T), BS):
241
+ p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (K, T), (1, H*K), (0, i_s), (BK, BS), (0, 1))
242
+ p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (V, T), (1, H*V), (i_v * BV, i_s), (BV, BS), (0, 1))
243
+ p_gk = tl.make_block_ptr(g + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,))
244
+
245
+ # [BS]
246
+ o_k = i_s + tl.arange(0, BS)
247
+ # [BK, BS]
248
+ b_k = tl.load(p_k, boundary_check=(0, 1))
249
+ # [BV, BS]
250
+ b_v = tl.load(p_v, boundary_check=(0, 1))
251
+ # [BS,]
252
+ b_gk = tl.load(p_gk, boundary_check=(0,))
253
+ # [BT, BS]
254
+ b_s = tl.dot(b_q, b_k) + (b_gq - b_lse)[:, None] - b_gk[None, :]
255
+ b_p = exp(tl.where(o_q[:, None] >= o_k[None, :], b_s, float('-inf')))
256
+
257
+ # [BT, BV] @ [BV, BS] -> [BT, BS]
258
+ b_dp = tl.dot(b_do, b_v)
259
+ b_ds = b_p * (b_dp.to(tl.float32) - b_delta[:, None])
260
+ # [BT, BS] @ [BS, BK] -> [BT, BK]
261
+ b_dq += tl.dot(b_ds.to(b_k.dtype), tl.trans(b_k))
262
+ # [BT]
263
+ b_dg += tl.sum(b_ds, 1)
264
+
265
+ for i_s in range(i_t * BT - BS, -BS, -BS):
266
+ p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (K, T), (1, H*K), (0, i_s), (BK, BS), (0, 1))
267
+ p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (V, T), (1, H*V), (i_v * BV, i_s), (BV, BS), (0, 1))
268
+ p_gk = tl.make_block_ptr(g + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,))
269
+
270
+ # [BK, BS]
271
+ b_k = tl.load(p_k, boundary_check=(0, 1))
272
+ # [BV, BS]
273
+ b_v = tl.load(p_v, boundary_check=(0, 1))
274
+ # [BS,]
275
+ b_gk = tl.load(p_gk, boundary_check=(0,)).to(tl.float32)
276
+
277
+ b_gn = tl.load(g + (bos + min(i_s + BS, T) - 1) * HQ + i_hq).to(tl.float32)
278
+ b_gp = tl.load(g + (bos + i_s - 1) * HQ + i_hq).to(tl.float32) if i_s % BT > 0 else 0.
279
+ # [BT, BS]
280
+ b_s = tl.dot(b_q, b_k) + (b_gq - b_lse)[:, None] + (b_gn - b_gk)[None, :]
281
+ b_p = exp(b_s)
282
+ # [BT, BV] @ [BV, BS] -> [BT, BS]
283
+ b_dp = tl.dot(b_do, b_v)
284
+ b_ds = b_p * (b_dp - b_delta[:, None])
285
+ # [BT, BS] @ [BS, BK] -> [BT, BK]
286
+ b_dq += tl.dot(b_ds.to(b_k.dtype), tl.trans(b_k))
287
+ # [BT]
288
+ b_dg += tl.sum(b_ds, 1)
289
+
290
+ b_gq += b_gn - b_gp
291
+
292
+ b_dq *= scale
293
+
294
+ tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
295
+ tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0,))
296
+
297
+
298
+ @triton.heuristics({
299
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
300
+ })
301
+ @triton.autotune(
302
+ configs=[
303
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
304
+ for num_warps in [1, 2, 4, 8]
305
+ for num_stages in [2, 3, 4]
306
+ ],
307
+ key=['B', 'H', 'G', 'K', 'V', 'BK', 'BV'],
308
+ )
309
+ @triton.jit(do_not_specialize=['T'])
310
+ def parallel_forgetting_attn_bwd_kernel_dkv(
311
+ q,
312
+ k,
313
+ v,
314
+ g,
315
+ lse,
316
+ delta,
317
+ do,
318
+ dk,
319
+ dv,
320
+ dg,
321
+ offsets,
322
+ indices,
323
+ scale,
324
+ T,
325
+ B: tl.constexpr,
326
+ H: tl.constexpr,
327
+ HQ: tl.constexpr,
328
+ G: tl.constexpr,
329
+ K: tl.constexpr,
330
+ V: tl.constexpr,
331
+ BT: tl.constexpr,
332
+ BS: tl.constexpr,
333
+ BK: tl.constexpr,
334
+ BV: tl.constexpr,
335
+ USE_OFFSETS: tl.constexpr
336
+ ):
337
+ i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
338
+ i_b, i_hq = i_bh // HQ, i_bh % HQ
339
+ i_h = i_hq // G
340
+
341
+ if USE_OFFSETS:
342
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
343
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
344
+ T = eos - bos
345
+ else:
346
+ i_n = i_b
347
+ bos, eos = i_n * T, i_n * T + T
348
+
349
+ p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, 0), (BT, BK), (1, 0))
350
+ p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
351
+ p_gk = tl.make_block_ptr(g + bos * HQ + i_hq, (T,), (HQ,), (i_t * BT,), (BT,), (0,))
352
+ p_dk = tl.make_block_ptr(dk + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, 0), (BT, BK), (1, 0))
353
+ p_dv = tl.make_block_ptr(dv + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
354
+ p_dg = tl.make_block_ptr(dg + (bos * HQ + i_hq), (T,), (HQ,), (i_t * BT,), (BT,), (0,))
355
+
356
+ # [BT, BK]
357
+ b_k = tl.load(p_k, boundary_check=(0, 1))
358
+ b_dk = tl.zeros([BT, BK], dtype=tl.float32)
359
+ # [BT, BV]
360
+ b_v = tl.load(p_v, boundary_check=(0, 1))
361
+ b_dv = tl.zeros([BT, BV], dtype=tl.float32)
362
+ # [BT]
363
+ b_gk = tl.load(p_gk, boundary_check=(0,)).to(tl.float32)
364
+ b_dg = tl.zeros([BT,], dtype=tl.float32)
365
+
366
+ o_k = i_t * BT + tl.arange(0, BT)
367
+ m_k = o_k < T
368
+ for i_s in range(i_t * BT, min((i_t + 1) * BT, T), BS):
369
+ p_q = tl.make_block_ptr(q + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_s, 0), (BS, BK), (1, 0))
370
+ p_do = tl.make_block_ptr(do + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
371
+ p_lse = tl.make_block_ptr(lse + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,))
372
+ p_delta = tl.make_block_ptr(delta + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,))
373
+ p_gq = tl.make_block_ptr(g + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,))
374
+
375
+ # [BS]
376
+ o_q = i_s + tl.arange(0, BS)
377
+ # [BS, BK]
378
+ b_q = tl.load(p_q, boundary_check=(0, 1))
379
+ b_q = (b_q * scale).to(b_q.dtype)
380
+ # [BS, BV]
381
+ b_do = tl.load(p_do, boundary_check=(0, 1))
382
+ # [BS]
383
+ b_lse = tl.load(p_lse, boundary_check=(0,))
384
+ b_delta = tl.load(p_delta, boundary_check=(0,))
385
+ b_gq = tl.load(p_gq, boundary_check=(0,)).to(tl.float32)
386
+
387
+ m_q = o_q < T
388
+ m_s = (o_k[:, None] <= o_q[None, :]) & m_k[:, None] & m_q[None, :]
389
+ # [BT, BS]
390
+ b_s = tl.dot(b_k, tl.trans(b_q)) - b_gk[:, None] + (b_gq - b_lse)[None, :]
391
+ b_p = tl.where(m_s, exp(b_s), 0)
392
+ # [BT, BS] @ [BS, BV] -> [BT, BV]
393
+ b_dv += tl.dot(b_p.to(b_do.dtype), b_do)
394
+ # [BT, BV] @ [BV, BS] -> [BT, BS]
395
+ b_dp = tl.dot(b_v, tl.trans(b_do))
396
+ # [BT, BS]
397
+ b_ds = b_p * (b_dp - b_delta[None, :])
398
+ # [BT, BS] @ [BS, BK] -> [BT, BK]
399
+ b_dk += tl.dot(b_ds.to(b_q.dtype), b_q)
400
+ # [BT]
401
+ b_dg -= tl.sum(b_ds, 1)
402
+
403
+ b_gk -= tl.load(g + (bos + min((i_t + 1) * BT, T) - 1) * HQ + i_hq).to(tl.float32)
404
+ for i_s in range((i_t + 1) * BT, T, BS):
405
+ p_q = tl.make_block_ptr(q + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_s, 0), (BS, BK), (1, 0))
406
+ p_do = tl.make_block_ptr(do + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
407
+ p_lse = tl.make_block_ptr(lse + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,))
408
+ p_delta = tl.make_block_ptr(delta + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,))
409
+ p_gq = tl.make_block_ptr(g + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,))
410
+
411
+ # [BS]
412
+ o_q = i_s + tl.arange(0, BS)
413
+ # [BS, BK]
414
+ b_q = tl.load(p_q, boundary_check=(0, 1))
415
+ b_q = (b_q * scale).to(b_q.dtype)
416
+ # [BS, BV]
417
+ b_do = tl.load(p_do, boundary_check=(0, 1))
418
+ # [BS]
419
+ b_lse = tl.load(p_lse, boundary_check=(0,))
420
+ b_delta = tl.load(p_delta, boundary_check=(0,))
421
+ b_gq = tl.load(p_gq, boundary_check=(0,)).to(tl.float32)
422
+
423
+ b_gn = tl.load(g + (bos + min(i_s + BS, T) - 1) * HQ + i_hq).to(tl.float32)
424
+ b_gp = tl.load(g + (bos + i_s - 1) * HQ + i_hq).to(tl.float32) if i_s % BT > 0 else 0.
425
+ # [BT, BS]
426
+ b_s = tl.dot(b_k, tl.trans(b_q)) - (b_gk + b_gp)[:, None] + (b_gq - b_lse)[None, :]
427
+ b_p = exp(b_s)
428
+ # [BT, BS] @ [BS, BV] -> [BT, BV]
429
+ b_dv += tl.dot(b_p.to(b_do.dtype), b_do)
430
+ # [BT, BV] @ [BV, BS] -> [BT, BS]
431
+ b_dp = tl.dot(b_v, tl.trans(b_do))
432
+ # [BT, BS]
433
+ b_ds = b_p * (b_dp - b_delta[None, :])
434
+ # [BT, BS] @ [BS, BK] -> [BT, BK]
435
+ b_dk += tl.dot(b_ds.to(b_q.dtype), b_q)
436
+ # [BT]
437
+ b_dg -= tl.sum(b_ds, 1)
438
+
439
+ b_gk -= b_gn - b_gp
440
+
441
+ tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
442
+ tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
443
+ tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0,))
444
+
445
+
446
+ def parallel_forgetting_attn_fwd(
447
+ q: torch.Tensor,
448
+ k: torch.Tensor,
449
+ v: torch.Tensor,
450
+ g: torch.Tensor,
451
+ scale: float,
452
+ chunk_size: int = 128,
453
+ offsets: Optional[torch.LongTensor] = None,
454
+ indices: Optional[torch.LongTensor] = None,
455
+ ):
456
+ B, T, H, K, V = *k.shape, v.shape[-1]
457
+ HQ = q.shape[2]
458
+ G = HQ // H
459
+ BT = chunk_size
460
+ BK = max(16, triton.next_power_of_2(K))
461
+ assert V <= 256, "V must be less than or equal to 256"
462
+ if check_shared_mem('hopper'):
463
+ BS = min(64, max(16, triton.next_power_of_2(T)))
464
+ else:
465
+ BS = min(32, max(16, triton.next_power_of_2(T)))
466
+ BV = min(256, max(16, triton.next_power_of_2(V)))
467
+ NV = triton.cdiv(V, BV)
468
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
469
+
470
+ o = torch.empty(B, T, HQ, V, dtype=v.dtype, device=q.device)
471
+ lse = torch.empty(B, T, HQ, dtype=torch.float, device=q.device)
472
+
473
+ grid = (NV, NT, B * HQ)
474
+ parallel_forgetting_attn_fwd_kernel[grid](
475
+ q=q,
476
+ k=k,
477
+ v=v,
478
+ g=g,
479
+ o=o,
480
+ lse=lse,
481
+ scale=scale,
482
+ offsets=offsets,
483
+ indices=indices,
484
+ B=B,
485
+ T=T,
486
+ H=H,
487
+ HQ=HQ,
488
+ G=G,
489
+ K=K,
490
+ V=V,
491
+ BT=BT,
492
+ BS=BS,
493
+ BK=BK,
494
+ BV=BV,
495
+ )
496
+ return o, lse
497
+
498
+
499
+ def parallel_forgetting_attn_bwd_preprocess(
500
+ o: torch.Tensor,
501
+ do: torch.Tensor
502
+ ):
503
+ V = o.shape[-1]
504
+ delta = torch.empty_like(o[..., 0], dtype=torch.float)
505
+ parallel_forgetting_attn_bwd_kernel_preprocess[(delta.numel(),)](
506
+ o=o,
507
+ do=do,
508
+ delta=delta,
509
+ B=triton.next_power_of_2(V),
510
+ V=V,
511
+ )
512
+ return delta
513
+
514
+
515
+ def parallel_forgetting_attn_bwd(
516
+ q: torch.Tensor,
517
+ k: torch.Tensor,
518
+ v: torch.Tensor,
519
+ g: torch.Tensor,
520
+ o: torch.Tensor,
521
+ lse: torch.Tensor,
522
+ do: torch.Tensor,
523
+ scale: float = None,
524
+ chunk_size: int = 128,
525
+ offsets: Optional[torch.LongTensor] = None,
526
+ indices: Optional[torch.LongTensor] = None,
527
+ ):
528
+ B, T, H, K, V = *k.shape, v.shape[-1]
529
+ HQ = q.shape[2]
530
+ G = HQ // H
531
+ BT = chunk_size
532
+ BS = min(32, max(16, triton.next_power_of_2(T)))
533
+ BK = max(16, triton.next_power_of_2(K))
534
+ BV = max(16, triton.next_power_of_2(V))
535
+ NV = triton.cdiv(V, BV)
536
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
537
+
538
+ delta = parallel_forgetting_attn_bwd_preprocess(o, do)
539
+ dq = q.new_empty(B, T, HQ, K, dtype=q.dtype)
540
+ dk = q.new_empty(B, T, HQ, K, dtype=k.dtype if H == HQ else torch.float)
541
+ dv = q.new_empty(B, T, HQ, V, dtype=v.dtype if H == HQ else torch.float)
542
+ dg = q.new_empty(g.shape, dtype=torch.float)
543
+ # NOTE: the original `dg` can be destroyed during autotuning
544
+ # this is [a known triton issue](https://github.com/triton-lang/triton/issues/5082), which will be fixed in 3.3 (?)
545
+ # so we need to make a copy of `dg`
546
+ dg2 = q.new_empty(g.shape, dtype=torch.float)
547
+ grid = (NV, NT, B * HQ)
548
+ parallel_forgetting_attn_bwd_kernel_dq[grid](
549
+ q=q,
550
+ k=k,
551
+ v=v,
552
+ g=g,
553
+ lse=lse,
554
+ delta=delta,
555
+ do=do,
556
+ dq=dq,
557
+ dg=dg,
558
+ offsets=offsets,
559
+ indices=indices,
560
+ scale=scale,
561
+ T=T,
562
+ B=B,
563
+ H=H,
564
+ HQ=HQ,
565
+ G=G,
566
+ K=K,
567
+ V=V,
568
+ BT=BT,
569
+ BS=BS,
570
+ BK=BK,
571
+ BV=BV
572
+ )
573
+ parallel_forgetting_attn_bwd_kernel_dkv[grid](
574
+ q=q,
575
+ k=k,
576
+ v=v,
577
+ g=g,
578
+ lse=lse,
579
+ delta=delta,
580
+ do=do,
581
+ dk=dk,
582
+ dv=dv,
583
+ dg=dg2,
584
+ offsets=offsets,
585
+ indices=indices,
586
+ scale=scale,
587
+ T=T,
588
+ B=B,
589
+ H=H,
590
+ HQ=HQ,
591
+ G=G,
592
+ K=K,
593
+ V=V,
594
+ BT=BT,
595
+ BS=BS,
596
+ BK=BK,
597
+ BV=BV
598
+ )
599
+ dk = reduce(dk, 'b t (h g) k -> b t h k', g=G, reduction='sum')
600
+ dv = reduce(dv, 'b t (h g) v -> b t h v', g=G, reduction='sum')
601
+ dg = dg.add_(dg2)
602
+ return dq, dk, dv, dg
603
+
604
+
605
+ @torch.compile
606
+ class ParallelForgettingAttentionFunction(torch.autograd.Function):
607
+
608
+ @staticmethod
609
+ @input_guard
610
+ @autocast_custom_fwd
611
+ def forward(ctx, q, k, v, g, scale, offsets):
612
+ ctx.dtype = q.dtype
613
+ if check_shared_mem('hopper'):
614
+ chunk_size = min(128, max(16, triton.next_power_of_2(q.shape[1])))
615
+ else:
616
+ chunk_size = min(64, max(16, triton.next_power_of_2(q.shape[1])))
617
+ # 2-d indices denoting the offsets of chunks in each sequence
618
+ # for example, if the passed `offsets` is [0, 100, 356] and `chunk_size` is 64,
619
+ # then there are 2 and 4 chunks in the 1st and 2nd sequences respectively, and `indices` will be
620
+ # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]]
621
+ indices = prepare_chunk_indices(offsets, chunk_size) if offsets is not None else None
622
+
623
+ g = chunk_local_cumsum(g, chunk_size, offsets=offsets, indices=indices, head_first=False)
624
+ o, lse = parallel_forgetting_attn_fwd(
625
+ q=q,
626
+ k=k,
627
+ v=v,
628
+ g=g,
629
+ scale=scale,
630
+ chunk_size=chunk_size,
631
+ offsets=offsets,
632
+ indices=indices
633
+ )
634
+ ctx.save_for_backward(q, k, v, g, o, lse)
635
+ ctx.chunk_size = chunk_size
636
+ ctx.offsets = offsets
637
+ ctx.indices = indices
638
+ ctx.scale = scale
639
+ return o.to(q.dtype)
640
+
641
+ @staticmethod
642
+ @input_guard
643
+ @autocast_custom_bwd
644
+ def backward(ctx, do):
645
+ q, k, v, g, o, lse = ctx.saved_tensors
646
+ dq, dk, dv, dg = parallel_forgetting_attn_bwd(
647
+ q=q,
648
+ k=k,
649
+ v=v,
650
+ g=g,
651
+ o=o,
652
+ lse=lse,
653
+ do=do,
654
+ scale=ctx.scale,
655
+ chunk_size=ctx.chunk_size,
656
+ offsets=ctx.offsets,
657
+ indices=ctx.indices
658
+ )
659
+ dg = chunk_global_cumsum(dg, reverse=True, head_first=False, offsets=ctx.offsets)
660
+ return dq.to(q), dk.to(k), dv.to(v), dg.to(g), None, None, None, None, None, None, None, None
661
+
662
+
663
+ def parallel_forgetting_attn(
664
+ q: torch.Tensor,
665
+ k: torch.Tensor,
666
+ v: torch.Tensor,
667
+ g: torch.Tensor,
668
+ scale: Optional[float] = None,
669
+ cu_seqlens: Optional[torch.LongTensor] = None,
670
+ head_first: bool = False
671
+ ) -> torch.Tensor:
672
+ r"""
673
+ Args:
674
+ q (torch.Tensor):
675
+ queries of shape `[B, T, HQ, K]` if `head_first=False` else `[B, HQ, T, K]`.
676
+ k (torch.Tensor):
677
+ keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
678
+ GQA will be applied if HQ is divisible by H.
679
+ v (torch.Tensor):
680
+ values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
681
+ g (torch.Tensor):
682
+ Forget gates (in **log space**) of shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`.
683
+ scale (Optional[int]):
684
+ Scale factor for attention scores.
685
+ If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
686
+ cu_seqlens (torch.LongTensor):
687
+ Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
688
+ consistent with the FlashAttention API.
689
+ head_first (Optional[bool]):
690
+ Whether the inputs are in the head-first format. Default: `False`.
691
+
692
+ Returns:
693
+ o (torch.Tensor):
694
+ Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`.
695
+ """
696
+ if scale is None:
697
+ scale = k.shape[-1] ** -0.5
698
+ if cu_seqlens is not None:
699
+ assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided"
700
+ if g is not None:
701
+ g = g.float()
702
+ if head_first:
703
+ q, k, v = map(lambda x: rearrange(x, 'b h t d -> b t h d'), (q, k, v))
704
+ g = rearrange(g, 'b h t -> b t h')
705
+ o = ParallelForgettingAttentionFunction.apply(q, k, v, g, scale, cu_seqlens)
706
+ if head_first:
707
+ o = rearrange(o, 'b t h d -> b h t d')
708
+ return o
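A minimal calling sketch for `parallel_forgetting_attn` (illustrative, not part of the commit), assuming a CUDA device and the import path implied by the file location. Shapes follow the docstring, with 8 query heads sharing 2 key/value heads (GQA); the forget gates are random log-sigmoid values used purely as placeholders.

import torch
import torch.nn.functional as F

from fla.ops.forgetting_attn.parallel import parallel_forgetting_attn

B, T, H, HQ, K, V = 1, 256, 2, 8, 64, 64
q = torch.randn(B, T, HQ, K, device='cuda', dtype=torch.bfloat16)
k = torch.randn(B, T, H, K, device='cuda', dtype=torch.bfloat16)
v = torch.randn(B, T, H, V, device='cuda', dtype=torch.bfloat16)
# forget gates must be given in log space, i.e. non-positive values
g = F.logsigmoid(torch.randn(B, T, HQ, device='cuda', dtype=torch.float))

o = parallel_forgetting_attn(q, k, v, g)  # scale defaults to K ** -0.5
print(o.shape)  # torch.Size([1, 256, 8, 64])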
fla/ops/gated_delta_rule/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (319 Bytes).
 
fla/ops/gated_delta_rule/__pycache__/fused_recurrent.cpython-312.pyc ADDED
Binary file (15.1 kB).
 
fla/ops/generalized_delta_rule/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (389 Bytes).
 
fla/ops/generalized_delta_rule/dplr/__pycache__/chunk.cpython-312.pyc ADDED
Binary file (11.6 kB).
 
fla/ops/generalized_delta_rule/dplr/__pycache__/chunk_A_bwd.cpython-312.pyc ADDED
Binary file (30.6 kB).
 
fla/ops/generalized_delta_rule/dplr/__pycache__/chunk_A_fwd.cpython-312.pyc ADDED
Binary file (25.4 kB).
 
fla/ops/generalized_delta_rule/dplr/__pycache__/chunk_h_fwd.cpython-312.pyc ADDED
Binary file (12.5 kB).
 
fla/ops/generalized_delta_rule/dplr/__pycache__/chunk_o_bwd.cpython-312.pyc ADDED
Binary file (28 kB).
 
fla/ops/generalized_delta_rule/dplr/__pycache__/chunk_o_fwd.cpython-312.pyc ADDED
Binary file (8.91 kB).
 
fla/ops/generalized_delta_rule/dplr/__pycache__/fused_recurrent.cpython-312.pyc ADDED
Binary file (14.5 kB).
 
fla/ops/generalized_delta_rule/dplr/__pycache__/wy_fast_bwd.cpython-312.pyc ADDED
Binary file (13.2 kB).
 
fla/ops/generalized_delta_rule/dplr/chunk.py ADDED
@@ -0,0 +1,388 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import triton
8
+
9
+ from fla.ops.common.utils import prepare_chunk_indices
10
+ from fla.ops.generalized_delta_rule.dplr.chunk_A_bwd import chunk_dplr_bwd_dqk_intra
11
+ from fla.ops.generalized_delta_rule.dplr.chunk_A_fwd import chunk_fwd_intra_dplr_fn
12
+ from fla.ops.generalized_delta_rule.dplr.chunk_h_bwd import chunk_dplr_bwd_dhu
13
+ from fla.ops.generalized_delta_rule.dplr.chunk_h_fwd import chunk_dplr_fwd_h
14
+ from fla.ops.generalized_delta_rule.dplr.chunk_o_bwd import chunk_dplr_bwd_dAu, chunk_dplr_bwd_dv, chunk_dplr_bwd_o
15
+ from fla.ops.generalized_delta_rule.dplr.chunk_o_fwd import chunk_dplr_fwd_o
16
+ from fla.ops.generalized_delta_rule.dplr.wy_fast_bwd import chunk_dplr_bwd_wy
17
+ from fla.ops.generalized_delta_rule.dplr.wy_fast_fwd import fwd_prepare_wy_repr
18
+ from fla.ops.rwkv6.chunk import chunk_rwkv6_fwd_cumsum
19
+ from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard
20
+
21
+
22
+ def chunk_dplr_fwd(
23
+ q: torch.Tensor,
24
+ k: torch.Tensor,
25
+ v: torch.Tensor,
26
+ a: torch.Tensor,
27
+ b: torch.Tensor,
28
+ gk: torch.Tensor,
29
+ scale: float,
30
+ initial_state: torch.Tensor,
31
+ output_final_state: bool,
32
+ offsets: Optional[torch.LongTensor] = None,
33
+ indices: Optional[torch.LongTensor] = None,
34
+ head_first: bool = True,
35
+ chunk_size: int = 64
36
+ ):
37
+ T = q.shape[2] if head_first else q.shape[1]
38
+ BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
39
+ gi, ge = chunk_rwkv6_fwd_cumsum(gk, BT, offsets=offsets, indices=indices, head_first=head_first)
40
+
41
+ A_ab, A_qk, A_ak, A_qb, qg, kg, ag, bg = chunk_fwd_intra_dplr_fn(
42
+ q=q,
43
+ k=k,
44
+ a=a,
45
+ b=b,
46
+ gi=gi,
47
+ ge=ge,
48
+ scale=scale,
49
+ offsets=offsets,
50
+ indices=indices,
51
+ chunk_size=BT,
52
+ head_first=head_first
53
+ )
54
+ del ge
55
+
56
+ # A_ab, A_ak, gi, ge torch.float32
57
+ # A_qk, A_qb, qg, kg, ag, bg, dtype=q.dtype, eg: bf16
58
+ w, u, _ = fwd_prepare_wy_repr(
59
+ ag=ag,
60
+ A_ab=A_ab,
61
+ A_ak=A_ak,
62
+ v=v,
63
+ offsets=offsets,
64
+ indices=indices,
65
+ head_first=head_first,
66
+ chunk_size=BT
67
+ )
68
+ del A_ab, A_ak
69
+ h, v_new, final_state = chunk_dplr_fwd_h(
70
+ kg=kg,
71
+ bg=bg,
72
+ v=v,
73
+ w=w,
74
+ u=u,
75
+ gk=gi,
76
+ initial_state=initial_state,
77
+ output_final_state=output_final_state,
78
+ offsets=offsets,
79
+ indices=indices,
80
+ head_first=head_first,
81
+ chunk_size=BT
82
+ )
83
+ del u, kg, bg, gi
84
+
85
+ o = chunk_dplr_fwd_o(
86
+ qg=qg,
87
+ v=v,
88
+ v_new=v_new,
89
+ A_qk=A_qk,
90
+ A_qb=A_qb,
91
+ h=h,
92
+ offsets=offsets,
93
+ indices=indices,
94
+ head_first=head_first,
95
+ chunk_size=BT
96
+ )
97
+ del v_new, h, A_qk, A_qb
98
+
99
+ return o, final_state
100
+
101
+
102
+ class ChunkDPLRDeltaRuleFunction(torch.autograd.Function):
103
+
104
+ @staticmethod
105
+ @input_guard
106
+ @autocast_custom_fwd
107
+ def forward(
108
+ ctx,
109
+ q: torch.Tensor,
110
+ k: torch.Tensor,
111
+ v: torch.Tensor,
112
+ a: torch.Tensor,
113
+ b: torch.Tensor,
114
+ gk: torch.Tensor,
115
+ scale: float,
116
+ initial_state: torch.Tensor,
117
+ output_final_state: bool,
118
+ offsets: Optional[torch.LongTensor] = None,
119
+ head_first: bool = True
120
+ ):
121
+ chunk_size = 16
122
+
123
+ # 2-d indices denoting the offsets of chunks in each sequence
124
+ # for example, if the passed `offsets` is [0, 100, 356] and `chunk_size` is 64,
125
+ # then there are 2 and 4 chunks in the 1st and 2nd sequences respectively, and `indices` will be
126
+ # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]]
127
+ indices = prepare_chunk_indices(offsets, chunk_size) if offsets is not None else None
128
+
129
+ o, final_state = chunk_dplr_fwd(
130
+ q=q,
131
+ k=k,
132
+ v=v,
133
+ a=a,
134
+ b=b,
135
+ gk=gk,
136
+ scale=scale,
137
+ initial_state=initial_state,
138
+ output_final_state=output_final_state,
139
+ offsets=offsets,
140
+ indices=indices,
141
+ head_first=head_first,
142
+ chunk_size=chunk_size
143
+ )
144
+ ctx.save_for_backward(q, k, v, a, b, gk, initial_state)
145
+ ctx.head_first = head_first
146
+ ctx.offsets = offsets
147
+ ctx.indices = indices
148
+ ctx.scale = scale
149
+ ctx.chunk_size = chunk_size
150
+ return o.to(q.dtype), final_state
151
+
152
+ @staticmethod
153
+ @input_guard
154
+ @autocast_custom_bwd
155
+ def backward(
156
+ ctx,
157
+ do: torch.Tensor,
158
+ dht: torch.Tensor
159
+ ):
160
+ q, k, v, a, b, gk, initial_state = ctx.saved_tensors
161
+ BT = ctx.chunk_size
162
+ head_first = ctx.head_first
163
+ offsets = ctx.offsets
164
+ indices = ctx.indices
165
+ scale = ctx.scale
166
+
167
+ # ******* start recomputing everything, as caching all forward intermediates would exhaust GPU memory *******
168
+ gi, ge = chunk_rwkv6_fwd_cumsum(gk, BT, offsets=offsets, indices=indices, head_first=head_first)
169
+
170
+ A_ab, A_qk, A_ak, A_qb, qg, kg, ag, bg = chunk_fwd_intra_dplr_fn(
171
+ q=q,
172
+ k=k,
173
+ a=a,
174
+ b=b,
175
+ gi=gi,
176
+ ge=ge,
177
+ scale=scale,
178
+ offsets=offsets,
179
+ indices=indices,
180
+ chunk_size=BT,
181
+ head_first=head_first
182
+ )
183
+ w, u, A_ab_inv = fwd_prepare_wy_repr(
184
+ ag=ag,
185
+ A_ab=A_ab,
186
+ A_ak=A_ak,
187
+ v=v,
188
+ offsets=offsets,
189
+ indices=indices,
190
+ head_first=head_first,
191
+ chunk_size=BT
192
+ )
193
+ del A_ab
194
+ h, v_new, _ = chunk_dplr_fwd_h(
195
+ kg=kg,
196
+ bg=bg,
197
+ v=v,
198
+ w=w,
199
+ u=u,
200
+ gk=gi,
201
+ initial_state=initial_state,
202
+ offsets=offsets,
203
+ indices=indices,
204
+ head_first=head_first,
205
+ chunk_size=BT
206
+ )
207
+ del u
208
+ # ******* end of recomputation *******
209
+ # A_ak, A_ab_inv, gi, ge torch.float32
210
+ # A_qk, A_qb, qg, kg, ag, bg, v_new dtype=q.dtype, eg: bf16
211
+
212
+ dv_new_intra, dA_qk, dA_qb = chunk_dplr_bwd_dAu(
213
+ v=v,
214
+ v_new=v_new,
215
+ do=do,
216
+ A_qb=A_qb,
217
+ scale=scale,
218
+ offsets=offsets,
219
+ indices=indices,
220
+ head_first=head_first,
221
+ chunk_size=BT
222
+ )
223
+
224
+ dh, dh0, dv_new = chunk_dplr_bwd_dhu(
225
+ qg=qg,
226
+ bg=bg,
227
+ w=w,
228
+ gk=gi,
229
+ h0=initial_state,
230
+ dht=dht,
231
+ do=do,
232
+ dv=dv_new_intra,
233
+ offsets=offsets,
234
+ indices=indices,
235
+ head_first=head_first,
236
+ chunk_size=BT
237
+ )
238
+
239
+ dv = chunk_dplr_bwd_dv(
240
+ A_qk=A_qk,
241
+ kg=kg,
242
+ do=do,
243
+ dh=dh,
244
+ offsets=offsets,
245
+ indices=indices,
246
+ head_first=head_first,
247
+ chunk_size=BT
248
+ )
249
+ del A_qk
250
+
251
+ dqg, dkg, dw, dbg, dgk_last = chunk_dplr_bwd_o(
252
+ k=kg,
253
+ b=bg,
254
+ v=v,
255
+ v_new=v_new,
256
+ do=do,
257
+ h=h,
258
+ dh=dh,
259
+ dv=dv_new,
260
+ w=w,
261
+ gk=gi,
262
+ offsets=offsets,
263
+ indices=indices,
264
+ chunk_size=BT,
265
+ scale=scale,
266
+ head_first=head_first,
267
+ )
268
+ del v_new
269
+
270
+ dA_ab, dA_ak, dv, dag = chunk_dplr_bwd_wy(
271
+ A_ab_inv=A_ab_inv,
272
+ A_ak=A_ak,
273
+ v=v,
274
+ ag=ag,
275
+ dw=dw,
276
+ du=dv_new,
277
+ dv0=dv,
278
+ offsets=offsets,
279
+ indices=indices,
280
+ head_first=head_first,
281
+ chunk_size=BT
282
+ )
283
+ del A_ak
284
+
285
+ dq, dk, da, db, dgk = chunk_dplr_bwd_dqk_intra(
286
+ q=q,
287
+ k=k,
288
+ a=a,
289
+ b=b,
290
+ gi=gi,
291
+ ge=ge,
292
+ dAqk=dA_qk,
293
+ dAqb=dA_qb,
294
+ dAak=dA_ak,
295
+ dAab=dA_ab,
296
+ dgk_last=dgk_last,
297
+ dqg=dqg,
298
+ dkg=dkg,
299
+ dag=dag,
300
+ dbg=dbg,
301
+ chunk_size=BT,
302
+ scale=scale,
303
+ head_first=head_first,
304
+ offsets=offsets,
305
+ indices=indices
306
+ )
307
+
308
+ return dq.to(q), dk.to(k), dv.to(v), da.to(a), db.to(b), dgk.to(gk), None, dh0, None, None, None
309
+
310
+
311
+ @torch.compiler.disable
312
+ def chunk_dplr_delta_rule(
313
+ q: torch.Tensor,
314
+ k: torch.Tensor,
315
+ v: torch.Tensor,
316
+ a: torch.Tensor,
317
+ b: torch.Tensor,
318
+ gk: torch.Tensor,
319
+ scale: Optional[float] = None,
320
+ initial_state: Optional[torch.Tensor] = None,
321
+ output_final_state: bool = False,
322
+ cu_seqlens: Optional[torch.LongTensor] = None,
323
+ head_first: bool = False
324
+ ):
325
+ r"""
326
+ Args:
327
+ q (torch.Tensor):
328
+ queries of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
329
+ k (torch.Tensor):
330
+ keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
331
+ v (torch.Tensor):
332
+ values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`.
333
+ a (torch.Tensor):
334
+ activations of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
335
+ b (torch.Tensor):
336
+ betas of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
337
+ gk (torch.Tensor):
338
+ gk of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. decay term in log space!
339
+ scale (Optional[int]):
340
+ Scale factor for the RetNet attention scores.
341
+ If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
342
+ initial_state (Optional[torch.Tensor]):
343
+ Initial state of shape `[N, H, K, V]` for `N` input sequences.
344
+ For equal-length input sequences, `N` equals the batch size `B`.
345
+ Default: `None`.
346
+ output_final_state (Optional[bool]):
347
+ Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
348
+ cu_seqlens (torch.LongTensor):
349
+ Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
350
+ consistent with the FlashAttention API.
351
+ head_first (Optional[bool]):
352
+ Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
353
+ Default: `False`.
354
+
355
+ Returns:
356
+ o (torch.Tensor):
357
+ Outputs of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`.
358
+ final_state (torch.Tensor):
359
+ Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.
360
+ """
361
+ assert q.dtype == k.dtype == v.dtype
362
+ # assert q.dtype != torch.float32, "ChunkDeltaRuleFunction does not support float32. Please use bfloat16."
363
+ # gk = gk.float()
364
+
365
+ if cu_seqlens is not None:
366
+ if q.shape[0] != 1:
367
+ raise ValueError(f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
368
+ f"Please flatten variable-length inputs before processing.")
369
+ if head_first:
370
+ raise RuntimeError("Sequences with variable lengths are not supported for head-first mode")
371
+ if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
372
+ raise ValueError(f"The number of initial states is expected to be equal to the number of input sequences, "
373
+ f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}.")
374
+ scale = k.shape[-1] ** -0.5 if scale is None else scale
375
+ o, final_state = ChunkDPLRDeltaRuleFunction.apply(
376
+ q,
377
+ k,
378
+ v,
379
+ a,
380
+ b,
381
+ gk,
382
+ scale,
383
+ initial_state,
384
+ output_final_state,
385
+ cu_seqlens,
386
+ head_first
387
+ )
388
+ return o, final_state
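A minimal calling sketch for `chunk_dplr_delta_rule` (illustrative, not part of the commit), assuming a CUDA device and the import path implied by the file location. The inputs follow the docstring: bf16 tensors in `[B, T, H, ...]` layout with `head_first=False`, and `gk` given in log space; `a` and `b` are random placeholders standing in for the low-rank terms a real model would produce.

import torch
import torch.nn.functional as F

from fla.ops.generalized_delta_rule.dplr.chunk import chunk_dplr_delta_rule

B, T, H, K, V = 1, 128, 4, 64, 64
dtype = torch.bfloat16
q = torch.randn(B, T, H, K, device='cuda', dtype=dtype)
k = torch.randn(B, T, H, K, device='cuda', dtype=dtype)
v = torch.randn(B, T, H, V, device='cuda', dtype=dtype)
a = torch.randn(B, T, H, K, device='cuda', dtype=dtype)  # low-rank term (placeholder)
b = torch.randn(B, T, H, K, device='cuda', dtype=dtype)  # low-rank term (placeholder)
gk = F.logsigmoid(torch.randn(B, T, H, K, device='cuda', dtype=torch.float)).to(dtype)

o, final_state = chunk_dplr_delta_rule(q, k, v, a, b, gk, output_final_state=True)
print(o.shape, final_state.shape)  # [1, 128, 4, 64] and [1, 4, 64, 64]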
fla/ops/generalized_delta_rule/dplr/chunk_A_bwd.py ADDED
@@ -0,0 +1,446 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.ops.utils.op import exp, gather
11
+ from fla.utils import check_shared_mem, is_gather_supported, use_cuda_graph
12
+
13
+
14
+ @triton.heuristics({
15
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
16
+ })
17
+ @triton.autotune(
18
+ configs=[
19
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
20
+ for num_warps in [2, 4, 8, 16, 32]
21
+ for num_stages in [2, 3, 4]
22
+ ],
23
+ key=['BK', 'NC', 'BT', 'K'],
24
+ use_cuda_graph=use_cuda_graph,
25
+ )
26
+ @triton.jit(do_not_specialize=['T'])
27
+ def chunk_dplr_bwd_kernel_intra(
28
+ q,
29
+ k,
30
+ a,
31
+ b,
32
+ gi,
33
+ ge,
34
+ dAqk,
35
+ dAqb,
36
+ dAak,
37
+ dAab,
38
+ dq,
39
+ dk,
40
+ da,
41
+ db,
42
+ dqg,
43
+ dkg,
44
+ dag,
45
+ dbg,
46
+ dgk,
47
+ dgk_offset,
48
+ offsets,
49
+ indices,
50
+ scale: tl.constexpr,
51
+ T,
52
+ H: tl.constexpr,
53
+ K: tl.constexpr,
54
+ BT: tl.constexpr,
55
+ BC: tl.constexpr,
56
+ BK: tl.constexpr,
57
+ NC: tl.constexpr,
58
+ USE_OFFSETS: tl.constexpr,
59
+ HEAD_FIRST: tl.constexpr,
60
+ GATHER_SUPPORTED: tl.constexpr
61
+ ):
62
+ i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
63
+ i_b, i_h = i_bh // H, i_bh % H
64
+ i_t, i_i = i_c // NC, i_c % NC
65
+ if USE_OFFSETS:
66
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
67
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
68
+ else:
69
+ bos, eos = i_b * T, i_b * T + T
70
+ T = eos - bos
71
+ if i_t * BT + i_i * BC >= T:
72
+ return
73
+
74
+ # offset calculation
75
+ ge += i_bh * T * K if HEAD_FIRST else (bos*H + i_h) * K
76
+ gi += i_bh * T * K if HEAD_FIRST else (bos*H + i_h) * K
77
+ q += i_bh * T * K if HEAD_FIRST else (bos*H + i_h) * K
78
+ a += i_bh * T * K if HEAD_FIRST else (bos*H + i_h) * K
79
+ b += i_bh * T * K if HEAD_FIRST else (bos*H + i_h) * K
80
+ k += i_bh * T * K if HEAD_FIRST else (bos*H + i_h) * K
81
+ dq += i_bh * T * K if HEAD_FIRST else (bos*H + i_h) * K
82
+ dk += i_bh * T * K if HEAD_FIRST else (bos*H + i_h) * K
83
+ da += i_bh * T * K if HEAD_FIRST else (bos*H + i_h) * K
84
+ db += i_bh * T * K if HEAD_FIRST else (bos*H + i_h) * K
85
+ dqg += i_bh * T * K if HEAD_FIRST else (bos*H + i_h) * K
86
+ dag += i_bh * T * K if HEAD_FIRST else (bos*H + i_h) * K
87
+ dkg += i_bh * T * K if HEAD_FIRST else (bos*H + i_h) * K
88
+ dbg += i_bh * T * K if HEAD_FIRST else (bos*H + i_h) * K
89
+ dgk += i_bh * T * K if HEAD_FIRST else (bos*H + i_h) * K
90
+ dgk_offset += i_bh * T * K if HEAD_FIRST else (bos*H + i_h) * K
91
+ dAqk += i_bh * T * BT if HEAD_FIRST else (bos*H + i_h) * BT
92
+ dAqb += i_bh * T * BT if HEAD_FIRST else (bos*H + i_h) * BT
93
+ dAak += i_bh * T * BT if HEAD_FIRST else (bos*H + i_h) * BT
94
+ dAab += i_bh * T * BT if HEAD_FIRST else (bos*H + i_h) * BT
95
+
96
+ stride_qk = K if HEAD_FIRST else H*K
97
+ stride_A = BT if HEAD_FIRST else H*BT
98
+
99
+ p_ge = tl.make_block_ptr(ge, (T, K), (stride_qk, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
100
+ p_gi = tl.make_block_ptr(gi, (T, K), (stride_qk, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
101
+ # [BC, BK]
102
+ b_ge = tl.load(p_ge, boundary_check=(0, 1))
103
+ b_gi = tl.load(p_gi, boundary_check=(0, 1))
104
+ b_dq = tl.zeros([BC, BK], dtype=tl.float32)
105
+ b_da = tl.zeros([BC, BK], dtype=tl.float32)
106
+ b_dk = tl.zeros([BC, BK], dtype=tl.float32)
107
+ b_db = tl.zeros([BC, BK], dtype=tl.float32)
108
+ # intra chunk gradient calculation
109
+ p_dAqk = tl.make_block_ptr(dAqk, (T, BT), (stride_A, 1), (i_t*BT + i_i*BC, i_i*BC), (BC, BC), (1, 0))
110
+ p_dAab = tl.make_block_ptr(dAab, (T, BT), (stride_A, 1), (i_t*BT + i_i*BC, i_i*BC), (BC, BC), (1, 0))
111
+ p_dAqb = tl.make_block_ptr(dAqb, (T, BT), (stride_A, 1), (i_t*BT + i_i*BC, i_i*BC), (BC, BC), (1, 0))
112
+ p_dAak = tl.make_block_ptr(dAak, (T, BT), (stride_A, 1), (i_t*BT + i_i*BC, i_i*BC), (BC, BC), (1, 0))
113
+ o_i = tl.arange(0, BC)
114
+ p_k = tl.make_block_ptr(k, (T, K), (stride_qk, 1), (i_t*BT + i_i*BC, i_k*BK), (BC, BK), (1, 0))
115
+ p_b = tl.make_block_ptr(b, (T, K), (stride_qk, 1), (i_t*BT + i_i*BC, i_k*BK), (BC, BK), (1, 0))
116
+ p_a = tl.make_block_ptr(a, (T, K), (stride_qk, 1), (i_t*BT + i_i*BC, i_k*BK), (BC, BK), (1, 0))
117
+ p_q = tl.make_block_ptr(q, (T, K), (stride_qk, 1), (i_t*BT + i_i*BC, i_k*BK), (BC, BK), (1, 0))
118
+ b_k = tl.load(p_k, boundary_check=(0, 1)).to(tl.float32)
119
+ b_b = tl.load(p_b, boundary_check=(0, 1)).to(tl.float32)
120
+ b_q = tl.load(p_q, boundary_check=(0, 1)).to(tl.float32)
121
+ b_a = tl.load(p_a, boundary_check=(0, 1)).to(tl.float32)
122
+ b_dAqk = tl.load(p_dAqk, boundary_check=(0, 1)).to(tl.float32)
123
+ b_dAab = tl.load(p_dAab, boundary_check=(0, 1)).to(tl.float32)
124
+ b_dAqb = tl.load(p_dAqb, boundary_check=(0, 1)).to(tl.float32)
125
+ b_dAak = tl.load(p_dAak, boundary_check=(0, 1)).to(tl.float32)
126
+
127
+ # inter chunk gradient calculation
128
+ o_k = i_k * BK + tl.arange(0, BK)
129
+ m_k = o_k < K
130
+ if i_i > 0:
131
+ p_gn = gi + (i_t * BT + i_i * BC - 1) * stride_qk + o_k
132
+ p_gn = tl.max_contiguous(tl.multiple_of(p_gn, BK), BK)
133
+ # [BK,]
134
+ b_gn = tl.load(p_gn, mask=m_k, other=0)
135
+ # [BK,]
136
+ for i_j in range(0, i_i):
137
+ p_kj = tl.make_block_ptr(k, (T, K), (stride_qk, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0))
138
+ p_bj = tl.make_block_ptr(b, (T, K), (stride_qk, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0))
139
+ p_gkj = tl.make_block_ptr(gi, (T, K), (stride_qk, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0))
140
+ p_dAqikj = tl.make_block_ptr(dAqk, (T, BT), (stride_A, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
141
+ p_dAaibj = tl.make_block_ptr(dAab, (T, BT), (stride_A, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
142
+ p_dAqibj = tl.make_block_ptr(dAqb, (T, BT), (stride_A, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
143
+ p_dAaikj = tl.make_block_ptr(dAak, (T, BT), (stride_A, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
144
+ # [BC, BK]
145
+ b_kj = tl.load(p_kj, boundary_check=(0, 1))
146
+ b_bj = tl.load(p_bj, boundary_check=(0, 1))
147
+ b_gkj = tl.load(p_gkj, boundary_check=(0, 1))
148
+ tmp = exp(b_gn[None, :] - b_gkj)
149
+ b_kjg = b_kj * tmp
150
+ b_bjg = b_bj * tmp
151
+ # [BC, BC]
152
+ b_dAqikj = tl.load(p_dAqikj, boundary_check=(0, 1))
153
+ b_dAaibj = tl.load(p_dAaibj, boundary_check=(0, 1))
154
+ b_dAqibj = tl.load(p_dAqibj, boundary_check=(0, 1))
155
+ b_dAaikj = tl.load(p_dAaikj, boundary_check=(0, 1))
156
+ # [BC, BK]
157
+ b_dq += tl.dot(b_dAqikj, b_kjg)
158
+ b_dq += tl.dot(b_dAqibj, b_bjg)
159
+ # [BC, BC]
160
+ b_da += tl.dot(b_dAaibj, b_bjg)
161
+ b_da += tl.dot(b_dAaikj, b_kjg)
162
+ b_dq *= exp(b_gi - b_gn[None, :])
163
+ b_da *= exp(b_ge - b_gn[None, :])
164
+
165
+ NC = min(NC, tl.cdiv(T - i_t * BT, BC))
166
+ if i_i < NC - 1:
167
+ p_gn = gi + (min(i_t * BT + i_i * BC + BC, T) - 1)*stride_qk + o_k
168
+ p_gn = tl.max_contiguous(tl.multiple_of(p_gn, BK), BK)
169
+ # [BK,]
170
+ b_gn = tl.load(p_gn, mask=m_k, other=0)
171
+ for i_j in range(i_i + 1, NC):
172
+ m_j = (i_t * BT + i_j * BC + tl.arange(0, BC)) < T
173
+ p_qj = tl.make_block_ptr(q, (T, K), (stride_qk, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0))
174
+ p_aj = tl.make_block_ptr(a, (T, K), (stride_qk, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0))
175
+ p_gij = tl.make_block_ptr(gi, (T, K), (stride_qk, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0))
176
+ p_gej = tl.make_block_ptr(ge, (T, K), (stride_qk, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0))
177
+ p_dAqjki = tl.make_block_ptr(dAqk, (BT, T), (1, stride_A), (i_i*BC, i_t*BT + i_j*BC), (BC, BC), (0, 1))
178
+ p_dAajbi = tl.make_block_ptr(dAab, (BT, T), (1, stride_A), (i_i*BC, i_t*BT + i_j*BC), (BC, BC), (0, 1))
179
+ p_dAqjbi = tl.make_block_ptr(dAqb, (BT, T), (1, stride_A), (i_i*BC, i_t*BT + i_j*BC), (BC, BC), (0, 1))
180
+ p_dAajki = tl.make_block_ptr(dAak, (BT, T), (1, stride_A), (i_i*BC, i_t*BT + i_j*BC), (BC, BC), (0, 1))
181
+ b_qj = tl.load(p_qj, boundary_check=(0, 1))
182
+ b_aj = tl.load(p_aj, boundary_check=(0, 1))
183
+ b_gij = tl.load(p_gij, boundary_check=(0, 1))
184
+ b_gej = tl.load(p_gej, boundary_check=(0, 1))
185
+ b_gij = tl.where(m_j[:, None] & m_k, b_gij, float('-inf'))
186
+ b_gej = tl.where(m_j[:, None] & m_k, b_gej, float('-inf'))
187
+ b_qjg = b_qj * exp(b_gij - b_gn[None, :])
188
+ b_ajg = b_aj * exp(b_gej - b_gn[None, :])
189
+ # [BC, BC]
190
+ b_dAqjki = tl.load(p_dAqjki, boundary_check=(0, 1))
191
+ b_dAajbi = tl.load(p_dAajbi, boundary_check=(0, 1))
192
+ b_dAqjbi = tl.load(p_dAqjbi, boundary_check=(0, 1))
193
+ b_dAajki = tl.load(p_dAajki, boundary_check=(0, 1))
194
+ b_dk += tl.dot(b_dAqjki, b_qjg)
195
+ b_dk += tl.dot(b_dAajki, b_ajg)
196
+ b_db += tl.dot(b_dAqjbi, b_qjg)
197
+ b_db += tl.dot(b_dAajbi, b_ajg)
198
+ tmp = exp(b_gn[None, :] - b_gi)
199
+ b_dk *= tmp
200
+ b_db *= tmp
201
+
202
+ # intra chunk gradient calculation
203
+ for j in range(0, min(BC, T - i_t * BT - i_i * BC)):
204
+ # trick to index the block
205
+ if GATHER_SUPPORTED:
206
+ row_idx = tl.full([1, BK], j, dtype=tl.int16)
207
+ col_idx = tl.full([BC, 1], j, dtype=tl.int16)
208
+ row_idx_bc = tl.full([1, BC], j, dtype=tl.int16)
209
+ # [1, BK]
210
+ b_kj = gather(b_k, row_idx, axis=0)
211
+ b_bj = gather(b_b, row_idx, axis=0)
212
+ b_gij = gather(b_gi, row_idx, axis=0)
213
+ b_gej = gather(b_ge, row_idx, axis=0)
214
+ b_qj = gather(b_q, row_idx, axis=0)
215
+ b_aj = gather(b_a, row_idx, axis=0)
216
+ # [BC, 1]
217
+ b_dAqk_j = gather(b_dAqk, col_idx, axis=1)
218
+ b_dAab_j = gather(b_dAab, col_idx, axis=1)
219
+ b_dAqb_j = gather(b_dAqb, col_idx, axis=1)
220
+ b_dAak_j = gather(b_dAak, col_idx, axis=1)
221
+ # [1, BC] -> [BC, 1]
222
+ b_dA_qk_j = tl.sum(gather(b_dAqk, row_idx_bc, axis=0), 0)[:, None]
224
+ b_dA_ab_j = tl.sum(gather(b_dAab, row_idx_bc, axis=0), 0)[:, None]
225
+ b_dA_qb_j = tl.sum(gather(b_dAqb, row_idx_bc, axis=0), 0)[:, None]
226
+ b_dA_ak_j = tl.sum(gather(b_dAak, row_idx_bc, axis=0), 0)[:, None]
227
+ else:
228
+ mask_idx = tl.arange(0, BC) == j
229
+ b_kj = tl.sum(tl.where(mask_idx[:, None], b_k, 0), 0)[None, :]
230
+ b_bj = tl.sum(tl.where(mask_idx[:, None], b_b, 0), 0)[None, :]
231
+ b_gij = tl.sum(tl.where(mask_idx[:, None], b_gi, 0), 0)[None, :]
232
+ b_gej = tl.sum(tl.where(mask_idx[:, None], b_ge, 0), 0)[None, :]
233
+ b_dAqk_j = tl.sum(tl.where(mask_idx[None, :], b_dAqk, 0), 1)[:, None]
234
+ b_dAab_j = tl.sum(tl.where(mask_idx[None, :], b_dAab, 0), 1)[:, None]
235
+ b_dAqb_j = tl.sum(tl.where(mask_idx[None, :], b_dAqb, 0), 1)[:, None]
236
+ b_dAak_j = tl.sum(tl.where(mask_idx[None, :], b_dAak, 0), 1)[:, None]
237
+ b_dA_qk_j = tl.sum(tl.where(mask_idx[:, None], b_dAqk, 0), 0)[:, None]
238
+ b_dA_ab_j = tl.sum(tl.where(mask_idx[:, None], b_dAab, 0), 0)[:, None]
239
+ b_dA_qb_j = tl.sum(tl.where(mask_idx[:, None], b_dAqb, 0), 0)[:, None]
240
+ b_dA_ak_j = tl.sum(tl.where(mask_idx[:, None], b_dAak, 0), 0)[:, None]
241
+ # [1, BK] b_qj, b_aj
242
+ b_qj = tl.sum(tl.where(mask_idx[:, None], b_q, 0), 0)[None, :]
243
+ b_aj = tl.sum(tl.where(mask_idx[:, None], b_a, 0), 0)[None, :]
244
+ # tl.static_print(b_kj)
245
+ m_e = o_i[:, None] > j
246
+ m_i = o_i[:, None] >= j
247
+ tmp1 = exp(b_gi - b_gij)
248
+ tmp2 = exp(b_ge - b_gij)
249
+ b_dq += tl.where(m_i, b_dAqk_j * b_kj * tmp1, 0.)
250
+ b_dq += tl.where(m_i, b_dAqb_j * b_bj * tmp1, 0.)
251
+ b_da += tl.where(m_e, b_dAab_j * b_bj * tmp2, 0.)
252
+ b_da += tl.where(m_e, b_dAak_j * b_kj * tmp2, 0.)
253
+
254
+ m_i = o_i[:, None] <= j
255
+ m_e = o_i[:, None] < j
256
+ tmp1 = exp(b_gij - b_gi)
257
+ tmp2 = exp(b_gej - b_gi)
258
+ b_dk += tl.where(m_i, b_dA_qk_j * b_qj * tmp1, 0.)
259
+ b_dk += tl.where(m_e, b_dA_ak_j * b_aj * tmp2, 0.)
260
+ b_db += tl.where(m_i, b_dA_qb_j * b_qj * tmp1, 0.)
261
+ b_db += tl.where(m_e, b_dA_ab_j * b_aj * tmp2, 0.)
262
+ # post processing
263
+ p_dq = tl.make_block_ptr(dq, (T, K), (stride_qk, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
264
+ p_dk = tl.make_block_ptr(dk, (T, K), (stride_qk, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
265
+ p_da = tl.make_block_ptr(da, (T, K), (stride_qk, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
266
+ p_db = tl.make_block_ptr(db, (T, K), (stride_qk, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
267
+ p_dgk = tl.make_block_ptr(dgk, (T, K), (stride_qk, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
268
+ p_dgk_offset = tl.make_block_ptr(dgk_offset, (T, K), (stride_qk, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
269
+ p_dqg = tl.make_block_ptr(dqg, (T, K), (stride_qk, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
270
+ p_dkg = tl.make_block_ptr(dkg, (T, K), (stride_qk, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
271
+ p_dag = tl.make_block_ptr(dag, (T, K), (stride_qk, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
272
+ p_dbg = tl.make_block_ptr(dbg, (T, K), (stride_qk, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
273
+ p_gn = gi + (min(i_t * BT + BT, T) - 1)*stride_qk + o_k
274
+ p_gn = tl.max_contiguous(tl.multiple_of(p_gn, BK), BK)
275
+ b_gn = tl.load(p_gn, mask=m_k, other=0)
276
+ b_da += tl.load(p_dag, boundary_check=(0, 1)) * exp(b_ge)
277
+ b_dq += tl.load(p_dqg, boundary_check=(0, 1)) * exp(b_gi) * scale
278
+ tmp = exp(b_gn[None, :] - b_gi)
279
+ b_dk += tl.load(p_dkg, boundary_check=(0, 1)) * tmp
280
+ b_db += tl.load(p_dbg, boundary_check=(0, 1)) * tmp
281
+ tl.store(p_dq, (b_dq).to(p_dq.dtype.element_ty), boundary_check=(0, 1))
282
+ tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
283
+ tl.store(p_da, b_da.to(p_da.dtype.element_ty), boundary_check=(0, 1))
284
+ tl.store(p_db, b_db.to(p_db.dtype.element_ty), boundary_check=(0, 1))
285
+ b_dgk = b_dq * b_q + b_da * b_a - b_dk * b_k - b_db * b_b
286
+ b_dgk_offset = b_da * b_a
287
+ tl.store(p_dgk, b_dgk.to(p_dgk.dtype.element_ty), boundary_check=(0, 1))
288
+ tl.store(p_dgk_offset, b_dgk_offset.to(p_dgk_offset.dtype.element_ty), boundary_check=(0, 1))
289
+
290
+
291
+ @triton.heuristics({
292
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None,
293
+ })
294
+ @triton.autotune(
295
+ configs=[
296
+ triton.Config({'BK': BK}, num_warps=num_warps, num_stages=num_stages)
297
+ for num_warps in [2, 4, 8, 16, 32]
298
+ for num_stages in [2, 3, 4]
299
+ for BK in [32, 64]
300
+ ],
301
+ key=['BK', 'BT', 'K'],
302
+ use_cuda_graph=use_cuda_graph,
303
+ )
304
+ @triton.jit(do_not_specialize=['T'])
305
+ def chunk_dplr_bwd_dgk_kernel(
306
+ dgk,
307
+ dgk_offset,
308
+ dgk_last,
309
+ dgk_output,
310
+ offsets,
311
+ indices,
312
+ T,
313
+ H: tl.constexpr,
314
+ K: tl.constexpr,
315
+ BT: tl.constexpr,
316
+ BK: tl.constexpr,
317
+ USE_OFFSETS: tl.constexpr,
318
+ HEAD_FIRST: tl.constexpr,
319
+ ):
320
+ i_t, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
321
+ i_b, i_h = i_bh // H, i_bh % H
322
+ if USE_OFFSETS:
323
+ i_tg = i_t
324
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
325
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
326
+ T = eos - bos
327
+ NT = tl.cdiv(T, BT)
328
+ else:
329
+ NT = tl.cdiv(T, BT)
330
+ i_tg = i_b * NT + i_t
331
+ bos, eos = i_b * T, i_b * T + T
332
+ T = eos - bos
333
+ stride_qk = K if HEAD_FIRST else H * K
334
+ dgk += i_bh * T * K if HEAD_FIRST else (bos * H + i_h) * K
335
+ dgk_offset += i_bh * T * K if HEAD_FIRST else (bos * H + i_h) * K
336
+ dgk_last += ((i_bh * NT + i_t) * K) if HEAD_FIRST else (i_tg * H + i_h) * K
337
+ dgk_output += i_bh * T * K if HEAD_FIRST else (bos * H + i_h) * K
338
+ p_dgk_last = dgk_last + tl.arange(0, BK) + i_k * BK
339
+ m_k = tl.arange(0, BK) + i_k * BK < K
340
+ b_dgk_last = tl.load(p_dgk_last, mask=m_k, other=0)
341
+ p_dgk_offset = tl.make_block_ptr(dgk_offset, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
342
+ p_dgk = tl.make_block_ptr(dgk, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
343
+ b_dgk = tl.load(p_dgk, boundary_check=(0, 1))
344
+ b_dgk_offset = tl.load(p_dgk_offset, boundary_check=(0, 1))
345
+ # m_inv_cumsum = (tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :]).to(tl.float32)
346
+ # b_dgk_cumsum = tl.dot(m_inv_cumsum, b_dgk, allow_tf32=False)
347
+ b_dgk_cumsum = tl.cumsum(b_dgk, 0, reverse=True)
348
+ b_dgk_cumsum += b_dgk_last[None, :]
349
+ b_dgk_cumsum -= b_dgk_offset
350
+ p_dgk_output = tl.make_block_ptr(dgk_output, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
351
+ tl.store(p_dgk_output, b_dgk_cumsum.to(p_dgk_output.dtype.element_ty), boundary_check=(0, 1))
352
+
353
+
354
+ def chunk_dplr_bwd_dqk_intra(
355
+ q: torch.Tensor,
356
+ k: torch.Tensor,
357
+ a: torch.Tensor,
358
+ b: torch.Tensor,
359
+ gi: torch.Tensor,
360
+ ge: torch.Tensor,
361
+ dAqk: torch.Tensor,
362
+ dAqb: torch.Tensor,
363
+ dAak: torch.Tensor,
364
+ dAab: torch.Tensor,
365
+ dqg: torch.Tensor,
366
+ dkg: torch.Tensor,
367
+ dag: torch.Tensor,
368
+ dbg: torch.Tensor,
369
+ dgk_last: torch.Tensor,
370
+ offsets: Optional[torch.LongTensor] = None,
371
+ indices: Optional[torch.LongTensor] = None,
372
+ head_first: bool = True,
373
+ scale: float = 1.0,
374
+ chunk_size: int = 64,
375
+ ):
376
+ if head_first:
377
+ B, H, T, K = q.shape
378
+ else:
379
+ B, T, H, K = q.shape
380
+ BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
381
+ BC = min(16, BT)
382
+ BK = min(64, triton.next_power_of_2(K)) if check_shared_mem() else min(32, triton.next_power_of_2(K))
383
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
384
+ NC = triton.cdiv(BT, BC)
385
+ NK = triton.cdiv(K, BK)
386
+
387
+ dq = torch.empty_like(q)
388
+ dk = torch.empty_like(k)
389
+ da = torch.empty_like(a)
390
+ db = torch.empty_like(b)
391
+ dgk = torch.empty_like(gi, dtype=torch.float)
392
+ dgk_offset = torch.empty_like(gi, dtype=torch.float)
393
+
394
+ grid = (NK, NT * NC, B * H)
395
+ chunk_dplr_bwd_kernel_intra[grid](
396
+ q=q,
397
+ k=k,
398
+ a=a,
399
+ b=b,
400
+ gi=gi,
401
+ ge=ge,
402
+ dAqk=dAqk,
403
+ dAqb=dAqb,
404
+ dAak=dAak,
405
+ dAab=dAab,
406
+ dq=dq,
407
+ dk=dk,
408
+ dgk=dgk,
409
+ dgk_offset=dgk_offset,
410
+ dqg=dqg,
411
+ dkg=dkg,
412
+ dag=dag,
413
+ dbg=dbg,
414
+ da=da,
415
+ db=db,
416
+ offsets=offsets,
417
+ indices=indices,
418
+ scale=scale,
419
+ T=T,
420
+ H=H,
421
+ K=K,
422
+ BT=BT,
423
+ BC=BC,
424
+ BK=BK,
425
+ NC=NC,
426
+ HEAD_FIRST=head_first,
427
+ GATHER_SUPPORTED=is_gather_supported
428
+ )
429
+
430
+ def grid2(meta): return (NT, triton.cdiv(K, meta['BK']), B * H)
431
+ dgk_output = torch.empty_like(dgk)
432
+
433
+ chunk_dplr_bwd_dgk_kernel[grid2](
434
+ dgk=dgk,
435
+ dgk_offset=dgk_offset,
436
+ dgk_last=dgk_last,
437
+ dgk_output=dgk_output,
438
+ offsets=offsets,
439
+ indices=indices,
440
+ T=T,
441
+ H=H,
442
+ K=K,
443
+ BT=BT,
444
+ HEAD_FIRST=head_first
445
+ )
446
+ return dq, dk, da, db, dgk_output
fla/ops/generalized_delta_rule/dplr/chunk_A_fwd.py ADDED
@@ -0,0 +1,324 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.ops.utils.op import exp, gather
11
+ from fla.utils import is_gather_supported, use_cuda_graph
12
+
13
+
14
+ @triton.heuristics({
15
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
16
+ })
17
+ @triton.autotune(
18
+ configs=[
19
+ triton.Config({'BK': BK}, num_warps=num_warps, num_stages=num_stages)
20
+ for BK in [32, 64]
21
+ for num_warps in [2, 4, 8, 16]
22
+ for num_stages in [2, 3, 4]
23
+ ],
24
+ key=['BC', 'K'],
25
+ use_cuda_graph=use_cuda_graph,
26
+ )
27
+ @triton.jit(do_not_specialize=['T'])
28
+ def chunk_dplr_fwd_A_kernel_intra_sub_inter(
29
+ q,
30
+ k,
31
+ a,
32
+ b,
33
+ gi, # cumsum
34
+ ge, # before cumsum
35
+ Aqk,
36
+ Aqb,
37
+ Aab,
38
+ Aak,
39
+ offsets,
40
+ indices,
41
+ scale: tl.constexpr,
42
+ T,
43
+ H: tl.constexpr,
44
+ K: tl.constexpr,
45
+ BT: tl.constexpr,
46
+ BC: tl.constexpr,
47
+ BK: tl.constexpr,
48
+ NC: tl.constexpr,
49
+ USE_OFFSETS: tl.constexpr,
50
+ HEAD_FIRST: tl.constexpr,
51
+ ):
52
+ i_t, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
53
+ i_b, i_h = i_bh // H, i_bh % H
54
+ i_i, i_j = i_c // NC, i_c % NC
55
+ if USE_OFFSETS:
56
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
57
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
58
+ T = eos - bos
59
+ else:
60
+ bos, eos = i_b * T, i_b * T + T
61
+
62
+ if i_t * BT + i_i * BC >= T:
63
+ return
64
+ if i_i <= i_j:
65
+ return
66
+
67
+ b_Aqk = tl.zeros([BC, BC], dtype=tl.float32)
68
+ b_Aqb = tl.zeros([BC, BC], dtype=tl.float32)
69
+ b_Aab = tl.zeros([BC, BC], dtype=tl.float32)
70
+ b_Aak = tl.zeros([BC, BC], dtype=tl.float32)
71
+ for i_k in range(tl.cdiv(K, BK)):
72
+ o_k = i_k * BK + tl.arange(0, BK)
73
+ m_k = o_k < K
74
+
75
+ if HEAD_FIRST:
76
+ p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
77
+ p_a = tl.make_block_ptr(a + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
78
+ p_gq_i = tl.make_block_ptr(gi + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
79
+ p_gq_e = tl.make_block_ptr(ge + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
80
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))
81
+ p_b = tl.make_block_ptr(b + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))
82
+ p_gk = tl.make_block_ptr(gi + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))
83
+ p_gn = tl.max_contiguous(tl.multiple_of(gi + (i_bh * T + i_t * BT + i_i * BC - 1) * K + o_k, BK), BK)
84
+ else:
85
+ p_q = tl.make_block_ptr(q + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
86
+ p_a = tl.make_block_ptr(a + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
87
+ p_gq_i = tl.make_block_ptr(gi + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
88
+ p_gq_e = tl.make_block_ptr(ge + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
89
+ p_k = tl.make_block_ptr(k + (bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))
90
+ p_b = tl.make_block_ptr(b + (bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))
91
+ p_gk = tl.make_block_ptr(gi + (bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))
92
+ p_gn = gi + (bos + i_t * BT + i_i * BC - 1) * H*K + i_h * K + o_k
93
+ # [BK,]
94
+ b_gn = tl.load(p_gn, mask=m_k, other=0).to(tl.float32)
95
+ # [BC, BK]
96
+ b_q = tl.load(p_q, boundary_check=(0, 1))
97
+ b_a = tl.load(p_a, boundary_check=(0, 1))
98
+ b_gq_i = tl.load(p_gq_i, boundary_check=(0, 1))
99
+ b_gq_e = tl.load(p_gq_e, boundary_check=(0, 1))
100
+ b_ag = b_a * exp(b_gq_e - b_gn[None, :])
101
+ b_qg = b_q * exp(b_gq_i - b_gn[None, :]) * scale
102
+ # [BK, BC]
103
+ b_k = tl.load(p_k, boundary_check=(0, 1))
104
+ b_b = tl.load(p_b, boundary_check=(0, 1))
105
+ b_gk = tl.load(p_gk, boundary_check=(0, 1)).to(tl.float32)
106
+ tmp = exp(b_gn[:, None] - b_gk)
107
+ b_kg = b_k * tmp
108
+ b_bg = b_b * tmp
109
+ # [BC, BC] using tf32 to improve precision here.
110
+ b_Aab += tl.dot(b_ag, b_bg)
111
+ b_Aak += tl.dot(b_ag, b_kg)
112
+ b_Aqk += tl.dot(b_qg, b_kg)
113
+ b_Aqb += tl.dot(b_qg, b_bg)
114
+
115
+ if HEAD_FIRST:
116
+ p_Aqk = tl.make_block_ptr(Aqk + i_bh*T*BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
117
+ p_Aqb = tl.make_block_ptr(Aqb + i_bh*T*BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
118
+ p_Aab = tl.make_block_ptr(Aab + i_bh*T*BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
119
+ p_Aak = tl.make_block_ptr(Aak + i_bh*T*BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
120
+ else:
121
+ p_Aqk = tl.make_block_ptr(Aqk + (bos*H+i_h)*BT, (T, BT), (H*BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
122
+ p_Aqb = tl.make_block_ptr(Aqb + (bos*H+i_h)*BT, (T, BT), (H*BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
123
+ p_Aab = tl.make_block_ptr(Aab + (bos*H+i_h)*BT, (T, BT), (H*BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
124
+ p_Aak = tl.make_block_ptr(Aak + (bos*H+i_h)*BT, (T, BT), (H*BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
125
+ tl.store(p_Aqk, b_Aqk.to(Aqk.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
126
+ tl.store(p_Aqb, b_Aqb.to(Aqb.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
127
+ tl.store(p_Aab, b_Aab.to(Aab.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
128
+ tl.store(p_Aak, b_Aak.to(Aak.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
129
+
130
+
131
+ @triton.heuristics({
132
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
133
+ })
134
+ @triton.autotune(
135
+ configs=[
136
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
137
+ for num_warps in [2, 4, 8, 16, 32]
138
+ for num_stages in [2, 3, 4]
139
+ ],
140
+ key=['BK', 'BT'],
141
+ use_cuda_graph=use_cuda_graph,
142
+ )
143
+ @triton.jit(do_not_specialize=['T'])
144
+ def chunk_dplr_fwd_A_kernel_intra_sub_intra(
145
+ q,
146
+ k,
147
+ a,
148
+ b,
149
+ gi,
150
+ ge,
151
+ qg,
152
+ kg,
153
+ ag,
154
+ bg,
155
+ Aqk,
156
+ Aqb,
157
+ Aab,
158
+ Aak,
159
+ offsets,
160
+ indices,
161
+ scale: tl.constexpr,
162
+ T,
163
+ H: tl.constexpr,
164
+ K: tl.constexpr,
165
+ BT: tl.constexpr,
166
+ BC: tl.constexpr,
167
+ BK: tl.constexpr,
168
+ NC: tl.constexpr,
169
+ USE_OFFSETS: tl.constexpr,
170
+ HEAD_FIRST: tl.constexpr,
171
+ GATHER_SUPPORTED: tl.constexpr
172
+ ):
173
+ i_t, i_i, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
174
+ i_b, i_h = i_bh // H, i_bh % H
175
+ i_j = i_i
176
+ if USE_OFFSETS:
177
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
178
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
179
+ T = eos - bos
180
+ else:
181
+ bos, eos = i_b * T, i_b * T + T
182
+
183
+ if i_t * BT + i_i * BC >= T:
184
+ return
185
+
186
+ o_i = tl.arange(0, BC)
187
+ o_k = tl.arange(0, BK)
188
+ m_k = o_k < K
189
+ m_A = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T
190
+ last_idx = min((i_t+1) * BT, T) - 1
191
+ if HEAD_FIRST:
192
+ o_A = i_bh * T*BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_j * BC
193
+ p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
194
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
195
+ p_a = tl.make_block_ptr(a + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
196
+ p_b = tl.make_block_ptr(b + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
197
+ p_gi = tl.make_block_ptr(gi + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
198
+ p_ge = tl.make_block_ptr(ge + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
199
+ p_g_last = gi + i_bh * T*K + last_idx * K + tl.arange(0, BK)
200
+ b_g_last = tl.load(p_g_last, mask=m_k, other=0)
201
+
202
+ p_qg = tl.make_block_ptr(qg + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
203
+ p_kg = tl.make_block_ptr(kg + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
204
+ p_ag = tl.make_block_ptr(ag + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
205
+ p_bg = tl.make_block_ptr(bg + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
206
+ else:
207
+ o_A = (bos + i_t * BT + i_i * BC + tl.arange(0, BC)) * H*BT + i_h * BT + i_j * BC
208
+ p_q = tl.make_block_ptr(q + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
209
+ p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
210
+ p_a = tl.make_block_ptr(a + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
211
+ p_b = tl.make_block_ptr(b + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
212
+ p_gi = tl.make_block_ptr(gi + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
213
+ p_ge = tl.make_block_ptr(ge + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
214
+ p_g_last = gi + (bos * H + i_h) * K + last_idx * H * K + tl.arange(0, BK)
215
+ b_g_last = tl.load(p_g_last, mask=m_k, other=0)
216
+ p_qg = tl.make_block_ptr(qg + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
217
+ p_kg = tl.make_block_ptr(kg + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
218
+ p_ag = tl.make_block_ptr(ag + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
219
+ p_bg = tl.make_block_ptr(bg + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
220
+
221
+ b_q = tl.load(p_q, boundary_check=(0, 1))
222
+ b_q = b_q * scale
223
+ b_k = tl.load(p_k, boundary_check=(0, 1))
224
+ b_a = tl.load(p_a, boundary_check=(0, 1))
225
+ b_b = tl.load(p_b, boundary_check=(0, 1))
226
+ b_gi = tl.load(p_gi, boundary_check=(0, 1)).to(tl.float32)
227
+ b_ge = tl.load(p_ge, boundary_check=(0, 1)).to(tl.float32)
228
+
229
+ # deal with decay term.
230
+ g_exp = exp(b_gi)
231
+ g_exp_inv = exp(-b_gi + b_g_last[None, :])
232
+ b_qg = b_q * g_exp
233
+ b_kg = b_k * g_exp_inv
234
+ b_bg = b_b * g_exp_inv
235
+ b_ag = b_a * exp(b_ge)
236
+ tl.store(p_qg, b_qg.to(p_qg.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
237
+ tl.store(p_bg, b_bg.to(p_bg.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
238
+ tl.store(p_ag, b_ag.to(p_ag.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
239
+ tl.store(p_kg, b_kg.to(p_kg.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
240
+ # tl.debug_barrier()
241
+
242
+ b_q = b_q.to(b_k.dtype)
243
+ # inner attn
244
+ for j in range(0, min(BC, T - i_t * BT - i_i * BC)):
245
+ # a trick to index the j-th row of b_k, b_g, b_b
246
+ if GATHER_SUPPORTED:
247
+ row_idx = tl.full([1, BK], j, dtype=tl.int16)
248
+ # [1, BK]
249
+ b_k_j = gather(b_k, row_idx, axis=0)
250
+ b_gk_j = gather(b_gi, row_idx, axis=0)
251
+ b_b_j = gather(b_b, row_idx, axis=0)
252
+ else:
253
+ mask = tl.arange(0, BC) == j
254
+ b_k_j = tl.sum(tl.where(mask[:, None], b_k, 0), 0)[None, :]
255
+ b_gk_j = tl.sum(tl.where(mask[:, None], b_gi, 0), 0)[None, :]
256
+ b_b_j = tl.sum(tl.where(mask[:, None], b_b, 0), 0)[None, :]
257
+ mask = tl.arange(0, BC) == j
258
+ tmp = exp(b_gi - b_gk_j)
259
+ b_A_qk = tl.sum(b_q * b_k_j * tmp, 1)
260
+ b_A_qk = tl.where(o_i >= j, b_A_qk, 0.)
261
+ b_A_qb = tl.sum(b_q * b_b_j * tmp, 1)
262
+ b_A_qb = tl.where(o_i >= j, b_A_qb, 0.)
263
+ tmp2 = exp(b_ge - b_gk_j)
264
+ b_A_ak = tl.sum(b_a * b_k_j * tmp2, 1)
265
+ b_A_ak = tl.where(o_i > j, b_A_ak, 0.)
266
+ b_A_ab = tl.sum(b_a * b_b_j * tmp2, 1)
267
+ b_A_ab = tl.where(o_i > j, b_A_ab, 0.)
268
+ tl.store(Aqk + o_A + j, b_A_qk.to(dtype=Aqk.dtype.element_ty, fp_downcast_rounding="rtne"), mask=m_A)
269
+ tl.store(Aqb + o_A + j, b_A_qb.to(dtype=Aqb.dtype.element_ty, fp_downcast_rounding="rtne"), mask=m_A)
270
+ tl.store(Aab + o_A + j, b_A_ab.to(dtype=Aqb.dtype.element_ty, fp_downcast_rounding="rtne"), mask=m_A)
271
+ tl.store(Aak + o_A + j, b_A_ak.to(dtype=Aqk.dtype.element_ty, fp_downcast_rounding="rtne"), mask=m_A)
272
+
273
+
274
+ def chunk_fwd_intra_dplr_fn(
275
+ q: torch.Tensor,
276
+ k: torch.Tensor,
277
+ a: torch.Tensor,
278
+ b: torch.Tensor,
279
+ gi: torch.Tensor,
280
+ ge: torch.Tensor,
281
+ scale: float,
282
+ chunk_size: int,
283
+ offsets: Optional[torch.LongTensor] = None,
284
+ indices: Optional[torch.LongTensor] = None,
285
+ head_first: bool = True,
286
+ ):
287
+ if head_first:
288
+ B, H, T, K = k.shape
289
+ else:
290
+ B, T, H, K = k.shape
291
+ BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
292
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
293
+ BC = min(16, BT)
294
+ NC = triton.cdiv(BT, BC)
295
+
296
+ Aqk = q.new_empty(B, *((H, T) if head_first else (T, H)), BT, dtype=q.dtype)
297
+ Aqb = q.new_empty(B, *((H, T) if head_first else (T, H)), BT, dtype=q.dtype)
298
+ # these matrices feed a matrix inverse, so it's better to keep them in float here.
299
+ Aab = q.new_empty(B, *((H, T) if head_first else (T, H)), BT, dtype=torch.float)
300
+ Aak = q.new_empty(B, *((H, T) if head_first else (T, H)), BT, dtype=torch.float)
301
+ grid = (NT, NC * NC, B * H)
302
+
303
+ chunk_dplr_fwd_A_kernel_intra_sub_inter[grid](
304
+ q=q, k=k, a=a, b=b, gi=gi, ge=ge, Aqk=Aqk, Aqb=Aqb, Aab=Aab, Aak=Aak,
305
+ offsets=offsets, indices=indices,
306
+ scale=scale,
307
+ T=T, H=H, K=K, BT=BT, BC=BC, NC=NC,
308
+ HEAD_FIRST=head_first
309
+ )
310
+ grid = (NT, NC, B * H)
311
+ BK = triton.next_power_of_2(K)
312
+ qg = torch.empty_like(q)
313
+ kg = torch.empty_like(k, dtype=q.dtype)
314
+ ag = torch.empty_like(a, dtype=q.dtype)
315
+ bg = torch.empty_like(b, dtype=q.dtype)
316
+ chunk_dplr_fwd_A_kernel_intra_sub_intra[grid](
317
+ q=q, k=k, a=a, b=b, gi=gi, ge=ge, Aqk=Aqk, Aqb=Aqb, Aab=Aab, Aak=Aak,
318
+ qg=qg, kg=kg, ag=ag, bg=bg,
319
+ offsets=offsets, indices=indices,
320
+ scale=scale,
321
+ T=T, H=H, K=K, BT=BT, BC=BC, BK=BK, HEAD_FIRST=head_first, NC=NC,
322
+ GATHER_SUPPORTED=is_gather_supported
323
+ )
324
+ return Aab, Aqk, Aak, Aqb, qg, kg, ag, bg
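For reference, the four intra-chunk score matrices returned above admit a simple dense definition; in the kernels only the (block-)lower-triangular entries are written, and anything above is masked downstream. The sketch below is illustrative only (single head, single chunk of length BT; the function and variable names are local to this example, not part of the library), with gi/ge the two cumulative log-decay tensors passed to the kernels and scale applied to q as in the kernels.

import torch

def intra_chunk_scores_ref(q, k, a, b, gi, ge, scale):
    # q, k, a, b, gi, ge: [BT, K] slices of one chunk for one head
    BT = q.shape[0]
    i = torch.arange(BT, device=q.device)[:, None]
    j = torch.arange(BT, device=q.device)[None, :]
    d_q = torch.exp(gi[:, None, :] - gi[None, :, :])  # feature-wise decay from step j to step i
    d_a = torch.exp(ge[:, None, :] - gi[None, :, :])
    Aqk = torch.einsum('ik,ijk,jk->ij', q * scale, d_q, k).masked_fill(i < j, 0.)
    Aqb = torch.einsum('ik,ijk,jk->ij', q * scale, d_q, b).masked_fill(i < j, 0.)
    Aak = torch.einsum('ik,ijk,jk->ij', a, d_a, k).masked_fill(i <= j, 0.)
    Aab = torch.einsum('ik,ijk,jk->ij', a, d_a, b).masked_fill(i <= j, 0.)
    return Aqk, Aqb, Aak, Aab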
fla/ops/generalized_delta_rule/dplr/chunk_h_bwd.py ADDED
@@ -0,0 +1,196 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.ops.common.utils import prepare_chunk_offsets
11
+ from fla.ops.utils.op import exp
12
+ from fla.utils import check_shared_mem, use_cuda_graph
13
+
14
+
15
+ @triton.heuristics({
16
+ 'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None,
17
+ 'USE_INITIAL_STATE': lambda args: args['dh0'] is not None,
18
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None,
19
+ })
20
+ @triton.autotune(
21
+ configs=[
22
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
23
+ for num_warps in [2, 4, 8, 16, 32]
24
+ for num_stages in [2, 3, 4]
25
+ ],
26
+ key=['BT', 'BK', 'BV', "V"],
27
+ use_cuda_graph=use_cuda_graph,
28
+ )
29
+ @triton.jit(do_not_specialize=['T'])
30
+ def chunk_dplr_bwd_kernel_dhu(
31
+ qg,
32
+ bg,
33
+ w,
34
+ gk,
35
+ dht,
36
+ dh0,
37
+ do,
38
+ dh,
39
+ dv,
40
+ dv2,
41
+ offsets,
42
+ chunk_offsets,
43
+ T,
44
+ H: tl.constexpr,
45
+ K: tl.constexpr,
46
+ V: tl.constexpr,
47
+ BT: tl.constexpr,
48
+ BC: tl.constexpr,
49
+ BK: tl.constexpr,
50
+ BV: tl.constexpr,
51
+ USE_FINAL_STATE_GRADIENT: tl.constexpr,
52
+ USE_INITIAL_STATE: tl.constexpr,
53
+ USE_OFFSETS: tl.constexpr,
54
+ HEAD_FIRST: tl.constexpr
55
+ ):
56
+ i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
57
+ i_n, i_h = i_nh // H, i_nh % H
58
+ if USE_OFFSETS:
59
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
60
+ T = eos - bos
61
+ NT = tl.cdiv(T, BT)
62
+ boh = tl.load(chunk_offsets + i_n).to(tl.int32)
63
+ else:
64
+ bos, eos = i_n * T, i_n * T + T
65
+ NT = tl.cdiv(T, BT)
66
+ boh = i_n * NT
67
+
68
+ # [BK, BV]
69
+ b_dh = tl.zeros([BK, BV], dtype=tl.float32)
70
+ if USE_FINAL_STATE_GRADIENT:
71
+ p_dht = tl.make_block_ptr(dht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
72
+ b_dh += tl.load(p_dht, boundary_check=(0, 1))
73
+
74
+ mask_k = tl.arange(0, BK) < K
75
+ for i_t in range(NT - 1, -1, -1):
76
+ if HEAD_FIRST:
77
+ p_dh = tl.make_block_ptr(dh + (i_nh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
78
+ else:
79
+ p_dh = tl.make_block_ptr(dh + ((boh+i_t) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
80
+ tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))
81
+ b_dh_tmp = tl.zeros([BK, BV], dtype=tl.float32)
82
+ for i_c in range(tl.cdiv(BT, BC) - 1, -1, -1):
83
+ if HEAD_FIRST:
84
+ p_qg = tl.make_block_ptr(qg + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
85
+ p_bg = tl.make_block_ptr(bg + i_nh * T*K, (T, K), (K, 1), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))
86
+ p_w = tl.make_block_ptr(w + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
87
+ p_dv = tl.make_block_ptr(dv + i_nh * T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
88
+ p_do = tl.make_block_ptr(do + i_nh * T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
89
+ p_dv2 = tl.make_block_ptr(dv2 + i_nh * T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
90
+ else:
91
+ p_qg = tl.make_block_ptr(qg+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
92
+ p_bg = tl.make_block_ptr(bg+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))
93
+ p_w = tl.make_block_ptr(w+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
94
+ p_dv = tl.make_block_ptr(dv+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
95
+ p_do = tl.make_block_ptr(do+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
96
+ p_dv2 = tl.make_block_ptr(dv2+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
97
+ # [BK, BT]
98
+ b_qg = tl.load(p_qg, boundary_check=(0, 1))
99
+ # [BT, BK]
100
+ b_bg = tl.load(p_bg, boundary_check=(0, 1))
101
+ b_w = tl.load(p_w, boundary_check=(0, 1))
102
+ # [BT, V]
103
+ b_do = tl.load(p_do, boundary_check=(0, 1))
104
+ b_dv = tl.load(p_dv, boundary_check=(0, 1))
105
+ b_dv2 = b_dv + tl.dot(b_bg, b_dh.to(b_bg.dtype))
106
+ tl.store(p_dv2, b_dv2.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
107
+ # [BK, BV]
108
+ b_dh_tmp += tl.dot(b_qg, b_do.to(b_qg.dtype))
109
+ b_dh_tmp += tl.dot(b_w, b_dv2.to(b_qg.dtype))
110
+ last_idx = min((i_t + 1) * BT, T) - 1
111
+ if HEAD_FIRST:
112
+ bg_last = tl.load(gk + (i_nh * T + last_idx) * K + tl.arange(0, BK), mask=mask_k)
113
+ else:
114
+ bg_last = tl.load(gk + ((bos + last_idx) * H + i_h) * K + tl.arange(0, BK), mask=mask_k)
115
+ b_dh *= exp(bg_last)[:, None]
116
+ b_dh += b_dh_tmp
117
+
118
+ if USE_INITIAL_STATE:
119
+ p_dh0 = tl.make_block_ptr(dh0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
120
+ tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), boundary_check=(0, 1))
121
+
122
+
123
+ def chunk_dplr_bwd_dhu(
124
+ qg: torch.Tensor,
125
+ bg: torch.Tensor,
126
+ w: torch.Tensor,
127
+ gk: torch.Tensor,
128
+ h0: torch.Tensor,
129
+ dht: Optional[torch.Tensor],
130
+ do: torch.Tensor,
131
+ dv: torch.Tensor,
132
+ offsets: Optional[torch.LongTensor] = None,
133
+ indices: Optional[torch.LongTensor] = None,
134
+ head_first: bool = True,
135
+ chunk_size: int = 64
136
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
137
+ if head_first:
138
+ B, H, T, K, V = *qg.shape, do.shape[-1]
139
+ else:
140
+ B, T, H, K, V = *qg.shape, do.shape[-1]
141
+ BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
142
+ BK = triton.next_power_of_2(K)
143
+ assert BK <= 256, "current kernel does not support head dimension larger than 256."
144
+ # H100
145
+ if check_shared_mem('hopper', qg.device.index):
146
+ BV = 64
147
+ BC = 64 if K <= 128 else 32
148
+ elif check_shared_mem('ampere', qg.device.index): # A100
149
+ BV = 32
150
+ BC = 32
151
+ else: # Etc: 4090
152
+ BV = 16
153
+ BC = 16
154
+
155
+ # N: the actual number of sequences in the batch with either equal or variable lengths
156
+ if offsets is None:
157
+ N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
158
+ else:
159
+ N, NT, chunk_offsets = len(offsets) - 1, len(indices), prepare_chunk_offsets(offsets, BT)
160
+
161
+ BC = min(BT, BC)
162
+ NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
163
+ assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization'
164
+
165
+ if head_first:
166
+ dh = qg.new_empty(B, H, NT, K, V)
167
+ else:
168
+ dh = qg.new_empty(B, NT, H, K, V)
169
+ dh0 = torch.empty_like(h0, dtype=torch.float32) if h0 is not None else None
170
+ dv2 = torch.zeros_like(dv)
171
+
172
+ grid = (NK, NV, N * H)
173
+ chunk_dplr_bwd_kernel_dhu[grid](
174
+ qg=qg,
175
+ bg=bg,
176
+ w=w,
177
+ gk=gk,
178
+ dht=dht,
179
+ dh0=dh0,
180
+ do=do,
181
+ dh=dh,
182
+ dv=dv,
183
+ dv2=dv2,
184
+ offsets=offsets,
185
+ chunk_offsets=chunk_offsets,
186
+ T=T,
187
+ H=H,
188
+ K=K,
189
+ V=V,
190
+ BT=BT,
191
+ BC=BC,
192
+ BK=BK,
193
+ BV=BV,
194
+ HEAD_FIRST=head_first
195
+ )
196
+ return dh, dh0, dv2
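Summarised per chunk, the backward state recurrence implemented by chunk_dplr_bwd_kernel_dhu is the following (a single-head reference sketch; names are illustrative, gk_last denotes the cumulative log-decay at the chunk's last position, and dh is the state gradient arriving from the next chunk):

import torch

def chunk_dh_ref(dh, qg, bg, w, do, dv, gk_last):
    # dh: [K, V]; qg, bg, w: [BT, K]; do, dv: [BT, V]; gk_last: [K]
    dv2 = dv + bg @ dh                      # gradient routed through v_new = u + w @ h
    dh_prev = torch.exp(gk_last)[:, None] * dh + qg.t() @ do + w.t() @ dv2
    return dh_prev, dv2                     # dh_prev is propagated to the previous chunk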
fla/ops/generalized_delta_rule/dplr/chunk_h_fwd.py ADDED
@@ -0,0 +1,197 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.ops.common.utils import prepare_chunk_offsets
11
+ from fla.ops.utils.op import exp
12
+ from fla.utils import check_shared_mem, use_cuda_graph
13
+
14
+
15
+ @triton.heuristics({
16
+ 'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
17
+ 'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
18
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None,
19
+ })
20
+ @triton.autotune(
21
+ configs=[
22
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
23
+ for num_warps in [2, 4, 8, 16, 32]
24
+ for num_stages in [2, 3, 4]
25
+ ],
26
+ key=['BT', 'BK', 'BV'],
27
+ use_cuda_graph=use_cuda_graph,
28
+ )
29
+ @triton.jit(do_not_specialize=['T'])
30
+ def chunk_dplr_fwd_kernel_h(
31
+ kg,
32
+ v,
33
+ w,
34
+ bg,
35
+ u,
36
+ v_new,
37
+ gk,
38
+ h,
39
+ h0,
40
+ ht,
41
+ offsets,
42
+ chunk_offsets,
43
+ T,
44
+ H: tl.constexpr,
45
+ K: tl.constexpr,
46
+ V: tl.constexpr,
47
+ BT: tl.constexpr,
48
+ BC: tl.constexpr,
49
+ BK: tl.constexpr,
50
+ BV: tl.constexpr,
51
+ NT: tl.constexpr,
52
+ USE_INITIAL_STATE: tl.constexpr,
53
+ STORE_FINAL_STATE: tl.constexpr,
54
+ USE_OFFSETS: tl.constexpr,
55
+ HEAD_FIRST: tl.constexpr,
56
+ ):
57
+ i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
58
+ i_n, i_h = i_nh // H, i_nh % H
59
+ if USE_OFFSETS:
60
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
61
+ T = eos - bos
62
+ NT = tl.cdiv(T, BT)
63
+ boh = tl.load(chunk_offsets + i_n).to(tl.int32)
64
+ else:
65
+ bos, eos = i_n * T, i_n * T + T
66
+ NT = tl.cdiv(T, BT)
67
+ boh = i_n * NT
68
+
69
+ # [BK, BV]
70
+ b_h = tl.zeros([BK, BV], dtype=tl.float32)
71
+ if USE_INITIAL_STATE:
72
+ p_h0 = tl.make_block_ptr(h0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
73
+ b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32)
74
+
75
+ for i_t in range(NT):
76
+ if HEAD_FIRST:
77
+ p_h = tl.make_block_ptr(h + (i_nh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
78
+ else:
79
+ p_h = tl.make_block_ptr(h + ((boh + i_t) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
80
+ tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
81
+
82
+ b_hc = tl.zeros([BK, BV], dtype=tl.float32)
83
+ # we need to keep the full K dimension in SRAM, which imposes a severe memory burden; sub-chunking alleviates it
84
+ for i_c in range(tl.cdiv(min(BT, T - i_t * BT), BC)):
85
+ if HEAD_FIRST:
86
+ p_kg = tl.make_block_ptr(kg + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
87
+ p_bg = tl.make_block_ptr(bg + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
88
+ p_w = tl.make_block_ptr(w + i_nh * T*K, (T, K), (K, 1), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))
89
+ p_v = tl.make_block_ptr(v + i_nh * T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
90
+ p_u = tl.make_block_ptr(u + i_nh * T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
91
+ p_v_new = tl.make_block_ptr(v_new+i_nh*T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
92
+ else:
93
+ p_kg = tl.make_block_ptr(kg+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
94
+ p_bg = tl.make_block_ptr(bg+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
95
+ p_w = tl.make_block_ptr(w+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))
96
+ p_v = tl.make_block_ptr(v+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
97
+ p_u = tl.make_block_ptr(u+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
98
+ p_v_new = tl.make_block_ptr(v_new+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT+i_c*BC, i_v * BV), (BC, BV), (1, 0))
99
+ # [BK, BC]
100
+ b_kg = tl.load(p_kg, boundary_check=(0, 1))
101
+ b_v = tl.load(p_v, boundary_check=(0, 1))
102
+ b_w = tl.load(p_w, boundary_check=(0, 1))
103
+ b_bg = tl.load(p_bg, boundary_check=(0, 1))
104
+ b_v2 = tl.dot(b_w, b_h.to(b_w.dtype)) + tl.load(p_u, boundary_check=(0, 1))
105
+ b_hc += tl.dot(b_kg, b_v)
106
+ b_hc += tl.dot(b_bg.to(b_hc.dtype), b_v2)
107
+ tl.store(p_v_new, b_v2.to(p_v_new.dtype.element_ty), boundary_check=(0, 1))
108
+
109
+ last_idx = min((i_t + 1) * BT, T) - 1
110
+ if HEAD_FIRST:
111
+ b_g_last = tl.load(gk + i_nh * T * K + last_idx * K + tl.arange(0, BK), mask=tl.arange(0, BK) < K).to(tl.float32)
112
+ else:
113
+ b_g_last = tl.load(gk + (bos + last_idx) * H * K + i_h * K +
114
+ tl.arange(0, BK), mask=tl.arange(0, BK) < K).to(tl.float32)
115
+ b_h *= exp(b_g_last[:, None])
116
+ b_h += b_hc
117
+
118
+ if STORE_FINAL_STATE:
119
+ p_ht = tl.make_block_ptr(ht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
120
+ tl.store(p_ht, b_h.to(p_ht.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
121
+
122
+
123
+ def chunk_dplr_fwd_h(
124
+ kg: torch.Tensor,
125
+ v: torch.Tensor,
126
+ w: torch.Tensor,
127
+ u: torch.Tensor,
128
+ bg: torch.Tensor,
129
+ gk: torch.Tensor,
130
+ initial_state: Optional[torch.Tensor] = None,
131
+ output_final_state: bool = False,
132
+ offsets: Optional[torch.LongTensor] = None,
133
+ indices: Optional[torch.LongTensor] = None,
134
+ head_first: bool = True,
135
+ chunk_size: int = 64
136
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
137
+ if head_first:
138
+ B, H, T, K, V = *kg.shape, u.shape[-1]
139
+ else:
140
+ B, T, H, K, V = *kg.shape, u.shape[-1]
141
+ BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
142
+ # N: the actual number of sequences in the batch with either equal or variable lengths
143
+ if offsets is None:
144
+ N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
145
+ else:
146
+ N, NT, chunk_offsets = len(offsets) - 1, len(indices), prepare_chunk_offsets(offsets, BT)
147
+ BK = triton.next_power_of_2(K)
148
+ assert BK <= 256, "current kernel does not support head dimension larger than 256."
149
+ # H100 can have larger block size
150
+
151
+ if check_shared_mem('hopper', kg.device.index):
152
+ BV = 64
153
+ BC = 64 if K <= 128 else 32
154
+ elif check_shared_mem('ampere', kg.device.index): # A100
155
+ BV = 32
156
+ BC = 32
157
+ else:
158
+ BV = 16
159
+ BC = 16
160
+
161
+ BC = min(BT, BC)
162
+ NK = triton.cdiv(K, BK)
163
+ NV = triton.cdiv(V, BV)
164
+ assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization'
165
+
166
+ if head_first:
167
+ h = kg.new_empty(B, H, NT, K, V)
168
+ else:
169
+ h = kg.new_empty(B, NT, H, K, V)
170
+ final_state = kg.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None
171
+ v_new = torch.empty_like(u)
172
+ grid = (NK, NV, N * H)
173
+ chunk_dplr_fwd_kernel_h[grid](
174
+ kg=kg,
175
+ v=v,
176
+ w=w,
177
+ bg=bg,
178
+ u=u,
179
+ v_new=v_new,
180
+ h=h,
181
+ gk=gk,
182
+ h0=initial_state,
183
+ ht=final_state,
184
+ offsets=offsets,
185
+ chunk_offsets=chunk_offsets,
186
+ T=T,
187
+ H=H,
188
+ K=K,
189
+ V=V,
190
+ BT=BT,
191
+ BC=BC,
192
+ BK=BK,
193
+ BV=BV,
194
+ NT=NT,
195
+ HEAD_FIRST=head_first
196
+ )
197
+ return h, v_new, final_state
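Per chunk, the forward kernel above maintains the running state with the update below (a single-head reference sketch; names are illustrative and gk_last is the cumulative log-decay at the chunk's last position):

import torch

def chunk_h_ref(h, kg, bg, w, v, u, gk_last):
    # h: [K, V] state entering the chunk; kg, bg, w: [BT, K]; v, u: [BT, V]; gk_last: [K]
    v_new = u + w @ h                       # the v_new tensor written back by the kernel
    h_next = torch.exp(gk_last)[:, None] * h + kg.t() @ v + bg.t() @ v_new
    return h_next, v_new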
fla/ops/generalized_delta_rule/dplr/chunk_o_fwd.py ADDED
@@ -0,0 +1,138 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.utils import check_shared_mem, use_cuda_graph
11
+
12
+ BK_LIST = [32, 64, 128] if check_shared_mem() else [16, 32]
13
+
14
+
15
+ @triton.heuristics({
16
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None,
17
+ })
18
+ @triton.autotune(
19
+ configs=[
20
+ triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
21
+ for BK in BK_LIST
22
+ for BV in BK_LIST
23
+ for num_warps in [2, 4, 8, 16, 32]
24
+ for num_stages in [2, 3, 4]
25
+ ],
26
+ key=['BT'],
27
+ use_cuda_graph=use_cuda_graph,
28
+ )
29
+ @triton.jit(do_not_specialize=['T'])
30
+ def chunk_dplr_fwd_kernel_o(
31
+ qg,
32
+ v,
33
+ v_new,
34
+ A_qk,
35
+ A_qb,
36
+ h,
37
+ o,
38
+ offsets,
39
+ indices,
40
+ T,
41
+ H: tl.constexpr,
42
+ K: tl.constexpr,
43
+ V: tl.constexpr,
44
+ BT: tl.constexpr,
45
+ BK: tl.constexpr,
46
+ BV: tl.constexpr,
47
+ USE_OFFSETS: tl.constexpr,
48
+ HEAD_FIRST: tl.constexpr,
49
+ ):
50
+ i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
51
+ i_b, i_h = i_bh // H, i_bh % H
52
+
53
+ if USE_OFFSETS:
54
+ i_tg = i_t
55
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
56
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
57
+ T = eos - bos
58
+ NT = tl.cdiv(T, BT)
59
+ else:
60
+ NT = tl.cdiv(T, BT)
61
+ i_tg = i_b * NT + i_t
62
+ bos, eos = i_b * T, i_b * T + T
63
+
64
+ b_o = tl.zeros([BT, BV], dtype=tl.float32)
65
+ for i_k in range(tl.cdiv(K, BK)):
66
+ if HEAD_FIRST:
67
+ p_qg = tl.make_block_ptr(qg + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
68
+ p_h = tl.make_block_ptr(h + (i_bh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
69
+ else:
70
+ p_qg = tl.make_block_ptr(qg + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
71
+ p_h = tl.make_block_ptr(h + (i_tg * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
72
+ b_qg = tl.load(p_qg, boundary_check=(0, 1))
73
+ b_h = tl.load(p_h, boundary_check=(0, 1))
74
+ b_o += tl.dot(b_qg, b_h)
75
+
76
+ if HEAD_FIRST:
77
+ p_Aqk = tl.make_block_ptr(A_qk + i_bh * T*BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
78
+ p_Aqb = tl.make_block_ptr(A_qb + i_bh * T*BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
79
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
80
+ p_v_new = tl.make_block_ptr(v_new + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
81
+ p_o = tl.make_block_ptr(o + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
82
+ else:
83
+ p_Aqk = tl.make_block_ptr(A_qk + (bos * H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
84
+ p_Aqb = tl.make_block_ptr(A_qb + (bos * H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
85
+ p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
86
+ p_v_new = tl.make_block_ptr(v_new + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
87
+ p_o = tl.make_block_ptr(o + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
88
+
89
+ m_s = tl.arange(0, BT)[:, None] >= tl.arange(0, BT)[None, :]
90
+ b_Aqk = tl.load(p_Aqk, boundary_check=(0, 1))
91
+ b_Aqb = tl.load(p_Aqb, boundary_check=(0, 1))
92
+ b_Aqk = tl.where(m_s, b_Aqk, 0)
93
+ b_Aqb = tl.where(m_s, b_Aqb, 0)
94
+ b_v = tl.load(p_v, boundary_check=(0, 1))
95
+ b_v_new = tl.load(p_v_new, boundary_check=(0, 1))
96
+ b_o = b_o + tl.dot(b_Aqk.to(b_v.dtype), b_v) + tl.dot(b_Aqb.to(b_v_new.dtype), b_v_new)
97
+ tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
98
+
99
+
100
+ def chunk_dplr_fwd_o(
101
+ qg: torch.Tensor,
102
+ v: torch.Tensor,
103
+ v_new: torch.Tensor,
104
+ A_qk: torch.Tensor,
105
+ A_qb: torch.Tensor,
106
+ h: torch.Tensor,
107
+ offsets: Optional[torch.LongTensor] = None,
108
+ indices: Optional[torch.LongTensor] = None,
109
+ head_first: bool = True,
110
+ chunk_size: int = 64
111
+ ) -> torch.Tensor:
112
+ if head_first:
113
+ B, H, T, K, V = *qg.shape, v.shape[-1]
114
+ else:
115
+ B, T, H, K, V = *qg.shape, v.shape[-1]
116
+ BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
117
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
118
+
119
+ o = torch.empty_like(v)
120
+ def grid(meta): return (triton.cdiv(V, meta['BV']), NT, B * H)
121
+ chunk_dplr_fwd_kernel_o[grid](
122
+ qg=qg,
123
+ v=v,
124
+ v_new=v_new,
125
+ A_qk=A_qk,
126
+ A_qb=A_qb,
127
+ h=h,
128
+ o=o,
129
+ offsets=offsets,
130
+ indices=indices,
131
+ T=T,
132
+ H=H,
133
+ K=K,
134
+ V=V,
135
+ BT=BT,
136
+ HEAD_FIRST=head_first
137
+ )
138
+ return o
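The per-chunk output assembled by chunk_dplr_fwd_kernel_o corresponds to the reference computation below (a single-head sketch for one chunk; names are illustrative):

import torch

def chunk_o_ref(qg, A_qk, A_qb, v, v_new, h):
    # qg: [BT, K]; A_qk, A_qb: [BT, BT]; v, v_new: [BT, V]; h: [K, V] state entering the chunk
    BT = qg.shape[0]
    m = torch.tril(torch.ones(BT, BT, dtype=torch.bool, device=qg.device))
    o_inter = qg @ h                                        # chunk-level state contribution
    o_intra = A_qk.masked_fill(~m, 0.) @ v + A_qb.masked_fill(~m, 0.) @ v_new
    return o_inter + o_intra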
fla/ops/generalized_delta_rule/dplr/fused_recurrent.py ADDED
@@ -0,0 +1,292 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.ops.utils.op import exp
11
+ from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard, use_cuda_graph
12
+
13
+
14
+ @triton.heuristics({
15
+ 'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
16
+ 'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
17
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
18
+ })
19
+ @triton.autotune(
20
+ configs=[
21
+ triton.Config({'BV': BV}, num_warps=num_warps, num_stages=num_stages)
22
+ for BV in [16, 32, 64]
23
+ for num_warps in [2, 4, 8, 16]
24
+ for num_stages in [2, 3, 4]
25
+ ],
26
+ key=['BK'],
27
+ use_cuda_graph=use_cuda_graph,
28
+ )
29
+ @triton.jit(do_not_specialize=['T'])
30
+ def fused_recurrent_dplr_delta_rule_fwd_kernel(
31
+ q,
32
+ k,
33
+ v,
34
+ a,
35
+ b,
36
+ gk,
37
+ o,
38
+ h0,
39
+ ht,
40
+ offsets,
41
+ scale,
42
+ T,
43
+ B: tl.constexpr,
44
+ H: tl.constexpr,
45
+ K: tl.constexpr,
46
+ V: tl.constexpr,
47
+ BK: tl.constexpr,
48
+ BV: tl.constexpr,
49
+ REVERSE: tl.constexpr,
50
+ USE_INITIAL_STATE: tl.constexpr,
51
+ STORE_FINAL_STATE: tl.constexpr,
52
+ USE_OFFSETS: tl.constexpr,
53
+ HEAD_FIRST: tl.constexpr
54
+ ):
55
+ i_v, i_nh = tl.program_id(0).to(tl.int64), tl.program_id(1).to(tl.int64)
56
+ i_n, i_h = i_nh // H, i_nh % H
57
+
58
+ if USE_OFFSETS:
59
+ bos, eos = tl.load(offsets + i_n).to(tl.int64), tl.load(offsets + i_n + 1).to(tl.int64)
60
+ T = eos - bos
61
+ else:
62
+ bos, eos = i_n * T, i_n * T + T
63
+
64
+ o_k = tl.arange(0, BK)
65
+ o_v = i_v * BV + tl.arange(0, BV)
66
+ if HEAD_FIRST:
67
+ p_q = q + i_nh * T*K + ((T-1) * K if REVERSE else 0) + o_k
68
+ p_k = k + i_nh * T*K + ((T-1) * K if REVERSE else 0) + o_k
69
+ p_a = a + i_nh * T*K + ((T-1) * K if REVERSE else 0) + o_k
70
+ p_b = b + i_nh * T*K + ((T-1) * K if REVERSE else 0) + o_k
71
+ p_gk = gk + i_nh * T*K + ((T-1) * K if REVERSE else 0) + o_k
72
+ p_v = v + i_nh * T*V + ((T-1) * V if REVERSE else 0) + o_v
73
+ p_o = o + i_nh * T*V + ((T-1) * V if REVERSE else 0) + o_v
74
+
75
+ else:
76
+ p_q = q + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k
77
+ p_k = k + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k
78
+ p_a = a + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k
79
+ p_b = b + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k
80
+ p_gk = gk + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k
81
+ p_v = v + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + o_v
82
+ p_o = o + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + o_v
83
+
84
+ mask_k = o_k < K
85
+ mask_v = o_v < V
86
+ mask_h = mask_k[None, :] & mask_v[:, None]
87
+ b_h = tl.zeros([BV, BK], dtype=tl.float32)
88
+
89
+ if USE_INITIAL_STATE:
90
+ p_h0 = h0 + i_nh * K*V + o_k[None, :] * V + o_v[:, None]
91
+ b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)
92
+
93
+ for _ in range(0, T):
94
+ b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) * scale
95
+ b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
96
+ b_a = tl.load(p_a, mask=mask_k, other=0).to(tl.float32)
97
+ b_b = tl.load(p_b, mask=mask_k, other=0).to(tl.float32)
98
+ b_gk = tl.load(p_gk, mask=mask_k, other=0).to(tl.float32)
99
+ b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
100
+
101
+ tmp = tl.sum(b_h * b_a[None, :], axis=1)
102
+ b_h = exp(b_gk)[None, :] * b_h + (tmp[:, None] * b_b[None, :] + b_k[None, :] * b_v[:, None])
103
+ b_o = tl.sum(b_h * b_q[None, :], axis=1)
104
+
105
+ tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)
106
+ p_q += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * K
107
+ p_k += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * K
108
+ p_a += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * K
109
+ p_b += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * K
110
+ p_gk += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * K
111
+ p_v += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * V
112
+ p_o += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * V
113
+
114
+ if STORE_FINAL_STATE:
115
+ p_ht = ht + i_nh * K*V + o_k[None, :] * V + o_v[:, None]
116
+ tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
117
+
118
+
119
+ def fused_recurrent_dplr_delta_rule_fwd(
120
+ q: torch.Tensor,
121
+ k: torch.Tensor,
122
+ v: torch.Tensor,
123
+ a: torch.Tensor,
124
+ b: torch.Tensor,
125
+ gk: torch.Tensor,
126
+ scale: Optional[float] = 1.0,
127
+ initial_state: Optional[torch.Tensor] = None,
128
+ output_final_state: bool = False,
129
+ reverse: bool = False,
130
+ offsets: Optional[torch.LongTensor] = None,
131
+ head_first: bool = True
132
+ ):
133
+ if head_first:
134
+ B, H, T, K, V = *k.shape, v.shape[-1]
135
+ else:
136
+ B, T, H, K, V = *k.shape, v.shape[-1]
137
+ N = B if offsets is None else len(offsets) - 1
138
+ BK = triton.next_power_of_2(K)
139
+
140
+ h0 = initial_state
141
+ if output_final_state:
142
+ ht = q.new_empty(N, H, K, V, dtype=torch.float32)
143
+ else:
144
+ ht = None
145
+ o = torch.empty_like(v)
146
+
147
+ def grid(meta): return (triton.cdiv(V, meta['BV']), N * H)
148
+ fused_recurrent_dplr_delta_rule_fwd_kernel[grid](
149
+ q,
150
+ k,
151
+ v,
152
+ a,
153
+ b,
154
+ gk,
155
+ o,
156
+ h0,
157
+ ht,
158
+ offsets,
159
+ scale,
160
+ T=T,
161
+ B=B,
162
+ H=H,
163
+ K=K,
164
+ V=V,
165
+ BK=BK,
166
+ REVERSE=reverse,
167
+ HEAD_FIRST=head_first
168
+ )
169
+ return o, ht
170
+
171
+
172
+ class FusedRecurrentDPLRDeltaRuleFunction(torch.autograd.Function):
173
+
174
+ @staticmethod
175
+ @input_guard
176
+ @autocast_custom_fwd
177
+ def forward(
178
+ ctx,
179
+ q: torch.Tensor,
180
+ k: torch.Tensor,
181
+ v: torch.Tensor,
182
+ a: torch.Tensor,
183
+ b: torch.Tensor,
184
+ gk: torch.Tensor,
185
+ scale: Optional[float] = 1.0,
186
+ initial_state: Optional[torch.Tensor] = None,
187
+ output_final_state: bool = False,
188
+ reverse: bool = False,
189
+ offsets: Optional[torch.LongTensor] = None,
190
+ head_first: bool = False
191
+ ):
192
+ o, ht = fused_recurrent_dplr_delta_rule_fwd(
193
+ q=q,
194
+ k=k,
195
+ v=v,
196
+ a=a,
197
+ b=b,
198
+ gk=gk,
199
+ scale=scale,
200
+ initial_state=initial_state,
201
+ output_final_state=output_final_state,
202
+ reverse=reverse,
203
+ offsets=offsets,
204
+ head_first=head_first
205
+ )
206
+ return o, ht
207
+
208
+ @staticmethod
209
+ @input_guard
210
+ @autocast_custom_bwd
211
+ def backward(ctx, do, dht):
212
+ raise NotImplementedError(
213
+ "Backward pass for fused_recurrent_dplr_delta_rule is not implemented and will not be supported. "
214
+ "This kernel is only for inference. "
215
+ "For training, please use `chunk_dplr_delta_rule`."
216
+ )
217
+
218
+
219
+ def fused_recurrent_dplr_delta_rule(
220
+ q: torch.Tensor,
221
+ k: torch.Tensor,
222
+ v: torch.Tensor,
223
+ a: torch.Tensor,
224
+ b: torch.Tensor,
225
+ gk: torch.Tensor,
226
+ scale: Optional[float] = 1.0,
227
+ initial_state: Optional[torch.Tensor] = None,
228
+ output_final_state: bool = False,
229
+ reverse: bool = False,
230
+ cu_seqlens: Optional[torch.Tensor] = None,
231
+ head_first: bool = False
232
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
233
+ r"""
234
+ This function computes the recurrence S_t = S_{t-1} @ (Diag(exp(gk_t)) + a_t b_t^T) + v_t k_t^T step by step.
235
+
236
+ Args:
237
+ q (torch.Tensor):
238
+ queries of shape `[B, H, T, K]`
239
+ k (torch.Tensor):
240
+ keys of shape `[B, H, T, K]`
241
+ v (torch.Tensor):
242
+ values of shape `[B, H, T, V]`
243
+ a (torch.Tensor):
244
+ as of shape `[B, H, T, K]`
245
+ b (torch.Tensor):
246
+ bs of shape `[B, H, T, K]`
247
+ gk (torch.Tensor):
248
+ gk of shape `[B, H, T, K]`
249
+ scale (Optional[float]):
250
+ Scale factor for the attention scores.
251
+ If None, it will default to `1 / sqrt(K)`. Default: `1.0`.
252
+ initial_state (Optional[torch.Tensor]):
253
+ Initial state of shape `[B, H, K, V]`. Default: `None`.
254
+ output_final_state (Optional[bool]):
255
+ Whether to output the final state of shape `[B, H, K, V]`. Default: `False`.
256
+ reverse (Optional[bool]):
257
+ If `True`, process the state passing in reverse order. Default: `False`.
258
+ cu_seqlens (Optional[torch.Tensor]):
259
+ Cumulative sequence lengths of shape `[N + 1]` used for variable-length training,
260
+ consistent with the FlashAttention API.
261
+ head_first (Optional[bool]):
262
+ Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
263
+ Default: `False`.
264
+ """
265
+ if cu_seqlens is not None:
266
+ if q.shape[0] != 1:
267
+ raise ValueError(f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
268
+ f"Please flatten variable-length inputs before processing.")
269
+ if head_first:
270
+ raise RuntimeError("Sequences with variable lengths are not supported for head-first mode")
271
+ if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
272
+ raise ValueError(f"The number of initial states is expected to be equal to the number of input sequences, "
273
+ f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}.")
274
+ if scale is None:
275
+ scale = q.shape[-1] ** -0.5
276
+ else:
277
+ assert scale > 0, "scale must be positive"
278
+ o, final_state = FusedRecurrentDPLRDeltaRuleFunction.apply(
279
+ q,
280
+ k,
281
+ v,
282
+ a,
283
+ b,
284
+ gk,
285
+ scale,
286
+ initial_state,
287
+ output_final_state,
288
+ reverse,
289
+ cu_seqlens,
290
+ head_first
291
+ )
292
+ return o, final_state
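A minimal usage sketch of the interface above (an illustration, not part of the diff). It assumes a CUDA device with Triton available, the default `head_first=False` layout `[B, T, H, K]`/`[B, T, H, V]`, and non-positive log-decay values for `gk`; all concrete shapes and dtypes are illustrative only.

import torch
from fla.ops.generalized_delta_rule.dplr.fused_recurrent import fused_recurrent_dplr_delta_rule

B, T, H, K, V = 2, 128, 4, 64, 64
q = torch.randn(B, T, H, K, device='cuda')
k = torch.randn(B, T, H, K, device='cuda')
v = torch.randn(B, T, H, V, device='cuda')
a = torch.randn(B, T, H, K, device='cuda')
b = torch.randn(B, T, H, K, device='cuda')
# log-domain decays, kept non-positive here so that exp(gk) <= 1 (an assumption of this sketch)
gk = -0.1 * torch.rand(B, T, H, K, device='cuda')

# inference-only kernel: the backward pass above intentionally raises NotImplementedError
o, ht = fused_recurrent_dplr_delta_rule(
    q, k, v, a, b, gk,
    scale=K ** -0.5,
    output_final_state=True,  # ht has shape [B, H, K, V] in float32
)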
fla/ops/generalized_delta_rule/dplr/naive.py ADDED
@@ -0,0 +1,96 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import torch
4
+ from einops import rearrange
5
+
6
+ # S_t = S_{t-1} @ (Diag(exp(gk_t)) + alpha_t beta_t^T) + v_t k_t^T
7
+ # q, k, alpha, beta [B, H, L, D_K]
8
+ # v [B, H, L, D_V]
9
+
10
+
11
+ def dplr_recurrence(q, k, v, alpha, beta, gk, initial_state=None, output_final_state=True):
12
+ orig_dtype = q.dtype
13
+ b, h, l, d_k = q.shape
14
+ q, k, v, beta, gk = map(lambda x: x.float(), [q, k, v, beta, gk])
15
+ d_v = v.shape[-1]
16
+ o = torch.zeros_like(v)
17
+ S = torch.zeros(b, h, d_k, d_v).to(v)
18
+ q = q * (d_k ** -0.5)
19
+
20
+ if initial_state is not None:
21
+ S += initial_state
22
+
23
+ for i in range(l):
24
+ _k = k[:, :, i]
25
+ _q = q[:, :, i]
26
+ _v = v[:, :, i]
27
+ _alpha = alpha[:, :, i].clone()
28
+ _beta = beta[:, :, i].clone()
29
+ _kv = _k[..., None] * _v[..., None, :] + (S.clone() * _alpha[..., None]).sum(-2, keepdim=True) * _beta[..., None]
30
+ S = S.clone() * gk[:, :, i].exp()[..., None] + _kv
31
+ o[:, :, i] = torch.einsum('bhd,bhdm->bhm', _q, S)
32
+ S = None if output_final_state is False else S
33
+ return o.to(orig_dtype), S
34
+
35
+
36
+ def dplr_chunkwise(q, k, v, alpha, beta, gk, initial_state=None, output_final_state=True, chunk_size=32):
37
+ b, h, l, d_k = q.shape
38
+ d_v = v.shape[-1]
39
+ q = q * (d_k ** -0.5)
40
+ v = v
41
+ assert l % chunk_size == 0
42
+
43
+ S = k.new_zeros(b, h, d_k, d_v).to(q)
44
+ if initial_state is not None:
45
+ S += initial_state
46
+
47
+ # note that diagonal is masked.
48
+ mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device), diagonal=0)
49
+ q, k, v, alpha, beta, gk = map(lambda x: rearrange(x, 'b h (n c) d -> b h n c d',
50
+ c=chunk_size).float(), [q, k, v, alpha, beta, gk])
51
+
52
+ gk_cumsum = gk.cumsum(-2)
53
+
54
+ # v2 = (alpha @ k.transpose(-1, -2)).masked_fill_(mask, 0) @ v
55
+ A_ab = torch.zeros(b, h, l // chunk_size, chunk_size, chunk_size).to(q.device)
56
+ A_qk = torch.zeros(b, h, l // chunk_size, chunk_size, chunk_size).to(q.device)
57
+ A_ak = torch.zeros(b, h, l // chunk_size, chunk_size, chunk_size).to(q.device)
58
+ A_qb = torch.zeros(b, h, l // chunk_size, chunk_size, chunk_size).to(q.device)
59
+
60
+ for i in range(chunk_size):
61
+ alpha_i = alpha[:, :, :, i, None]
62
+ q_i = q[:, :, :, i, None]
63
+ gk_i = gk_cumsum[:, :, :, i, None]
64
+ mask = (torch.arange(chunk_size) <= i).to(q.device)
65
+ attn_i = (gk_i - gk_cumsum).masked_fill(~mask.unsqueeze(-1), float('-inf')).exp()
66
+ A_qk[:, :, :, i, :] = (q_i * k * attn_i).sum(-1).clone()
67
+ A_qb[:, :, :, i, :] = (q_i * beta * attn_i).sum(-1).clone()
68
+ mask = (torch.arange(chunk_size) < i).to(q.device)
69
+ # shift by one.
70
+ attn_i = (gk_i - gk[:, :, :, i, None] - gk_cumsum).masked_fill(~mask.unsqueeze(-1), float('-inf')).exp()
71
+ A_ab[:, :, :, i, :] = (alpha_i * beta * attn_i).sum(-1).clone()
72
+ A_ak[:, :, :, i, :] = (alpha_i * k * attn_i).sum(-1).clone()
73
+
74
+ A_ab = A_ab
75
+ for i in range(1, chunk_size):
76
+ A_ab[..., i, :i] = A_ab[..., i, :i].clone() + (A_ab[..., i, :, None].clone() * A_ab[..., :, :i].clone()).sum(-2)
77
+
78
+ A_ab = A_ab + torch.eye(chunk_size, dtype=torch.float, device=q.device)
79
+ u = A_ab @ (A_ak @ v)
80
+ w = A_ab @ ((gk_cumsum-gk).exp() * alpha)
81
+
82
+ o = torch.zeros_like(v)
83
+ mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device), diagonal=1)
84
+ for i in range(0, l // chunk_size):
85
+ q_i, k_i, v_i, u_i, w_i, beta_i = q[:, :, i], k[:, :, i], v[:, :, i], u[:, :, i], w[:, :, i], beta[:, :, i]
86
+ v2_i = u_i + w_i @ S
87
+
88
+ o_1 = A_qk[:, :, i] @ v_i
89
+ o_2 = A_qb[:, :, i] @ v2_i
90
+ o_3 = (q_i * gk_cumsum[:, :, i].exp()) @ S
91
+ o[:, :, i] = o_1 + o_2 + o_3
92
+ decay = (gk_cumsum[:, :, i, -1, None] - gk_cumsum[:, :, i]).exp()
93
+ S = S*gk_cumsum[:, :, i, -1, :, None].exp() + (k_i * decay).transpose(-1, -2) @ v_i + \
94
+ (beta_i * decay).transpose(-1, -2) @ v2_i
95
+ S = None if output_final_state is False else S
96
+ return rearrange(o, 'b h n c d -> b h (n c) d'), S
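A small cross-check sketch for the two references above (an illustration, not part of the diff). It assumes the head-first `[B, H, T, D]` layout both functions expect, non-positive log decays for `gk`, and a sequence length divisible by `chunk_size`; the chunkwise form is expected to match the step-by-step recurrence up to floating-point error.

import torch

from fla.ops.generalized_delta_rule.dplr.naive import dplr_chunkwise, dplr_recurrence

b, h, l, d_k, d_v = 1, 2, 64, 16, 16
q, k, alpha, beta = (torch.randn(b, h, l, d_k) for _ in range(4))
v = torch.randn(b, h, l, d_v)
gk = -0.1 * torch.rand(b, h, l, d_k)  # non-positive log decay (assumption)

o_rec, S_rec = dplr_recurrence(q, k, v, alpha, beta, gk)
o_chk, S_chk = dplr_chunkwise(q, k, v, alpha, beta, gk, chunk_size=32)
# both differences should be tiny for float32 inputs
print((o_rec - o_chk).abs().max(), (S_rec - S_chk).abs().max())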
fla/ops/generalized_delta_rule/dplr/wy_fast_bwd.py ADDED
@@ -0,0 +1,184 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.utils import check_shared_mem, is_intel_alchemist, use_cuda_graph
11
+
12
+ # https://github.com/intel/intel-xpu-backend-for-triton/issues/3449
13
+ triton_config = {'grf_mode': 'large'} if is_intel_alchemist else {}
14
+
15
+
16
+ @triton.heuristics({
17
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
18
+ })
19
+ @triton.autotune(
20
+ configs=[
21
+ triton.Config(triton_config, num_warps=num_warps, num_stages=num_stages)
22
+ for num_warps in [2, 4, 8, 16, 32]
23
+ for num_stages in [2, 3, 4]
24
+ ],
25
+ key=['BT', 'BK', 'BV'],
26
+ use_cuda_graph=use_cuda_graph,
27
+ )
28
+ @triton.jit(do_not_specialize=['T'])
29
+ def bwd_prepare_wy_repr_kernel(
30
+ A_ab_inv,
31
+ A_ak,
32
+ ag,
33
+ v,
34
+ dw,
35
+ du,
36
+ dv,
37
+ dv0,
38
+ dag,
39
+ dAak,
40
+ dAab,
41
+ offsets,
42
+ indices,
43
+ T,
44
+ H: tl.constexpr,
45
+ K: tl.constexpr,
46
+ V: tl.constexpr,
47
+ BT: tl.constexpr,
48
+ BK: tl.constexpr,
49
+ BV: tl.constexpr,
50
+ USE_OFFSETS: tl.constexpr,
51
+ HEAD_FIRST: tl.constexpr
52
+ ):
53
+ i_t, i_bh = tl.program_id(0), tl.program_id(1)
54
+ i_b, i_h = i_bh // H, i_bh % H
55
+ if USE_OFFSETS:
56
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
57
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
58
+ T = eos - bos
59
+ else:
60
+ bos, eos = i_b * T, i_b * T + T
61
+
62
+ if HEAD_FIRST:
63
+ p_Aab_inv_t = tl.make_block_ptr(A_ab_inv + i_bh * T * BT, (BT, T), (1, BT), (0, i_t * BT), (BT, BT), (0, 1))
64
+ p_Aak_t = tl.make_block_ptr(A_ak + i_bh * T * BT, (BT, T), (1, BT), (0, i_t * BT), (BT, BT), (0, 1))
65
+ p_dAak = tl.make_block_ptr(dAak + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
66
+ p_dAab = tl.make_block_ptr(dAab + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
67
+ else:
68
+ p_Aak_t = tl.make_block_ptr(A_ak + (bos*H + i_h) * BT, (BT, T), (1, H*BT), (0, i_t * BT), (BT, BT), (0, 1))
69
+ p_Aab_inv_t = tl.make_block_ptr(A_ab_inv + (bos*H + i_h) * BT, (BT, T), (1, H*BT), (0, i_t * BT), (BT, BT), (0, 1))
70
+ p_dAak = tl.make_block_ptr(dAak + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
71
+ p_dAab = tl.make_block_ptr(dAab + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
72
+
73
+ b_A_ab_inv_t = tl.load(p_Aab_inv_t, boundary_check=(0, 1))
74
+ b_A_ak_t = tl.load(p_Aak_t, boundary_check=(0, 1))
75
+ b_A_ak_t = tl.where(tl.arange(0, BT)[:, None] < tl.arange(0, BT)[None, :], b_A_ak_t, 0)
76
+ b_A_ab_inv_t = tl.where(tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :], b_A_ab_inv_t, 0)
77
+ b_A_tmp_t = tl.dot(b_A_ak_t, b_A_ab_inv_t).to(v.dtype.element_ty)
78
+ b_dA_tmp = tl.zeros([BT, BT], dtype=tl.float32)
79
+
80
+ for i_v in range(tl.cdiv(V, BV)):
81
+ if HEAD_FIRST:
82
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
83
+ p_dv = tl.make_block_ptr(dv + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
84
+ p_dv0 = tl.make_block_ptr(dv0 + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
85
+ p_du = tl.make_block_ptr(du + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
86
+ else:
87
+ p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
88
+ p_dv = tl.make_block_ptr(dv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
89
+ p_dv0 = tl.make_block_ptr(dv0 + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
90
+ p_du = tl.make_block_ptr(du + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
91
+ b_v = tl.load(p_v, boundary_check=(0, 1))
92
+ b_du = tl.load(p_du, boundary_check=(0, 1))
93
+ b_dA_tmp += tl.dot(b_du.to(b_v.dtype), tl.trans(b_v))
94
+ b_dv0 = tl.load(p_dv0, boundary_check=(0, 1))
95
+ b_dv = b_dv0 + tl.dot(b_A_tmp_t, b_du)
96
+ tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
97
+
98
+ b_dA_tmp = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_dA_tmp, 0)
99
+ b_dA_ak = tl.dot(b_A_ab_inv_t, b_dA_tmp)
100
+ b_dA_ak = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_dA_ak, 0)
101
+ tl.store(p_dAak, b_dA_ak, boundary_check=(0, 1))
102
+ b_dA_ab_inv = tl.dot(b_dA_tmp, b_A_ak_t)
103
+
104
+ for i_k in range(tl.cdiv(K, BK)):
105
+ if HEAD_FIRST:
106
+ p_ag = tl.make_block_ptr(ag + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
107
+ p_dag = tl.make_block_ptr(dag + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
108
+ p_dw = tl.make_block_ptr(dw + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
109
+ else:
110
+ p_ag = tl.make_block_ptr(ag + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
111
+ p_dag = tl.make_block_ptr(dag + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
112
+ p_dw = tl.make_block_ptr(dw + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
113
+ b_ag = tl.load(p_ag, boundary_check=(0, 1))
114
+ b_dw = tl.load(p_dw, boundary_check=(0, 1))
115
+ b_dA_ab_inv += tl.dot(b_dw, tl.trans(b_ag))
116
+ b_dag = tl.dot(b_A_ab_inv_t.to(b_dw.dtype), b_dw)
117
+ tl.store(p_dag, b_dag.to(p_dag.dtype.element_ty), boundary_check=(0, 1))
118
+
119
+ # if we know dL/dA^(-1), for dL/dA, we can use the following formula:
120
+ # dL/dA = -(A^(-1))^T @ (dL/dA^(-1)) @ (A^(-1))^T
121
+ # in the fwd pass we use fwd substitution to calculate (I-lower(A_ab))^-1.
122
+ # denote A = I - lower(A_ab), B = A^-1
123
+ # in the backward pass.
124
+ # dL/dA = -(B)^T @ (dL/dB) @ B^T
125
+ # dL/dA_ab = lower(B^T @ dL/dB @ B^T)
126
+ b_dA_ab_inv = tl.where(tl.arange(0, BT)[:, None] >= tl.arange(0, BT)[None, :], b_dA_ab_inv, 0)
127
+ b_dA_ab_inv = tl.dot(b_A_ab_inv_t, b_dA_ab_inv)
128
+ b_dA_ab_inv = tl.dot(b_dA_ab_inv, b_A_ab_inv_t)
129
+ b_dA_ab_inv = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_dA_ab_inv, 0)
130
+ tl.store(p_dAab, b_dA_ab_inv, boundary_check=(0, 1))
131
+
132
+
133
+ def chunk_dplr_bwd_wy(
134
+ A_ab_inv: torch.Tensor,
135
+ A_ak: torch.Tensor,
136
+ v: torch.Tensor,
137
+ ag: torch.Tensor,
138
+ dw: torch.Tensor,
139
+ du: torch.Tensor,
140
+ dv0: torch.Tensor,
141
+ offsets: Optional[torch.LongTensor],
142
+ indices: Optional[torch.LongTensor],
143
+ head_first: bool,
144
+ chunk_size: int,
145
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
146
+ A_ab_inv, A_ak, v, ag, dw, du = map(lambda x: x.contiguous(), [A_ab_inv, A_ak, v, ag, dw, du])
147
+ if head_first:
148
+ B, H, T, K, V = *dw.shape, du.shape[-1]
149
+ else:
150
+ B, T, H, K, V = *dw.shape, du.shape[-1]
151
+ BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
152
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
153
+ BK = min(triton.next_power_of_2(K), 64)
154
+ BV = min(triton.next_power_of_2(V), 64) if check_shared_mem() else min(triton.next_power_of_2(V), 32)
155
+
156
+ dA_ab = torch.empty_like(A_ab_inv, dtype=torch.float)
157
+ dA_ak = torch.empty_like(A_ak, dtype=torch.float)
158
+ dv = torch.empty_like(v)
159
+ dag = torch.empty_like(ag)
160
+
161
+ bwd_prepare_wy_repr_kernel[(NT, B * H)](
162
+ A_ab_inv=A_ab_inv,
163
+ A_ak=A_ak,
164
+ ag=ag,
165
+ v=v,
166
+ dw=dw,
167
+ du=du,
168
+ dv=dv,
169
+ dv0=dv0,
170
+ dag=dag,
171
+ dAak=dA_ak,
172
+ dAab=dA_ab,
173
+ offsets=offsets,
174
+ indices=indices,
175
+ T=T,
176
+ H=H,
177
+ K=K,
178
+ V=V,
179
+ BT=BT,
180
+ BK=BK,
181
+ BV=BV,
182
+ HEAD_FIRST=head_first
183
+ )
184
+ return dA_ab, dA_ak, dv, dag
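The comment inside `bwd_prepare_wy_repr_kernel` relies on the identity dL/dA = -(A^{-1})^T @ (dL/dA^{-1}) @ (A^{-1})^T. Below is a small autograd check of that identity in plain PyTorch (a sketch, not part of the diff), using a generic unit-lower-triangular matrix of the same I - lower(A_ab) form.

import torch

n = 8
A = torch.eye(n) - torch.tril(torch.randn(n, n), diagonal=-1)  # A = I - lower(A_ab)
A.requires_grad_(True)
B = torch.linalg.inv(A)

G = torch.randn(n, n)     # an arbitrary upstream gradient dL/dB
(B * G).sum().backward()  # seeds dL/dB = G

Bd = B.detach()
expected = -(Bd.t() @ G @ Bd.t())       # -(A^-1)^T @ dL/dB @ (A^-1)^T
print((A.grad - expected).abs().max())  # ~0 up to floating-point error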
fla/ops/generalized_delta_rule/dplr/wy_fast_fwd.py ADDED
@@ -0,0 +1,318 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2024, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.ops.utils.op import gather
11
+ from fla.utils import is_gather_supported, use_cuda_graph
12
+
13
+
14
+ @triton.heuristics({
15
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
16
+ })
17
+ @triton.autotune(
18
+ configs=[
19
+ triton.Config({}, num_warps=num_warps)
20
+ for num_warps in [1, 2, 4, 8, 16]
21
+ ],
22
+ key=['BT'],
23
+ use_cuda_graph=use_cuda_graph,
24
+ )
25
+ @triton.jit(do_not_specialize=['T'])
26
+ def fwd_prepare_wy_repr_kernel_chunk32(
27
+ A_ab,
28
+ A_ab_inv,
29
+ offsets,
30
+ indices,
31
+ T,
32
+ H: tl.constexpr,
33
+ BT: tl.constexpr,
34
+ BC: tl.constexpr, # placeholder, do not delete
35
+ USE_OFFSETS: tl.constexpr,
36
+ HEAD_FIRST: tl.constexpr
37
+ ):
38
+ i_t, i_bh = tl.program_id(0), tl.program_id(1)
39
+ i_b, i_h = i_bh // H, i_bh % H
40
+ if USE_OFFSETS:
41
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
42
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
43
+ T = eos - bos
44
+ else:
45
+ bos, eos = i_b * T, i_b * T + T
46
+ if HEAD_FIRST:
47
+ p_Aab = tl.make_block_ptr(A_ab + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
48
+ p_Aab_inv = tl.make_block_ptr(A_ab_inv + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
49
+ else:
50
+ p_Aab = tl.make_block_ptr(A_ab + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
51
+ p_Aab_inv = tl.make_block_ptr(A_ab_inv + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
52
+ b_A_ab = tl.load(p_Aab, boundary_check=(0, 1))
53
+ b_A_ab = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_A_ab, 0)
54
+ for i in range(1, BT):
55
+ mask = tl.arange(0, BT) == i
56
+ b_a = tl.sum(tl.where(mask[:, None], b_A_ab, 0), 0)
57
+ b_a = b_a + tl.sum(b_a[:, None] * b_A_ab, 0) * (tl.arange(0, BT) < i)
58
+ b_A_ab = tl.where(mask[:, None], b_a, b_A_ab)
59
+ b_A_ab += tl.arange(0, BT)[:, None] == tl.arange(0, BT)[None, :]
60
+ tl.store(p_Aab_inv, b_A_ab.to(p_Aab_inv.dtype.element_ty), boundary_check=(0, 1))
61
+
62
+
63
+ @triton.heuristics({
64
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
65
+ })
66
+ @triton.autotune(
67
+ configs=[
68
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
69
+ for num_warps in [2, 4, 8]
70
+ for num_stages in [2, 3, 4]
71
+ ],
72
+ key=['BC'],
73
+ use_cuda_graph=use_cuda_graph,
74
+ )
75
+ @triton.jit(do_not_specialize=['T'])
76
+ def fwd_prepare_wy_repr_kernel_chunk64(
77
+ A_ab,
78
+ A_ab_inv,
79
+ offsets,
80
+ indices,
81
+ T,
82
+ H: tl.constexpr,
83
+ BT: tl.constexpr,
84
+ BC: tl.constexpr,
85
+ USE_OFFSETS: tl.constexpr,
86
+ HEAD_FIRST: tl.constexpr,
87
+ GATHER_SUPPORTED: tl.constexpr = is_gather_supported
88
+ ):
89
+ i_t, i_bh = tl.program_id(0), tl.program_id(1)
90
+ i_b, i_h = i_bh // H, i_bh % H
91
+ if USE_OFFSETS:
92
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
93
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
94
+ T = eos - bos
95
+ else:
96
+ bos, eos = i_b * T, i_b * T + T
97
+
98
+ if HEAD_FIRST:
99
+
100
+ p_A1 = tl.make_block_ptr(A_ab + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BC, BC), (1, 0))
101
+ p_A2 = tl.make_block_ptr(A_ab + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + BC, BC), (BC, BC), (1, 0))
102
+ p_A3 = tl.make_block_ptr(A_ab + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + BC, 0), (BC, BC), (1, 0))
103
+ p_A_inv1 = tl.make_block_ptr(A_ab_inv + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BC, BC), (1, 0))
104
+ p_A_inv2 = tl.make_block_ptr(A_ab_inv + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + BC, BC), (BC, BC), (1, 0))
105
+ p_A_inv3 = tl.make_block_ptr(A_ab_inv + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + BC, 0), (BC, BC), (1, 0))
106
+ p_A_inv4 = tl.make_block_ptr(A_ab_inv + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, BC), (BC, BC), (1, 0))
107
+ else:
108
+ p_A1 = tl.make_block_ptr(A_ab + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BC, BC), (1, 0))
109
+ p_A2 = tl.make_block_ptr(A_ab + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, BC), (BC, BC), (1, 0))
110
+ p_A3 = tl.make_block_ptr(A_ab + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, 0), (BC, BC), (1, 0))
111
+ p_A_inv1 = tl.make_block_ptr(A_ab_inv + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BC, BC), (1, 0))
112
+ p_A_inv2 = tl.make_block_ptr(A_ab_inv + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, BC), (BC, BC), (1, 0))
113
+ p_A_inv3 = tl.make_block_ptr(A_ab_inv + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, 0), (BC, BC), (1, 0))
114
+ p_A_inv4 = tl.make_block_ptr(A_ab_inv + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, BC), (BC, BC), (1, 0))
115
+
116
+ b_A = tl.load(p_A1, boundary_check=(0, 1))
117
+ b_A2 = tl.load(p_A2, boundary_check=(0, 1))
118
+ b_A3 = tl.load(p_A3, boundary_check=(0, 1))
119
+ b_A = tl.where(tl.arange(0, BC)[:, None] > tl.arange(0, BC)[None, :], b_A, 0)
120
+ b_A2 = tl.where(tl.arange(0, BC)[:, None] > tl.arange(0, BC)[None, :], b_A2, 0)
121
+
122
+ for i in range(1, BC):
123
+ if GATHER_SUPPORTED:
124
+ row_idx = tl.full([1, BC], i, dtype=tl.int16)
125
+ # [1, BC] -> [BC]
126
+ b_a = tl.sum(gather(b_A, row_idx, axis=0), 0)
127
+ b_a2 = tl.sum(gather(b_A2, row_idx, axis=0), 0)
128
+ else:
129
+ mask = tl.arange(0, BC) == i
130
+ b_a = tl.sum(tl.where(mask[:, None], b_A, 0), 0)
131
+ b_a2 = tl.sum(tl.where(mask[:, None], b_A2, 0), 0)
132
+ mask = tl.arange(0, BC) == i
133
+ # b_a = tl.sum(tl.where(mask[:, None], b_A, 0), 0)
134
+ # b_a2 = tl.sum(tl.where(mask[:, None], b_A2, 0), 0)
135
+ b_a = b_a + tl.sum(b_a[:, None] * b_A, 0) * (tl.arange(0, BC) < i)
136
+ b_a2 = b_a2 + tl.sum(b_a2[:, None] * b_A2, 0) * (tl.arange(0, BC) < i)
137
+ b_A = tl.where(mask[:, None], b_a, b_A)
138
+ b_A2 = tl.where(mask[:, None], b_a2, b_A2)
139
+
140
+ # blockwise computation of lower triangular matrix's inverse
141
+ # i.e., [A11, 0; A21, A22]^-1 = [A11^-1, 0; -A22^-1 A21 A11^-1, A22^-1]
142
+ b_A += tl.arange(0, BC)[:, None] == tl.arange(0, BC)[None, :]
143
+ b_A2 += tl.arange(0, BC)[:, None] == tl.arange(0, BC)[None, :]
144
+ b_A3 = tl.dot(tl.dot(b_A2, b_A3), b_A)
145
+ # tl.debug_barrier()
146
+ tl.store(p_A_inv1, b_A.to(p_A_inv1.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
147
+ tl.store(p_A_inv2, b_A2.to(p_A_inv2.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
148
+ tl.store(p_A_inv3, b_A3.to(p_A_inv3.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
149
+ # causal mask
150
+ tl.store(p_A_inv4, tl.zeros([BC, BC], dtype=tl.float32).to(p_A_inv4.dtype.element_ty), boundary_check=(0, 1))
151
+
152
+
153
+ @triton.heuristics({
154
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
155
+ })
156
+ @triton.autotune(
157
+ configs=[
158
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
159
+ for num_warps in [2, 4, 8, 16, 32]
160
+ for num_stages in [2, 3, 4]
161
+ ],
162
+ key=['BT', 'BK', 'BV'],
163
+ use_cuda_graph=use_cuda_graph,
164
+ )
165
+ @triton.jit(do_not_specialize=['T'])
166
+ def fwd_wu_kernel(
167
+ u,
168
+ w,
169
+ ag,
170
+ v,
171
+ A_ab_inv,
172
+ A_ak,
173
+ offsets,
174
+ indices,
175
+ T,
176
+ H: tl.constexpr,
177
+ K: tl.constexpr,
178
+ V: tl.constexpr,
179
+ BT: tl.constexpr,
180
+ BK: tl.constexpr,
181
+ BV: tl.constexpr,
182
+ USE_OFFSETS: tl.constexpr,
183
+ HEAD_FIRST: tl.constexpr,
184
+ ):
185
+ i_t, i_bh = tl.program_id(0), tl.program_id(1)
186
+ i_b, i_h = i_bh // H, i_bh % H
187
+ if USE_OFFSETS:
188
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
189
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
190
+ T = eos - bos
191
+ else:
192
+ bos, eos = i_b * T, i_b * T + T
193
+
194
+ if HEAD_FIRST:
195
+ p_A_ab_inv = tl.make_block_ptr(A_ab_inv + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
196
+ p_A_ak = tl.make_block_ptr(A_ak + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
197
+ else:
198
+ p_A_ab_inv = tl.make_block_ptr(A_ab_inv + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
199
+ p_A_ak = tl.make_block_ptr(A_ak + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
200
+ b_Aab_inv = tl.load(p_A_ab_inv, boundary_check=(0, 1))
201
+ b_Aak = tl.load(p_A_ak, boundary_check=(0, 1))
202
+ o_s = tl.arange(0, BT)
203
+ b_Aab_inv = tl.where(o_s[:, None] >= o_s[None, :], b_Aab_inv, 0)
204
+ b_Aak = tl.where(o_s[:, None] > o_s[None, :], b_Aak, 0)
205
+ # let's use tf32 here
206
+ b_Aak = tl.dot(b_Aab_inv, b_Aak)
207
+ # (SY 01/04) should be bf16 or tf32? To verify.
208
+ b_Aak = b_Aak.to(v.dtype.element_ty, fp_downcast_rounding="rtne")
209
+ b_Aab_inv = b_Aab_inv.to(ag.dtype.element_ty, fp_downcast_rounding="rtne")
210
+
211
+ for i_k in range(tl.cdiv(K, BK)):
212
+ if HEAD_FIRST:
213
+ p_ag = tl.make_block_ptr(ag + i_bh * T * K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
214
+ p_w = tl.make_block_ptr(w + i_bh * T * K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
215
+ else:
216
+ p_ag = tl.make_block_ptr(ag + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
217
+ p_w = tl.make_block_ptr(w + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
218
+ b_ag = tl.load(p_ag, boundary_check=(0, 1))
219
+ b_w = tl.dot(b_Aab_inv, b_ag) # both bf16 or fp16
220
+ tl.store(p_w, b_w.to(p_w.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
221
+
222
+ for i_v in range(tl.cdiv(V, BV)):
223
+ if HEAD_FIRST:
224
+ p_v = tl.make_block_ptr(v + i_bh * T * V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
225
+ p_u = tl.make_block_ptr(u + i_bh * T * V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
226
+ else:
227
+ p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
228
+ p_u = tl.make_block_ptr(u + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
229
+ b_v = tl.load(p_v, boundary_check=(0, 1))
230
+ b_u = tl.dot(b_Aak, b_v) # both bf16 or fp16
231
+ tl.store(p_u, b_u.to(p_u.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
232
+
233
+
234
+ def fwd_prepare_wy_repr(
235
+ ag: torch.Tensor,
236
+ v: torch.Tensor,
237
+ A_ak: torch.Tensor,
238
+ A_ab: torch.Tensor,
239
+ offsets: Optional[torch.LongTensor],
240
+ indices: Optional[torch.LongTensor],
241
+ head_first: bool = True,
242
+ chunk_size: int = 64
243
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
244
+ if head_first:
245
+ B, H, T, K = ag.shape
246
+ else:
247
+ B, T, H, K = ag.shape
248
+ BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
249
+
250
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
251
+ BC = min(BT, 32)
252
+ fwd_fn = fwd_prepare_wy_repr_kernel_chunk64 if BT == 64 else fwd_prepare_wy_repr_kernel_chunk32
253
+ A_ab_inv = torch.empty_like(A_ab)
254
+ fwd_fn[(NT, B * H)](
255
+ A_ab=A_ab,
256
+ A_ab_inv=A_ab_inv,
257
+ offsets=offsets,
258
+ indices=indices,
259
+ T=T,
260
+ H=H,
261
+ BT=BT,
262
+ BC=BC,
263
+ HEAD_FIRST=head_first
264
+ )
265
+ w, u = fwd_wu(
266
+ ag=ag,
267
+ v=v,
268
+ A_ak=A_ak,
269
+ A_ab_inv=A_ab_inv,
270
+ offsets=offsets,
271
+ indices=indices,
272
+ head_first=head_first,
273
+ chunk_size=BT
274
+ )
275
+ return w, u, A_ab_inv
276
+
277
+
278
+ def fwd_wu(
279
+ ag: torch.Tensor,
280
+ v: torch.Tensor,
281
+ A_ak: torch.Tensor,
282
+ A_ab_inv: torch.Tensor,
283
+ offsets: Optional[torch.LongTensor],
284
+ indices: Optional[torch.LongTensor],
285
+ head_first: bool,
286
+ chunk_size: int
287
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
288
+ if head_first:
289
+ B, H, T, K, V = *ag.shape, v.shape[-1]
290
+ else:
291
+ B, T, H, K, V = *ag.shape, v.shape[-1]
292
+ BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
293
+
294
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
295
+ BK = min(triton.next_power_of_2(K), 64)
296
+ BV = min(triton.next_power_of_2(V), 64)
297
+
298
+ u = torch.empty_like(v)
299
+ w = torch.empty_like(ag)
300
+ fwd_wu_kernel[(NT, B*H)](
301
+ ag=ag,
302
+ v=v,
303
+ A_ak=A_ak,
304
+ A_ab_inv=A_ab_inv,
305
+ w=w,
306
+ u=u,
307
+ offsets=offsets,
308
+ indices=indices,
309
+ T=T,
310
+ H=H,
311
+ K=K,
312
+ V=V,
313
+ BT=BT,
314
+ BK=BK,
315
+ BV=BV,
316
+ HEAD_FIRST=head_first
317
+ )
318
+ return w, u
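The chunk-64 kernels above use the blockwise inverse of a lower block-triangular matrix, [A11, 0; A21, A22]^-1 = [A11^-1, 0; -A22^-1 A21 A11^-1, A22^-1]. A short PyTorch check of that identity (a sketch, not part of the diff):

import torch

n = 4
A11 = torch.eye(n) + torch.tril(torch.randn(n, n), -1)  # unit lower-triangular diagonal blocks
A22 = torch.eye(n) + torch.tril(torch.randn(n, n), -1)
A21 = torch.randn(n, n)

M = torch.zeros(2 * n, 2 * n)
M[:n, :n], M[n:, :n], M[n:, n:] = A11, A21, A22

inv11, inv22 = torch.linalg.inv(A11), torch.linalg.inv(A22)
blockwise = torch.zeros_like(M)
blockwise[:n, :n] = inv11
blockwise[n:, n:] = inv22
blockwise[n:, :n] = -inv22 @ A21 @ inv11

print((torch.linalg.inv(M) - blockwise).abs().max())  # ~0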
fla/ops/generalized_delta_rule/iplr/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (328 Bytes). View file
 
fla/ops/generalized_delta_rule/iplr/__pycache__/fused_recurrent.cpython-312.pyc ADDED
Binary file (27.4 kB). View file
 
fla/ops/generalized_delta_rule/iplr/fused_recurrent.py ADDED
@@ -0,0 +1,451 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2024-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.utils import input_guard
11
+
12
+
13
+ @triton.heuristics({
14
+ 'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
15
+ 'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
16
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
17
+ })
18
+ @triton.autotune(
19
+ configs=[
20
+ triton.Config({'BV': BV}, num_warps=num_warps, num_stages=num_stages)
21
+ for BV in [32, 64]
22
+ for num_warps in [2, 4, 8, 16]
23
+ for num_stages in [2, 3, 4]
24
+ ],
25
+ key=["BK"],
26
+ )
27
+ @triton.jit
28
+ def fused_recurrent_fwd_kernel(
29
+ q, # query [B, H, L, K]
30
+ k, # key [B, H, L, V]
31
+ v, # value [B, H, L, V].
32
+ a, # a [B, H, L, K]
33
+ b, # b [B, H, L, K]
34
+ o, # output [B, H, L, V]
35
+ ha, # tmp variable [B, H, L, V] for storing intermediate results of (h * a[None, :]).sum(0)
36
+ h0, # initial hidden state [B, H, K, V]
37
+ ht, # final hidden state [B, H, K, V]
38
+ offsets, # varlen offsets
39
+ scale, # K ** -0.5
40
+ H, # n_heads
41
+ T, # seq_len
42
+ K: tl.constexpr, # K
43
+ V: tl.constexpr, # V
44
+ BK: tl.constexpr, # BLOCK SIZE along the K dimension
45
+ BV: tl.constexpr, # BLOCK SIZE along the V dimension
46
+ USE_INITIAL_STATE: tl.constexpr, # whether to use initial state
47
+ STORE_FINAL_STATE: tl.constexpr, # whether to store final state
48
+ USE_OFFSETS: tl.constexpr,
49
+ HEAD_FIRST: tl.constexpr
50
+ ):
51
+ # indices
52
+ i_v, i_nh = tl.program_id(0), tl.program_id(1)
53
+ i_n, i_h = i_nh // H, i_nh % H
54
+
55
+ if USE_OFFSETS:
56
+ bos, eos = tl.load(offsets + i_n).to(tl.int64), tl.load(offsets + i_n + 1).to(tl.int64)
57
+ T = eos - bos
58
+ else:
59
+ bos, eos = i_n * T, i_n * T + T
60
+
61
+ if HEAD_FIRST:
62
+ p_q = q + i_nh * T*K + tl.arange(0, BK)
63
+ p_k = k + i_nh * T*K + tl.arange(0, BK)
64
+ p_a = a + i_nh * T*K + tl.arange(0, BK)
65
+ p_b = b + i_nh * T*K + tl.arange(0, BK)
66
+ p_o = o + i_nh * T*V + i_v * BV + tl.arange(0, BV)
67
+ p_v = v + i_nh * T*V + i_v * BV + tl.arange(0, BV)
68
+ p_ha = ha + i_nh * T*V + i_v * BV + tl.arange(0, BV)
69
+ else:
70
+ p_q = q + (bos * H + i_h) * K + tl.arange(0, BK)
71
+ p_k = k + (bos * H + i_h) * K + tl.arange(0, BK)
72
+ p_a = a + (bos * H + i_h) * K + tl.arange(0, BK)
73
+ p_b = b + (bos * H + i_h) * K + tl.arange(0, BK)
74
+ p_ha = ha + (bos * H + i_h) * V + i_v * BV + tl.arange(0, BV)
75
+ p_v = v + (bos * H + i_h) * V + i_v * BV + tl.arange(0, BV)
76
+ p_o = o + (bos * H + i_h) * V + i_v * BV + tl.arange(0, BV)
77
+
78
+ mask_k = tl.arange(0, BK) < K
79
+ mask_v = (i_v * BV + tl.arange(0, BV)) < V
80
+ mask_h = mask_k[None, :] & mask_v[:, None]
81
+
82
+ b_h = tl.zeros([BV, BK], dtype=tl.float32)
83
+
84
+ if USE_INITIAL_STATE:
85
+ p_h0 = h0 + i_nh * K * V + (tl.arange(0, BK)[None, :]) * V + ((i_v * BV + tl.arange(0, BV))[:, None])
86
+ b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)
87
+
88
+ for _ in range(0, T):
89
+ b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
90
+ b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
91
+ b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) * scale
92
+ b_a = tl.load(p_a, mask=mask_k, other=0).to(tl.float32)
93
+ b_b = tl.load(p_b, mask=mask_k, other=0).to(tl.float32)
94
+ # to store
95
+ tmp = tl.sum(b_h * b_a[None, :], axis=1)
96
+ b_h += (tmp[:, None] * b_b[None, :] + b_k[None, :] * b_v[:, None])
97
+ _o = b_h * b_q[None, :]
98
+ _o = tl.sum(_o, axis=1)
99
+ tl.store(p_o, _o.to(p_o.dtype.element_ty), mask=mask_v)
100
+ tl.store(p_ha, tmp.to(p_ha.dtype.element_ty), mask=mask_v)
101
+ p_q += K if HEAD_FIRST else K*H
102
+ p_k += K if HEAD_FIRST else K*H
103
+ p_o += V if HEAD_FIRST else V*H
104
+ p_v += V if HEAD_FIRST else V*H
105
+ p_ha += V if HEAD_FIRST else V*H
106
+ p_a += K if HEAD_FIRST else K*H
107
+ p_b += K if HEAD_FIRST else K*H
108
+
109
+ if STORE_FINAL_STATE:
110
+ p_ht = ht + i_nh * K * V + (tl.arange(0, BK)[None, :]) * V + ((i_v * BV + tl.arange(0, BV))[:, None])
111
+ tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
112
+
113
+
114
+ @triton.heuristics({
115
+ 'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
116
+ 'USE_DHT': lambda args: args['dht'] is not None,
117
+ 'USE_DH0': lambda args: args['dh0'] is not None,
118
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
119
+ })
120
+ @triton.autotune(
121
+ configs=[
122
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
123
+ for num_warps in [2, 4, 8, 16]
124
+ for num_stages in [2, 3]
125
+ ],
126
+ key=["BK", "BV"],
127
+ )
128
+ @triton.jit
129
+ def fused_recurrent_bwd_kernel(
130
+ # B: batch_size, H: n_heads, T: seq_len, D: b_dhead
131
+ # NV: number of split in the V dimension. NK: number of split in the K dimension
132
+ q, # query [B, H, L, K]
133
+ k, # key [B, H, L, V]
134
+ v, # value [B, H, L, V]
135
+ a, # a [B, H, L, K]
136
+ b, # b [B, H, L, K]
137
+ ha, # ha [B, H, L, V]
138
+ dht, # gradient of final state [B, H, K, V]
139
+ dh0, # gradient of initial state [B, H, K, V]
140
+ do, # gradient of output [B, H, L, V]
141
+ dq, # gradient of query [NV, B, H, L, K]
142
+ dk, # gradient of key [NV, B, H, L, K]
143
+ dv, # gradient of value [NK, B, H, L, V]
144
+ da, # gradient of a [NV, B, H, L, K]
145
+ db, # gradient of b [NV, B, H, L, K]
146
+ dha, # gradient of ha [NK, B, H, L, V]
147
+ h0, # initial state [B, H, K, V]
148
+ scale, # K ** -0.5
149
+ offsets, # offsets
150
+ B, # batch_size
151
+ H, # n_heads
152
+ T, # seq_len
153
+ K: tl.constexpr, # K
154
+ V: tl.constexpr, # V
155
+ BK: tl.constexpr, # BLOCK SIZE along the K dimension
156
+ BV: tl.constexpr, # BLOCK SIZE along the V dimension
157
+ USE_INITIAL_STATE: tl.constexpr, # whether to use initial state h0
158
+ USE_DH0: tl.constexpr, # whether to use dh0
159
+ USE_DHT: tl.constexpr, # whether to use dht
160
+ USE_OFFSETS: tl.constexpr,
161
+ HEAD_FIRST: tl.constexpr
162
+ ):
163
+ i_v, i_nh = tl.program_id(0), tl.program_id(1)
164
+ i_n, i_h = i_nh // H, i_nh % H
165
+ dk += i_v * B * H * K * T
166
+ db += i_v * B * H * K * T
167
+ dq += i_v * B * H * K * T
168
+ da += i_v * B * H * K * T
169
+ if USE_OFFSETS:
170
+ bos, eos = tl.load(offsets + i_n).to(tl.int64), tl.load(offsets + i_n + 1).to(tl.int64)
171
+ T = eos - bos
172
+ else:
173
+ bos, eos = i_n * T, i_n * T + T
174
+ mask_k = tl.arange(0, BK) < K
175
+ mask_v = (tl.arange(0, BV) + i_v * BV) < V
176
+
177
+ q += (i_nh * T*K) if HEAD_FIRST else ((bos * H + i_h) * K)
178
+ k += (i_nh * T*K) if HEAD_FIRST else ((bos * H + i_h) * K)
179
+ v += (i_nh * T*V + i_v * BV) if HEAD_FIRST else ((bos * H + i_h) * V + i_v * BV)
180
+ ha += (i_nh * T*V + i_v * BV) if HEAD_FIRST else ((bos * H + i_h) * V + i_v * BV)
181
+ a += (i_nh * T*K) if HEAD_FIRST else ((bos * H + i_h) * K)
182
+ b += (i_nh * T*K) if HEAD_FIRST else ((bos * H + i_h) * K)
183
+ do += (i_nh * T*V + i_v * BV) if HEAD_FIRST else ((bos * H + i_h) * V + i_v * BV)
184
+ dq += (i_nh * T*K) if HEAD_FIRST else ((bos * H + i_h) * K)
185
+ dk += (i_nh * T*K) if HEAD_FIRST else ((bos * H + i_h) * K)
186
+ dv += (i_nh * T*V + i_v * BV) if HEAD_FIRST else ((bos * H + i_h) * V + i_v * BV)
187
+ da += (i_nh * T*K) if HEAD_FIRST else ((bos * H + i_h) * K)
188
+ db += (i_nh * T*K) if HEAD_FIRST else ((bos * H + i_h) * K)
189
+ dha += (i_nh * T*V + i_v * BV) if HEAD_FIRST else ((bos * H + i_h) * V + i_v * BV)
190
+
191
+ p_q = q + tl.arange(0, BK) + (T - 1) * K * (1 if HEAD_FIRST else H)
192
+ p_k = k + tl.arange(0, BK) + (T - 1) * K * (1 if HEAD_FIRST else H)
193
+ p_v = v + tl.arange(0, BV) + (T - 1) * V * (1 if HEAD_FIRST else H)
194
+ p_ha = ha + tl.arange(0, BV) + (T - 1) * V * (1 if HEAD_FIRST else H)
195
+ p_a = a + tl.arange(0, BK) + (T - 1) * K * (1 if HEAD_FIRST else H)
196
+ p_b = b + tl.arange(0, BK) + (T - 1) * K * (1 if HEAD_FIRST else H)
197
+ p_do = do + tl.arange(0, BV) + (T - 1) * V * (1 if HEAD_FIRST else H)
198
+ p_dk = dk + tl.arange(0, BK) + (T - 1) * K * (1 if HEAD_FIRST else H)
199
+ p_dv = dv + tl.arange(0, BV) + (T - 1) * V * (1 if HEAD_FIRST else H)
200
+ p_dha = dha + tl.arange(0, BV) + (T - 1) * V * (1 if HEAD_FIRST else H)
201
+ p_db = db + tl.arange(0, BK) + (T - 1) * K * (1 if HEAD_FIRST else H)
202
+ p_da = da + tl.arange(0, BK) + (T - 1) * K * (1 if HEAD_FIRST else H)
203
+ p_dq = dq + tl.arange(0, BK) + (T - 1) * K * (1 if HEAD_FIRST else H)
204
+
205
+ b_dh = tl.zeros([BK, BV], dtype=tl.float32)
206
+ if USE_DHT:
207
+ p_ht = dht + i_nh * K * V + (tl.arange(0, BK)[:, None]) * V + ((i_v * BV + tl.arange(0, BV))[None, :])
208
+ b_dh += tl.load(p_ht, mask=mask_k[:, None] & mask_v[None, :], other=0).to(tl.float32)
209
+
210
+ for _ in range(T):
211
+ b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) * scale
212
+ b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
213
+ b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
214
+ b_do = tl.load(p_do, mask=mask_v, other=0).to(tl.float32)
215
+ b_b = tl.load(p_b, mask=mask_k, other=0).to(tl.float32)
216
+ b_a = tl.load(p_a, mask=mask_k, other=0).to(tl.float32)
217
+ b_ha = tl.load(p_ha, mask=mask_v, other=0).to(tl.float32)
218
+
219
+ b_dh += b_q[:, None] * b_do[None, :]
220
+ d_k = tl.sum(b_dh * b_v[None, :], axis=1)
221
+ d_v = tl.sum(b_dh * b_k[:, None], axis=0)
222
+ tl.store(p_dk, d_k.to(p_dk.dtype.element_ty), mask=mask_k)
223
+ tl.store(p_dv, d_v.to(p_dv.dtype.element_ty), mask=mask_v)
224
+
225
+ b_dha = tl.sum(b_dh * b_b[:, None], axis=0)
226
+ tl.store(p_dha, b_dha.to(p_dha.dtype.element_ty), mask=mask_v)
227
+ b_db = tl.sum(b_dh * b_ha[None, :], axis=1)
228
+ tl.store(p_db, b_db.to(p_db.dtype.element_ty), mask=mask_k)
229
+
230
+ b_dh += b_dha[None, :] * b_a[:, None]
231
+ p_do -= V if HEAD_FIRST else V*H
232
+ p_q -= K if HEAD_FIRST else K*H
233
+ p_k -= K if HEAD_FIRST else K*H
234
+ p_v -= V if HEAD_FIRST else V*H
235
+ p_dk -= K if HEAD_FIRST else K*H
236
+ p_dv -= V if HEAD_FIRST else V*H
237
+ p_b -= K if HEAD_FIRST else K*H
238
+ p_db -= K if HEAD_FIRST else K*H
239
+ p_a -= K if HEAD_FIRST else K*H
240
+ p_dha -= V if HEAD_FIRST else V*H
241
+ p_ha -= V if HEAD_FIRST else V*H
242
+
243
+ if USE_DH0:
244
+ p_dh0 = dh0 + i_nh * K * V + (tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :])
245
+ tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), mask=mask_k[:, None] & mask_v[None, :])
246
+
247
+ tl.debug_barrier()
248
+
249
+ b_h = tl.zeros([BK, BV], dtype=tl.float32)
250
+
251
+ if USE_INITIAL_STATE:
252
+ mask_kv = mask_k[:, None] & mask_v[None, :]
253
+ p_h0 = h0 + i_nh * K * V + (tl.arange(0, BK)[:, None]) * V + ((i_v * BV + tl.arange(0, BV))[None, :])
254
+ b_h += tl.load(p_h0, mask=mask_kv, other=0).to(tl.float32)
255
+
256
+ p_k = k + tl.arange(0, BK)
257
+ p_v = v + tl.arange(0, BV)
258
+ p_ha = ha + tl.arange(0, BV)
259
+ p_do = do + tl.arange(0, BV)
260
+ p_dha = dha + tl.arange(0, BV)
261
+ p_da = da + tl.arange(0, BK)
262
+ p_dq = dq + tl.arange(0, BK)
263
+ p_b = b + tl.arange(0, BK)
264
+
265
+ for i in range(0, T):
266
+ b_dha = tl.load(p_dha, mask=mask_v, other=0).to(tl.float32)
267
+ d_a = tl.sum(b_dha[None, :] * b_h, axis=1)
268
+ tl.store(p_da, d_a.to(p_da.dtype.element_ty), mask=mask_k)
269
+ b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
270
+ b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
271
+ b_do = tl.load(p_do, mask=mask_v, other=0).to(tl.float32)
272
+ b_b = tl.load(p_b, mask=mask_k, other=0).to(tl.float32)
273
+ b_ha = tl.load(p_ha, mask=mask_v, other=0).to(tl.float32)
274
+ b_h += b_k[:, None] * b_v[None, :] + b_b[:, None] * b_ha[None, :]
275
+ _d_q = b_h * b_do[None, :]
276
+ d_q = tl.sum(_d_q, axis=1) * scale
277
+ tl.store(p_dq, d_q.to(p_dq.dtype.element_ty), mask=mask_k)
278
+
279
+ p_k += K if HEAD_FIRST else K*H
280
+ p_do += V if HEAD_FIRST else V*H
281
+ p_v += V if HEAD_FIRST else V*H
282
+ p_da += K if HEAD_FIRST else K*H
283
+ p_dha += V if HEAD_FIRST else V*H
284
+ p_ha += V if HEAD_FIRST else V*H
285
+ p_dq += K if HEAD_FIRST else K*H
286
+ p_b += K if HEAD_FIRST else K*H
287
+
288
+
289
+ class FusedRecurrentIPLRDeltaRuleFunction(torch.autograd.Function):
290
+
291
+ @staticmethod
292
+ @input_guard
293
+ def forward(ctx, q, k, v, a, b, scale=None, initial_state=None, output_final_state=False, offsets=None, head_first=False):
294
+ if head_first:
295
+ B, H, T, K, V = *k.shape, v.shape[-1]
296
+ else:
297
+ B, T, H, K, V = *k.shape, v.shape[-1]
298
+ N = B if offsets is None else len(offsets) - 1
299
+
300
+ BK = triton.next_power_of_2(K)
301
+ if output_final_state:
302
+ final_state = q.new_empty(B, H, K, V, dtype=torch.float32)
303
+ else:
304
+ final_state = None
305
+
306
+ ha = torch.empty_like(v, dtype=torch.float32)
307
+
308
+ def grid(meta): return (
309
+ triton.cdiv(V, meta['BV']),
310
+ N * H
311
+ )
312
+ o = torch.empty_like(v)
313
+ fused_recurrent_fwd_kernel[grid](
314
+ q=q,
315
+ k=k,
316
+ v=v,
317
+ a=a,
318
+ b=b,
319
+ o=o,
320
+ ha=ha,
321
+ h0=initial_state,
322
+ ht=final_state,
323
+ scale=scale,
324
+ offsets=offsets,
325
+ H=H,
326
+ T=T,
327
+ K=K,
328
+ V=V,
329
+ BK=BK,
330
+ HEAD_FIRST=head_first
331
+ )
332
+ ctx.save_for_backward(q, k, v, a, b, ha, initial_state)
333
+ ctx.scale = scale
334
+ ctx.head_first = head_first
335
+ ctx.offsets = offsets
336
+ return o, final_state
337
+
338
+ @staticmethod
339
+ @input_guard
340
+ def backward(ctx, do, dht):
341
+ q, k, v, a, b, ha, initial_state = ctx.saved_tensors
342
+ if ctx.head_first:
343
+ B, H, T, K, V = *q.shape, v.shape[-1]
344
+ else:
345
+ B, T, H, K, V = *q.shape, v.shape[-1]
346
+
347
+ N = B if ctx.offsets is None else len(ctx.offsets) - 1
348
+ scale = ctx.scale
349
+ BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 64)
350
+ NV = triton.cdiv(V, BV)
351
+
352
+ dq = q.new_empty(NV, *q.shape)
353
+ dk = k.new_empty(NV, *k.shape)
354
+ da = a.new_empty(NV, *a.shape)
355
+ db = b.new_empty(NV, *b.shape)
356
+ dv = torch.empty_like(v)
357
+ dha = torch.empty_like(ha)
358
+ grid = (NV, N * H)
359
+
360
+ if initial_state is not None and initial_state.requires_grad:
361
+ dh0 = torch.empty_like(initial_state, dtype=torch.float32)
362
+ else:
363
+ dh0 = None
364
+
365
+ fused_recurrent_bwd_kernel[grid](
366
+ q=q,
367
+ k=k,
368
+ v=v,
369
+ a=a,
370
+ b=b,
371
+ ha=ha,
372
+ dht=dht,
373
+ dh0=dh0,
374
+ do=do,
375
+ dq=dq,
376
+ dk=dk,
377
+ dv=dv,
378
+ da=da,
379
+ db=db,
380
+ dha=dha,
381
+ h0=initial_state,
382
+ scale=scale,
383
+ offsets=ctx.offsets,
384
+ B=B,
385
+ H=H,
386
+ T=T,
387
+ K=K,
388
+ V=V,
389
+ BK=BK,
390
+ BV=BV,
391
+ HEAD_FIRST=ctx.head_first
392
+ )
393
+ dq = dq.sum(0)
394
+ dk = dk.sum(0)
395
+ da = da.sum(0)
396
+ db = db.sum(0)
397
+ return dq.to(q), dk.to(k), dv.to(v), da.to(a), db.to(b), None, dh0, None, None, None
398
+
399
+
400
+ def fused_recurrent_iplr_delta_rule(
401
+ q: torch.Tensor,
402
+ k: torch.Tensor,
403
+ v: torch.Tensor,
404
+ a: torch.Tensor,
405
+ b: torch.Tensor,
406
+ scale: float = None,
407
+ initial_state: torch.Tensor = None,
408
+ output_final_state: bool = False,
409
+ offsets: torch.Tensor = None,
410
+ head_first: bool = False
411
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
412
+ r"""
413
+ This function computes the recurrence S_t = S_t @ (I + a_t b_t^T) + v_t k_t^T in a recurrent manner.
414
+
415
+ Args:
416
+ q (torch.Tensor):
417
+ queries of shape `[B, H, T, K]`
418
+ k (torch.Tensor):
419
+ keys of shape `[B, H, T, K]`
420
+ v (torch.Tensor):
421
+ values of shape `[B, H, T, V]`
422
+ a (torch.Tensor):
423
+ as of shape `[B, H, T, K]`
424
+ b (torch.Tensor):
425
+ bs of shape `[B, H, T, K]`
426
+ scale (Optional[int]):
427
+ Scale factor for the RetNet attention scores.
428
+ If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
429
+ initial_state (Optional[torch.Tensor]):
430
+ Initial state of shape `[B, H, K, V]`. Default: `None`.
431
+ output_final_state (Optional[bool]):
432
+ Whether to output the final state of shape `[B, H, K, V]`. Default: `False`.
433
+ offsets (Optional[torch.Tensor]):
434
+
435
+ """
436
+ if offsets is not None:
437
+ if q.shape[0] != 1:
438
+ raise ValueError(f"The batch size is expected to be 1 rather than {q.shape[0]} when using `offsets`."
439
+ f"Please flatten variable-length inputs before processing.")
440
+ if head_first:
441
+ raise RuntimeError("Sequences with variable lengths are not supported for head-first mode")
442
+ if initial_state is not None and initial_state.shape[0] != len(offsets) - 1:
443
+ raise ValueError(f"The number of initial states is expected to be equal to the number of input sequences, "
444
+ f"i.e., {len(offsets) - 1} rather than {initial_state.shape[0]}.")
445
+ if scale is None:
446
+ scale = q.shape[-1] ** -0.5
447
+ else:
448
+ assert scale > 0, "scale must be positive"
449
+ o, final_state = FusedRecurrentIPLRDeltaRuleFunction.apply(
450
+ q, k, v, a, b, scale, initial_state, output_final_state, offsets, head_first)
451
+ return o, final_state
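A minimal usage sketch of `fused_recurrent_iplr_delta_rule` (an illustration, not part of the diff). It assumes a CUDA device with Triton available and the default `head_first=False` layout `[B, T, H, K]`/`[B, T, H, V]`; the delta-rule-style choice a = -beta * k, b = k is only an illustrative assumption, not something the interface requires.

import torch
import torch.nn.functional as F

from fla.ops.generalized_delta_rule.iplr.fused_recurrent import fused_recurrent_iplr_delta_rule

B, T, H, K, V = 2, 128, 4, 64, 64
q = torch.randn(B, T, H, K, device='cuda')
k = F.normalize(torch.randn(B, T, H, K, device='cuda'), dim=-1)
v = torch.randn(B, T, H, V, device='cuda')
beta = torch.rand(B, T, H, 1, device='cuda')
a = -beta * k  # illustrative low-rank correction; any [B, T, H, K] tensor works here
b = k

# scale defaults to 1 / sqrt(K) when left as None
o, ht = fused_recurrent_iplr_delta_rule(q, k, v, a, b, output_final_state=True)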
fla/ops/generalized_delta_rule/iplr/wy_fast.py ADDED
@@ -0,0 +1,338 @@
1
+
2
+ # -*- coding: utf-8 -*-
3
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
4
+
5
+ from typing import Optional, Tuple
6
+
7
+ import torch
8
+ import triton
9
+ import triton.language as tl
10
+
11
+ from fla.utils import check_shared_mem, is_nvidia_hopper
12
+
13
+ NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8]
14
+
15
+
16
+ @triton.heuristics({
17
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
18
+ })
19
+ @triton.autotune(
20
+ configs=[
21
+ triton.Config({}, num_warps=num_warps)
22
+ for num_warps in [1, 2, 4, 8, 16]
23
+ ],
24
+ key=['BK']
25
+ )
26
+ @triton.jit(do_not_specialize=['T'])
27
+ def fwd_prepare_wy_repr_kernel_chunk32(
28
+ a,
29
+ b,
30
+ A,
31
+ offsets,
32
+ indices,
33
+ T,
34
+ H: tl.constexpr,
35
+ K: tl.constexpr,
36
+ BT: tl.constexpr,
37
+ BK: tl.constexpr,
38
+ BC: tl.constexpr, # dummy placeholder
39
+ USE_OFFSETS: tl.constexpr,
40
+ HEAD_FIRST: tl.constexpr,
41
+ ):
42
+ i_t, i_bh = tl.program_id(0), tl.program_id(1)
43
+ i_b, i_h = i_bh // H, i_bh % H
44
+ if USE_OFFSETS:
45
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
46
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
47
+ T = eos - bos
48
+ else:
49
+ bos, eos = i_b * T, i_b * T + T
50
+
51
+ b_A = tl.zeros([BT, BT], dtype=tl.float32)
52
+ for i_k in range(tl.cdiv(K, BK)):
53
+ if HEAD_FIRST:
54
+ p_a = tl.make_block_ptr(a + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
55
+ p_b = tl.make_block_ptr(b + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
56
+ else:
57
+ p_a = tl.make_block_ptr(a + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
58
+ p_b = tl.make_block_ptr(b + (bos * H + i_h) * K, (K, T), (1, K*H), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
59
+ b_a = tl.load(p_a, boundary_check=(0, 1))
60
+ b_b = tl.load(p_b, boundary_check=(0, 1))
61
+ b_A += tl.dot(b_a, b_b)
62
+
63
+ b_A = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_A, 0)
64
+ for i in range(1, BT):
65
+ mask = tl.arange(0, BT) == i
66
+ b_a = tl.sum(tl.where(mask[:, None], b_A, 0), 0)
67
+ b_a = b_a + tl.sum(b_a[:, None] * b_A, 0) * (tl.arange(0, BT) < i)
68
+ b_A = tl.where(mask[:, None], b_a, b_A)
69
+ b_A += tl.arange(0, BT)[:, None] == tl.arange(0, BT)[None, :]
70
+
71
+ if HEAD_FIRST:
72
+ p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
73
+ else:
74
+ p_A = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
75
+ tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1))
76
+
77
+
78
+ @triton.heuristics({
79
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
80
+ })
81
+ @triton.autotune(
82
+ configs=[
83
+ triton.Config({}, num_warps=num_warps)
84
+ for num_warps in [1, 2, 4, 8, 16]
85
+ ],
86
+ key=['BK']
87
+ )
88
+ @triton.jit(do_not_specialize=['T'])
89
+ def fwd_prepare_wy_repr_kernel_chunk64(
90
+ a,
91
+ b,
92
+ A,
93
+ offsets,
94
+ indices,
95
+ T,
96
+ H: tl.constexpr,
97
+ K: tl.constexpr,
98
+ BT: tl.constexpr,
99
+ BK: tl.constexpr,
100
+ BC: tl.constexpr,
101
+ USE_OFFSETS: tl.constexpr,
102
+ HEAD_FIRST: tl.constexpr
103
+ ):
104
+ i_t, i_bh = tl.program_id(0), tl.program_id(1)
105
+ i_b, i_h = i_bh // H, i_bh % H
106
+ if USE_OFFSETS:
107
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
108
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
109
+ T = eos - bos
110
+ else:
111
+ bos, eos = i_b * T, i_b * T + T
112
+
113
+ b_A = tl.zeros([BC, BC], dtype=tl.float32)
114
+ b_A2 = tl.zeros([BC, BC], dtype=tl.float32)
115
+ b_A3 = tl.zeros([BC, BC], dtype=tl.float32)
116
+
117
+ for i_k in range(tl.cdiv(K, BK)):
118
+ if HEAD_FIRST:
119
+ p_a1 = tl.make_block_ptr(a + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0))
120
+ p_a2 = tl.make_block_ptr(a + i_bh * T*K, (T, K), (K, 1), (i_t * BT + BC, i_k * BK), (BC, BK), (1, 0))
121
+ p_b1 = tl.make_block_ptr(b + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BC), (0, 1))
122
+ p_b2 = tl.make_block_ptr(b + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + BC), (BK, BC), (0, 1))
123
+ else:
124
+ p_a1 = tl.make_block_ptr(a + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0))
125
+ p_a2 = tl.make_block_ptr(a + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + BC, i_k * BK), (BC, BK), (1, 0))
126
+ p_b1 = tl.make_block_ptr(b + (bos * H + i_h) * K, (K, T), (1, K*H), (i_k * BK, i_t * BT), (BK, BC), (0, 1))
127
+ p_b2 = tl.make_block_ptr(b + (bos * H + i_h) * K, (K, T), (1, K*H), (i_k * BK, i_t * BT + BC), (BK, BC), (0, 1))
128
+ b_a1 = tl.load(p_a1, boundary_check=(0, 1))
129
+ b_a2 = tl.load(p_a2, boundary_check=(0, 1))
130
+ b_b1 = tl.load(p_b1, boundary_check=(0, 1))
131
+ b_b2 = tl.load(p_b2, boundary_check=(0, 1))
132
+ b_A += tl.dot(b_a1, b_b1, allow_tf32=False)
133
+ b_A2 += tl.dot(b_a2, b_b2, allow_tf32=False)
134
+ b_A3 += tl.dot(b_a2, b_b1, allow_tf32=False)
135
+
136
+ b_A = tl.where(tl.arange(0, BC)[:, None] > tl.arange(0, BC)[None, :], b_A, 0)
137
+ b_A2 = tl.where(tl.arange(0, BC)[:, None] > tl.arange(0, BC)[None, :], b_A2, 0)
138
+
139
+ for i in range(1, BC):
140
+ mask = tl.arange(0, BC) == i
141
+ b_a = tl.sum(tl.where(mask[:, None], b_A, 0), 0)
142
+ b_a2 = tl.sum(tl.where(mask[:, None], b_A2, 0), 0)
143
+ b_a = b_a + tl.sum(b_a[:, None] * b_A, 0) * (tl.arange(0, BC) < i)
144
+ b_a2 = b_a2 + tl.sum(b_a2[:, None] * b_A2, 0) * (tl.arange(0, BC) < i)
145
+ b_A = tl.where(mask[:, None], b_a, b_A)
146
+ b_A2 = tl.where(mask[:, None], b_a2, b_A2)
147
+
148
+ # blockwise computation of lower triangular matrix's inverse
149
+ # i.e., [A11, 0; A21, A22]^-1 = [A11^-1, 0; -A22^-1 A21 A11^-1, A22^-1]
150
+ b_A += tl.arange(0, BC)[:, None] == tl.arange(0, BC)[None, :]
151
+ b_A2 += tl.arange(0, BC)[:, None] == tl.arange(0, BC)[None, :]
152
+ b_A3 = tl.dot(tl.dot(b_A2, b_A3, allow_tf32=False), b_A, allow_tf32=False)
153
+
154
+ if HEAD_FIRST:
155
+ p_A1 = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BC, BC), (1, 0))
156
+ p_A2 = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + BC, BC), (BC, BC), (1, 0))
157
+ p_A3 = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + BC, 0), (BC, BC), (1, 0))
158
+ p_A4 = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, BC), (BC, BC), (1, 0))
159
+ else:
160
+ p_A1 = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BC, BC), (1, 0))
161
+ p_A2 = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, BC), (BC, BC), (1, 0))
162
+ p_A3 = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, 0), (BC, BC), (1, 0))
163
+ p_A4 = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, BC), (BC, BC), (1, 0))
164
+ tl.store(p_A1, b_A.to(p_A1.dtype.element_ty), boundary_check=(0, 1))
165
+ tl.store(p_A2, b_A2.to(p_A2.dtype.element_ty), boundary_check=(0, 1))
166
+ tl.store(p_A3, b_A3.to(p_A3.dtype.element_ty), boundary_check=(0, 1))
167
+ # causal mask
168
+ tl.store(p_A4, tl.zeros([BC, BC], dtype=tl.float32).to(p_A4.dtype.element_ty), boundary_check=(0, 1))
169
+
170
+
171
+ @triton.heuristics({
172
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
173
+ })
174
+ @triton.autotune(
175
+ configs=[
176
+ triton.Config({}, num_warps=num_warps)
177
+ for num_warps in NUM_WARPS
178
+ ],
179
+ key=['BT', 'BK', 'BV']
180
+ )
181
+ @triton.jit(do_not_specialize=['T'])
182
+ def fwd_wu_kernel(
183
+ w,
184
+ u,
185
+ a,
186
+ k,
187
+ v,
188
+ A,
189
+ offsets,
190
+ indices,
191
+ T,
192
+ H: tl.constexpr,
193
+ K: tl.constexpr,
194
+ V: tl.constexpr,
195
+ BT: tl.constexpr,
196
+ BK: tl.constexpr,
197
+ BV: tl.constexpr,
198
+ USE_OFFSETS: tl.constexpr,
199
+ HEAD_FIRST: tl.constexpr
200
+ ):
201
+ i_t, i_bh = tl.program_id(0), tl.program_id(1)
202
+ i_b, i_h = i_bh // H, i_bh % H
203
+ if USE_OFFSETS:
204
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
205
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
206
+ T = eos - bos
207
+ else:
208
+ bos, eos = i_b * T, i_b * T + T
209
+
210
+ if HEAD_FIRST:
211
+ p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
212
+ else:
213
+ p_A = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
214
+
215
+ b_A = tl.load(p_A, boundary_check=(0, 1))
216
+ b_Aak = tl.zeros([BT, BT], dtype=tl.float32)
217
+
218
+ for i_k in range(tl.cdiv(K, BK)):
219
+ if HEAD_FIRST:
220
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
221
+ p_a = tl.make_block_ptr(a + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
222
+ p_w = tl.make_block_ptr(w + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
223
+ else:
224
+ p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
225
+ p_a = tl.make_block_ptr(a + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
226
+ p_w = tl.make_block_ptr(w + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
227
+ b_k = tl.load(p_k, boundary_check=(0, 1))
228
+ b_a = tl.load(p_a, boundary_check=(0, 1))
229
+ b_w = tl.dot(b_A, b_a)
230
+ b_Aak += tl.dot(b_a, tl.trans(b_k))
231
+ tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))
232
+
233
+ b_Aak = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_Aak, 0)
234
+ b_Aak = b_Aak.to(k.dtype.element_ty)
235
+
236
+ for i_v in range(tl.cdiv(V, BV)):
237
+ if HEAD_FIRST:
238
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
239
+ p_u = tl.make_block_ptr(u + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
240
+ else:
241
+ p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
242
+ p_u = tl.make_block_ptr(u + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
243
+ b_v = tl.load(p_v, boundary_check=(0, 1))
244
+ b_v = tl.dot(b_Aak, b_v).to(v.dtype.element_ty)
245
+ b_u = tl.dot(b_A, b_v)
246
+ tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1))
247
+
248
+
249
+ def fwd_prepare_wy_repr(
250
+ a: torch.Tensor,
251
+ b: torch.Tensor,
252
+ v: torch.Tensor,
253
+ k: torch.Tensor,
254
+ offsets: Optional[torch.LongTensor],
255
+ indices: Optional[torch.LongTensor],
256
+ head_first: bool = True,
257
+ chunk_size: int = 64
258
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
259
+ if head_first:
260
+ B, H, T, K = a.shape
261
+ else:
262
+ B, T, H, K = a.shape
263
+ BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
264
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
265
+ BC = min(BT, 32)
266
+ BK = min(triton.next_power_of_2(K), 64)
267
+
268
+ A = torch.empty(B, *((H, T) if head_first else (T, H)), BT, device=a.device, dtype=a.dtype)
269
+ fwd_fn = fwd_prepare_wy_repr_kernel_chunk64 if BT == 64 else fwd_prepare_wy_repr_kernel_chunk32
270
+
271
+ fwd_fn[(NT, B * H)](
272
+ a=a,
273
+ b=b,
274
+ A=A,
275
+ offsets=offsets,
276
+ indices=indices,
277
+ T=T,
278
+ H=H,
279
+ K=K,
280
+ BT=BT,
281
+ BK=BK,
282
+ BC=BC,
283
+ HEAD_FIRST=head_first
284
+ )
285
+ w, u = fwd_wu(
286
+ a=a,
287
+ v=v,
288
+ k=k,
289
+ A=A,
290
+ offsets=offsets,
291
+ indices=indices,
292
+ head_first=head_first,
293
+ chunk_size=chunk_size
294
+ )
295
+ return w, u, A
296
+
297
+
298
+ def fwd_wu(
299
+ a: torch.Tensor,
300
+ v: torch.Tensor,
301
+ k: torch.Tensor,
302
+ A: torch.Tensor,
303
+ offsets: Optional[torch.LongTensor],
304
+ indices: Optional[torch.LongTensor],
305
+ head_first: bool,
306
+ chunk_size: int
307
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
308
+ if head_first:
309
+ B, H, T, K, V = *a.shape, v.shape[-1]
310
+ else:
311
+ B, T, H, K, V = *a.shape, v.shape[-1]
312
+ BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
313
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
314
+ CONST_TILING = 64 if check_shared_mem() else 32
315
+ BK = min(triton.next_power_of_2(K), CONST_TILING)
316
+ BV = min(triton.next_power_of_2(V), CONST_TILING)
317
+
318
+ u = torch.empty_like(v)
319
+ w = torch.empty_like(a)
320
+ fwd_wu_kernel[(NT, B*H)](
321
+ a=a,
322
+ v=v,
323
+ w=w,
324
+ u=u,
325
+ A=A,
326
+ k=k,
327
+ offsets=offsets,
328
+ indices=indices,
329
+ T=T,
330
+ H=H,
331
+ K=K,
332
+ V=V,
333
+ BT=BT,
334
+ BK=BK,
335
+ BV=BV,
336
+ HEAD_FIRST=head_first
337
+ )
338
+ return w, u
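
For reference, the per-chunk math implemented by `fwd_wu_kernel` above can be written in a few lines of plain PyTorch. This is a minimal single-chunk sketch (head-first layout, no variable-length offsets), and `fwd_wu_reference` is a hypothetical helper name used only for illustration, not part of the library:

```python
import torch

def fwd_wu_reference(a, k, v, A):
    # Single-chunk reference: a, k are [BT, K], v is [BT, V], A is [BT, BT].
    # w = A @ a, which the kernel accumulates blockwise over K.
    w = A @ a
    # Aak = strictly lower-triangular part of a @ k^T (the causal mask in the kernel).
    Aak = torch.tril(a @ k.transpose(-1, -2), diagonal=-1)
    # u = A @ (Aak @ v), which the kernel accumulates blockwise over V.
    u = A @ (Aak @ v)
    return w, u
```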
fla/ops/gla/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (333 Bytes). View file
 
fla/ops/gla/__pycache__/chunk.cpython-312.pyc ADDED
Binary file (81.8 kB). View file
 
fla/ops/gla/__pycache__/fused_chunk.cpython-312.pyc ADDED
Binary file (35.3 kB). View file
 
fla/ops/gsa/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from .chunk import chunk_gsa
4
+ from .fused_recurrent import fused_recurrent_gsa
5
+
6
+ __all__ = [
7
+ 'chunk_gsa',
8
+ 'fused_recurrent_gsa'
9
+ ]
fla/ops/gsa/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (282 Bytes). View file
 
fla/ops/gsa/__pycache__/chunk.cpython-312.pyc ADDED
Binary file (69.4 kB). View file
 
fla/ops/gsa/chunk.py ADDED
@@ -0,0 +1,1264 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+ from einops import reduce
10
+
11
+ from fla.ops.common.chunk_h import chunk_bwd_dh, chunk_fwd_h
12
+ from fla.ops.gla.chunk import chunk_gla_bwd, chunk_gla_fwd
13
+ from fla.ops.utils import chunk_local_cumsum, softmax_bwd, softmax_fwd
14
+ from fla.ops.utils.op import exp, safe_exp
15
+ from fla.utils import input_guard
16
+
17
+
18
+ @triton.heuristics({
19
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
20
+ })
21
+ @triton.autotune(
22
+ configs=[
23
+ triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
24
+ for BK in [32, 64]
25
+ for BV in [32, 64]
26
+ for num_warps in [2, 4, 8]
27
+ for num_stages in [2, 3, 4]
28
+ ],
29
+ key=['BT']
30
+ )
31
+ @triton.jit(do_not_specialize=['T'])
32
+ def chunk_gsa_fwd_k_kernel_inter(
33
+ q,
34
+ k,
35
+ h,
36
+ g,
37
+ o,
38
+ A,
39
+ offsets,
40
+ indices,
41
+ scale,
42
+ T,
43
+ HQ: tl.constexpr,
44
+ H: tl.constexpr,
45
+ K: tl.constexpr,
46
+ V: tl.constexpr,
47
+ BT: tl.constexpr,
48
+ BK: tl.constexpr,
49
+ BV: tl.constexpr,
50
+ NG: tl.constexpr,
51
+ USE_OFFSETS: tl.constexpr,
52
+ HEAD_FIRST: tl.constexpr
53
+ ):
54
+ i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
55
+ i_bg = i_bh // NG
56
+ i_b, i_hq = i_bh // HQ, i_bh % HQ
57
+ i_h = i_hq // NG
58
+ if USE_OFFSETS:
59
+ i_tg = i_t
60
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
61
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
62
+ T = eos - bos
63
+ NT = tl.cdiv(T, BT)
64
+ else:
65
+ NT = tl.cdiv(T, BT)
66
+ i_tg = i_b * NT + i_t
67
+ bos, eos = i_b * T, i_b * T + T
68
+
69
+ o_i = tl.arange(0, BT)
70
+ m_s = o_i[:, None] >= o_i[None, :]
71
+
72
+ b_o = tl.zeros([BT, BV], dtype=tl.float32)
73
+ b_A = tl.zeros([BT, BT], dtype=tl.float32)
74
+ for i_k in range(tl.cdiv(K, BK)):
75
+ if HEAD_FIRST:
76
+ p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
77
+ p_k = tl.make_block_ptr(k + i_bg * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
78
+ p_h = tl.make_block_ptr(h + (i_bg * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
79
+ else:
80
+ p_q = tl.make_block_ptr(q + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
81
+ p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
82
+ p_h = tl.make_block_ptr(h + (i_tg * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
83
+
84
+ # [BT, BK]
85
+ b_q = tl.load(p_q, boundary_check=(0, 1))
86
+ b_q = (b_q * scale).to(b_q.dtype)
87
+ # [BK, BT]
88
+ b_k = tl.load(p_k, boundary_check=(0, 1))
89
+ # [BK, BV]
90
+ b_h = tl.load(p_h, boundary_check=(0, 1))
91
+ # [BT, BV]
92
+ b_o += tl.dot(b_q, b_h)
93
+ # [BT, BT]
94
+ b_A += tl.dot(b_q, b_k)
95
+ if HEAD_FIRST:
96
+ p_g = tl.make_block_ptr(g + i_bg * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
97
+ p_o = tl.make_block_ptr(o + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
98
+ p_A = tl.make_block_ptr(A + i_bh * T*BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
99
+ else:
100
+ p_g = tl.make_block_ptr(g + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
101
+ p_o = tl.make_block_ptr(o + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
102
+ p_A = tl.make_block_ptr(A + (bos * HQ + i_hq) * BT, (T, BT), (HQ*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
103
+ # [BT, BV]
104
+ b_g = tl.load(p_g, boundary_check=(0, 1))
105
+ b_o = b_o * exp(b_g)
106
+ tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
107
+
108
+ # [BT, BT]
109
+ b_A = tl.where(m_s, b_A, 0.)
110
+ if i_v == 0:
111
+ tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1))
112
+
113
+
114
+ @triton.heuristics({
115
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
116
+ })
117
+ @triton.jit(do_not_specialize=['T'])
118
+ def chunk_gsa_fwd_k_kernel_intra(
119
+ v,
120
+ g,
121
+ o,
122
+ A,
123
+ offsets,
124
+ indices,
125
+ T,
126
+ HQ: tl.constexpr,
127
+ H: tl.constexpr,
128
+ V: tl.constexpr,
129
+ BT: tl.constexpr,
130
+ BC: tl.constexpr,
131
+ BV: tl.constexpr,
132
+ NC: tl.constexpr,
133
+ NG: tl.constexpr,
134
+ USE_OFFSETS: tl.constexpr,
135
+ HEAD_FIRST: tl.constexpr
136
+ ):
137
+ i_v, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
138
+ i_bg = i_bh // NG
139
+ i_b, i_hq = i_bh // HQ, i_bh % HQ
140
+ i_h = i_hq // NG
141
+ i_t, i_i = i_c // NC, i_c % NC
142
+ if USE_OFFSETS:
143
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
144
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
145
+ T = eos - bos
146
+ else:
147
+ bos, eos = i_b * T, i_b * T + T
148
+
149
+ o_v = i_v * BV + tl.arange(0, BV)
150
+ m_v = o_v < V
151
+
152
+ if i_t * BT + i_i * BC > T:
153
+ return
154
+
155
+ if HEAD_FIRST:
156
+ p_g = tl.make_block_ptr(g + i_bg * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
157
+ p_gn = tl.max_contiguous(tl.multiple_of(g + i_bg * T*V + min(i_t * BT + i_i * BC, T) * V + o_v, BV), BV)
158
+ else:
159
+ p_g = tl.make_block_ptr(g + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
160
+ p_gn = g + (bos + min(i_t * BT + i_i * BC, T)) * H*V + i_h * V + o_v
161
+ # [BV,]
162
+ b_gn = tl.load(p_gn, mask=m_v, other=0)
163
+ # [BC, BV]
164
+ b_o = tl.zeros([BC, BV], dtype=tl.float32)
165
+ for i_j in range(0, i_i):
166
+ if HEAD_FIRST:
167
+ p_A = tl.make_block_ptr(A + i_bh * T*BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
168
+ p_v = tl.make_block_ptr(v + i_bg * T*V, (T, V), (V, 1), (i_t * BT + i_j * BC, i_v * BV), (BC, BV), (1, 0))
169
+ p_gv = tl.make_block_ptr(g + i_bg * T*V, (T, V), (V, 1), (i_t * BT + i_j * BC, i_v * BV), (BC, BV), (1, 0))
170
+ else:
171
+ p_A = tl.make_block_ptr(A + (bos*HQ+i_hq) * BT, (T, BT), (HQ*BT, 1), (i_t*BT+i_i*BC, i_j * BC), (BC, BC), (1, 0))
172
+ p_v = tl.make_block_ptr(v + (bos*H+i_h) * V, (T, V), (H*V, 1), (i_t * BT + i_j * BC, i_v * BV), (BC, BV), (1, 0))
173
+ p_gv = tl.make_block_ptr(g + (bos*H+i_h) * V, (T, V), (H*V, 1), (i_t * BT + i_j * BC, i_v * BV), (BC, BV), (1, 0))
174
+ # [BC, BV]
175
+ b_v = tl.load(p_v, boundary_check=(0, 1))
176
+ b_gv = tl.load(p_gv, boundary_check=(0, 1))
177
+ b_vg = (b_v * exp(b_gn[None, :] - b_gv)).to(b_v.dtype)
178
+ # [BC, BC]
179
+ b_A = tl.load(p_A, boundary_check=(0, 1))
180
+ b_o += tl.dot(b_A, b_vg)
181
+ # [BC, BV]
182
+ b_g = tl.load(p_g, boundary_check=(0, 1))
183
+ b_o *= exp(b_g - b_gn[None, :])
184
+
185
+ o_i = tl.arange(0, BC)
186
+ if HEAD_FIRST:
187
+ o_A = i_bh * T*BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_i * BC
188
+ else:
189
+ o_A = (bos + i_t * BT + i_i * BC + tl.arange(0, BC)) * HQ*BT + i_hq * BT + i_i * BC
190
+ m_A = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T
191
+ for j in range(0, min(BC, T - i_t * BT - i_i * BC)):
192
+ if HEAD_FIRST:
193
+ p_v = tl.max_contiguous(tl.multiple_of(v + i_bg * T*V + (i_t * BT + i_i * BC + j) * V + o_v, BV), BV)
194
+ p_gv = tl.max_contiguous(tl.multiple_of(g + i_bg * T*V + (i_t * BT + i_i * BC + j) * V + o_v, BV), BV)
195
+ else:
196
+ p_v = v + (bos + i_t * BT + i_i * BC + j) * H*V + i_h * V + o_v
197
+ p_gv = g + (bos + i_t * BT + i_i * BC + j) * H*V + i_h * V + o_v
198
+ # [BC,]
199
+ b_A = tl.load(A + o_A + j, mask=m_A, other=0)
200
+ # [BV,]
201
+ b_v = tl.load(p_v, mask=m_v, other=0).to(tl.float32)
202
+ b_gv = tl.load(p_gv, mask=m_v, other=0).to(tl.float32)
203
+ # [BC, BV]
204
+ b_vg = b_v[None, :] * exp(b_g - b_gv[None, :])
205
+ # avoid 0 * inf = inf
206
+ b_o += tl.where(o_i[:, None] >= j, b_A[:, None] * b_vg, 0.)
207
+ if HEAD_FIRST:
208
+ p_o = tl.make_block_ptr(o + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
209
+ else:
210
+ p_o = tl.make_block_ptr(o + (bos*HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
211
+ b_o += tl.load(p_o, boundary_check=(0, 1))
212
+ tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
213
+
214
+
215
+ @triton.heuristics({
216
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
217
+ })
218
+ @triton.autotune(
219
+ configs=[
220
+ triton.Config({}, num_warps=num_warps)
221
+ for num_warps in [2, 4, 8]
222
+ ],
223
+ key=["BT"]
224
+ )
225
+ @triton.jit(do_not_specialize=['T'])
226
+ def chunk_gsa_bwd_k_kernel_dA(
227
+ v,
228
+ g,
229
+ do,
230
+ dA,
231
+ indices,
232
+ offsets,
233
+ scale,
234
+ T,
235
+ B: tl.constexpr,
236
+ HQ: tl.constexpr,
237
+ H: tl.constexpr,
238
+ V: tl.constexpr,
239
+ BT: tl.constexpr,
240
+ BC: tl.constexpr,
241
+ BV: tl.constexpr,
242
+ NC: tl.constexpr,
243
+ NG: tl.constexpr,
244
+ USE_OFFSETS: tl.constexpr,
245
+ HEAD_FIRST: tl.constexpr
246
+ ):
247
+ i_v, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
248
+ i_bg = i_bh // NG
249
+ i_b, i_hq = i_bh // HQ, i_bh % HQ
250
+ i_h = i_hq // NG
251
+ i_t, i_i, i_j = i_c // (NC * NC), (i_c % (NC * NC)) // NC, (i_c % (NC * NC)) % NC
252
+ if USE_OFFSETS:
253
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
254
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
255
+ all = T
256
+ T = eos - bos
257
+ else:
258
+ bos, eos = i_b * T, i_b * T + T
259
+ all = B * T
260
+
261
+ o_v = i_v * BV + tl.arange(0, BV)
262
+ m_v = o_v < V
263
+
264
+ if i_t * BT + i_i * BC > T:
265
+ return
266
+
267
+ if HEAD_FIRST:
268
+ p_dA = tl.make_block_ptr(dA+(i_v*B*H+i_bh)*T*BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
269
+ else:
270
+ p_dA = tl.make_block_ptr(dA+((i_v*all+bos)*HQ+i_hq)*BT, (T, BT), (HQ*BT, 1), (i_t*BT+i_i*BC, i_j*BC), (BC, BC), (1, 0))
271
+
272
+ # [BC, BC]
273
+ b_dA = tl.zeros([BC, BC], dtype=tl.float32)
274
+ if i_i > i_j:
275
+ if HEAD_FIRST:
276
+ p_v = tl.make_block_ptr(v + i_bg * T*V, (V, T), (1, V), (i_v * BV, i_t * BT + i_j * BC), (BV, BC), (0, 1))
277
+ p_gv = tl.make_block_ptr(g + i_bg * T*V, (V, T), (1, V), (i_v * BV, i_t * BT + i_j * BC), (BV, BC), (0, 1))
278
+ p_gn = tl.max_contiguous(tl.multiple_of(g + i_bg * T*V + (i_t * BT + i_i * BC) * V + o_v, BV), BV)
279
+ p_g = tl.make_block_ptr(g + i_bg * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
280
+ p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
281
+ else:
282
+ p_v = tl.make_block_ptr(v + (bos*H+i_h) * V, (V, T), (1, H*V), (i_v * BV, i_t*BT + i_j*BC), (BV, BC), (0, 1))
283
+ p_gv = tl.make_block_ptr(g + (bos*H+i_h) * V, (V, T), (1, H*V), (i_v * BV, i_t*BT + i_j*BC), (BV, BC), (0, 1))
284
+ p_gn = g + (bos + i_t*BT + i_i*BC) * H*V + i_h * V + o_v
285
+ p_g = tl.make_block_ptr(g + (bos*H+i_h) * V, (T, V), (H*V, 1), (i_t*BT + i_i*BC, i_v*BV), (BC, BV), (1, 0))
286
+ p_do = tl.make_block_ptr(do + (bos*HQ+i_hq) * V, (T, V), (HQ*V, 1), (i_t*BT + i_i*BC, i_v*BV), (BC, BV), (1, 0))
287
+ # [BV,]
288
+ b_gn = tl.load(p_gn, mask=m_v, other=0.)
289
+ # [BC, BV]
290
+ b_g = tl.load(p_g, boundary_check=(0, 1))
291
+ b_do = tl.load(p_do, boundary_check=(0, 1))
292
+ b_do = (b_do * exp(b_g - b_gn[None, :]) * scale).to(b_do.dtype)
293
+ # [BV, BC]
294
+ b_v = tl.load(p_v, boundary_check=(0, 1))
295
+ b_gv = tl.load(p_gv, boundary_check=(0, 1))
296
+ b_vg = (b_v * exp(b_gn[:, None] - b_gv)).to(b_v.dtype)
297
+ # [BC, BC]
298
+ b_dA = tl.dot(b_do, b_vg)
299
+ elif i_i == i_j:
300
+ if HEAD_FIRST:
301
+ p_g = tl.make_block_ptr(g + i_bg * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
302
+ p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
303
+ p_v = tl.max_contiguous(tl.multiple_of(v + i_bg * T*V + (i_t * BT + i_j * BC) * V + o_v, BV), BV)
304
+ p_gv = tl.max_contiguous(tl.multiple_of(g + i_bg * T*V + (i_t * BT + i_j * BC) * V + o_v, BV), BV)
305
+ else:
306
+ p_g = tl.make_block_ptr(g + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t*BT + i_i*BC, i_v*BV), (BC, BV), (1, 0))
307
+ p_do = tl.make_block_ptr(do + (bos*HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t*BT + i_i*BC, i_v*BV), (BC, BV), (1, 0))
308
+ p_v = v + (bos + i_t*BT + i_j*BC) * H*V + i_h * V + o_v
309
+ p_gv = g + (bos + i_t*BT + i_j*BC) * H*V + i_h * V + o_v
310
+ # [BC, BV]
311
+ b_g = tl.load(p_g, boundary_check=(0, 1))
312
+ b_do = tl.load(p_do, boundary_check=(0, 1)) * scale
313
+ m_v = o_v < V
314
+
315
+ o_i = tl.arange(0, BC)
316
+ # [BC, BC]
317
+ m_dA = o_i[:, None] >= o_i[None, :]
318
+ for j in range(0, min(BC, T - i_t * BT - i_j * BC)):
319
+ # [BV,]
320
+ b_v = tl.load(p_v, mask=m_v, other=0).to(tl.float32)
321
+ b_gv = tl.load(p_gv, mask=m_v, other=0).to(tl.float32)
322
+ # [BC,]
323
+ b_dAj = tl.sum(b_do * b_v[None, :] * exp(b_g - b_gv[None, :]), 1)
324
+ b_dA = tl.where((o_i == j)[None, :], b_dAj[:, None], b_dA)
325
+
326
+ p_v += (1 if HEAD_FIRST else H) * V
327
+ p_gv += (1 if HEAD_FIRST else H) * V
328
+ b_dA = tl.where(m_dA, b_dA, 0.)
329
+ tl.store(p_dA, b_dA.to(dA.dtype.element_ty), boundary_check=(0, 1))
330
+
331
+
332
+ @triton.heuristics({
333
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
334
+ })
335
+ @triton.autotune(
336
+ configs=[
337
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
338
+ for num_warps in [2, 4]
339
+ for num_stages in [2, 3, 4]
340
+ ],
341
+ key=['BT']
342
+ )
343
+ @triton.jit(do_not_specialize=['T'])
344
+ def chunk_gsa_bwd_k_kernel_dqkvg(
345
+ q,
346
+ k,
347
+ v,
348
+ h,
349
+ g,
350
+ A,
351
+ do,
352
+ dh,
353
+ dq,
354
+ dk,
355
+ dv,
356
+ dg,
357
+ dgv,
358
+ dA,
359
+ offsets,
360
+ indices,
361
+ scale,
362
+ T,
363
+ B: tl.constexpr,
364
+ HQ: tl.constexpr,
365
+ H: tl.constexpr,
366
+ K: tl.constexpr,
367
+ V: tl.constexpr,
368
+ BT: tl.constexpr,
369
+ BK: tl.constexpr,
370
+ BV: tl.constexpr,
371
+ NG: tl.constexpr,
372
+ USE_OFFSETS: tl.constexpr,
373
+ HEAD_FIRST: tl.constexpr
374
+ ):
375
+ i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
376
+ i_bg = i_bh // NG
377
+ i_b, i_hq = i_bh // HQ, i_bh % HQ
378
+ i_h = i_hq // NG
379
+ if USE_OFFSETS:
380
+ i_tg = i_t
381
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
382
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
383
+ all = T
384
+ T = eos - bos
385
+ NT = tl.cdiv(T, BT)
386
+ else:
387
+ NT = tl.cdiv(T, BT)
388
+ i_tg = i_b * NT + i_t
389
+ bos, eos = i_b * T, i_b * T + T
390
+ all = B * T
391
+
392
+ o_i = tl.arange(0, BT)
393
+ o_t = min(i_t * BT + BT, T)
394
+ m_s = o_i[:, None] >= o_i[None, :]
395
+
396
+ if HEAD_FIRST:
397
+ p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
398
+ p_k = tl.make_block_ptr(k + i_bg * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
399
+ p_A = tl.make_block_ptr(A + (i_k*B*H+i_bh) * T*BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
400
+ else:
401
+ p_q = tl.make_block_ptr(q + (bos*HQ+i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
402
+ p_k = tl.make_block_ptr(k + (bos*H+i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
403
+ p_A = tl.make_block_ptr(A + ((i_k*all+bos)*HQ+i_hq)*BT, (T, BT), (HQ*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
404
+
405
+ # [BT, BK]
406
+ b_q = tl.load(p_q, boundary_check=(0, 1))
407
+ b_k = tl.load(p_k, boundary_check=(0, 1))
408
+ # [BT, BT]
409
+ b_A = tl.dot((b_q * scale).to(b_q.dtype), tl.trans(b_k))
410
+ b_A = tl.where(m_s, b_A, 0.)
411
+ tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1))
412
+
413
+ b_dq = tl.zeros([BT, BK], dtype=tl.float32)
414
+ b_dk = tl.zeros([BT, BK], dtype=tl.float32)
415
+ for i_v in range(tl.cdiv(V, BV)):
416
+ o_v = i_v * BV + tl.arange(0, BV)
417
+ if HEAD_FIRST:
418
+ p_v = tl.make_block_ptr(v + i_bg * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
419
+ p_g = tl.make_block_ptr(g + i_bg * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
420
+ p_gn = tl.max_contiguous(tl.multiple_of(g + i_bg * T*V + (o_t - 1) * V + o_v, BV), BV)
421
+ p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
422
+ p_dv = tl.make_block_ptr(dv + (i_k*B*H+i_bh) * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
423
+ p_dg = tl.make_block_ptr(dg + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
424
+ p_dgv = tl.make_block_ptr(dgv + (i_k*B*H+i_bh) * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
425
+ p_h = tl.make_block_ptr(h + i_bg * NT*K*V + i_t * K*V, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
426
+ p_dh = tl.make_block_ptr(dh + i_bh * NT*K*V + i_t * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
427
+ else:
428
+ p_v = tl.make_block_ptr(v + (bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
429
+ p_g = tl.make_block_ptr(g + (bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
430
+ p_gn = g + (bos + o_t - 1) * H*V + i_h * V + o_v
431
+ p_do = tl.make_block_ptr(do + (bos*HQ+i_hq)*V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
432
+ p_dv = tl.make_block_ptr(dv + ((i_k*all+bos)*HQ+i_hq)*V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
433
+ p_dg = tl.make_block_ptr(dg + (bos*HQ+i_hq)*V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
434
+ p_dgv = tl.make_block_ptr(dgv+((i_k*all+bos)*HQ+i_hq)*V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
435
+ p_h = tl.make_block_ptr(h + (i_tg * H + i_h) * K*V, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
436
+ p_dh = tl.make_block_ptr(dh + (i_tg * HQ + i_hq) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
437
+ m_v = o_v < V
438
+
439
+ # [BV,]
440
+ b_gn = tl.load(p_gn, mask=m_v, other=0)
441
+ # [BT, BV]
442
+ b_v = tl.load(p_v, boundary_check=(0, 1))
443
+ b_g = tl.load(p_g, boundary_check=(0, 1))
444
+ b_gv = exp(b_gn[None, :] - b_g)
445
+ # [BV, BK]
446
+ b_h = tl.load(p_h, boundary_check=(0, 1))
447
+ # [BT, BV]
448
+ b_do = tl.load(p_do, boundary_check=(0, 1))
449
+ b_do = (b_do * exp(b_g) * scale).to(b_do.dtype)
450
+ # [BK, BV]
451
+ b_dh = tl.load(p_dh, boundary_check=(0, 1))
452
+ # [BV]
453
+ b_dg = tl.sum(tl.trans(b_h) * b_dh, 0) * exp(b_gn)
454
+
455
+ b_dh = b_dh.to(b_k.dtype)
456
+ # [BT, BK]
457
+ b_dq += tl.dot(b_do, b_h.to(b_k.dtype))
458
+ b_dk += tl.dot((b_v * b_gv).to(b_v.dtype), tl.trans(b_dh))
459
+ # [BT, BV]
460
+ b_dv = tl.dot(b_k, b_dh) * b_gv
461
+ # [BV]
462
+ b_dg += tl.sum(b_dv * b_v, 0)
463
+
464
+ if i_k == 0:
465
+ b_dgv = tl.load(p_dg, boundary_check=(0, 1)) + b_dg[None, :]
466
+ else:
467
+ b_dgv = tl.zeros([BT, BV], dtype=tl.float32) + b_dg[None, :]
468
+
469
+ tl.store(p_dgv, b_dgv.to(p_dgv.dtype.element_ty), boundary_check=(0, 1))
470
+ tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
471
+ if HEAD_FIRST:
472
+ p_dA = tl.make_block_ptr(dA + i_bh * T*BT, (T, BT, ), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
473
+ p_dq = tl.make_block_ptr(dq + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
474
+ p_dk = tl.make_block_ptr(dk + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
475
+ else:
476
+ p_dA = tl.make_block_ptr(dA + (bos*HQ + i_hq) * BT, (T, BT), (HQ*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
477
+ p_dq = tl.make_block_ptr(dq + (bos*HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
478
+ p_dk = tl.make_block_ptr(dk + (bos*HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
479
+ # [BT, BT]
480
+ b_dA = tl.load(p_dA, boundary_check=(0, 1))
481
+ # [BT, BK]
482
+ b_dq += tl.dot(b_dA, b_k)
483
+ b_dk += tl.dot(tl.trans(b_dA).to(b_k.dtype), b_q)
484
+
485
+ tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
486
+ tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
487
+
488
+
489
+ @triton.heuristics({
490
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
491
+ })
492
+ @triton.jit(do_not_specialize=['T'])
493
+ def chunk_gsa_bwd_k_kernel_intra_dvg(
494
+ v,
495
+ g,
496
+ o,
497
+ A,
498
+ do,
499
+ dv,
500
+ dg,
501
+ offsets,
502
+ indices,
503
+ T,
504
+ HQ: tl.constexpr,
505
+ H: tl.constexpr,
506
+ V: tl.constexpr,
507
+ BT: tl.constexpr,
508
+ BC: tl.constexpr,
509
+ BV: tl.constexpr,
510
+ NC: tl.constexpr,
511
+ NG: tl.constexpr,
512
+ USE_OFFSETS: tl.constexpr,
513
+ HEAD_FIRST: tl.constexpr
514
+ ):
515
+ i_v, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
516
+ i_bg = i_bh // NG
517
+ i_b, i_hq = i_bh // HQ, i_bh % HQ
518
+ i_h = i_hq // NG
519
+ i_t, i_i = i_c // NC, i_c % NC
520
+ if USE_OFFSETS:
521
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
522
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
523
+ T = eos - bos
524
+ else:
525
+ bos, eos = i_b * T, i_b * T + T
526
+
527
+ o_v = i_v * BV + tl.arange(0, BV)
528
+ m_v = o_v < V
529
+
530
+ if i_t * BT + i_i * BC > T:
531
+ return
532
+
533
+ if HEAD_FIRST:
534
+ p_gv = tl.make_block_ptr(g + i_bg * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
535
+ p_gn = tl.max_contiguous(tl.multiple_of(g + i_bg * T*V + (min(i_t * BT + i_i * BC + BC, T) - 1) * V + o_v, BV), BV)
536
+ else:
537
+ p_gv = tl.make_block_ptr(g + (bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
538
+ p_gn = g + (bos + min(i_t * BT + i_i * BC + BC, T)-1)*H*V + i_h*V + o_v
539
+ # [BV,]
540
+ b_gn = tl.load(p_gn, mask=m_v, other=0)
541
+ # [BC, BV]
542
+ b_gv = tl.load(p_gv, boundary_check=(0, 1))
543
+ b_dv = tl.zeros([BC, BV], dtype=tl.float32)
544
+ for i_j in range(i_i + 1, NC):
545
+ if HEAD_FIRST:
546
+ p_g = tl.make_block_ptr(g + i_bg * T*V, (T, V), (V, 1), (i_t * BT + i_j * BC, i_v * BV), (BC, BV), (1, 0))
547
+ p_A = tl.make_block_ptr(A + i_bh * T*BT, (BT, T), (1, BT), (i_i * BC, i_t * BT + i_j * BC), (BC, BC), (0, 1))
548
+ p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_j * BC, i_v * BV), (BC, BV), (1, 0))
549
+ else:
550
+ p_g = tl.make_block_ptr(g + (bos*H+i_h) * V, (T, V), (H*V, 1), (i_t * BT + i_j * BC, i_v * BV), (BC, BV), (1, 0))
551
+ p_A = tl.make_block_ptr(A + (bos*HQ+i_hq) * BT, (BT, T), (1, HQ*BT), (i_i*BC, i_t*BT + i_j*BC), (BC, BC), (0, 1))
552
+ p_do = tl.make_block_ptr(do + (bos*HQ+i_hq) * V, (T, V), (HQ*V, 1), (i_t*BT + i_j*BC, i_v*BV), (BC, BV), (1, 0))
553
+ # [BC, BV]
554
+ b_g = tl.load(p_g, boundary_check=(0, 1))
555
+ b_do = tl.load(p_do, boundary_check=(0, 1)) * safe_exp(b_g - b_gn[None, :])
556
+ # [BC, BC]
557
+ b_A = tl.load(p_A, boundary_check=(0, 1))
558
+ # [BC, BV]
559
+ b_dv += tl.dot(b_A, b_do.to(b_A.dtype))
560
+ b_dv *= exp(b_gn[None, :] - b_gv)
561
+
562
+ o_i = tl.arange(0, BC)
563
+ o_c = i_i * BC + tl.arange(0, BC)
564
+
565
+ if HEAD_FIRST:
566
+ p_g = tl.max_contiguous(tl.multiple_of(g + i_bg * T*V + (i_t * BT + i_i * BC) * V + o_v, BV), BV)
567
+ p_A = tl.max_contiguous(tl.multiple_of(A + i_bh * T*BT + (i_t * BT + i_i * BC) * BT + o_c, BC), BC)
568
+ p_do = tl.max_contiguous(tl.multiple_of(do + i_bh * T*V + (i_t * BT + i_i * BC) * V + o_v, BV), BV)
569
+ else:
570
+ p_g = g + (bos + i_t * BT + i_i * BC) * H*V + i_h * V + o_v
571
+ p_A = A + (bos + i_t*BT + i_i*BC) * HQ*BT + i_hq * BT + o_c
572
+ p_do = do + (bos + i_t*BT + i_i*BC) * HQ*V + i_hq * V + o_v
573
+
574
+ for j in range(0, min(BC, T - i_t * BT - i_i * BC)):
575
+ # [BC,]
576
+ b_A = tl.load(p_A)
577
+ # [BV,]
578
+ b_g = tl.load(p_g, mask=m_v, other=0)
579
+ b_do = tl.load(p_do, mask=m_v, other=0)
580
+ # [BC, BV]
581
+ m_i = o_i[:, None] <= j
582
+ b_dv += tl.where(m_i, exp(b_g[None, :] - b_gv) * b_A[:, None] * b_do[None, :], 0.)
583
+
584
+ p_g += (1 if HEAD_FIRST else H) * V
585
+ p_A += (1 if HEAD_FIRST else HQ) * BT
586
+ p_do += (1 if HEAD_FIRST else HQ) * V
587
+ if HEAD_FIRST:
588
+ p_o = tl.make_block_ptr(o + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
589
+ p_v = tl.make_block_ptr(v + i_bg * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
590
+ p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
591
+ p_dv = tl.make_block_ptr(dv + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
592
+ p_dg = tl.make_block_ptr(dg + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
593
+ else:
594
+ p_o = tl.make_block_ptr(o + (bos*HQ+i_hq)*V, (T, V), (HQ*V, 1), (i_t*BT + i_i*BC, i_v*BV), (BC, BV), (1, 0))
595
+ p_v = tl.make_block_ptr(v + (bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT + i_i*BC, i_v*BV), (BC, BV), (1, 0))
596
+ p_do = tl.make_block_ptr(do + (bos*HQ+i_hq)*V, (T, V), (HQ*V, 1), (i_t*BT + i_i*BC, i_v*BV), (BC, BV), (1, 0))
597
+ p_dv = tl.make_block_ptr(dv + (bos*HQ+i_hq)*V, (T, V), (HQ*V, 1), (i_t*BT + i_i*BC, i_v*BV), (BC, BV), (1, 0))
598
+ p_dg = tl.make_block_ptr(dg + (bos*HQ+i_hq)*V, (T, V), (HQ*V, 1), (i_t*BT + i_i*BC, i_v*BV), (BC, BV), (1, 0))
599
+
600
+ b_o = tl.load(p_o, boundary_check=(0, 1)).to(tl.float32)
601
+ b_v = tl.load(p_v, boundary_check=(0, 1)).to(tl.float32)
602
+ b_do = tl.load(p_do, boundary_check=(0, 1)).to(tl.float32)
603
+ b_dv = b_dv + tl.load(p_dv, boundary_check=(0, 1)).to(tl.float32)
604
+ b_dg = b_o * b_do - b_v * b_dv
605
+ tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
606
+ tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0, 1))
607
+
608
+
609
+ def chunk_gsa_fwd_v(
610
+ q: torch.Tensor,
611
+ k: torch.Tensor,
612
+ v: torch.Tensor,
613
+ g: torch.Tensor,
614
+ scale: float = 1.,
615
+ initial_state: Optional[torch.Tensor] = None,
616
+ output_final_state: bool = False,
617
+ offsets: Optional[torch.LongTensor] = None,
618
+ indices: Optional[torch.LongTensor] = None,
619
+ head_first: bool = True,
620
+ chunk_size: int = 64
621
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
622
+ _, A, h, ht, o = chunk_gla_fwd(
623
+ q=q,
624
+ k=k,
625
+ v=v,
626
+ g=None,
627
+ g_cumsum=g,
628
+ scale=scale,
629
+ initial_state=initial_state,
630
+ output_final_state=output_final_state,
631
+ offsets=offsets,
632
+ indices=indices,
633
+ head_first=head_first,
634
+ chunk_size=chunk_size
635
+ )
636
+ return A, h, ht, o
637
+
638
+
639
+ def chunk_gsa_fwd_k(
640
+ q: torch.Tensor,
641
+ k: torch.Tensor,
642
+ v: torch.Tensor,
643
+ g: torch.Tensor,
644
+ h0: Optional[torch.Tensor] = None,
645
+ output_final_state: bool = False,
646
+ scale: float = 1.,
647
+ offsets: Optional[torch.LongTensor] = None,
648
+ indices: Optional[torch.LongTensor] = None,
649
+ head_first: bool = True,
650
+ chunk_size: int = 64
651
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
652
+ if head_first:
653
+ B, H, T, K, V = *k.shape, v.shape[-1]
654
+ else:
655
+ B, T, H, K, V = *k.shape, v.shape[-1]
656
+ BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
657
+ BC = min(16, BT)
658
+ BV = min(64, triton.next_power_of_2(V))
659
+ HQ = q.shape[1] if head_first else q.shape[2]
660
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
661
+ NC = triton.cdiv(BT, BC)
662
+ NG = HQ // H
663
+
664
+ h, ht = chunk_fwd_h(
665
+ k=k,
666
+ v=v,
667
+ g=None,
668
+ gk=None,
669
+ gv=g,
670
+ h0=h0,
671
+ output_final_state=output_final_state,
672
+ offsets=offsets,
673
+ head_first=head_first,
674
+ chunk_size=BT,
675
+ states_in_fp32=False
676
+ )
677
+ o = v.new_empty(B, *((HQ, T) if head_first else (T, HQ)), V)
678
+ A = q.new_empty(B, *((HQ, T) if head_first else (T, HQ)), BT)
679
+ def grid(meta): return (triton.cdiv(V, meta['BV']), NT, B * HQ)
680
+ chunk_gsa_fwd_k_kernel_inter[grid](
681
+ q,
682
+ k,
683
+ h,
684
+ g,
685
+ o,
686
+ A,
687
+ offsets=offsets,
688
+ indices=indices,
689
+ scale=scale,
690
+ T=T,
691
+ HQ=HQ,
692
+ H=H,
693
+ K=K,
694
+ V=V,
695
+ BT=BT,
696
+ NG=NG,
697
+ HEAD_FIRST=head_first
698
+ )
699
+
700
+ def grid(meta): return (triton.cdiv(V, meta['BV']), NT * NC, B * HQ)
701
+ chunk_gsa_fwd_k_kernel_intra[grid](
702
+ v,
703
+ g,
704
+ o,
705
+ A,
706
+ offsets=offsets,
707
+ indices=indices,
708
+ T=T,
709
+ HQ=HQ,
710
+ H=H,
711
+ V=V,
712
+ BT=BT,
713
+ BC=BC,
714
+ BV=BV,
715
+ NC=NC,
716
+ NG=NG,
717
+ HEAD_FIRST=head_first,
718
+ num_warps=4,
719
+ num_stages=2
720
+ )
721
+ return A, h, ht, o
722
+
723
+
724
+ def chunk_gsa_bwd_v(
725
+ q: torch.Tensor,
726
+ k: torch.Tensor,
727
+ v: torch.Tensor,
728
+ g: torch.Tensor,
729
+ h0: torch.Tensor,
730
+ h: torch.Tensor,
731
+ A: torch.Tensor,
732
+ do: torch.Tensor,
733
+ dht: torch.Tensor,
734
+ dg: torch.Tensor,
735
+ scale: float = 1.,
736
+ offsets: Optional[torch.LongTensor] = None,
737
+ indices: Optional[torch.LongTensor] = None,
738
+ head_first: bool = True,
739
+ chunk_size: int = 64
740
+ ):
741
+ dq, dk, dv, dg, dh0 = chunk_gla_bwd(
742
+ q=q,
743
+ k=k,
744
+ v=v,
745
+ g=None,
746
+ g_cumsum=g,
747
+ scale=scale,
748
+ initial_state=h0,
749
+ h=h,
750
+ A=A,
751
+ do=do,
752
+ dht=dht,
753
+ offsets=offsets,
754
+ indices=indices,
755
+ head_first=head_first,
756
+ chunk_size=chunk_size
757
+ )
758
+ return dq, dk, dv, dg, dh0
759
+
760
+
761
+ def chunk_gsa_bwd_k(
762
+ q: torch.Tensor,
763
+ k: torch.Tensor,
764
+ v: torch.Tensor,
765
+ g: torch.Tensor,
766
+ h: torch.Tensor,
767
+ h0: torch.Tensor,
768
+ o: torch.Tensor,
769
+ do: torch.Tensor,
770
+ dht: torch.Tensor,
771
+ dg: torch.Tensor,
772
+ scale: float = 1.,
773
+ offsets: Optional[torch.LongTensor] = None,
774
+ indices: Optional[torch.LongTensor] = None,
775
+ head_first: bool = True,
776
+ chunk_size: int = 64
777
+ ):
778
+ if head_first:
779
+ B, H, T, K, V = *k.shape, v.shape[-1]
780
+ else:
781
+ B, T, H, K, V = *k.shape, v.shape[-1]
782
+ BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
783
+ BC = min(16, BT)
784
+ BK = min(64, triton.next_power_of_2(K))
785
+ BV = min(64, triton.next_power_of_2(V))
786
+ HQ = q.shape[1] if head_first else q.shape[2]
787
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
788
+ NC = triton.cdiv(BT, BC)
789
+ NK = triton.cdiv(K, BK)
790
+ NV = triton.cdiv(V, BV)
791
+ NG = HQ // H
792
+
793
+ if h is None:
794
+ h, _ = chunk_fwd_h(
795
+ k=k,
796
+ v=v,
797
+ g=None,
798
+ gk=None,
799
+ gv=g,
800
+ h0=h0,
801
+ output_final_state=False,
802
+ offsets=offsets,
803
+ head_first=head_first,
804
+ chunk_size=BT,
805
+ states_in_fp32=False
806
+ )
807
+ dh, dh0 = chunk_bwd_dh(
808
+ q=q,
809
+ k=k,
810
+ v=v,
811
+ g=None,
812
+ gk=None,
813
+ gv=g,
814
+ do=do,
815
+ h0=h0,
816
+ dht=dht,
817
+ scale=scale,
818
+ offsets=offsets,
819
+ head_first=head_first,
820
+ chunk_size=BT,
821
+ states_in_fp32=True
822
+ )
823
+ dA = q.new_empty(NV, B, *((HQ, T) if head_first else (T, HQ)), BT)
824
+ grid = (NV, NT * NC * NC, B * HQ)
825
+ chunk_gsa_bwd_k_kernel_dA[grid](
826
+ v,
827
+ g,
828
+ do,
829
+ dA,
830
+ offsets=offsets,
831
+ indices=indices,
832
+ scale=scale,
833
+ T=T,
834
+ B=B,
835
+ HQ=HQ,
836
+ H=H,
837
+ V=V,
838
+ BT=BT,
839
+ BC=BC,
840
+ BV=BV,
841
+ NC=NC,
842
+ NG=NG,
843
+ HEAD_FIRST=head_first
844
+ )
845
+ dA = dA.sum(0, dtype=dA.dtype)
846
+
847
+ A = do.new_empty(NK, B, *((HQ, T) if head_first else (T, HQ)), BT)
848
+ dq = torch.empty_like(q)
849
+ dk = k.new_empty(B, *((HQ, T) if head_first else (T, HQ)), K)
850
+ dv = v.new_empty(NK, B, *((HQ, T) if head_first else (T, HQ)), V)
851
+ dgv = g.new_empty(NK, B, *((HQ, T) if head_first else (T, HQ)), V, dtype=torch.float)
852
+ grid = (NK, NT, B * HQ)
853
+ chunk_gsa_bwd_k_kernel_dqkvg[grid](
854
+ q,
855
+ k,
856
+ v,
857
+ h,
858
+ g,
859
+ A,
860
+ do,
861
+ dh,
862
+ dq,
863
+ dk,
864
+ dv,
865
+ dg,
866
+ dgv,
867
+ dA,
868
+ offsets=offsets,
869
+ indices=indices,
870
+ scale=scale,
871
+ T=T,
872
+ B=B,
873
+ HQ=HQ,
874
+ H=H,
875
+ K=K,
876
+ V=V,
877
+ BT=BT,
878
+ BK=BK,
879
+ BV=BV,
880
+ NG=NG,
881
+ HEAD_FIRST=head_first
882
+ )
883
+ A = A.sum(0, dtype=A.dtype)
884
+ dv = dv.sum(0, dtype=dv.dtype)
885
+ dgv = dgv.sum(0, dtype=dgv.dtype)
886
+
887
+ def grid(meta): return (triton.cdiv(V, meta['BV']), NT * NC, B * HQ)
888
+ chunk_gsa_bwd_k_kernel_intra_dvg[grid](
889
+ v,
890
+ g,
891
+ o,
892
+ A,
893
+ do,
894
+ dv,
895
+ dg,
896
+ offsets=offsets,
897
+ indices=indices,
898
+ T=T,
899
+ HQ=HQ,
900
+ H=H,
901
+ V=V,
902
+ BT=BT,
903
+ BC=BC,
904
+ BV=BV,
905
+ NC=NC,
906
+ NG=NG,
907
+ HEAD_FIRST=head_first,
908
+ num_warps=4,
909
+ num_stages=2
910
+ )
911
+ dg = dgv.add_(chunk_local_cumsum(dg, chunk_size=BT, reverse=True, offsets=offsets, indices=indices, head_first=head_first))
912
+
913
+ return dq, dk, dv, dg, dh0
914
+
915
+
916
+ def chunk_gsa_fwd(
917
+ q: torch.Tensor,
918
+ k: torch.Tensor,
919
+ v: torch.Tensor,
920
+ s: torch.Tensor,
921
+ g: torch.Tensor,
922
+ initial_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
923
+ output_final_state: bool = False,
924
+ scale: float = 1.,
925
+ offsets: Optional[torch.LongTensor] = None,
926
+ indices: Optional[torch.LongTensor] = None,
927
+ head_first: bool = True,
928
+ chunk_size: int = 64
929
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
930
+ hk0, hv0 = None, None
931
+ if initial_state is not None:
932
+ hk0, hv0 = initial_state
933
+ Ak, hk, hkt, ok = chunk_gsa_fwd_k(
934
+ q=q,
935
+ k=k,
936
+ v=s,
937
+ g=g,
938
+ h0=hk0,
939
+ output_final_state=output_final_state,
940
+ scale=scale,
941
+ offsets=offsets,
942
+ indices=indices,
943
+ head_first=head_first,
944
+ chunk_size=chunk_size
945
+ )
946
+
947
+ # p is kept in fp32 for safe softmax backward
948
+ p = softmax_fwd(ok, dtype=torch.float)
949
+
950
+ qv = p.to(q.dtype)
951
+ Av, hv, hvt, ov = chunk_gsa_fwd_v(
952
+ q=qv,
953
+ k=s,
954
+ v=v,
955
+ g=g,
956
+ scale=1.,
957
+ initial_state=hv0,
958
+ output_final_state=output_final_state,
959
+ offsets=offsets,
960
+ indices=indices,
961
+ head_first=head_first,
962
+ chunk_size=chunk_size
963
+ )
964
+ return Ak, hk, hkt, ok, p, Av, hv, hvt, ov
965
+
966
+
967
+ def chunk_gsa_bwd(
968
+ q: torch.Tensor,
969
+ k: torch.Tensor,
970
+ v: torch.Tensor,
971
+ s: torch.Tensor,
972
+ g: torch.Tensor,
973
+ ok: torch.Tensor,
974
+ p: torch.Tensor,
975
+ A: Tuple[torch.Tensor, torch.Tensor],
976
+ h: Tuple[torch.Tensor, torch.Tensor],
977
+ initial_state: Optional[Tuple[torch.Tensor, torch.Tensor]],
978
+ scale: float,
979
+ do: torch.Tensor,
980
+ dht: Tuple[torch.Tensor, torch.Tensor],
981
+ offsets: Optional[torch.LongTensor] = None,
982
+ indices: Optional[torch.LongTensor] = None,
983
+ head_first: bool = True,
984
+ chunk_size: int = 64
985
+ ):
986
+ hk0, hv0 = None, None
987
+ if initial_state is not None:
988
+ hk0, hv0 = initial_state
989
+
990
+ _, Av = A
991
+ hk, hv = h
992
+ dhkt, dhvt = dht
993
+
994
+ qv = p.to(q.dtype)
995
+ dqv, dsv, dv, dg, dhv0 = chunk_gsa_bwd_v(
996
+ q=qv,
997
+ k=s,
998
+ v=v,
999
+ g=g,
1000
+ h0=hv0,
1001
+ h=hv,
1002
+ A=Av,
1003
+ do=do,
1004
+ dht=dhvt,
1005
+ dg=None,
1006
+ scale=1.,
1007
+ offsets=offsets,
1008
+ indices=indices,
1009
+ head_first=head_first,
1010
+ chunk_size=chunk_size
1011
+ )
1012
+
1013
+ # softmax gradient, equivalent to:
1014
+ # dok = qv * (dqv - (qv * dqv).sum(-1, True))
1015
+ dok = softmax_bwd(p, dqv, dtype=ok.dtype)
1016
+
1017
+ dq, dk, dsk, dg, dhk0 = chunk_gsa_bwd_k(
1018
+ q=q,
1019
+ k=k,
1020
+ v=s,
1021
+ g=g,
1022
+ h0=hk0,
1023
+ h=hk,
1024
+ o=ok,
1025
+ do=dok,
1026
+ dht=dhkt,
1027
+ dg=dg,
1028
+ scale=scale,
1029
+ offsets=offsets,
1030
+ indices=indices,
1031
+ head_first=head_first,
1032
+ chunk_size=chunk_size
1033
+ )
1034
+
1035
+ ds = dsv.add_(dsk)
1036
+ if q.shape[1] != k.shape[1]:
1037
+ dk, dv, ds, dg = map(lambda x: reduce(x, 'b (h g) ... -> b h ...', 'sum', h=k.shape[1]), (dk, dv, ds, dg))
1038
+ dg = dg.to(s.dtype)
1039
+ return dq, dk, dv, ds, dg, dhk0, dhv0
1040
+
1041
+
1042
+ class ChunkGSAFunction(torch.autograd.Function):
1043
+
1044
+ @staticmethod
1045
+ @input_guard
1046
+ def forward(
1047
+ ctx,
1048
+ q: torch.Tensor,
1049
+ k: torch.Tensor,
1050
+ v: torch.Tensor,
1051
+ s: torch.Tensor,
1052
+ g: torch.Tensor,
1053
+ scale: float,
1054
+ hk0: Optional[torch.Tensor],
1055
+ hv0: Optional[torch.Tensor],
1056
+ output_final_state: bool,
1057
+ checkpoint_level: int,
1058
+ offsets: Optional[torch.LongTensor],
1059
+ head_first: bool = True
1060
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
1061
+ T = q.shape[2] if head_first else q.shape[1]
1062
+ chunk_size = min(64, max(16, triton.next_power_of_2(T)))
1063
+
1064
+ # 2-d indices denoting the offsets of chunks in each sequence
1065
+ # for example, if the passed `offsets` is [0, 100, 356] and `chunk_size` is 64,
1066
+ # then there are 2 and 4 chunks in the 1st and 2nd sequences respectively, and `indices` will be
1067
+ # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]]
1068
+ indices = None
1069
+ if offsets is not None:
1070
+ indices = torch.cat([torch.arange(n) for n in triton.cdiv(offsets[1:] - offsets[:-1], chunk_size).tolist()])
1071
+ indices = torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(offsets)
1072
+ g_org, g = g, chunk_local_cumsum(g, chunk_size, offsets=offsets, indices=indices, head_first=head_first)
1073
+ Ak, hk, hkt, ok, p, Av, hv, hvt, ov = chunk_gsa_fwd(
1074
+ q=q,
1075
+ k=k,
1076
+ v=v,
1077
+ s=s,
1078
+ g=g,
1079
+ initial_state=(hk0, hv0),
1080
+ output_final_state=output_final_state,
1081
+ scale=scale,
1082
+ offsets=offsets,
1083
+ indices=indices,
1084
+ head_first=head_first,
1085
+ chunk_size=chunk_size
1086
+ )
1087
+
1088
+ if checkpoint_level >= 1:
1089
+ del g
1090
+ g = g_org
1091
+ if checkpoint_level > 1:
1092
+ del hk
1093
+ del hv
1094
+ hk, hv = None, None
1095
+ else:
1096
+ hk0, hv0 = None, None
1097
+
1098
+ ctx.save_for_backward(q, k, v, s, g, ok, p, Av, hk0, hv0, hk, hv)
1099
+ ctx.checkpoint_level = checkpoint_level
1100
+ ctx.scale = scale
1101
+ ctx.offsets = offsets
1102
+ ctx.indices = indices
1103
+ ctx.head_first = head_first
1104
+ ctx.chunk_size = chunk_size
1105
+ return ov, hkt, hvt
1106
+
1107
+ @staticmethod
1108
+ @input_guard
1109
+ def backward(ctx, dov, dhkt=None, dhvt=None):
1110
+ q, k, v, s, g, ok, p, Av, hk0, hv0, hk, hv = ctx.saved_tensors
1111
+ scale = ctx.scale
1112
+ offsets = ctx.offsets
1113
+ indices = ctx.indices
1114
+ head_first = ctx.head_first
1115
+ chunk_size = ctx.chunk_size
1116
+
1117
+ if ctx.checkpoint_level >= 1:
1118
+ g = chunk_local_cumsum(g, chunk_size, offsets=offsets, indices=indices, head_first=head_first)
1119
+ dq, dk, dv, ds, dg, dhk0, dhv0 = chunk_gsa_bwd(
1120
+ q=q,
1121
+ k=k,
1122
+ v=v,
1123
+ s=s,
1124
+ g=g,
1125
+ ok=ok,
1126
+ p=p,
1127
+ A=(None, Av),
1128
+ h=(hk, hv),
1129
+ initial_state=(hk0, hv0),
1130
+ scale=scale,
1131
+ do=dov,
1132
+ dht=(dhkt, dhvt),
1133
+ offsets=offsets,
1134
+ indices=indices,
1135
+ head_first=head_first,
1136
+ chunk_size=chunk_size
1137
+ )
1138
+ return dq, dk, dv, ds, dg, None, dhk0, dhv0, None, None, None, None
1139
+
1140
+
1141
+ @torch.compiler.disable
1142
+ def chunk_gsa(
1143
+ q: torch.Tensor,
1144
+ k: torch.Tensor,
1145
+ v: torch.Tensor,
1146
+ s: torch.Tensor,
1147
+ g: Optional[torch.Tensor] = None,
1148
+ scale: Optional[int] = None,
1149
+ initial_state: Optional[Tuple[torch.Tensor]] = None,
1150
+ output_final_state: Optional[bool] = False,
1151
+ checkpoint_level: Optional[int] = 2,
1152
+ cu_seqlens: Optional[torch.LongTensor] = None,
1153
+ head_first: Optional[bool] = True
1154
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
1155
+ r"""
1156
+ Args:
1157
+ q (torch.Tensor):
1158
+ queries of shape `[B, HQ, T, K]` if `head_first=True` else `[B, T, HQ, K]`.
1159
+ k (torch.Tensor):
1160
+ keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
1161
+ GQA is performed if `H` is not equal to `HQ`.
1162
+ v (torch.Tensor):
1163
+ values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`.
1164
+ s (torch.Tensor):
1165
+ slot representations of shape `[B, H, T, M]` if `head_first=True` else `[B, T, H, M]`.
1166
+ g (torch.Tensor):
1167
+ Forget gates of shape `[B, H, T, M]` if `head_first=True` else `[B, T, H, M]` applied to keys.
1168
+ If not provided, this function is equivalent to vanilla ABC.
1169
+ scale (Optional[int]):
1170
+ Scale factor for attention scores.
1171
+ If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
1172
+ initial_state (Optional[Tuple[torch.Tensor]]):
1173
+ Initial state tuple having tensors of shape `[N, H, K, M]` and `[N, H, M, V]` for `N` input sequences.
1174
+ For equal-length input sequences, `N` equals the batch size `B`.
1175
+ Default: `None`.
1176
+ output_final_state (Optional[bool]):
1177
+ Whether to output the final state tuple, having tensors of shape `[N, H, K, M]` and `[N, H, M, V]`.
1178
+ Default: `False`.
1179
+ checkpoint_level (Optional[int]):
1180
+ Checkpointing level; higher values will save more memories and do more recomputations during backward.
1181
+ Default: `2`:
1182
+ - Level `0`: no memory saved, no recomputation.
1183
+ - Level `1`: recompute the fp32 cumulative values during backward.
1184
+ - Level `2`: recompute the fp32 cumulative values and forward hidden states during backward.
1185
+ cu_seqlens (torch.LongTensor):
1186
+ Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
1187
+ consistent with the FlashAttention API.
1188
+ head_first (Optional[bool]):
1189
+ Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
1190
+ Default: `True`.
1191
+
1192
+ Returns:
1193
+ o (torch.Tensor):
1194
+ Outputs of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`.
1195
+ final_state (Tuple[torch.Tensor]):
1196
+ Final state tuple having tensors of shape `[N, H, K, M]` and `[N, H, M, V]` if `output_final_state=True`.
1197
+ `None` otherwise.
1198
+
1199
+ Examples::
1200
+ >>> import torch
1201
+ >>> import torch.nn.functional as F
1202
+ >>> from einops import rearrange
1203
+ >>> from fla.ops.gsa import chunk_gsa
1204
+ # inputs with equal lengths
1205
+ >>> B, T, H, K, V, M = 4, 2048, 4, 512, 512, 64
1206
+ >>> q = torch.randn(B, T, H, K, device='cuda')
1207
+ >>> k = torch.randn(B, T, H, K, device='cuda')
1208
+ >>> v = torch.randn(B, T, H, V, device='cuda')
1209
+ >>> s = torch.randn(B, T, H, M, device='cuda')
1210
+ >>> g = F.logsigmoid(torch.randn(B, T, H, M, device='cuda'))
1211
+ >>> h0 = (torch.randn(B, H, K, M, device='cuda'), torch.randn(B, H, M, V, device='cuda'))
1212
+ >>> o, (hk, hv) = chunk_gsa(q, k, v, s, g,
1213
+ initial_state=h0,
1214
+ output_final_state=True,
1215
+ head_first=False)
1216
+ # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
1217
+ >>> q, k, v, s, g = map(lambda x: rearrange(x, 'b t h d -> 1 (b t) h d'), (q, k, v, s, g))
1218
+ # for a batch with 4 sequences, `cu_seqlens` with 5 boundary positions is expected
1219
+ >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
1220
+ >>> o_var, (hk_var, hv_var) = chunk_gsa(q, k, v, s, g,
1221
+ initial_state=h0,
1222
+ output_final_state=True,
1223
+ cu_seqlens=cu_seqlens,
1224
+ head_first=False)
1225
+ >>> assert o.allclose(o_var.view(o.shape))
1226
+ >>> assert hk.allclose(hk_var)
1227
+ >>> assert hv.allclose(hv_var)
1228
+ """
1229
+ if cu_seqlens is not None:
1230
+ if q.shape[0] != 1:
1231
+ raise ValueError(f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`. "
1232
+ f"Please flatten variable-length inputs before processing.")
1233
+ if head_first:
1234
+ raise RuntimeError("Sequences with variable lengths are not supported for head-first mode")
1235
+ if initial_state is not None and initial_state[0].shape[0] != len(cu_seqlens) - 1:
1236
+ raise ValueError(f"The number of initial states is expected to be equal to the number of input sequences, "
1237
+ f"i.e., {len(cu_seqlens) - 1} rather than {initial_state[0].shape[0]}.")
1238
+ assert checkpoint_level in [0, 1, 2]
1239
+ if g is None:
1240
+ # TODO: these 3 steps take a huge amount of time and ought to be optimized
1241
+ z = s.float().logcumsumexp(2)
1242
+ g = torch.cat((z[:, :, :1], z[:, :, :-1]), 2) - z
1243
+ s = torch.exp(s - z).to(k.dtype)
1244
+ if scale is None:
1245
+ scale = q.shape[-1] ** -0.5
1246
+
1247
+ hk0, hv0 = None, None
1248
+ if initial_state is not None:
1249
+ hk0, hv0 = initial_state
1250
+ o, *final_state = ChunkGSAFunction.apply(
1251
+ q,
1252
+ k,
1253
+ v,
1254
+ s,
1255
+ g,
1256
+ scale,
1257
+ hk0,
1258
+ hv0,
1259
+ output_final_state,
1260
+ checkpoint_level,
1261
+ cu_seqlens,
1262
+ head_first
1263
+ )
1264
+ return o, final_state
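
As an aside, the 2-d chunk `indices` built in `ChunkGSAFunction.forward` (see the comment in that method) can be reproduced standalone. The sketch below uses the example values from that comment and a plain-PyTorch ceiling division in place of `triton.cdiv`; it is illustrative only:

```python
import torch

offsets = torch.tensor([0, 100, 356])
chunk_size = 64

# number of chunks per sequence: ceil(len / chunk_size), i.e. what triton.cdiv computes
lens = offsets[1:] - offsets[:-1]
n_chunks = ((lens + chunk_size - 1) // chunk_size).tolist()   # [2, 4]

# per-sequence chunk index, then pair it with the sequence index
indices = torch.cat([torch.arange(n) for n in n_chunks])
indices = torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(offsets)
print(indices.tolist())
# [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]]
```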
fla/ops/gsa/naive.py ADDED
@@ -0,0 +1,68 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from typing import Optional
4
+
5
+ import torch
6
+ from einops import repeat
7
+
8
+
9
+ def naive_recurrent_gsa(
10
+ q: torch.Tensor,
11
+ k: torch.Tensor,
12
+ v: torch.Tensor,
13
+ s: torch.Tensor,
14
+ g: Optional[torch.Tensor] = None,
15
+ scale: Optional[int] = None,
16
+ initial_state: Optional[torch.Tensor] = None,
17
+ output_final_state: Optional[bool] = False
18
+ ) -> torch.Tensor:
19
+ dtype = q.dtype
20
+
21
+ NG = q.shape[1]//k.shape[1]
22
+ # [batch_size, n_heads, seq_len, n_slots]
23
+ if g is None:
24
+ z = s.float().logcumsumexp(2)
25
+ g = torch.cat((z[:, :, :1], z[:, :, :-1]), 2) - z
26
+ s = torch.exp(s - z)
27
+ q, k, v, s, g = map(lambda x: x.float(), (q, k, v, s, g))
28
+ k, v, s, g = map(lambda x: repeat(x, 'b h t d -> b (h g) t d', g=NG), (k, v, s, g))
29
+ if initial_state is not None:
30
+ initial_state = tuple(map(lambda x: repeat(x, 'b h k v -> b (h g) k v', g=NG), initial_state))
31
+
32
+ B, H, T, K, V, M = *q.shape, v.shape[-1], s.shape[-1]
33
+
34
+ hk = torch.zeros(B, H, K, M, dtype=torch.float, device=q.device)
35
+ ok = torch.zeros_like(s)
36
+
37
+ if scale is None:
38
+ scale = q.shape[-1] ** -0.5
39
+
40
+ final_state = None
41
+ if initial_state is not None:
42
+ hk += initial_state[0]
43
+
44
+ for i in range(T):
45
+ q_i = q[:, :, i] * scale
46
+ k_i = k[:, :, i]
47
+ v_i = s[:, :, i]
48
+ g_i = g[:, :, i].exp()
49
+ hk = hk * g_i[..., None, :] + k_i[..., None] * v_i[..., None, :]
50
+ ok[:, :, i] = (q_i[..., None] * hk).sum(-2)
51
+
52
+ qv = ok.softmax(-1)
53
+ hv = torch.zeros(B, H, M, V, dtype=torch.float, device=q.device)
54
+ ov = torch.zeros_like(v)
55
+ if initial_state is not None:
56
+ hv += initial_state[1]
57
+
58
+ for i in range(T):
59
+ q_i = qv[:, :, i]
60
+ k_i = s[:, :, i]
61
+ v_i = v[:, :, i]
62
+ g_i = g[:, :, i].exp()
63
+ hv = hv * g_i[..., :, None] + k_i[..., None] * v_i[..., None, :]
64
+ ov[:, :, i] = (q_i[..., None] * hv).sum(-2)
65
+
66
+ if output_final_state:
67
+ final_state = (hk.view(B, -1, NG, K, M)[:, :, 0], hv.view(B, -1, NG, M, V)[:, :, 0])
68
+ return ov.to(dtype), final_state
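
Since `naive_recurrent_gsa` is pure PyTorch, it can be exercised on CPU as a quick sanity check for the Triton kernels. A minimal sketch, assuming the head-first `[B, H, T, ...]` layout used by the function and that the package is installed so this file is importable as `fla.ops.gsa.naive`:

```python
import torch
import torch.nn.functional as F

from fla.ops.gsa.naive import naive_recurrent_gsa

B, H, T, K, V, M = 2, 4, 16, 8, 8, 4
q = torch.randn(B, H, T, K)
k = torch.randn(B, H, T, K)
v = torch.randn(B, H, T, V)
s = torch.randn(B, H, T, M)
# forget gates are expected as log-space values, hence logsigmoid
g = F.logsigmoid(torch.randn(B, H, T, M))

o, (hk, hv) = naive_recurrent_gsa(q, k, v, s, g, output_final_state=True)
print(o.shape, hk.shape, hv.shape)
# torch.Size([2, 4, 16, 8]) torch.Size([2, 4, 8, 4]) torch.Size([2, 4, 4, 8])
```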
fla/ops/hgrn/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (285 Bytes). View file
 
fla/ops/hgrn/fused_recurrent.py ADDED
@@ -0,0 +1,308 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.ops.utils.op import exp
11
+ from fla.utils import input_guard
12
+
13
+
14
+ @triton.heuristics({
15
+ 'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
16
+ 'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
17
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
18
+ })
19
+ @triton.autotune(
20
+ configs=[
21
+ triton.Config({'BD': BD}, num_warps=num_warps)
22
+ for BD in [32, 64, 128]
23
+ for num_warps in [1, 2, 4, 8]
24
+ ],
25
+ key=['D']
26
+ )
27
+ @triton.jit(do_not_specialize=['T'])
28
+ def fused_recurrent_hgrn_fwd_kernel(
29
+ x,
30
+ g,
31
+ o,
32
+ h0,
33
+ ht,
34
+ offsets,
35
+ T,
36
+ D: tl.constexpr,
37
+ BD: tl.constexpr,
38
+ USE_INITIAL_STATE: tl.constexpr,
39
+ STORE_FINAL_STATE: tl.constexpr,
40
+ USE_OFFSETS: tl.constexpr
41
+ ):
42
+ i_d, i_n = tl.program_id(0), tl.program_id(1)
43
+ if USE_OFFSETS:
44
+ bos, eos = tl.load(offsets + i_n).to(tl.int64), tl.load(offsets + i_n + 1).to(tl.int64)
45
+ T = eos - bos
46
+ else:
47
+ bos, eos = i_n * T, i_n * T + T
48
+
49
+ o_d = i_d * BD + tl.arange(0, BD)
50
+ mask = o_d < D
51
+
52
+ p_x = x + bos * D + o_d
53
+ p_g = g + bos * D + o_d
54
+ p_o = o + bos * D + o_d
55
+
56
+ b_h = tl.zeros([BD], dtype=tl.float32)
57
+ if USE_INITIAL_STATE:
58
+ p_h0 = h0 + i_n * D + o_d
59
+ b_h += tl.load(p_h0, mask=mask, other=0).to(tl.float32)
60
+ for _ in range(0, T):
61
+ b_x = tl.load(p_x, mask=mask, other=0).to(tl.float32)
62
+ b_g = tl.load(p_g, mask=mask, other=0).to(tl.float32)
63
+ b_h = exp(b_g) * b_h + b_x
64
+ tl.store(p_o, b_h.to(p_o.dtype.element_ty), mask=mask)
65
+
66
+ p_x += D
67
+ p_g += D
68
+ p_o += D
69
+
70
+ if STORE_FINAL_STATE:
71
+ p_ht = ht + i_n * D + o_d
72
+ tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask)
73
+
74
+
75
+ @triton.heuristics({
76
+ 'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
77
+ 'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None,
78
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
79
+ })
80
+ @triton.autotune(
81
+ configs=[
82
+ triton.Config({'BD': BD}, num_warps=num_warps)
83
+ for BD in [32, 64, 128]
84
+ for num_warps in [1, 2, 4, 8]
85
+ ],
86
+ key=['D']
87
+ )
88
+ @triton.jit(do_not_specialize=['T'])
89
+ def fused_recurrent_hgrn_bwd_kernel(
90
+ g,
91
+ o,
92
+ h0,
93
+ dx,
94
+ dg,
95
+ do,
96
+ dht,
97
+ dh0,
98
+ offsets,
99
+ T,
100
+ D: tl.constexpr,
101
+ BD: tl.constexpr,
102
+ USE_INITIAL_STATE: tl.constexpr,
103
+ USE_FINAL_STATE_GRADIENT: tl.constexpr,
104
+ USE_OFFSETS: tl.constexpr
105
+ ):
106
+ i_d, i_n = tl.program_id(0), tl.program_id(1)
107
+ if USE_OFFSETS:
108
+ bos, eos = tl.load(offsets + i_n).to(tl.int64), tl.load(offsets + i_n + 1).to(tl.int64)
109
+ T = eos - bos
110
+ else:
111
+ bos, eos = i_n * T, i_n * T + T
112
+
113
+ o_d = i_d * BD + tl.arange(0, BD)
114
+ mask = o_d < D
115
+
116
+ p_g = g + (bos + T - 1) * D + o_d
117
+ p_o = o + (bos + T - 2) * D + o_d
118
+ p_dx = dx + (bos + T - 1) * D + o_d
119
+ p_dg = dg + (bos + T - 1) * D + o_d
120
+ p_do = do + (bos + T - 1) * D + o_d
121
+
122
+ b_dh = tl.zeros([BD], dtype=tl.float32)
123
+ if USE_FINAL_STATE_GRADIENT:
124
+ p_dht = dht + i_n * D + o_d
125
+ b_dh += tl.load(p_dht, mask=mask, other=0).to(tl.float32)
126
+
127
+ for i in range(T - 1, -1, -1):
128
+ b_g = tl.load(p_g, mask=mask, other=0).to(tl.float32)
129
+ b_do = tl.load(p_do, mask=mask, other=0).to(tl.float32)
130
+ if i > 0:
131
+ b_o = tl.load(p_o, mask=mask, other=0).to(tl.float32)
132
+ elif USE_INITIAL_STATE:
133
+ b_o = tl.load(h0 + i_n * D + o_d, mask=mask, other=0).to(tl.float32)
134
+ else:
135
+ b_o = tl.zeros([BD], dtype=tl.float32)
136
+
137
+ b_dh = b_dh + b_do
138
+ b_dx = b_dh
139
+ b_dh = b_dh * exp(b_g)
140
+ b_dg = b_dh * b_o
141
+ tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), mask=mask)
142
+ tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), mask=mask)
143
+
144
+ p_g -= D
145
+ p_o -= D
146
+ p_dx -= D
147
+ p_dg -= D
148
+ p_do -= D
149
+
150
+ if USE_INITIAL_STATE:
151
+ p_dh0 = dh0 + i_n * D + o_d
152
+ tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), mask=mask)
153
+
154
+
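The backward kernel above walks the sequence in reverse and applies the adjoint of that recurrence: after accumulating do_t into dh_t it emits dx_t = dh_t, propagates dh_{t-1} = exp(g_t) * dh_t, and emits dg_t = exp(g_t) * dh_t * h_{t-1} (with h_{-1} taken from h0, or zero). A pure-PyTorch sketch of the same reverse pass, for intuition only (the helper is illustrative, not part of fla):

```python
import torch

def naive_hgrn_bwd(g, o, do, h0=None, dht=None):
    # g, o, do: [B, T, D]; h0, dht: [B, D] or None
    B, T, D = do.shape
    dh = torch.zeros(B, D, dtype=torch.float, device=do.device)
    if dht is not None:
        dh = dh + dht
    dx, dg = torch.empty_like(do), torch.empty_like(g)
    for t in range(T - 1, -1, -1):
        dh = dh + do[:, t]                      # gradient reaching h_t
        dx[:, t] = dh                           # d h_t / d x_t = 1
        dh = dh * g[:, t].exp()                 # propagate to h_{t-1}
        if t > 0:
            h_prev = o[:, t - 1]
        else:
            h_prev = h0 if h0 is not None else torch.zeros_like(dh)
        dg[:, t] = dh * h_prev                  # d h_t / d g_t = exp(g_t) * h_{t-1}
    dh0 = dh if h0 is not None else None
    return dx, dg, dh0
```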
155
+ def fused_recurrent_hgrn_fwd(
156
+ x: torch.Tensor,
157
+ g: torch.Tensor,
158
+ initial_state: torch.Tensor = None,
159
+ output_final_state: bool = False,
160
+ offsets: Optional[torch.LongTensor] = None,
161
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
162
+ B, T, D = x.shape
163
+ N = B if offsets is None else len(offsets) - 1
164
+
165
+ o = torch.empty_like(x)
166
+ final_state = x.new_empty(N, D) if output_final_state else None
167
+
168
+ def grid(meta): return (triton.cdiv(D, meta['BD']), N)
169
+ fused_recurrent_hgrn_fwd_kernel[grid](
170
+ x=x,
171
+ g=g,
172
+ o=o,
173
+ h0=initial_state,
174
+ ht=final_state,
175
+ offsets=offsets,
176
+ T=T,
177
+ D=D
178
+ )
179
+ return o, final_state
180
+
181
+
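The launch grid above is one program per (feature block, sequence): `(triton.cdiv(D, BD), N)`. When `offsets` is given it follows the cu_seqlens convention, i.e. N packed sequences are described by N+1 cumulative boundaries. A purely illustrative example (the values are made up):

```python
import torch
import triton

offsets = torch.tensor([0, 5, 12, 20])   # 3 sequences of lengths 5, 7 and 8, packed along T
N, D, BD = len(offsets) - 1, 512, 64
grid = (triton.cdiv(D, BD), N)           # (feature blocks, sequences), as in the wrapper
# program (i_d, i_n) reads rows offsets[i_n]:offsets[i_n + 1] of the packed [1, T, D] inputs
```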
182
+ def fused_recurrent_hgrn_bwd(
183
+ g: torch.Tensor,
184
+ o: torch.Tensor,
185
+ do: torch.Tensor,
186
+ dht: torch.Tensor = None,
187
+ initial_state: torch.Tensor = None,
188
+ offsets: Optional[torch.LongTensor] = None
189
+ ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
190
+ B, T, D = do.shape
191
+ N = B if offsets is None else len(offsets) - 1
192
+
193
+ dx = torch.empty_like(o, dtype=torch.float)
194
+ dg = torch.empty_like(g, dtype=torch.float)
195
+ dh0 = torch.empty_like(initial_state, dtype=torch.float) if initial_state is not None else None
196
+ def grid(meta): return (triton.cdiv(D, meta['BD']), N)
197
+ fused_recurrent_hgrn_bwd_kernel[grid](
198
+ g=g,
199
+ o=o,
200
+ h0=initial_state,
201
+ dx=dx,
202
+ dg=dg,
203
+ do=do,
204
+ dht=dht,
205
+ dh0=dh0,
206
+ offsets=offsets,
207
+ T=T,
208
+ D=D
209
+ )
210
+ return dx, dg, dh0
211
+
212
+
213
+ class FusedRecurrentHGRNFunction(torch.autograd.Function):
214
+
215
+ @staticmethod
216
+ @input_guard
217
+ def forward(
218
+ ctx,
219
+ x: torch.Tensor,
220
+ g: torch.Tensor,
221
+ initial_state: torch.Tensor = None,
222
+ output_final_state: bool = False,
223
+ offsets: Optional[torch.LongTensor] = None
224
+ ):
225
+ o, ht = fused_recurrent_hgrn_fwd(
226
+ x=x,
227
+ g=g,
228
+ initial_state=initial_state,
229
+ output_final_state=output_final_state,
230
+ offsets=offsets
231
+ )
232
+ ctx.save_for_backward(g, o, initial_state)
233
+ ctx.offsets = offsets
234
+ return o, ht
235
+
236
+ @staticmethod
237
+ @input_guard
238
+ def backward(ctx, do, dht=None):
239
+ g, o, initial_state = ctx.saved_tensors
240
+ offsets = ctx.offsets
241
+
242
+ dx, dg, dh0 = fused_recurrent_hgrn_bwd(
243
+ g=g,
244
+ o=o,
245
+ do=do,
246
+ dht=dht,
247
+ initial_state=initial_state,
248
+ offsets=offsets
249
+ )
250
+ return dx, dg, dh0, None, None
251
+
252
+
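A quick smoke test of the autograd wiring above, using the public `fused_recurrent_hgrn` entry point defined just below (illustrative shapes; a CUDA device is assumed since the kernels are Triton-based):

```python
import torch
import torch.nn.functional as F
from fla.ops.hgrn import fused_recurrent_hgrn

x = torch.randn(1, 32, 16, device='cuda', requires_grad=True)
g = F.logsigmoid(torch.randn(1, 32, 16, device='cuda')).requires_grad_()
o, _ = fused_recurrent_hgrn(x, g)
o.sum().backward()                                 # backward returns dx, dg, dh0, None, None
assert x.grad.shape == g.grad.shape == x.shape
```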
253
+ @torch.compiler.disable
254
+ def fused_recurrent_hgrn(
255
+ x: torch.Tensor,
256
+ g: torch.Tensor,
257
+ initial_state: torch.Tensor = None,
258
+ output_final_state: bool = False,
259
+ cu_seqlens: Optional[torch.LongTensor] = None,
260
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
261
+ r"""
262
+ Args:
263
+ x (torch.Tensor):
264
+ Inputs of shape `[B, T, D]`.
265
+ g (torch.Tensor):
266
+ Forget gates in log space of shape `[B, T, D]`, e.g. outputs of `F.logsigmoid`.
267
+ initial_state (Optional[torch.Tensor]):
268
+ Initial state of shape `[N, D]` for `N` input sequences.
269
+ For equal-length input sequences, `N` equals the batch size `B`.
270
+ Default: `None`.
271
+ output_final_state (Optional[bool]):
272
+ Whether to output the final state of shape `[N, D]`. Default: `False`.
273
+ cu_seqlens (torch.LongTensor):
274
+ Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
275
+ consistent with the FlashAttention API.
276
+
277
+ Returns:
278
+ o (torch.Tensor):
279
+ Outputs of shape `[B, T, D]`.
280
+ final_state (torch.Tensor):
281
+ Final state of shape `[N, D]` if `output_final_state=True` else `None`.
282
+
283
+ Examples::
284
+ >>> import torch
285
+ >>> import torch.nn.functional as F
286
+ >>> from einops import rearrange
287
+ >>> from fla.ops.hgrn import fused_recurrent_hgrn
288
+ # inputs with equal lengths
289
+ >>> B, T, D = 4, 2048, 512
290
+ >>> x = torch.randn(B, T, D, device='cuda')
291
+ >>> g = F.logsigmoid(torch.randn(B, T, D, device='cuda'))
292
+ >>> h0 = torch.randn(B, D, device='cuda')
293
+ >>> o, ht = fused_recurrent_hgrn(x, g, initial_state=h0, output_final_state=True)
294
+ # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
295
+ >>> x, g = map(lambda x: rearrange(x, 'b t d -> 1 (b t) d'), (x, g))
296
+ # for a batch with 4 sequences, `cu_seqlens` with 5 boundary positions is expected
297
+ >>> cu_seqlens = x.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
298
+ >>> o_var, ht_var = fused_recurrent_hgrn(x, g, initial_state=h0, output_final_state=True, cu_seqlens=cu_seqlens)
299
+ >>> assert o.allclose(o_var.view(o.shape))
300
+ >>> assert ht.allclose(ht_var)
301
+ """
302
+ return FusedRecurrentHGRNFunction.apply(
303
+ x,
304
+ g,
305
+ initial_state,
306
+ output_final_state,
307
+ cu_seqlens
308
+ )
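Beyond the docstring examples, here is a minimal sketch of how the op could sit inside a gated recurrent block; the module below is hypothetical and not part of fla, it only assumes the public `fused_recurrent_hgrn` API shown above:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from fla.ops.hgrn import fused_recurrent_hgrn

class TinyHGRNBlock(nn.Module):
    """Hypothetical block: two projections feed the fused elementwise recurrence."""
    def __init__(self, d: int):
        super().__init__()
        self.proj_x = nn.Linear(d, d)
        self.proj_g = nn.Linear(d, d)

    def forward(self, h: torch.Tensor) -> torch.Tensor:   # h: [B, T, d] on a CUDA device
        x = self.proj_x(h)
        g = F.logsigmoid(self.proj_g(h))                   # log-space forget gates in (-inf, 0)
        o, _ = fused_recurrent_hgrn(x, g)
        return o
```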
fla/ops/hgrn/naive.py ADDED
@@ -0,0 +1,63 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from typing import Optional, Tuple
4
+
5
+ import torch
6
+
7
+
8
+ def naive_recurrent_hgrn(
9
+ x: torch.Tensor,
10
+ g: torch.Tensor,
11
+ initial_state: Optional[torch.Tensor] = None,
12
+ output_final_state: Optional[bool] = False
13
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
14
+ dtype = x.dtype
15
+ x, g = map(lambda i: i.float(), (x, g))
16
+ B, T, D = x.shape
17
+
18
+ h = torch.zeros(B, D, dtype=torch.float, device=x.device)
19
+ o = torch.zeros_like(x)
20
+
21
+ final_state = None
22
+ if initial_state is not None:
23
+ h += initial_state
24
+
25
+ for i in range(T):
26
+ h = g[:, i].exp() * h + x[:, i]
27
+ o[:, i] = h
28
+
29
+ if output_final_state:
30
+ final_state = h
31
+ return o.to(dtype), final_state
32
+
33
+
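The chunked variant below relies on the standard decomposition of this linear recurrence: within a chunk starting at position c, h_t = exp(g_c + ... + g_t) * h_{c-1} + h_t_local, where the local term runs the same recurrence from a zero state. A small numerical check of that identity (illustrative shapes, not part of fla):

```python
import torch
import torch.nn.functional as F

C, D = 4, 8
x, g = torch.randn(C, D), F.logsigmoid(torch.randn(C, D))
h_prev = torch.randn(D)

h = h_prev.clone()
for t in range(C):                       # exact recurrence across the chunk
    h = g[t].exp() * h + x[t]

h_local = torch.zeros(D)
for t in range(C):                       # same recurrence started from a zero state
    h_local = g[t].exp() * h_local + x[t]

h_chunk = g.sum(0).exp() * h_prev + h_local   # chunk identity
assert torch.allclose(h, h_chunk, atol=1e-5)
```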
34
+ def naive_chunk_hgrn(
35
+ x: torch.Tensor,
36
+ g: torch.Tensor,
37
+ initial_state: Optional[torch.Tensor] = None,
38
+ output_final_state: Optional[bool] = False,
39
+ chunk_size: int = 64
40
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
41
+ dtype = x.dtype
42
+ x, g = map(lambda i: i.float(), (x, g))
43
+ B, T, D = x.shape
44
+
45
+ gc = g.view(B, -1, chunk_size, D).cumsum(-2).view_as(g)  # per-chunk cumulative log-gates (assumes T % chunk_size == 0)
46
+ h = torch.zeros(B, D, dtype=torch.float, device=x.device)
47
+ o = torch.zeros_like(x)
48
+
49
+ final_state = None
50
+ if initial_state is not None:
51
+ h += initial_state
52
+
53
+ for i in range(0, T, chunk_size):
54
+ hp = h
55
+ h = torch.zeros(B, D, dtype=torch.float, device=x.device)
56
+ for j in range(i, i + chunk_size):
57
+ h = g[:, j].exp() * h + x[:, j]
58
+ o[:, j] = hp * gc[:, j].exp() + h
59
+ h = o[:, j].clone()  # runs once per chunk, after the inner loop: carry the exact end-of-chunk state forward
60
+
61
+ if output_final_state:
62
+ final_state = h
63
+ return o.to(dtype), final_state
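As a sanity check, the naive reference and the fused Triton kernel should agree up to floating-point tolerance; a hedged consistency sketch (CUDA device required, tolerances illustrative):

```python
import torch
import torch.nn.functional as F
from fla.ops.hgrn import fused_recurrent_hgrn
from fla.ops.hgrn.naive import naive_recurrent_hgrn

B, T, D = 2, 128, 64
x = torch.randn(B, T, D, device='cuda')
g = F.logsigmoid(torch.randn(B, T, D, device='cuda'))
h0 = torch.randn(B, D, device='cuda')

ref, ref_ht = naive_recurrent_hgrn(x, g, initial_state=h0, output_final_state=True)
tri, tri_ht = fused_recurrent_hgrn(x, g, initial_state=h0, output_final_state=True)
assert torch.allclose(ref, tri, atol=1e-3)
assert torch.allclose(ref_ht, tri_ht, atol=1e-3)
```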
fla/ops/lightning_attn/__pycache__/fused_recurrent.cpython-312.pyc ADDED
Binary file (3.75 kB). View file
 
fla/ops/linear_attn/__pycache__/fused_chunk.cpython-312.pyc ADDED
Binary file (18.5 kB). View file
 
fla/ops/linear_attn/__pycache__/fused_recurrent.cpython-312.pyc ADDED
Binary file (13.9 kB). View file
 
fla/ops/nsa/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (268 Bytes). View file