zaydzuhri's picture
Add files using upload-large-folder tool
f72219a verified
raw
history blame
17.3 kB
# -*- coding: utf-8 -*-
# Copyright (c) 2024-2025, Songlin Yang, Yu Zhang
from typing import Tuple
import torch
import triton
import triton.language as tl
from fla.utils import input_guard
# Forward pass of the IPLR delta-rule recurrence, evaluated one time step at a time.
#
# Launch grid (built by the caller): (cdiv(V, BV), N * H). Program (i_v, i_nh) handles the
# value columns [i_v * BV, (i_v + 1) * BV) of sequence i_n = i_nh // H and head i_h = i_nh % H.
# It keeps that slice of the state as a float32 [BV, BK] tile `b_h`, i.e. transposed relative
# to the [K, V] layout used by `h0` / `ht` (b_h[v, k] <-> h0[..., k, v]).
#
# Per time step, with S the [V, K] state before the step:
#   ha_t = S a_t                           (`tmp`, written to `ha` for the backward pass)
#   S   <- S + ha_t b_t^T + v_t k_t^T
#   o_t  = S (scale * q_t)                 (uses the updated state)
#
# The whole K dimension is handled as a single BK-wide block (there is no loop over K blocks),
# so BK must be >= K; entries beyond K are masked to zero.
# Autotuning explores BV / num_warps / num_stages and is keyed on BK.
@triton.heuristics({
    'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
    'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({'BV': BV}, num_warps=num_warps, num_stages=num_stages)
        for BV in [32, 64]
        for num_warps in [2, 4, 8, 16]
        for num_stages in [2, 3, 4]
    ],
    key=["BK"],
)
@triton.jit
def fused_recurrent_fwd_kernel(
    q,  # query: [B, H, T, K] if HEAD_FIRST else [B, T, H, K]
    k,  # key: same layout as q
    v,  # value: [B, H, T, V] if HEAD_FIRST else [B, T, H, V]
    a,  # a: same layout as q
    b,  # b: same layout as q
    o,  # output: same layout as v
    ha,  # per-step S @ a_t (sum over K of b_h * a_t), saved for the backward pass; same layout as v
    h0,  # initial hidden state [N, H, K, V] (N = number of sequences)
    ht,  # final hidden state [N, H, K, V]
    offsets,  # varlen offsets [N + 1]; sequence n spans tokens [offsets[n], offsets[n + 1])
    scale,  # multiplier applied to q (the public wrapper defaults it to K ** -0.5)
    H,  # n_heads
    T,  # seq_len (total packed length when offsets are used; overridden per sequence below)
    K: tl.constexpr,  # K
    V: tl.constexpr,  # V
    BK: tl.constexpr,  # BLOCK SIZE along the K dimension; must be >= K
    BV: tl.constexpr,  # BLOCK SIZE along the V dimension
    USE_INITIAL_STATE: tl.constexpr,  # whether to use initial state
    STORE_FINAL_STATE: tl.constexpr,  # whether to store final state
    USE_OFFSETS: tl.constexpr,  # whether sequences are packed and delimited by `offsets`
    HEAD_FIRST: tl.constexpr  # layout selector, see the q / v comments above
):
    # indices: i_v selects the block of BV value columns, i_nh the (sequence, head) pair
    i_v, i_nh = tl.program_id(0), tl.program_id(1)
    i_n, i_h = i_nh // H, i_nh % H
    if USE_OFFSETS:
        # variable-length mode: this sequence occupies tokens [bos, eos) of the packed time axis
        bos, eos = tl.load(offsets + i_n).to(tl.int64), tl.load(offsets + i_n + 1).to(tl.int64)
        T = eos - bos
    else:
        bos, eos = i_n * T, i_n * T + T
    # pointers to time step 0 of this (sequence, head); in the HEAD_FIRST layout consecutive
    # time steps are one row apart, otherwise they are H rows apart (see the increments below)
    if HEAD_FIRST:
        p_q = q + i_nh * T*K + tl.arange(0, BK)
        p_k = k + i_nh * T*K + tl.arange(0, BK)
        p_a = a + i_nh * T*K + tl.arange(0, BK)
        p_b = b + i_nh * T*K + tl.arange(0, BK)
        p_o = o + i_nh * T*V + i_v * BV + tl.arange(0, BV)
        p_v = v + i_nh * T*V + i_v * BV + tl.arange(0, BV)
        p_ha = ha + i_nh * T*V + i_v * BV + tl.arange(0, BV)
    else:
        p_q = q + (bos * H + i_h) * K + tl.arange(0, BK)
        p_k = k + (bos * H + i_h) * K + tl.arange(0, BK)
        p_a = a + (bos * H + i_h) * K + tl.arange(0, BK)
        p_b = b + (bos * H + i_h) * K + tl.arange(0, BK)
        p_ha = ha + (bos * H + i_h) * V + i_v * BV + tl.arange(0, BV)
        p_v = v + (bos * H + i_h) * V + i_v * BV + tl.arange(0, BV)
        p_o = o + (bos * H + i_h) * V + i_v * BV + tl.arange(0, BV)
    mask_k = tl.arange(0, BK) < K
    mask_v = (i_v * BV + tl.arange(0, BV)) < V
    mask_h = mask_k[None, :] & mask_v[:, None]
    # state tile, stored transposed ([V, K]) relative to h0 / ht
    b_h = tl.zeros([BV, BK], dtype=tl.float32)
    if USE_INITIAL_STATE:
        p_h0 = h0 + i_nh * K * V + (tl.arange(0, BK)[None, :]) * V + ((i_v * BV + tl.arange(0, BV))[:, None])
        b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)
    for _ in range(0, T):
        b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
        b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
        b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) * scale
        b_a = tl.load(p_a, mask=mask_k, other=0).to(tl.float32)
        b_b = tl.load(p_b, mask=mask_k, other=0).to(tl.float32)
        # to store: ha_t = S a_t, computed from the state *before* this step's update
        tmp = tl.sum(b_h * b_a[None, :], axis=1)
        # S <- S + ha_t b_t^T + v_t k_t^T
        b_h += (tmp[:, None] * b_b[None, :] + b_k[None, :] * b_v[:, None])
        # o_t = S q_t, reading the updated state
        _o = b_h * b_q[None, :]
        _o = tl.sum(_o, axis=1)
        tl.store(p_o, _o.to(p_o.dtype.element_ty), mask=mask_v)
        tl.store(p_ha, tmp.to(p_ha.dtype.element_ty), mask=mask_v)
        # advance every pointer by one time step
        p_q += K if HEAD_FIRST else K*H
        p_k += K if HEAD_FIRST else K*H
        p_o += V if HEAD_FIRST else V*H
        p_v += V if HEAD_FIRST else V*H
        p_ha += V if HEAD_FIRST else V*H
        p_a += K if HEAD_FIRST else K*H
        p_b += K if HEAD_FIRST else K*H
    if STORE_FINAL_STATE:
        # write the tile back in the [K, V] layout (same indexing as the h0 load)
        p_ht = ht + i_nh * K * V + (tl.arange(0, BK)[None, :]) * V + ((i_v * BV + tl.arange(0, BV))[:, None])
        tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
