zaydzuhri's picture
Add files using upload-large-folder tool
4135502 verified
raw
history blame
2.19 kB
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
from typing import Optional
import torch
from fla.ops.generalized_delta_rule import chunk_dplr_delta_rule
def chunk_rwkv7(
    r: torch.Tensor,
    w: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    scale: float = 1.0,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = True,
    cu_seqlens: Optional[torch.LongTensor] = None,
    head_first: bool = False
):
    """
    Chunked RWKV-7 recurrence, computed by delegating to `chunk_dplr_delta_rule`.

    The RWKV-7 inputs are forwarded under the generalized delta-rule names:
    `r` is passed as the query `q`, and `w` is passed as the log decay `gk`.
    `k`, `v`, `a`, `b` and every remaining option are passed through unchanged.

    Args:
        r (torch.Tensor):
            r of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
            Forwarded as `q`.
        w (torch.Tensor):
            log decay of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
            Forwarded as `gk`.
        k (torch.Tensor):
            k of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
        v (torch.Tensor):
            v of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`.
        a (torch.Tensor):
            a of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
        b (torch.Tensor):
            b of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
        scale (float):
            scale of the attention. Default: `1.0`.
        initial_state (Optional[torch.Tensor]):
            Initial state of shape `[N, H, K, V]` for `N` input sequences.
            For equal-length input sequences, `N` equals the batch size `B`.
            Default: `None`.
        output_final_state (bool):
            Whether to output the final state of shape `[N, H, K, V]`. Default: `True`.
        cu_seqlens (Optional[torch.LongTensor]):
            Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
            consistent with the FlashAttention API. Default: `None`.
        head_first (bool):
            whether to use head first. Recommended to be False to avoid extra transposes.
            Default: `False`.

    Returns:
        Exactly what `chunk_dplr_delta_rule` returns. This is presumably the output
        tensor and the final state (the latter only populated when
        `output_final_state=True`); confirm against `fla.ops.generalized_delta_rule`.
    """
    # RWKV-7 is a special case of the diagonal-plus-low-rank delta rule:
    # only the names differ (r -> q, w -> gk), so this is a pure rename-and-forward.
    return chunk_dplr_delta_rule(
        q=r,
        k=k,
        v=v,
        a=a,
        b=b,
        gk=w,
        scale=scale,
        initial_state=initial_state,
        output_final_state=output_final_state,
        cu_seqlens=cu_seqlens,
        head_first=head_first
    )