# -*- coding: utf-8 -*-

import logging
from typing import Optional

import torch
import triton
import triton.language as tl

from fla.utils import check_pytorch_version, device, input_guard, use_cuda_graph

logger = logging.getLogger(__name__)

if not check_pytorch_version('2.4'):
    logger.warning('PyTorch < 2.4 detected - computations may be slower due to lack of optimizations')


@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE': block_size}, num_warps=num_warps)
        for block_size in [128, 256, 512, 1024, 2048, 4096, 8192]
        for num_warps in [1, 2, 4, 8, 16, 32]
    ],
    key=['hidden_dim'],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit
def fused_addcmul_fwd_kernel(
    hidden_ptr,
    x_ptr,
    ixr_ptr,
    ixw_ptr,
    ixk_ptr,
    ixv_ptr,
    ixa_ptr,
    ixg_ptr,
    oxr_ptr,
    oxw_ptr,
    oxk_ptr,
    oxv_ptr,
    oxa_ptr,
    oxg_ptr,
    use_xg: tl.constexpr,
    xnumel,
    hidden_dim: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # Computes o_* = hidden + x * ix_* for the five (optionally six) RWKV-7
    # token-shift branches in a single pass over the flattened input.
    xoffset = tl.program_id(0) * BLOCK_SIZE
    xindex = xoffset + tl.arange(0, BLOCK_SIZE)[:]
    xmask = xindex < xnumel
    # The mixing vectors are broadcast over the (batch, seqlen) dims,
    # so they are indexed by the position within the hidden dimension only.
    x0 = xindex % hidden_dim

    b_hidden = tl.load(hidden_ptr + xindex, xmask, other=0.).to(tl.float32)
    b_x = tl.load(x_ptr + xindex, xmask, other=0.).to(tl.float32)
    b_ixr = tl.load(ixr_ptr + x0, eviction_policy='evict_last').to(tl.float32)
    b_ixw = tl.load(ixw_ptr + x0, eviction_policy='evict_last').to(tl.float32)
    b_ixk = tl.load(ixk_ptr + x0, eviction_policy='evict_last').to(tl.float32)
    b_ixv = tl.load(ixv_ptr + x0, eviction_policy='evict_last').to(tl.float32)
    b_ixa = tl.load(ixa_ptr + x0, eviction_policy='evict_last').to(tl.float32)

    b_oxr = b_hidden + b_x * b_ixr
    b_oxw = b_hidden + b_x * b_ixw
    b_oxk = b_hidden + b_x * b_ixk
    b_oxv = b_hidden + b_x * b_ixv
    b_oxa = b_hidden + b_x * b_ixa

    tl.store(oxr_ptr + xindex, b_oxr.to(oxr_ptr.dtype.element_ty), xmask)
    tl.store(oxw_ptr + xindex, b_oxw.to(oxw_ptr.dtype.element_ty), xmask)
    tl.store(oxk_ptr + xindex, b_oxk.to(oxk_ptr.dtype.element_ty), xmask)
    tl.store(oxv_ptr + xindex, b_oxv.to(oxv_ptr.dtype.element_ty), xmask)
    tl.store(oxa_ptr + xindex, b_oxa.to(oxa_ptr.dtype.element_ty), xmask)

    if use_xg:
        b_ixg = tl.load(ixg_ptr + x0, eviction_policy='evict_last').to(tl.float32)
        b_oxg = b_hidden + b_x * b_ixg
        tl.store(oxg_ptr + xindex, b_oxg.to(oxg_ptr.dtype.element_ty), xmask)


@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE': block_size}, num_warps=num_warps)
        for block_size in [128, 256, 512, 1024, 2048, 4096, 8192]
        for num_warps in [1, 2, 4, 8, 16, 32]
    ],
    key=['hidden_dim'],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit
def addcmul_bwd_kernel1(
    ixr_ptr,
    ixw_ptr,
    ixk_ptr,
    ixv_ptr,
    ixa_ptr,
    ixg_ptr,
    dxr_ptr,
    dxw_ptr,
    dxk_ptr,
    dxv_ptr,
    dxa_ptr,
    dxg_ptr,
    ghidden_ptr,
    gx_ptr,
    use_xg: tl.constexpr,
    xnumel,
    hidden_dim: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # Accumulates the gradients w.r.t. `hidden_states` and `xx`:
    #   d_hidden = sum_i d_ox_i
    #   d_x      = sum_i d_ox_i * ix_i
    xoffset = tl.program_id(0) * BLOCK_SIZE
    xindex = xoffset + tl.arange(0, BLOCK_SIZE)[:]
    xmask = xindex < xnumel
    x0 = xindex % hidden_dim

    b_dxr = tl.load(dxr_ptr + xindex, xmask, other=0.).to(tl.float32)
    b_dxw = tl.load(dxw_ptr + xindex, xmask, other=0.).to(tl.float32)
    b_dxk = tl.load(dxk_ptr + xindex, xmask, other=0.).to(tl.float32)
    b_dxv = tl.load(dxv_ptr + xindex, xmask, other=0.).to(tl.float32)
    b_dxa = tl.load(dxa_ptr + xindex, xmask, other=0.).to(tl.float32)
    b_ixr = tl.load(ixr_ptr + x0, eviction_policy='evict_last').to(tl.float32)
    b_ixw = tl.load(ixw_ptr + x0, eviction_policy='evict_last').to(tl.float32)
    b_ixk = tl.load(ixk_ptr + x0, eviction_policy='evict_last').to(tl.float32)
    b_ixv = tl.load(ixv_ptr + x0, eviction_policy='evict_last').to(tl.float32)
    b_ixa = tl.load(ixa_ptr + x0, eviction_policy='evict_last').to(tl.float32)

    if use_xg:
        b_dxg = tl.load(dxg_ptr + xindex, xmask, other=0.).to(tl.float32)
        b_ixg = tl.load(ixg_ptr + x0, eviction_policy='evict_last').to(tl.float32)
        g_hidden = b_dxr + b_dxw + b_dxk + b_dxv + b_dxa + b_dxg
        g_x = b_dxr * b_ixr + b_dxw * b_ixw + b_dxk * b_ixk + b_dxv * b_ixv + b_dxa * b_ixa + b_dxg * b_ixg
    else:
        g_hidden = b_dxr + b_dxw + b_dxk + b_dxv + b_dxa
        g_x = b_dxr * b_ixr + b_dxw * b_ixw + b_dxk * b_ixk + b_dxv * b_ixv + b_dxa * b_ixa

    tl.store(ghidden_ptr + xindex, g_hidden.to(ghidden_ptr.dtype.element_ty), xmask)
    tl.store(gx_ptr + xindex, g_x.to(gx_ptr.dtype.element_ty), xmask)


def addcmul_bwd1(
    d_oxr, d_oxw, d_oxk, d_oxv, d_oxa, d_oxg,
    x_r, x_w, x_k, x_v, x_a, x_g,
    hidden_states, xx,
    use_xg
):
    d_hidden = torch.empty_like(hidden_states)
    d_xx = torch.empty_like(xx)
    numel = hidden_states.numel()

    def grid(meta):
        return (triton.cdiv(meta['xnumel'], meta['BLOCK_SIZE']),)

    addcmul_bwd_kernel1[grid](
        ixr_ptr=x_r,
        ixw_ptr=x_w,
        ixk_ptr=x_k,
        ixv_ptr=x_v,
        ixa_ptr=x_a,
        ixg_ptr=x_g,
        dxr_ptr=d_oxr,
        dxw_ptr=d_oxw,
        dxk_ptr=d_oxk,
        dxv_ptr=d_oxv,
        dxa_ptr=d_oxa,
        dxg_ptr=d_oxg,
        ghidden_ptr=d_hidden,
        gx_ptr=d_xx,
        use_xg=use_xg,
        xnumel=numel,
        hidden_dim=hidden_states.size(-1),
    )
    return d_hidden, d_xx


@torch.compile(fullgraph=True)
def addcmul_bwd2(d_oxr, d_oxw, d_oxk, d_oxv, d_oxa, d_oxg, xx, use_xg: bool):
    # Gradients w.r.t. the mixing vectors: reduce over the batch and sequence dims.
    g_xr = (d_oxr * xx).sum(dim=(0, 1), keepdim=True)
    g_xw = (d_oxw * xx).sum(dim=(0, 1), keepdim=True)
    g_xk = (d_oxk * xx).sum(dim=(0, 1), keepdim=True)
    g_xv = (d_oxv * xx).sum(dim=(0, 1), keepdim=True)
    g_xa = (d_oxa * xx).sum(dim=(0, 1), keepdim=True)
    g_xg = (d_oxg * xx).sum(dim=(0, 1), keepdim=True) if use_xg else None
    return g_xr, g_xw, g_xk, g_xv, g_xa, g_xg


class Rwkv7FusedAddcmul(torch.autograd.Function):

    @staticmethod
    @input_guard
    def forward(ctx, hidden_states, xx, x_r, x_w, x_k, x_v, x_a, x_g, num_elements):
        oxr = torch.empty_like(hidden_states)
        oxw = torch.empty_like(hidden_states)
        oxk = torch.empty_like(hidden_states)
        oxv = torch.empty_like(hidden_states)
        oxa = torch.empty_like(hidden_states)
        if x_g is not None:
            use_xg = True
            oxg = torch.empty_like(hidden_states)
        else:
            use_xg = False
            oxg = None
        ctx.save_for_backward(hidden_states, xx, x_r, x_w, x_k, x_v, x_a, x_g)
        ctx.use_xg = use_xg

        def grid(meta):
            return (triton.cdiv(meta['xnumel'], meta['BLOCK_SIZE']),)

        fused_addcmul_fwd_kernel[grid](
            hidden_states,
            xx,
            x_r,
            x_w,
            x_k,
            x_v,
            x_a,
            x_g,
            oxr,
            oxw,
            oxk,
            oxv,
            oxa,
            oxg,
            use_xg,
            num_elements,
            hidden_states.size(-1),
        )
        return oxr, oxw, oxk, oxv, oxa, oxg

    @staticmethod
    @input_guard
    def backward(ctx, dxr, dxw, dxk, dxv, dxa, dxg):
        hidden_states, xx, x_r, x_w, x_k, x_v, x_a, x_g = ctx.saved_tensors

        d_hidden, d_xx = addcmul_bwd1(
            dxr, dxw, dxk, dxv, dxa, dxg,
            x_r, x_w, x_k, x_v, x_a, x_g,
            hidden_states, xx, ctx.use_xg
        )
        d_ixr, d_ixw, d_ixk, d_ixv, d_ixa, d_ixg = addcmul_bwd2(
            dxr, dxw, dxk, dxv, dxa, dxg, xx, ctx.use_xg
        )
        return d_hidden, d_xx, d_ixr, d_ixw, d_ixk, d_ixv, d_ixa, d_ixg, None


def fused_addcmul_rwkv7(
    hidden_states: torch.Tensor,
    xx: torch.Tensor,
    xr: torch.Tensor,
    xw: torch.Tensor,
    xk: torch.Tensor,
    xv: torch.Tensor,
    xa: torch.Tensor,
    xg: Optional[torch.Tensor] = None
):
    """
    Computes the RWKV-7 token-shift interpolations
    ``o_i = hidden_states + xx * x_i`` for the r/w/k/v/a (and optionally g)
    branches. For inputs with fewer than 2**24 elements on CUDA the plain
    ``torch.addcmul`` path is used; otherwise the fused Triton kernel is dispatched.
    """
    num_elements = hidden_states.numel()
    if num_elements < 16777216 and device == "cuda":
        return torch_addcmul_rwkv7(hidden_states, xx, xr, xw, xk, xv, xa, xg)
    else:
        return Rwkv7FusedAddcmul.apply(hidden_states, xx, xr, xw, xk, xv, xa, xg, num_elements)


def torch_addcmul_rwkv7(hidden_states, xx, xr, xw, xk, xv, xa, xg=None):
    oxr = torch.addcmul(hidden_states, xx, xr)
    oxw = torch.addcmul(hidden_states, xx, xw)
    oxk = torch.addcmul(hidden_states, xx, xk)
    oxv = torch.addcmul(hidden_states, xx, xv)
    oxa = torch.addcmul(hidden_states, xx, xa)
    if xg is not None:
        oxg = torch.addcmul(hidden_states, xx, xg)
        return oxr, oxw, oxk, oxv, oxa, oxg
    else:
        return oxr, oxw, oxk, oxv, oxa, None
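

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the library API). It assumes
# a CUDA device and mixing vectors of shape (1, 1, hidden_dim), which is how
# RWKV-7 blocks typically store their token-shift coefficients; the tensor
# names and shapes below are example assumptions, not requirements of this module.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    B, T, D = 2, 128, 512
    hidden = torch.randn(B, T, D, device=device, requires_grad=True)
    delta = torch.randn(B, T, D, device=device, requires_grad=True)
    x_r, x_w, x_k, x_v, x_a, x_g = (
        torch.randn(1, 1, D, device=device, requires_grad=True) for _ in range(6)
    )

    oxr, oxw, oxk, oxv, oxa, oxg = fused_addcmul_rwkv7(hidden, delta, x_r, x_w, x_k, x_v, x_a, x_g)
    # Each output is hidden + delta * x_i, with x_i broadcast over the batch/sequence dims.
    torch.testing.assert_close(oxr, torch.addcmul(hidden, delta, x_r))

    # All six outputs participate in the backward pass; gradients flow to
    # hidden, delta, and every mixing vector.
    (oxr.sum() + oxw.sum() + oxk.sum() + oxv.sum() + oxa.sum() + oxg.sum()).backward()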