# Backward pass of the IPLR delta-rule recurrence (pairs with fused_recurrent_fwd_kernel).
#
# Launch grid (built by the caller): (NV, N * H) with NV = cdiv(V, BV). Program (i_v, i_nh)
# handles the value columns [i_v * BV, (i_v + 1) * BV) of sequence i_n = i_nh // H and head
# i_h = i_nh % H. It covers the whole K dimension in one BK-wide block, so BK must be >= K.
#
# Two passes over time:
#   1. t = T-1 .. 0: accumulate b_dh, the float32 [BK, BV] gradient of the [K, V] state, and
#      emit dk, dv, dha (gradient of the forward's `ha`) and db. The final b_dh is written to dh0.
#   2. t = 0 .. T-1: rebuild the state from h0 (plus the k/v and b/ha updates) and emit da, which
#      needs the dha stored in pass 1, and dq.
# dq / dk / da / db sum over all V columns, so each program writes a partial result into its
# own i_v slice of an [NV, ...] buffer and the caller reduces over that leading dimension.
@triton.heuristics({
    'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
    'USE_DHT': lambda args: args['dht'] is not None,
    'USE_DH0': lambda args: args['dh0'] is not None,
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [2, 4, 8, 16]
        for num_stages in [2, 3]
    ],
    key=["BK", "BV"],
)
@triton.jit
def fused_recurrent_bwd_kernel(
    # B: batch_size, H: n_heads, T: seq_len (total packed length when offsets are used)
    # NV: number of BV-wide blocks along V (leading dimension of dq / dk / da / db)
    q,  # query: [B, H, T, K] if HEAD_FIRST else [B, T, H, K]
    k,  # key: same layout as q
    v,  # value: [B, H, T, V] if HEAD_FIRST else [B, T, H, V]
    a,  # a: same layout as q
    b,  # b: same layout as q
    ha,  # ha saved by the forward kernel: same layout as v
    dht,  # gradient of final state [N, H, K, V]
    dh0,  # gradient of initial state [N, H, K, V] (written by this kernel)
    do,  # gradient of output: same layout as v
    dq,  # gradient of query, per-V-block partials [NV, *q.shape]
    dk,  # gradient of key, per-V-block partials [NV, *k.shape]
    dv,  # gradient of value, same shape as v (each program writes only its V slice)
    da,  # gradient of a, per-V-block partials [NV, *a.shape]
    db,  # gradient of b, per-V-block partials [NV, *b.shape]
    dha,  # gradient of ha, same shape as ha; written in pass 1 and re-read in pass 2
    h0,  # initial state [N, H, K, V]
    scale,  # multiplier that was applied to q in the forward pass
    offsets,  # varlen offsets [N + 1]
    B,  # batch_size
    H,  # n_heads
    T,  # seq_len (overridden per sequence when USE_OFFSETS)
    K: tl.constexpr,  # K
    V: tl.constexpr,  # V
    BK: tl.constexpr,  # BLOCK SIZE along the K dimension; must be >= K
    BV: tl.constexpr,  # BLOCK SIZE along the V dimension
    USE_INITIAL_STATE: tl.constexpr,  # whether to use initial state h0
    USE_DH0: tl.constexpr,  # whether to use dh0
    USE_DHT: tl.constexpr,  # whether to use dht
    USE_OFFSETS: tl.constexpr,  # whether sequences are packed and delimited by `offsets`
    HEAD_FIRST: tl.constexpr  # layout selector, see the q / v comments above
):
    i_v, i_nh = tl.program_id(0), tl.program_id(1)
    i_n, i_h = i_nh // H, i_nh % H
    # move to this program's i_v slice of the [NV, ...] partial buffers; B * H * K * T is the
    # size of one slice and uses the *global* T, so this must precede the varlen override below
    dk += i_v * B * H * K * T
    db += i_v * B * H * K * T
    dq += i_v * B * H * K * T
    da += i_v * B * H * K * T
    if USE_OFFSETS:
        bos, eos = tl.load(offsets + i_n).to(tl.int64), tl.load(offsets + i_n + 1).to(tl.int64)
        T = eos - bos
    else:
        bos, eos = i_n * T, i_n * T + T
    mask_k = tl.arange(0, BK) < K
    mask_v = (tl.arange(0, BV) + i_v * BV) < V
    # base pointers: time step 0 of this (sequence, head), and this program's V slice for
    # the V-sized tensors
    q += (i_nh * T*K) if HEAD_FIRST else ((bos * H + i_h) * K)
    k += (i_nh * T*K) if HEAD_FIRST else ((bos * H + i_h) * K)
    v += (i_nh * T*V + i_v * BV) if HEAD_FIRST else ((bos * H + i_h) * V + i_v * BV)
    ha += (i_nh * T*V + i_v * BV) if HEAD_FIRST else ((bos * H + i_h) * V + i_v * BV)
    a += (i_nh * T*K) if HEAD_FIRST else ((bos * H + i_h) * K)
    b += (i_nh * T*K) if HEAD_FIRST else ((bos * H + i_h) * K)
    do += (i_nh * T*V + i_v * BV) if HEAD_FIRST else ((bos * H + i_h) * V + i_v * BV)
    dq += (i_nh * T*K) if HEAD_FIRST else ((bos * H + i_h) * K)
    dk += (i_nh * T*K) if HEAD_FIRST else ((bos * H + i_h) * K)
    dv += (i_nh * T*V + i_v * BV) if HEAD_FIRST else ((bos * H + i_h) * V + i_v * BV)
    da += (i_nh * T*K) if HEAD_FIRST else ((bos * H + i_h) * K)
    db += (i_nh * T*K) if HEAD_FIRST else ((bos * H + i_h) * K)
    dha += (i_nh * T*V + i_v * BV) if HEAD_FIRST else ((bos * H + i_h) * V + i_v * BV)
    # pass 1 walks backwards, so start at the last time step (t = T - 1);
    # p_da / p_dq are not used in pass 1 (pass 2 re-initializes them)
    p_q = q + tl.arange(0, BK) + (T - 1) * K * (1 if HEAD_FIRST else H)
    p_k = k + tl.arange(0, BK) + (T - 1) * K * (1 if HEAD_FIRST else H)
    p_v = v + tl.arange(0, BV) + (T - 1) * V * (1 if HEAD_FIRST else H)
    p_ha = ha + tl.arange(0, BV) + (T - 1) * V * (1 if HEAD_FIRST else H)
    p_a = a + tl.arange(0, BK) + (T - 1) * K * (1 if HEAD_FIRST else H)
    p_b = b + tl.arange(0, BK) + (T - 1) * K * (1 if HEAD_FIRST else H)
    p_do = do + tl.arange(0, BV) + (T - 1) * V * (1 if HEAD_FIRST else H)
    p_dk = dk + tl.arange(0, BK) + (T - 1) * K * (1 if HEAD_FIRST else H)
    p_dv = dv + tl.arange(0, BV) + (T - 1) * V * (1 if HEAD_FIRST else H)
    p_dha = dha + tl.arange(0, BV) + (T - 1) * V * (1 if HEAD_FIRST else H)
    p_db = db + tl.arange(0, BK) + (T - 1) * K * (1 if HEAD_FIRST else H)
    p_da = da + tl.arange(0, BK) + (T - 1) * K * (1 if HEAD_FIRST else H)
    p_dq = dq + tl.arange(0, BK) + (T - 1) * K * (1 if HEAD_FIRST else H)
    # gradient of the [K, V] state, seeded with the gradient of the final state if given
    b_dh = tl.zeros([BK, BV], dtype=tl.float32)
    if USE_DHT:
        p_ht = dht + i_nh * K * V + (tl.arange(0, BK)[:, None]) * V + ((i_v * BV + tl.arange(0, BV))[None, :])
        b_dh += tl.load(p_ht, mask=mask_k[:, None] & mask_v[None, :], other=0).to(tl.float32)
    for _ in range(T):
        b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) * scale
        b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
        b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
        b_do = tl.load(p_do, mask=mask_v, other=0).to(tl.float32)
        b_b = tl.load(p_b, mask=mask_k, other=0).to(tl.float32)
        b_a = tl.load(p_a, mask=mask_k, other=0).to(tl.float32)
        b_ha = tl.load(p_ha, mask=mask_v, other=0).to(tl.float32)
        # add this step's output contribution: b_dh is now the gradient of the state after step t
        b_dh += b_q[:, None] * b_do[None, :]
        # gradients of the rank-1 update k_t v_t^T
        d_k = tl.sum(b_dh * b_v[None, :], axis=1)
        d_v = tl.sum(b_dh * b_k[:, None], axis=0)
        tl.store(p_dk, d_k.to(p_dk.dtype.element_ty), mask=mask_k)
        tl.store(p_dv, d_v.to(p_dv.dtype.element_ty), mask=mask_v)
        # gradients of the rank-1 update b_t ha_t^T
        b_dha = tl.sum(b_dh * b_b[:, None], axis=0)
        tl.store(p_dha, b_dha.to(p_dha.dtype.element_ty), mask=mask_v)
        b_db = tl.sum(b_dh * b_ha[None, :], axis=1)
        tl.store(p_db, b_db.to(p_db.dtype.element_ty), mask=mask_k)
        # propagate to the state before step t: the identity path is already in b_dh;
        # add the path through ha_t = (state before step t)^T a_t
        b_dh += b_dha[None, :] * b_a[:, None]
        # step every pass-1 pointer back by one time step
        p_do -= V if HEAD_FIRST else V*H
        p_q -= K if HEAD_FIRST else K*H
        p_k -= K if HEAD_FIRST else K*H
        p_v -= V if HEAD_FIRST else V*H
        p_dk -= K if HEAD_FIRST else K*H
        p_dv -= V if HEAD_FIRST else V*H
        p_b -= K if HEAD_FIRST else K*H
        p_db -= K if HEAD_FIRST else K*H
        p_a -= K if HEAD_FIRST else K*H
        p_dha -= V if HEAD_FIRST else V*H
        p_ha -= V if HEAD_FIRST else V*H
    if USE_DH0:
        p_dh0 = dh0 + i_nh * K * V + (tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :])
        tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), mask=mask_k[:, None] & mask_v[None, :])
    # barrier between the passes: pass 2 re-loads the dha values this program stored in pass 1
    tl.debug_barrier()
    # pass 2: replay the recurrence forwards, b_h is the [K, V] state
    b_h = tl.zeros([BK, BV], dtype=tl.float32)
    if USE_INITIAL_STATE:
        mask_kv = mask_k[:, None] & mask_v[None, :]
        p_h0 = h0 + i_nh * K * V + (tl.arange(0, BK)[:, None]) * V + ((i_v * BV + tl.arange(0, BV))[None, :])
        b_h += tl.load(p_h0, mask=mask_kv, other=0).to(tl.float32)
    # reset the pointers used in pass 2 to time step 0
    p_k = k + tl.arange(0, BK)
    p_v = v + tl.arange(0, BV)
    p_ha = ha + tl.arange(0, BV)
    p_do = do + tl.arange(0, BV)
    p_dha = dha + tl.arange(0, BV)
    p_da = da + tl.arange(0, BK)
    p_dq = dq + tl.arange(0, BK)
    p_b = b + tl.arange(0, BK)
    for i in range(0, T):
        # da_t = (state before step t) @ dha_t; b_h has not been updated for step t yet
        b_dha = tl.load(p_dha, mask=mask_v, other=0).to(tl.float32)
        d_a = tl.sum(b_dha[None, :] * b_h, axis=1)
        tl.store(p_da, d_a.to(p_da.dtype.element_ty), mask=mask_k)
        b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
        b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
        b_do = tl.load(p_do, mask=mask_v, other=0).to(tl.float32)
        b_b = tl.load(p_b, mask=mask_k, other=0).to(tl.float32)
        b_ha = tl.load(p_ha, mask=mask_v, other=0).to(tl.float32)
        # advance to the state after step t, reusing the ha_t saved by the forward pass
        b_h += b_k[:, None] * b_v[None, :] + b_b[:, None] * b_ha[None, :]
        # dq_t = (state after step t) @ do_t, times the query scale
        _d_q = b_h * b_do[None, :]
        d_q = tl.sum(_d_q, axis=1) * scale
        tl.store(p_dq, d_q.to(p_dq.dtype.element_ty), mask=mask_k)
        p_k += K if HEAD_FIRST else K*H
        p_do += V if HEAD_FIRST else V*H
        p_v += V if HEAD_FIRST else V*H
        p_da += K if HEAD_FIRST else K*H
        p_dha += V if HEAD_FIRST else V*H
        p_ha += V if HEAD_FIRST else V*H
        p_dq += K if HEAD_FIRST else K*H
        p_b += K if HEAD_FIRST else K*H
