from typing import Optional

import torch

from fla.ops.generalized_delta_rule import chunk_dplr_delta_rule
|
|
def chunk_rwkv7(
    r: torch.Tensor,
    w: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    scale: float = 1.0,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = True,
    cu_seqlens: Optional[torch.LongTensor] = None,
    head_first: bool = False
):
""" |
|
Args: |
|
r (torch.Tensor): |
|
r of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. |
|
w (torch.Tensor): |
|
log decay of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. |
|
k (torch.Tensor): |
|
k of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. |
|
v (torch.Tensor): |
|
v of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`. |
|
a (torch.Tensor): |
|
a of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. |
|
b (torch.Tensor): |
|
b of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. |
|
scale (float): |
|
scale of the attention. |
|
initial_state (Optional[torch.Tensor]): |
|
Initial state of shape `[N, H, K, V]` for `N` input sequences. |
|
For equal-length input sequences, `N` equals the batch size `B`. |
|
Default: `None`. |
|
output_final_state (Optional[bool]): |
|
Whether to output the final state of shape `[N, H, K, V]`. Default: `False`. |
|
cu_seqlens (torch.LongTensor): |
|
Cumulative sequence lengths of shape `[N+1]` used for variable-length training, |
|
consistent with the FlashAttention API. |
|
head_first (bool): |
|
whether to use head first. Recommended to be False to avoid extra transposes. |
|
""" |
|
    return chunk_dplr_delta_rule(
        q=r,
        k=k,
        v=v,
        a=a,
        b=b,
        gk=w,
        scale=scale,
        initial_state=initial_state,
        output_final_state=output_final_state,
        cu_seqlens=cu_seqlens,
        head_first=head_first
    )
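

# Usage sketch (illustrative only): the shapes, dtype, device, and the way w/a/b are
# generated below are assumptions for demonstration, not requirements of the kernel.
# A CUDA device is assumed to be available.
if __name__ == "__main__":
    B, T, H, K, V = 2, 1024, 4, 64, 64
    device, dtype = 'cuda', torch.bfloat16

    r = torch.randn(B, T, H, K, device=device, dtype=dtype)
    k = torch.randn(B, T, H, K, device=device, dtype=dtype)
    v = torch.randn(B, T, H, V, device=device, dtype=dtype)
    # w holds per-channel log decays, so its entries must be negative.
    w = torch.nn.functional.logsigmoid(torch.randn(B, T, H, K, device=device, dtype=dtype))
    # a/b parameterize the rank-1 state correction; mirroring RWKV-7's convention,
    # a = -kappa (unit-norm) and b = kappa * eta with eta in [0, 1] (random here).
    kappa = torch.nn.functional.normalize(
        torch.randn(B, T, H, K, device=device, dtype=dtype), p=2.0, dim=-1
    )
    a = -kappa
    b = kappa * torch.rand(B, T, H, K, device=device, dtype=dtype)

    # head_first=False (default): all inputs are laid out as [B, T, H, *].
    o, final_state = chunk_rwkv7(r, w, k, v, a, b, output_final_state=True)
    print(o.shape, final_state.shape)  # [B, T, H, V], [B, H, K, V]

    # For packed variable-length batches, flatten to [1, total_T, H, *] and pass
    # cumulative sequence lengths, e.g. cu_seqlens = torch.tensor([0, 1024, 2048]).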
|
|