from typing import Optional

import torch
import triton
from einops import rearrange

from fla.modules.l2norm import l2norm_bwd, l2norm_fwd
from fla.ops.common.chunk_delta_h import chunk_gated_delta_rule_bwd_dhu, chunk_gated_delta_rule_fwd_h
from fla.ops.common.chunk_o import chunk_bwd_dqkwg, chunk_bwd_dv_local, chunk_fwd_o
from fla.ops.common.utils import prepare_chunk_indices
from fla.ops.delta_rule.wy_fast import bwd_prepare_wy_repr, fwd_prepare_wy_repr, fwd_recompute_w_u
from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard
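

# The helpers below implement the chunkwise-parallel form of the delta rule (cf. Yang et al., 2024,
# "Parallelizing Linear Transformers with the Delta Rule over Sequence Length"): within each length-BT
# chunk the rank-1 (Householder-like) state updates are folded into a WY representation so they can be
# evaluated with dense matmuls, and chunks are chained through a recurrence over the [K, V] state.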
def chunk_delta_rule_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    beta: torch.Tensor,
    scale: float,
    initial_state: torch.Tensor,
    output_final_state: bool,
    offsets: Optional[torch.LongTensor] = None,
    indices: Optional[torch.LongTensor] = None,
    head_first: bool = True,
    chunk_size: int = 64
):
    T = q.shape[2] if head_first else q.shape[1]
    BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
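
    # Chunk-local WY representation: fold each chunk's per-token rank-1 updates (I - beta_t k_t k_t^T)
    # into "pseudo" keys/values w and u (roughly w = T @ k, u = T @ v for a chunk-local triangular
    # matrix T built from beta); A caches that triangular factor so backward can recompute w/u cheaply.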
    w, u, A = fwd_prepare_wy_repr(
        k=k,
        v=v,
        beta=beta,
        offsets=offsets,
        indices=indices,
        head_first=head_first,
        chunk_size=BT
    )
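
    # Chunk-level recurrence over the [K, V] state: produces the per-chunk hidden states h, the
    # corrected values v_new (roughly u minus what the previous state already predicts), and, if
    # requested, the final state. g=None because the plain delta rule has no gating/decay term;
    # the gated kernels are reused with gating disabled.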
    h, v_new, final_state = chunk_gated_delta_rule_fwd_h(
        k=k,
        w=w,
        u=u,
        g=None,
        initial_state=initial_state,
        output_final_state=output_final_state,
        offsets=offsets,
        indices=indices,
        head_first=head_first,
        chunk_size=BT
    )
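    # Outputs: the inter-chunk contribution q @ h plus the causal intra-chunk attention over
    # (k, v_new), both scaled by `scale`.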
    o = chunk_fwd_o(
        q=q,
        k=k,
        v=v_new,
        h=h,
        g=None,
        scale=scale,
        offsets=offsets,
        indices=indices,
        head_first=head_first,
        chunk_size=BT
    )
    return o, A, final_state


def chunk_delta_rule_bwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    beta: torch.Tensor,
    A: torch.Tensor,
    scale: float,
    initial_state: torch.Tensor,
    do: torch.Tensor,
    dht: torch.Tensor,
    offsets: Optional[torch.LongTensor] = None,
    indices: Optional[torch.LongTensor] = None,
    head_first: bool = True,
    chunk_size: int = 64
):
    T = q.shape[2] if head_first else q.shape[1]
    BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
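    # Recompute w/u from the cached triangular factor A instead of keeping them alive from the
    # forward pass, trading a little compute for activation memory.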
    w, u = fwd_recompute_w_u(
        k=k,
        v=v,
        beta=beta,
        A=A,
        offsets=offsets,
        indices=indices,
        head_first=head_first,
        chunk_size=BT
    )
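    # Likewise recompute the chunkwise hidden states h and v_new; output_final_state=False since
    # only the intermediate states are needed for the gradients.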
    h, v_new, _ = chunk_gated_delta_rule_fwd_h(
        k=k,
        w=w,
        u=u,
        g=None,
        initial_state=initial_state,
        output_final_state=False,
        offsets=offsets,
        indices=indices,
        head_first=head_first,
        chunk_size=BT
    )
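    # Intra-chunk (local) part of dv from the causal attention between q and k; the cross-chunk
    # part flowing through the state is added by chunk_gated_delta_rule_bwd_dhu below.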
    dv = chunk_bwd_dv_local(
        q=q,
        k=k,
        do=do,
        g=None,
        dh=None,
        scale=scale,
        offsets=offsets,
        indices=indices,
        head_first=head_first,
        chunk_size=BT
    )
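    # Backward through the chunk-level state recurrence: yields the per-chunk state gradients dh,
    # the gradient dh0 w.r.t. the initial state, and completes dv with the cross-chunk term.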
    dh, dh0, dv = chunk_gated_delta_rule_bwd_dhu(
        q=q,
        k=k,
        w=w,
        g=None,
        h0=initial_state,
        dht=dht,
        do=do,
        dv=dv,
        scale=scale,
        offsets=offsets,
        indices=indices,
        head_first=head_first,
        chunk_size=BT
    )
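    # Gradients w.r.t. q, k and the pseudo keys w from both the intra-chunk attention and the state
    # contributions; the discarded last output is the gating gradient, unused here since g=None.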
    dq, dk, dw, _ = chunk_bwd_dqkwg(
        q=q,
        k=k,
        v=v_new,
        h=h,
        w=w,
        dv=dv,
        do=do,
        dh=dh,
        g=None,
        scale=scale,
        offsets=offsets,
        indices=indices,
        head_first=head_first,
        chunk_size=BT
    )
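    # Backward through the WY representation: map dw/du back onto k, v and beta; the extra dk
    # contribution is accumulated into dk right below.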
    dk2, dv, db = bwd_prepare_wy_repr(
        k=k,
        v=v,
        beta=beta,
        A=A,
        dw=dw,
        du=dv,
        offsets=offsets,
        indices=indices,
        head_first=head_first,
        chunk_size=BT
    )
    dk.add_(dk2)
    return dq, dk, dv, db, dh0


class ChunkDeltaRuleFunction(torch.autograd.Function):

    @staticmethod
    @input_guard
    @autocast_custom_fwd
    def forward(
        ctx,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        beta: torch.Tensor,
        scale: float,
        initial_state: torch.Tensor,
        output_final_state: bool,
        offsets: Optional[torch.LongTensor] = None,
        head_first: bool = True,
        use_qk_l2norm_in_kernel: bool = True
    ):
        T = q.shape[2] if head_first else q.shape[1]
        chunk_size = min(64, max(triton.next_power_of_2(T), 16))

        q_orig = q
        k_orig = k
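
        # When in-kernel L2 normalization is enabled, q/k are normalized here and only the raw
        # tensors kept above are saved for backward; the normalization is recomputed there,
        # trading a little compute for activation memory.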
        if use_qk_l2norm_in_kernel:
            q = l2norm_fwd(q)
            k = l2norm_fwd(k)
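
        # For variable-length inputs (`offsets`/`cu_seqlens`), precompute the chunk-to-sequence
        # index pairs that the Triton kernels iterate over.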
        indices = prepare_chunk_indices(offsets, chunk_size) if offsets is not None else None

        o, A, final_state = chunk_delta_rule_fwd(
            q=q,
            k=k,
            v=v,
            beta=beta,
            scale=scale,
            initial_state=initial_state,
            output_final_state=output_final_state,
            offsets=offsets,
            indices=indices,
            head_first=head_first,
            chunk_size=chunk_size
        )
        ctx.save_for_backward(q_orig, k_orig, v, beta, A, initial_state)
        ctx.chunk_size = chunk_size
        ctx.scale = scale
        ctx.offsets = offsets
        ctx.indices = indices
        ctx.head_first = head_first
        ctx.use_qk_l2norm_in_kernel = use_qk_l2norm_in_kernel
        return o.to(q.dtype), final_state

    @staticmethod
    @input_guard
    @autocast_custom_bwd
    def backward(
        ctx,
        do: torch.Tensor,
        dht: torch.Tensor
    ):
        q, k, v, beta, A, initial_state = ctx.saved_tensors
        use_qk_l2norm_in_kernel = ctx.use_qk_l2norm_in_kernel
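        # The forward pass saved the raw q/k, so redo the L2 normalization here and keep the raw
        # tensors around for l2norm_bwd below.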
        if use_qk_l2norm_in_kernel:
            q, q_orig = l2norm_fwd(q), q
            k, k_orig = l2norm_fwd(k), k

        dq, dk, dv, db, dh0 = chunk_delta_rule_bwd(
            q=q,
            k=k,
            v=v,
            beta=beta,
            A=A,
            scale=ctx.scale,
            initial_state=initial_state,
            do=do,
            dht=dht,
            offsets=ctx.offsets,
            indices=ctx.indices,
            head_first=ctx.head_first,
            chunk_size=ctx.chunk_size
        )
        if use_qk_l2norm_in_kernel:
            dq = l2norm_bwd(q_orig, dq)
            dk = l2norm_bwd(k_orig, dk)
        return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), db.to(beta.dtype), None, dh0, None, None, None, None


@torch.compiler.disable
def chunk_delta_rule(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    beta: torch.Tensor,
    scale: Optional[float] = None,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = False,
    cu_seqlens: Optional[torch.LongTensor] = None,
    head_first: bool = False,
    use_qk_l2norm_in_kernel: bool = False
):
r""" |
|
Args: |
|
q (torch.Tensor): |
|
queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. |
|
k (torch.Tensor): |
|
keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. |
|
v (torch.Tensor): |
|
values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. |
|
beta (torch.Tensor): |
|
betas of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`. |
|
scale (Optional[int]): |
|
Scale factor for the RetNet attention scores. |
|
If not provided, it will default to `1 / sqrt(K)`. Default: `None`. |
|
initial_state (Optional[torch.Tensor]): |
|
Initial state of shape `[N, H, K, V]` for `N` input sequences. |
|
For equal-length input sequences, `N` equals the batch size `B`. |
|
Default: `None`. |
|
output_final_state (Optional[bool]): |
|
Whether to output the final state of shape `[N, H, K, V]`. Default: `False`. |
|
cu_seqlens (torch.LongTensor): |
|
Cumulative sequence lengths of shape `[N+1]` used for variable-length training, |
|
consistent with the FlashAttention API. |
|
head_first (Optional[bool]): |
|
Whether the inputs are in the head-first format, which is not supported for variable-length inputs. |
|
Default: `False`. |
|
use_qk_l2norm_in_kernel (Optional[bool]): |
|
Whether to use qk l2norm within the kernel for saving GPU memory. |
|
Default: `False`. |
|
|
|
    Returns:
        o (torch.Tensor):
            Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
        final_state (torch.Tensor):
            Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.

    Examples::
        >>> import torch
        >>> import torch.nn.functional as F
        >>> from einops import rearrange
        >>> from fla.ops.delta_rule import chunk_delta_rule
        # inputs with equal lengths
        >>> B, T, H, K, V = 4, 2048, 4, 512, 512
        >>> q = torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda')
        >>> k = F.normalize(torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda'), p=2, dim=-1)
        >>> v = torch.randn(B, T, H, V, dtype=torch.bfloat16, device='cuda')
        >>> beta = torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda').sigmoid()
        >>> h0 = torch.randn(B, H, K, V, dtype=torch.bfloat16, device='cuda')
        >>> o, ht = chunk_delta_rule(
            q, k, v, beta,
            initial_state=h0,
            output_final_state=True
        )
        # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
        >>> q, k, v, beta = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, beta))
        # for a batch with 4 sequences, a `cu_seqlens` tensor with 5 boundary positions is expected
        >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
        >>> o_var, ht_var = chunk_delta_rule(
            q, k, v, beta,
            initial_state=h0,
            output_final_state=True,
            cu_seqlens=cu_seqlens
        )
""" |
|
assert q.dtype == k.dtype == v.dtype |
|
assert q.dtype != torch.float32, "ChunkDeltaRuleFunction does not support float32. Please use bfloat16." |
|
assert len(beta.shape) == 3, "beta must be of shape (batch size, num of head, seq len)." |
|
|
|
if cu_seqlens is not None: |
|
if q.shape[0] != 1: |
|
raise ValueError( |
|
f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`." |
|
f"Please flatten variable-length inputs before processing." |
|
) |
|
if head_first: |
|
raise RuntimeError( |
|
"Sequences with variable lengths are not supported for head-first mode" |
|
) |
|
if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1: |
|
raise ValueError( |
|
f"The number of initial states is expected to be equal to the number of input sequences, " |
|
f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}." |
|
) |
|
if head_first: |
|
q, k, v = map(lambda x: rearrange(x, 'b h t d -> b t h d'), (q, k, v)) |
|
beta = rearrange(beta, 'b h t -> b t h') |
|
scale = k.shape[-1] ** -0.5 if scale is None else scale |
|
o, final_state = ChunkDeltaRuleFunction.apply( |
|
q, |
|
k, |
|
v, |
|
beta, |
|
scale, |
|
initial_state, |
|
output_final_state, |
|
cu_seqlens, |
|
False, |
|
use_qk_l2norm_in_kernel |
|
) |
|
if head_first: |
|
o = rearrange(o, 'b t h v -> b h t v') |
|
return o, final_state |
|
|
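

# ------------------------------------------------------------------------------------------------
# Minimal sequential reference (an illustrative sketch, not part of the public API): the helper
# `naive_delta_rule_ref` below is a hypothetical pure-PyTorch version of the token-by-token
# recurrence that the chunked kernels above parallelize. It assumes `head_first=False` layouts and
# keys that are already L2-normalized (as in the docstring example), and is only intended for
# sanity-checking `chunk_delta_rule(..., use_qk_l2norm_in_kernel=False)` on tiny shapes.
def naive_delta_rule_ref(
    q: torch.Tensor,                                 # [B, T, H, K]
    k: torch.Tensor,                                 # [B, T, H, K], L2-normalized along the last dim
    v: torch.Tensor,                                 # [B, T, H, V]
    beta: torch.Tensor,                              # [B, T, H]
    scale: Optional[float] = None,
    initial_state: Optional[torch.Tensor] = None,    # [B, H, K, V]
    output_final_state: bool = False
):
    B, T, H, K = q.shape
    V = v.shape[-1]
    dtype = v.dtype
    scale = K ** -0.5 if scale is None else scale
    q, k, v, beta = q.float(), k.float(), v.float(), beta.float()
    S = v.new_zeros(B, H, K, V) if initial_state is None else initial_state.float().clone()
    o = v.new_zeros(B, T, H, V)
    for t in range(T):
        q_t, k_t, v_t = q[:, t], k[:, t], v[:, t]    # [B, H, K] / [B, H, V]
        b_t = beta[:, t].unsqueeze(-1)               # [B, H, 1]
        # delta rule: move the value stored under k_t towards v_t with per-token step size beta_t
        v_err = v_t - torch.einsum('bhkv,bhk->bhv', S, k_t)
        S = S + torch.einsum('bhk,bhv->bhkv', k_t, b_t * v_err)
        # readout with the state that already includes the current token (causal, inclusive)
        o[:, t] = torch.einsum('bhkv,bhk->bhv', S, q_t) * scale
    return o.to(dtype), (S if output_final_state else None)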