# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

from __future__ import annotations

import math
from typing import TYPE_CHECKING, Dict, Optional, Tuple

import torch
import torch.nn as nn
from einops import rearrange
from torch.nn import functional as F

from fla.modules import FusedRMSNormGated, RMSNorm, ShortConvolution
from fla.ops.gated_delta_rule import chunk_gated_delta_rule, fused_recurrent_gated_delta_rule

if TYPE_CHECKING:
    from transformers.processing_utils import Unpack

    from fla.models.utils import Cache


@torch.compile
def elu_p1(x):
    return (F.elu(x, 1., False) + 1.).to(x)


@torch.compile
def sum_norm(x):
    return (x / x.sum(-1, keepdim=True)).to(x)


class GatedDeltaNet(nn.Module):
    """
    The layer implementation for [Gated Delta Networks: Improving Mamba2 with Delta Rule](https://arxiv.org/abs/2412.06464).  # noqa

    Similar to Mamba2, each layer contains around 6 * hidden_size * hidden_size parameters.

    Parameter allocation when use_gate=True:
        - 0.75 * hidden_size * hidden_size for q_proj and k_proj each
        - 1.5 * hidden_size * hidden_size for v_proj, g_proj and o_proj each
        - Others are negligibly small.
        - In total = 0.75 * 2 + 1.5 * 3 = 6 * hidden_size * hidden_size
    NOTE: num_heads * head_dim = 0.75 * hidden_size, so make sure to set num_heads and head_dim accordingly.

    Parameter allocation when use_gate=False:
        - 1 * hidden_size * hidden_size for q_proj and k_proj each
        - 2 * hidden_size * hidden_size for v_proj and o_proj each
        - Others are negligibly small.
        - In total = 1 * 2 + 2 * 2 = 6 * hidden_size * hidden_size

    Args:
        hidden_size (int, Optional):
            The hidden size of the input. Default: 2048.
        expand_v (float, Optional):
            The expansion ratio for the value dim. Default: 2.0.
        head_dim (int, Optional):
            The dimension of each head. Default: 256.
        num_heads (int, Optional):
            The number of heads. Default: 6.
        mode (str, Optional):
            Which Gated DeltaNet kernel to use.
            Currently available: `chunk` and `fused_recurrent`.
            Default: `chunk`.
        use_gate (bool, Optional):
            Whether to use the output gate. Default: `True`.
        use_short_conv (bool, Optional):
            Whether to use short convolutions. Default: `True`.
        conv_size (int, Optional):
            The kernel size of the short convolution, only used when `use_short_conv` is `True`. Default: 4.
        conv_bias (bool, Optional):
            Whether to use bias in the short convolution, only used when `use_short_conv` is `True`. Default: `False`.
        layer_idx (int, Optional):
            The index of the layer. Default: None.
        norm_eps (float, Optional):
            The epsilon value for the normalization layer. Default: 1e-5.
    """

    def __init__(
        self,
        hidden_size: int = 2048,
        expand_v: float = 2,
        head_dim: int = 256,
        num_heads: int = 6,
        mode: str = 'chunk',
        use_gate: bool = True,
        use_short_conv: bool = True,
        conv_size: int = 4,
        conv_bias: bool = False,
        layer_idx: Optional[int] = None,
        norm_eps: float = 1e-5,
        **kwargs
    ) -> None:
        super().__init__()

        self.mode = mode
        self.hidden_size = hidden_size
        self.expand_v = expand_v

        self.use_gate = use_gate
        self.use_short_conv = use_short_conv
        self.conv_size = conv_size
        self.conv_bias = conv_bias

        self.head_dim = head_dim
        self.num_heads = num_heads

        self.key_dim = int(self.num_heads * self.head_dim)
        self.value_dim = int(self.key_dim * self.expand_v)
        self.head_k_dim = head_dim
        self.head_v_dim = int(head_dim * self.expand_v)
        self.layer_idx = layer_idx

        # Consistency check: ensure expand_v produces integer dimensions
        if not math.isclose(self.key_dim * expand_v, self.value_dim, rel_tol=1e-5):
            raise ValueError(
                f"expand_v={expand_v} does not produce an integer value when multiplied by key_dim={self.key_dim}. "
                f"Resulting value_dim would be {self.key_dim * expand_v}, which is invalid for nn.Linear."
            )
        if not math.isclose(head_dim * expand_v, self.head_v_dim, rel_tol=1e-5):
            raise ValueError(
                f"expand_v={expand_v} does not produce an integer value when multiplied by head_dim={head_dim}. "
                f"Resulting head_v_dim would be {head_dim * expand_v}, which is invalid for FusedRMSNormGated."
            )
        assert mode in ['chunk', 'fused_recurrent'], f"Not supported mode `{mode}`."

        self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
        self.k_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
        self.v_proj = nn.Linear(hidden_size, self.value_dim, bias=False)

        self.a_proj = nn.Linear(hidden_size, self.num_heads, bias=False)
        self.b_proj = nn.Linear(hidden_size, self.num_heads, bias=False)

        A = torch.empty(self.num_heads, dtype=torch.float32).uniform_(0, 16)
        self.A_log = nn.Parameter(torch.log(A))
        self.A_log._no_weight_decay = True

        # hard coded for now
        dt_min = 0.001
        dt_max = 0.1
        dt_init_floor = 1e-4
        dt = torch.exp(
            torch.rand(self.num_heads) * (math.log(dt_max) - math.log(dt_min))
            + math.log(dt_min)
        )
        dt = torch.clamp(dt, min=dt_init_floor)
        # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
        inv_dt = dt + torch.log(-torch.expm1(-dt))
        self.dt_bias = nn.Parameter(inv_dt)
        # Just to be explicit. Without this we already don't put wd on dt_bias because of the check
        # name.endswith("bias") in param_grouping.py
        self.dt_bias._no_weight_decay = True

        if use_short_conv:
            self.conv_size = conv_size
            self.q_conv1d = ShortConvolution(
                hidden_size=self.key_dim,
                kernel_size=conv_size,
                activation='silu'
            )
            self.k_conv1d = ShortConvolution(
                hidden_size=self.key_dim,
                kernel_size=conv_size,
                activation='silu'
            )
            self.v_conv1d = ShortConvolution(
                hidden_size=self.value_dim,
                kernel_size=conv_size,
                activation='silu'
            )
        else:
            raise UserWarning(
                "ShortConvolution is crucial to the performance. "
                "Do not turn it off, i.e., setting `use_short_conv=False` unless you know what you are doing."
            )
        if use_gate:
            self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
            self.o_norm = FusedRMSNormGated(self.head_v_dim, eps=norm_eps)
        else:
            self.o_norm = RMSNorm(self.head_v_dim, eps=norm_eps)
        self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        **kwargs: Unpack[Dict]
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
        if attention_mask is not None:
            assert len(attention_mask.shape) == 2, (
                "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
                "for padding purposes (0 indicating padding). "
                "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
            )

        mode = 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode
        if self.training:
            assert mode == 'chunk', "Only chunk mode is supported in training."

        last_state = None
        if past_key_values is not None and len(past_key_values) > self.layer_idx:
            last_state = past_key_values[self.layer_idx]

        cu_seqlens = kwargs.get('cu_seqlens', None)
        if self.use_short_conv:
            conv_state_q, conv_state_k, conv_state_v = None, None, None
            if last_state is not None:
                conv_state_q, conv_state_k, conv_state_v = last_state['conv_state']
            conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None
            q, conv_state_q = self.q_conv1d(
                x=self.q_proj(hidden_states),
                mask=conv_mask,
                cache=conv_state_q,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens
            )
            k, conv_state_k = self.k_conv1d(
                x=self.k_proj(hidden_states),
                mask=conv_mask,
                cache=conv_state_k,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens
            )
            v, conv_state_v = self.v_conv1d(
                x=self.v_proj(hidden_states),
                mask=conv_mask,
                cache=conv_state_v,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens
            )
        else:
            q = F.silu(self.q_proj(hidden_states))
            k = F.silu(self.k_proj(hidden_states))
            v = F.silu(self.v_proj(hidden_states))

        q, k = map(lambda x: rearrange(x, 'b t (h d) -> b t h d', d=self.head_k_dim), (q, k))
        v = rearrange(v, 'b t (h d) -> b t h d', d=self.head_v_dim)

        # beta in (0, 1) is the delta-rule writing strength; g <= 0 is the per-head decay term in log space
        beta = self.b_proj(hidden_states).sigmoid()
        g = -self.A_log.float().exp() * F.softplus(self.a_proj(hidden_states).float() + self.dt_bias)

        # dealing with padding
        if attention_mask is not None:
            beta = beta.mul(attention_mask[:, -beta.shape[-2]:, None])
            g = g.mul(attention_mask[:, -g.shape[-2]:, None])

        recurrent_state = last_state['recurrent_state'] if last_state is not None else None
        if mode == 'chunk':
            o, recurrent_state = chunk_gated_delta_rule(
                q=q,
                k=k,
                v=v,
                g=g,
                beta=beta,
                initial_state=recurrent_state,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens,
                head_first=False,
                use_qk_l2norm_in_kernel=True
            )
        elif mode == 'fused_recurrent':
            o, recurrent_state = fused_recurrent_gated_delta_rule(
                q=q,
                k=k,
                v=v,
                g=g,
                beta=beta,
                initial_state=recurrent_state,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens,
                head_first=False,
                use_qk_l2norm_in_kernel=True
            )

        if past_key_values is not None:
            past_key_values.update(
                recurrent_state=recurrent_state,
                conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None,
                layer_idx=self.layer_idx,
                offset=q.shape[1]
            )

        if self.use_gate:
            g = rearrange(self.g_proj(hidden_states), '... (h d) -> ... h d', d=self.head_v_dim)
            o = self.o_norm(o, g)
        else:
            o = self.o_norm(o)
        o = rearrange(o, 'b t h d -> b t (h d)')
        o = self.o_proj(o)

        return o, None, past_key_values
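

# A minimal forward-pass sketch: it assumes a CUDA device and the Triton kernels behind
# `chunk_gated_delta_rule` are available; the sizes below are illustrative only.
if __name__ == '__main__':
    torch.manual_seed(0)
    # num_heads * head_dim = 0.75 * hidden_size, as required by the sizing note in the docstring
    layer = GatedDeltaNet(hidden_size=2048, num_heads=6, head_dim=256, layer_idx=0).to('cuda', torch.bfloat16)
    x = torch.randn(2, 128, 2048, device='cuda', dtype=torch.bfloat16)
    o, _, _ = layer(x)  # the output keeps the (batch, seq_len, hidden_size) shape of the input
    assert o.shape == x.shape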