from typing import Optional, Tuple

import torch

from fla.ops.simple_gla.fused_recurrent import fused_recurrent_simple_gla


def fused_recurrent_retention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    scale: Optional[float] = None,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = False,
    reverse: bool = False,
    cu_seqlens: Optional[torch.LongTensor] = None,
    head_first: bool = True
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Retention is simple GLA with a fixed, data-independent per-head decay:
    # gamma_h = 1 - 2^(-5 - h) for h = 0, ..., H - 1, as in RetNet.
    if head_first:
        n_heads = q.shape[1]
    else:
        n_heads = q.shape[2]
    # Per-head log-decay log(gamma_h), computed in float32 for stability.
    s = (1 - q.new_tensor(2., dtype=torch.float).pow(-5. - q.new_tensor(range(n_heads), dtype=torch.float))).log()
    if head_first:
        # [B, H, T, D] layout: broadcast the decay over batch and time.
        g = s[None, :, None].expand(q.shape[0], q.shape[1], q.shape[2]).contiguous()
    else:
        # [B, T, H, D] layout: broadcast the decay over batch and time.
        g = s[None, None, :].expand(q.shape[0], q.shape[1], q.shape[2]).contiguous()
    return fused_recurrent_simple_gla(
        q=q,
        k=k,
        v=v,
        g=g,
        scale=scale,
        initial_state=initial_state,
        output_final_state=output_final_state,
        reverse=reverse,
        cu_seqlens=cu_seqlens,
        head_first=head_first
    )
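

# A minimal usage sketch (not part of the library). It assumes a CUDA device,
# since the underlying fused kernel is typically GPU-only, and uses arbitrary
# illustrative shapes [batch, heads, seq_len, head_dim] with the default
# head_first=True layout.
if __name__ == "__main__":
    B, H, T, D = 2, 4, 128, 64
    q = torch.randn(B, H, T, D, device='cuda')
    k = torch.randn(B, H, T, D, device='cuda')
    v = torch.randn(B, H, T, D, device='cuda')
    # Returns the output and (since output_final_state=True) the final
    # recurrent state, which can seed a subsequent call via initial_state.
    o, final_state = fused_recurrent_retention(q, k, v, output_final_state=True)
    print(o.shape, final_state.shape)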