# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

from typing import Optional, Tuple

import torch

from fla.ops.simple_gla.fused_recurrent import fused_recurrent_simple_gla


def fused_recurrent_retention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    scale: Optional[float] = None,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = False,
    reverse: bool = False,
    cu_seqlens: Optional[torch.LongTensor] = None,
    head_first: bool = True
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Fused recurrent retention, implemented as a special case of simple GLA
    with a fixed, head-dependent decay.

    Args:
        q, k (torch.Tensor):
            Queries/keys of shape `[B, H, T, K]` if `head_first=True`, else `[B, T, H, K]`.
        v (torch.Tensor):
            Values of shape `[B, H, T, V]` if `head_first=True`, else `[B, T, H, V]`.
        scale (Optional[float]):
            Scale factor for the attention scores. If `None`, the underlying
            kernel's default (`K ** -0.5`) is used.
        initial_state (Optional[torch.Tensor]):
            Initial recurrent state of shape `[B, H, K, V]`, if provided.
        output_final_state (bool):
            Whether to also return the final recurrent state.
        reverse (bool):
            Whether to process the sequence in reverse order.
        cu_seqlens (Optional[torch.LongTensor]):
            Cumulative sequence lengths for variable-length (packed) inputs.
        head_first (bool):
            Whether inputs are laid out as `[B, H, T, ...]` (True) or `[B, T, H, ...]` (False).

    Returns:
        A tuple of the output tensor and the final state
        (`None` if `output_final_state=False`).
    """
    n_heads = q.shape[1] if head_first else q.shape[2]
    # Per-head log decay: log(1 - 2^(-5 - h)) for head index h, as in RetNet.
    s = (1 - q.new_tensor(2., dtype=torch.float).pow(-5. - q.new_tensor(range(n_heads), dtype=torch.float))).log()
    # Broadcast the per-head decay to a gate tensor matching the batch/head/time layout.
    if head_first:
        g = s[None, :, None].expand(q.shape[0], q.shape[1], q.shape[2]).contiguous()
    else:
        g = s[None, None, :].expand(q.shape[0], q.shape[1], q.shape[2]).contiguous()
    return fused_recurrent_simple_gla(
        q=q,
        k=k,
        v=v,
        g=g,
        scale=scale,
        initial_state=initial_state,
        output_final_state=output_final_state,
        reverse=reverse,
        cu_seqlens=cu_seqlens,
        head_first=head_first
    )
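

# A minimal usage sketch (not part of the original file); the shapes, device,
# and dtype below are assumptions for illustration. With `head_first=True`,
# q/k are `[B, H, T, K]` and v is `[B, H, T, V]`; the fused kernel runs on GPU.
if __name__ == "__main__":
    B, H, T, K, V = 2, 4, 128, 64, 64
    q = torch.randn(B, H, T, K, device='cuda', dtype=torch.bfloat16)
    k = torch.randn(B, H, T, K, device='cuda', dtype=torch.bfloat16)
    v = torch.randn(B, H, T, V, device='cuda', dtype=torch.bfloat16)
    # Run the recurrence and also request the final state, e.g. for resuming decoding.
    o, final_state = fused_recurrent_retention(q, k, v, output_final_state=True, head_first=True)
    print(o.shape, final_state.shape)  # [B, H, T, V] and [B, H, K, V]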