from typing import Optional, Tuple

import torch
import triton
import triton.language as tl
from einops import rearrange

from fla.modules.l2norm import l2norm_bwd, l2norm_fwd
from fla.utils import input_guard


@triton.heuristics({
    'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
    'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.jit(do_not_specialize=['T'])
def fused_recurrent_delta_rule_fwd_kernel(
    q,
    k,
    v,
    u,
    beta,
    o,
    h0,
    ht,
    offsets,
    scale,
    T,
    B: tl.constexpr,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
    STORE_FINAL_STATE: tl.constexpr,
    IS_BETA_HEADWISE: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
    HEAD_FIRST: tl.constexpr
):
    i_v, i_k, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_n, i_h = i_nh // H, i_nh % H
    if USE_OFFSETS:
        bos, eos = tl.load(offsets + i_n).to(tl.int64), tl.load(offsets + i_n + 1).to(tl.int64)
        all = T
        T = eos - bos
    else:
        bos, eos = i_n * T, i_n * T + T
        all = B * T

    if HEAD_FIRST:
        p_q = q + i_nh * T*K + i_k * BK + tl.arange(0, BK)
        p_k = k + i_nh * T*K + i_k * BK + tl.arange(0, BK)
        p_v = v + i_nh * T*V + i_v * BV + tl.arange(0, BV)
        p_u = u + i_nh * T*V + i_v * BV + tl.arange(0, BV)
        if IS_BETA_HEADWISE:
            p_beta = beta + i_nh * T*V + i_v * BV + tl.arange(0, BV)
        else:
            p_beta = beta + i_nh * T
        p_o = o + (i_k * B*H + i_nh) * T*V + i_v * BV + tl.arange(0, BV)
    else:
        p_q = q + (bos * H + i_h) * K + i_k * BK + tl.arange(0, BK)
        p_k = k + (bos * H + i_h) * K + i_k * BK + tl.arange(0, BK)
        p_v = v + (bos * H + i_h) * V + i_v * BV + tl.arange(0, BV)
        p_u = u + (bos * H + i_h) * V + i_v * BV + tl.arange(0, BV)
        if IS_BETA_HEADWISE:
            p_beta = beta + (bos * H + i_h) * V + i_v * BV + tl.arange(0, BV)
        else:
            p_beta = beta + bos * H + i_h
        p_o = o + ((i_k * all + bos) * H + i_h) * V + i_v * BV + tl.arange(0, BV)

    mask_k = (i_k * BK + tl.arange(0, BK)) < K
    mask_v = (i_v * BV + tl.arange(0, BV)) < V
    mask_h = mask_k[None, :] & mask_v[:, None]

    b_h = tl.zeros([BV, BK], dtype=tl.float32)
    if USE_INITIAL_STATE:
        p_h0 = h0 + i_nh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None])
        b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)
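
    # Delta-rule recurrence, one timestep per iteration; `b_h` holds a
    # [BV, BK] tile of the state S:
    #     u_t = v_t - S_{t-1} k_t    (stored to `u` for reuse in the backward pass)
    #     S_t = S_{t-1} + (beta_t * u_t) k_t^T
    #     o_t = S_t q_t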
    for _ in range(0, T):
        b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
        b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
        b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) * scale
        b_v_minus = tl.sum(b_h * b_k[None, :], axis=1)
        b_v -= b_v_minus
        if IS_BETA_HEADWISE:
            b_beta = tl.load(p_beta, mask=mask_v, other=0).to(tl.float32)
        else:
            b_beta = tl.load(p_beta).to(tl.float32)
        tl.store(p_u, b_v.to(p_v.dtype.element_ty), mask=mask_v)
        b_v *= b_beta
        b_h += b_k[None, :] * b_v[:, None]
        b_o = b_h * b_q[None, :]
        b_o = tl.sum(b_o, axis=1)
        tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)

        p_q += K if HEAD_FIRST else H*K
        p_k += K if HEAD_FIRST else H*K
        p_o += V if HEAD_FIRST else H*V
        p_v += V if HEAD_FIRST else H*V
        p_u += V if HEAD_FIRST else H*V
        p_beta += (1 if HEAD_FIRST else H) * (V if IS_BETA_HEADWISE else 1)

    if STORE_FINAL_STATE:
        p_ht = ht + i_nh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None])
        tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)


@triton.heuristics({
    'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
    'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None,
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.jit(do_not_specialize=['T'])
def fused_recurrent_delta_rule_bwd_kernel(
    q,
    k,
    v,
    beta,
    h0,
    dh0,
    dht,
    do,
    dq,
    dk,
    dv,
    db,
    offsets,
    scale,
    B: tl.constexpr,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    NK: tl.constexpr,
    IS_BETA_HEADWISE: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
    USE_FINAL_STATE_GRADIENT: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
    HEAD_FIRST: tl.constexpr
):
    i_v, i_k, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_n, i_h = i_nh // H, i_nh % H
    if USE_OFFSETS:
        bos, eos = tl.load(offsets + i_n).to(tl.int64), tl.load(offsets + i_n + 1).to(tl.int64)
        all = T
        T = eos - bos
    else:
        bos, eos = i_n * T, i_n * T + T
        all = B * T

    mask_k = i_k * BK + tl.arange(0, BK) < K
    mask_v = i_v * BV + tl.arange(0, BV) < V

    if HEAD_FIRST:
        p_q = q + i_nh * T*K + i_k * BK + tl.arange(0, BK) + (T - 1) * K
        p_k = k + i_nh * T*K + i_k * BK + tl.arange(0, BK) + (T - 1) * K
        p_v = v + i_nh * T*V + i_v * BV + tl.arange(0, BV) + (T - 1) * V
        p_do = do + i_nh * T*V + i_v * BV + tl.arange(0, BV) + (T - 1) * V
        p_dk = dk + (i_v * B*H + i_nh) * T*K + i_k * BK + tl.arange(0, BK) + (T - 1) * K
        p_dv = dv + (i_k * B*H + i_nh) * T*V + i_v * BV + tl.arange(0, BV) + (T - 1) * V
        if IS_BETA_HEADWISE:
            p_beta = beta + i_nh * T*V + i_v * BV + tl.arange(0, BV) + (T - 1) * V
            p_dbeta = db + (i_v * NK*B*H + i_k * B*H + i_nh) * T*V + tl.arange(0, BV) + (T - 1) * V
        else:
            p_beta = beta + i_nh * T + T - 1
            p_dbeta = db + (i_v * B*H + i_nh) * T + T - 1
    else:
        p_q = q + (bos * H + i_h) * K + i_k * BK + tl.arange(0, BK) + (T - 1) * H*K
        p_k = k + (bos * H + i_h) * K + i_k * BK + tl.arange(0, BK) + (T - 1) * H*K
        p_v = v + (bos * H + i_h) * V + i_v * BV + tl.arange(0, BV) + (T - 1) * H*V
        p_do = do + (bos * H + i_h) * V + i_v * BV + tl.arange(0, BV) + (T - 1) * H*V
        p_dk = dk + ((i_v * all + bos) * H + i_h) * K + i_k * BK + tl.arange(0, BK) + (T - 1) * H*K
        p_dv = dv + ((i_k * all + bos) * H + i_h) * V + i_v * BV + tl.arange(0, BV) + (T - 1) * H*V
        if IS_BETA_HEADWISE:
            p_beta = beta + (bos + T - 1) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
            p_dbeta = db + ((i_v * NK + i_k) * all + bos + T - 1) * H*V + i_h * V + tl.arange(0, BV)
        else:
            p_beta = beta + (bos + T - 1) * H + i_h
            p_dbeta = db + (i_v * all + bos + T - 1) * H + i_h

    b_dh = tl.zeros([BK, BV], dtype=tl.float32)
    if USE_FINAL_STATE_GRADIENT:
        p_ht = dht + i_nh * K * V + (i_k * BK + tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :])
        b_dh += tl.load(p_ht, mask=mask_k[:, None] & mask_v[None, :], other=0).to(tl.float32)
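
    # First pass (reverse time): accumulate the state gradient into `b_dh` and
    # produce dk, dv and dbeta. Note that the `v` passed to this kernel holds
    # the pseudo-values u = v - S k saved by the forward pass.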
    for _ in range(T):
        b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) * scale
        b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
        b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
        b_do = tl.load(p_do, mask=mask_v, other=0).to(tl.float32)
        if IS_BETA_HEADWISE:
            b_beta = tl.load(p_beta, mask=mask_v, other=0).to(tl.float32)
        else:
            b_beta = tl.load(p_beta).to(tl.float32)
        b_dh += b_q[:, None] * b_do[None, :]
        b_dk = tl.sum(b_dh * (b_v * b_beta)[None, :], axis=1)
        b_dv = tl.sum(b_dh * b_k[:, None], axis=0)

        b_db = b_dv * b_v if IS_BETA_HEADWISE else tl.sum(b_dv * b_v)
        b_dv = b_dv * b_beta

        tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), mask=mask_k)
        tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), mask=mask_v)
        if IS_BETA_HEADWISE:
            tl.store(p_dbeta, b_db.to(p_dbeta.dtype.element_ty), mask=mask_v)
        else:
            tl.store(p_dbeta, b_db.to(p_dbeta.dtype.element_ty))

        b_dh -= b_k[:, None] * b_dv[None, :]

        p_q -= K if HEAD_FIRST else H*K
        p_k -= K if HEAD_FIRST else H*K
        p_v -= V if HEAD_FIRST else H*V
        p_do -= V if HEAD_FIRST else H*V
        p_dk -= K if HEAD_FIRST else H*K
        p_dv -= V if HEAD_FIRST else H*V
        p_dbeta -= (1 if HEAD_FIRST else H) * (V if IS_BETA_HEADWISE else 1)
        p_beta -= (1 if HEAD_FIRST else H) * (V if IS_BETA_HEADWISE else 1)

    if USE_INITIAL_STATE:
        p_dh0 = dh0 + i_nh * K * V + (i_k * BK + tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :])
        tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), mask=mask_k[:, None] & mask_v[None, :])

    tl.debug_barrier()
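
    # Second pass (forward time): recompute the state to form dq and to
    # subtract the state-dependent contribution from dk.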
    b_h = tl.zeros([BK, BV], dtype=tl.float32)

    if HEAD_FIRST:
        p_q = q + i_nh * T*K + i_k * BK + tl.arange(0, BK)
        p_k = k + i_nh * T*K + i_k * BK + tl.arange(0, BK)
        p_v = v + i_nh * T*V + i_v * BV + tl.arange(0, BV)
        if IS_BETA_HEADWISE:
            p_beta = beta + i_nh * T*V + i_v * BV + tl.arange(0, BV)
        else:
            p_beta = beta + i_nh * T
        p_do = do + i_nh * T*V + i_v * BV + tl.arange(0, BV)
        p_dq = dq + (i_v * B*H + i_nh) * T*K + i_k * BK + tl.arange(0, BK)
        p_dk = dk + (i_v * B*H + i_nh) * T*K + i_k * BK + tl.arange(0, BK)
        p_dv = dv + (i_k * B*H + i_nh) * T*V + i_v * BV + tl.arange(0, BV)
    else:
        p_q = q + (bos * H + i_h) * K + i_k * BK + tl.arange(0, BK)
        p_k = k + (bos * H + i_h) * K + i_k * BK + tl.arange(0, BK)
        p_v = v + (bos * H + i_h) * V + i_v * BV + tl.arange(0, BV)
        if IS_BETA_HEADWISE:
            p_beta = beta + (bos * H + i_h) * V + i_v * BV + tl.arange(0, BV)
        else:
            p_beta = beta + bos * H + i_h
        p_do = do + (bos * H + i_h) * V + i_v * BV + tl.arange(0, BV)
        p_dq = dq + ((i_v * all + bos) * H + i_h) * K + i_k * BK + tl.arange(0, BK)
        p_dk = dk + ((i_v * all + bos) * H + i_h) * K + i_k * BK + tl.arange(0, BK)
        p_dv = dv + ((i_k * all + bos) * H + i_h) * V + i_v * BV + tl.arange(0, BV)

    if USE_INITIAL_STATE:
        mask_h = mask_k[:, None] & mask_v[None, :]
        p_h0 = h0 + i_nh * K * V + (i_k * BK + tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :])
        b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)

    for _ in range(0, T):
        b_dk = tl.load(p_dk, mask=mask_k, other=0).to(tl.float32)
        b_dv = tl.load(p_dv, mask=mask_v, other=0).to(tl.float32)
        b_dk -= tl.sum(b_dv[None, :] * b_h, axis=1)
        tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), mask=mask_k)

        b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
        b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
        b_do = tl.load(p_do, mask=mask_v, other=0).to(tl.float32)
        if IS_BETA_HEADWISE:
            b_beta = tl.load(p_beta, mask=mask_v, other=0).to(tl.float32)
        else:
            b_beta = tl.load(p_beta).to(tl.float32)
        b_v *= b_beta

        b_h += b_k[:, None] * b_v[None, :]
        b_dq = b_h * b_do[None, :]
        b_dq = tl.sum(b_dq, axis=1) * scale
        tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), mask=mask_k)

        p_k += K if HEAD_FIRST else H*K
        p_v += V if HEAD_FIRST else H*V
        p_do += V if HEAD_FIRST else H*V
        p_dq += K if HEAD_FIRST else H*K
        p_dk += K if HEAD_FIRST else H*K
        p_dv += V if HEAD_FIRST else H*V
        p_beta += (1 if HEAD_FIRST else H) * (V if IS_BETA_HEADWISE else 1)


def fused_recurrent_delta_rule_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    beta: torch.Tensor,
    scale: float,
    initial_state: torch.Tensor,
    output_final_state: bool,
    offsets: Optional[torch.LongTensor] = None,
    head_first: bool = True
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    if head_first:
        B, H, T, K, V = *k.shape, v.shape[-1]
    else:
        B, T, H, K, V = *k.shape, v.shape[-1]
    N = B if offsets is None else len(offsets) - 1
    BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 8)
    NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
    assert NK == 1, "NK > 1 is not supported yet"
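    # Only the value dimension is tiled across programs (grid axis 0); the key
    # dimension must fit into a single block of size BK.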
    num_stages = 1
    num_warps = 1

    o = q.new_empty(NK, *v.shape)
    if output_final_state:
        final_state = q.new_empty(N, H, K, V, dtype=torch.float32)
    else:
        final_state = None

    grid = (NV, NK, N * H)
    u = torch.empty_like(v)
    fused_recurrent_delta_rule_fwd_kernel[grid](
        q,
        k,
        v,
        u,
        beta,
        o,
        initial_state,
        final_state,
        offsets,
        scale,
        T=T,
        B=B,
        H=H,
        K=K,
        V=V,
        BK=BK,
        BV=BV,
        IS_BETA_HEADWISE=beta.ndim == v.ndim,
        HEAD_FIRST=head_first,
        num_warps=num_warps,
        num_stages=num_stages,
    )
    o = o.squeeze(0)
    return o, u, final_state


def fused_recurrent_delta_rule_bwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    beta: torch.Tensor,
    dht: torch.Tensor,
    do: torch.Tensor,
    scale: float,
    initial_state: torch.Tensor,
    offsets: Optional[torch.LongTensor] = None,
    head_first: bool = True
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    if head_first:
        B, H, T, K, V = *k.shape, v.shape[-1]
    else:
        B, T, H, K, V = *k.shape, v.shape[-1]
    N = B if offsets is None else len(offsets) - 1
    BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 32)
    NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
    assert NK == 1, "NK > 1 is not supported yet"
    num_stages = 1
    num_warps = 2

    beta_vector = beta.ndim == v.ndim
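
    # dq/dk/dv (and db) carry extra leading dimensions holding per-block
    # partial results, which are summed out after the kernel launch.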
    dq = q.new_empty(NV, *q.shape)
    dk = q.new_empty(NV, *k.shape)
    dv = q.new_empty(NK, *v.shape)
    if beta_vector:
        db = q.new_empty(NV, NK, B, H, T, V) if head_first else q.new_empty(NV, NK, B, T, H, V)
    else:
        db = q.new_empty(NV, B, H, T) if head_first else q.new_empty(NV, B, T, H)
    grid = (NV, NK, N * H)

    if initial_state is not None and initial_state.requires_grad:
        dh0 = torch.empty_like(initial_state, dtype=torch.float32)
    else:
        dh0 = None

    fused_recurrent_delta_rule_bwd_kernel[grid](
        q,
        k,
        v,
        beta,
        initial_state,
        dh0,
        dht,
        do,
        dq,
        dk,
        dv,
        db,
        offsets,
        scale,
        T=T,
        B=B,
        H=H,
        K=K,
        V=V,
        BK=BK,
        BV=BV,
        NK=NK,
        IS_BETA_HEADWISE=beta_vector,
        HEAD_FIRST=head_first,
        num_warps=num_warps,
        num_stages=num_stages
    )
    dq = dq.sum(0)
    dk = dk.sum(0)
    dv = dv.sum(0)
    db = db.sum((0, 1)) if beta_vector else db.sum(0)

    return dq, dk, dv, db, dh0


class FusedRecurrentFunction(torch.autograd.Function):

    @staticmethod
    @input_guard
    def forward(
        ctx,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        beta: torch.Tensor,
        scale: float,
        initial_state: torch.Tensor,
        output_final_state: bool,
        offsets: Optional[torch.LongTensor] = None,
        head_first: bool = True,
        use_qk_l2norm_in_kernel: bool = False
    ):
        q_orig = q
        k_orig = k

        if use_qk_l2norm_in_kernel:
            q = l2norm_fwd(q)
            k = l2norm_fwd(k)

        o, u, final_state = fused_recurrent_delta_rule_fwd(
            q=q,
            k=k,
            v=v,
            beta=beta,
            scale=scale,
            initial_state=initial_state,
            output_final_state=output_final_state,
            offsets=offsets,
            head_first=head_first
        )
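
        # The pseudo-values `u` are saved in place of `v`; the backward kernel
        # reconstructs the recurrent state from `u` directly.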
        ctx.save_for_backward(q_orig, k_orig, u, beta, initial_state)
        ctx.scale = scale
        ctx.offsets = offsets
        ctx.head_first = head_first
        ctx.use_qk_l2norm_in_kernel = use_qk_l2norm_in_kernel
        return o, final_state

    @staticmethod
    @input_guard
    def backward(ctx, do, dht):
        q, k, v, beta, initial_state = ctx.saved_tensors
        if ctx.use_qk_l2norm_in_kernel:
            q, q_orig = l2norm_fwd(q), q
            k, k_orig = l2norm_fwd(k), k
        dq, dk, dv, db, dh0 = fused_recurrent_delta_rule_bwd(
            q=q,
            k=k,
            v=v,
            beta=beta,
            dht=dht,
            do=do,
            scale=ctx.scale,
            initial_state=initial_state,
            offsets=ctx.offsets,
            head_first=ctx.head_first
        )
        if ctx.use_qk_l2norm_in_kernel:
            dq, dk = l2norm_bwd(q_orig, dq), l2norm_bwd(k_orig, dk)
        return dq.to(q), dk.to(k), dv.to(v), db.to(beta), None, dh0, None, None, None, None


@torch.compiler.disable
def fused_recurrent_delta_rule(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    beta: torch.Tensor = None,
    scale: float = None,
    initial_state: torch.Tensor = None,
    output_final_state: bool = False,
    cu_seqlens: Optional[torch.LongTensor] = None,
    head_first: bool = False,
    use_qk_l2norm_in_kernel: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""
    Args:
        q (torch.Tensor):
            queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
        k (torch.Tensor):
            keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
        v (torch.Tensor):
            values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
        beta (torch.Tensor):
            betas of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`.
        scale (Optional[float]):
            Scale factor for the attention scores.
            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
        initial_state (Optional[torch.Tensor]):
            Initial state of shape `[N, H, K, V]` for `N` input sequences.
            For equal-length input sequences, `N` equals the batch size `B`.
            Default: `None`.
        output_final_state (Optional[bool]):
            Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
        cu_seqlens (torch.LongTensor):
            Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
            consistent with the FlashAttention API.
        head_first (Optional[bool]):
            Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
            Default: `False`.
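        use_qk_l2norm_in_kernel (Optional[bool]):
            Whether to apply L2 normalization to `q` and `k` inside the kernel,
            with gradients propagated through the normalization in the backward pass.
            Default: `False`.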

    Returns:
        o (torch.Tensor):
            Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
        final_state (torch.Tensor):
            Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.

    Examples::
        >>> import torch
        >>> import torch.nn.functional as F
        >>> from einops import rearrange
        >>> from fla.ops.delta_rule import fused_recurrent_delta_rule
        # inputs with equal lengths
        >>> B, T, H, K, V = 4, 2048, 4, 512, 512
        >>> q = torch.randn(B, T, H, K, device='cuda')
        >>> k = F.normalize(torch.randn(B, T, H, K, device='cuda'), p=2, dim=-1)
        >>> v = torch.randn(B, T, H, V, device='cuda')
        >>> beta = torch.rand(B, T, H, device='cuda').sigmoid()
        >>> h0 = torch.randn(B, H, K, V, device='cuda')
        >>> o, ht = fused_recurrent_delta_rule(
            q, k, v, beta,
            initial_state=h0,
            output_final_state=True
        )
        # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
        >>> q, k, v, beta = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, beta))
        # for a batch with 4 sequences, `cu_seqlens` with 5 boundary positions is expected
        >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
        >>> o_var, ht_var = fused_recurrent_delta_rule(
            q, k, v, beta,
            initial_state=h0,
            output_final_state=True,
            cu_seqlens=cu_seqlens
        )
        >>> assert o.allclose(o_var.view(o.shape))
        >>> assert ht.allclose(ht_var)
    """
    if cu_seqlens is not None:
        if q.shape[0] != 1:
            raise ValueError(
                f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`. "
                f"Please flatten variable-length inputs before processing."
            )
        if head_first:
            raise RuntimeError(
                "Sequences with variable lengths are not supported for head-first mode"
            )
        if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
            raise ValueError(
                f"The number of initial states is expected to be equal to the number of input sequences, "
                f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
            )
    if scale is None:
        scale = k.shape[-1] ** -0.5
    else:
        assert scale > 0, "scale must be positive"
    if beta is None:
        beta = torch.ones_like(q[..., 0])
    if head_first:
        q, k, v = map(lambda x: rearrange(x, 'b h t d -> b t h d'), (q, k, v))
        beta = rearrange(beta, 'b h t -> b t h')
    o, final_state = FusedRecurrentFunction.apply(
        q,
        k,
        v,
        beta,
        scale,
        initial_state,
        output_final_state,
        cu_seqlens,
        False,
        use_qk_l2norm_in_kernel
    )
    if head_first:
        o = rearrange(o, 'b t h v -> b h t v')
    return o, final_state
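

if __name__ == "__main__":
    # Minimal smoke test mirroring the docstring example above; a sketch that
    # assumes a CUDA device is available (shapes kept small for speed).
    import torch.nn.functional as F
    B, T, H, K, V = 2, 64, 2, 32, 32
    q = torch.randn(B, T, H, K, device='cuda')
    k = F.normalize(torch.randn(B, T, H, K, device='cuda'), p=2, dim=-1)
    v = torch.randn(B, T, H, V, device='cuda')
    beta = torch.rand(B, T, H, device='cuda').sigmoid()
    o, ht = fused_recurrent_delta_rule(q, k, v, beta, output_final_state=True, head_first=False)
    assert o.shape == (B, T, H, V) and ht.shape == (B, H, K, V)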