class FusedRecurrentIPLRDeltaRuleFunction(torch.autograd.Function):
    """Autograd function for the fused recurrent IPLR delta rule.

    ``forward`` launches ``fused_recurrent_fwd_kernel`` and saves the per-step
    ``ha`` buffer it produces. ``backward`` launches ``fused_recurrent_bwd_kernel``
    and reduces its per-V-block partial gradients.
    """

    @staticmethod
    @input_guard
    def forward(ctx, q, k, v, a, b, scale=None, initial_state=None, output_final_state=False, offsets=None, head_first=False):
        """Compute the outputs and, optionally, the float32 final state.

        Inputs use the ``[B, T, H, *]`` layout, or ``[B, H, T, *]`` when
        ``head_first`` is set. ``offsets``, if given, holds ``N + 1`` sequence
        boundaries into the packed time axis. ``scale`` defaults to ``K ** -0.5``
        when ``None``.

        Returns:
            ``(o, final_state)``. ``o`` has the shape and dtype of ``v``.
            ``final_state`` is a ``[N, H, K, V]`` float32 tensor, or ``None``
            when ``output_final_state`` is false.
        """
        if head_first:
            B, H, T, K, V = *k.shape, v.shape[-1]
        else:
            B, T, H, K, V = *k.shape, v.shape[-1]
        # Number of independent sequences (B is 1 in variable-length mode).
        N = B if offsets is None else len(offsets) - 1
        if scale is None:
            # Same default as the public wrapper; the kernel cannot take None.
            scale = K ** -0.5
        # The kernel handles all of K in a single block.
        BK = triton.next_power_of_2(K)
        # The kernel writes one [K, V] state per (sequence, head) pair, i.e. N * H states.
        # The buffer must therefore be sized by N, not B: with `offsets`, B is 1 while
        # N can be larger, and a [B, H, K, V] buffer would be written out of bounds.
        if output_final_state:
            final_state = q.new_empty(N, H, K, V, dtype=torch.float32)
        else:
            final_state = None
        # Per-step (state @ a_t) products, kept in float32 for the backward pass.
        ha = torch.empty_like(v, dtype=torch.float32)

        def grid(meta):
            # BV is picked by the autotuner, so the grid depends on the selected config.
            return (triton.cdiv(V, meta['BV']), N * H)

        o = torch.empty_like(v)
        fused_recurrent_fwd_kernel[grid](
            q=q,
            k=k,
            v=v,
            a=a,
            b=b,
            o=o,
            ha=ha,
            h0=initial_state,
            ht=final_state,
            scale=scale,
            offsets=offsets,
            H=H,
            T=T,
            K=K,
            V=V,
            BK=BK,
            HEAD_FIRST=head_first
        )
        ctx.save_for_backward(q, k, v, a, b, ha, initial_state)
        ctx.scale = scale
        ctx.head_first = head_first
        ctx.offsets = offsets
        return o, final_state

    @staticmethod
    @input_guard
    def backward(ctx, do, dht):
        """Compute gradients for q, k, v, a, b and, if it requires grad, the initial state.

        Returns one entry per ``forward`` input. ``scale``, ``output_final_state``,
        ``offsets`` and ``head_first`` receive ``None``.
        """
        q, k, v, a, b, ha, initial_state = ctx.saved_tensors
        if ctx.head_first:
            B, H, T, K, V = *q.shape, v.shape[-1]
        else:
            B, T, H, K, V = *q.shape, v.shape[-1]
        N = B if ctx.offsets is None else len(ctx.offsets) - 1
        scale = ctx.scale
        # The kernel keeps the full K dimension in one block and splits V into NV blocks.
        BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 64)
        NV = triton.cdiv(V, BV)
        # dq / dk / da / db receive one partial result per V block; they are summed below.
        dq = q.new_empty(NV, *q.shape)
        dk = k.new_empty(NV, *k.shape)
        da = a.new_empty(NV, *a.shape)
        db = b.new_empty(NV, *b.shape)
        dv = torch.empty_like(v)
        # Scratch buffer: written and re-read by the backward kernel itself.
        dha = torch.empty_like(ha)
        grid = (NV, N * H)
        if initial_state is not None and initial_state.requires_grad:
            dh0 = torch.empty_like(initial_state, dtype=torch.float32)
        else:
            dh0 = None
        fused_recurrent_bwd_kernel[grid](
            q=q,
            k=k,
            v=v,
            a=a,
            b=b,
            ha=ha,
            dht=dht,
            dh0=dh0,
            do=do,
            dq=dq,
            dk=dk,
            dv=dv,
            da=da,
            db=db,
            dha=dha,
            h0=initial_state,
            scale=scale,
            offsets=ctx.offsets,
            B=B,
            H=H,
            T=T,
            K=K,
            V=V,
            BK=BK,
            BV=BV,
            HEAD_FIRST=ctx.head_first
        )
        dq = dq.sum(0)
        dk = dk.sum(0)
        da = da.sum(0)
        db = db.sum(0)
        return dq.to(q), dk.to(k), dv.to(v), da.to(a), db.to(b), None, dh0, None, None, None
def fused_recurrent_iplr_delta_rule(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    scale: float = None,
    initial_state: torch.Tensor = None,
    output_final_state: bool = False,
    offsets: torch.Tensor = None,
    head_first: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""
    Compute the IPLR (identity plus low-rank) delta-rule recurrence in a recurrent manner:

        S_t = S_{t-1} @ (I + a_t b_t^T) + v_t k_t^T,    o_t = S_t @ (scale * q_t)

    where S_t is a `[V, K]` state. States passed in or returned are stored transposed, as `[K, V]`.

    Args:
        q (torch.Tensor):
            queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
        k (torch.Tensor):
            keys, same shape as `q`.
        v (torch.Tensor):
            values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
        a (torch.Tensor):
            the `a` vectors of the low-rank term, same shape as `q`.
        b (torch.Tensor):
            the `b` vectors of the low-rank term, same shape as `q`.
        scale (Optional[float]):
            Scale factor applied to the queries. Must be positive.
            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
        initial_state (Optional[torch.Tensor]):
            Initial state of shape `[N, H, K, V]`, where `N` is the batch size, or the number of
            sequences when `offsets` is given. Default: `None`.
        output_final_state (Optional[bool]):
            Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
        offsets (Optional[torch.Tensor]):
            Sequence boundaries of shape `[N + 1]` for variable-length inputs packed along the time
            axis; sequence `n` spans `offsets[n]:offsets[n + 1]`. Requires batch size 1 and
            `head_first=False`. Default: `None`.
        head_first (Optional[bool]):
            Whether the inputs use the head-first layout `[B, H, T, ...]`. Default: `False`.

    Returns:
        o (torch.Tensor):
            Outputs with the same shape as `v`.
        final_state (Optional[torch.Tensor]):
            The float32 final state if `output_final_state` is `True`, otherwise `None`.

    Raises:
        ValueError: if `offsets` is given with a batch size other than 1, or with a number of
            initial states different from the number of sequences.
        RuntimeError: if `offsets` is combined with `head_first=True`.
    """
    if offsets is not None:
        if q.shape[0] != 1:
            raise ValueError(f"The batch size is expected to be 1 rather than {q.shape[0]} when using `offsets`. "
                             "Please flatten variable-length inputs before processing.")
        if head_first:
            raise RuntimeError("Sequences with variable lengths are not supported for head-first mode")
        if initial_state is not None and initial_state.shape[0] != len(offsets) - 1:
            raise ValueError(f"The number of initial states is expected to be equal to the number of input sequences, "
                             f"i.e., {len(offsets) - 1} rather than {initial_state.shape[0]}.")
    if scale is None:
        scale = q.shape[-1] ** -0.5
    else:
        assert scale > 0, "scale must be positive"
    o, final_state = FusedRecurrentIPLRDeltaRuleFunction.apply(
        q, k, v, a, b, scale, initial_state, output_final_state, offsets, head_first)
    return o, final_state