# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

from typing import Optional, Tuple

import torch
import triton
import triton.language as tl

from fla.ops.common.utils import prepare_chunk_offsets
from fla.ops.utils.op import exp
from fla.utils import check_shared_mem, use_cuda_graph


# Chunk-wise forward pass over the recurrent state `h`.
#
# Launch grid is (NK, NV, N * H): program ids map to (i_k, i_v, i_nh), and
# i_nh is split into a sequence index i_n and a head index i_h.
#
# For every chunk i_t of BT time steps the kernel:
#   1. stores the state at the *start* of the chunk into `h`;
#   2. walks the chunk in sub-chunks of BC steps, and for each sub-chunk
#        v2  = w @ h + u            (written out to `v_new`)
#        hc += kg @ v + bg @ v2
#      where `h` is still the start-of-chunk state;
#   3. updates h = h * exp(gk[last step of the chunk]) + hc.
# After the last chunk the state is optionally written to `ht`.
#
# The three heuristics derive compile-time flags from whether the optional
# pointers `h0`, `ht` and `offsets` were supplied.
@triton.heuristics({
    'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
    'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
    'USE_OFFSETS': lambda args: args['offsets'] is not None,
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [2, 4, 8, 16, 32]
        for num_stages in [2, 3, 4]
    ],
    key=['BT', 'BK', 'BV'],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit(do_not_specialize=['T'])
def chunk_dplr_fwd_kernel_h(
    kg,
    v,
    w,
    bg,
    u,
    v_new,
    gk,
    h,
    h0,
    ht,
    offsets,
    chunk_offsets,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BC: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    NT: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
    STORE_FINAL_STATE: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
    HEAD_FIRST: tl.constexpr,
):
    i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_n, i_h = i_nh // H, i_nh % H
    if USE_OFFSETS:
        # Variable-length mode: this sequence covers [bos, eos) of the packed
        # time axis, and `boh` is the index of its first chunk.
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
        NT = tl.cdiv(T, BT)
        boh = tl.load(chunk_offsets + i_n).to(tl.int32)
    else:
        bos, eos = i_n * T, i_n * T + T
        NT = tl.cdiv(T, BT)
        boh = i_n * NT

    # [BK, BV]
    b_h = tl.zeros([BK, BV], dtype=tl.float32)
    if USE_INITIAL_STATE:
        p_h0 = tl.make_block_ptr(h0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32)

    for i_t in range(NT):
        if HEAD_FIRST:
            p_h = tl.make_block_ptr(h + (i_nh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        else:
            p_h = tl.make_block_ptr(h + ((boh + i_t) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        # Record the state as it is on entry to chunk i_t.
        tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))

        # Contribution of this chunk, accumulated separately so that `b_h`
        # stays at its start-of-chunk value for every sub-chunk below.
        b_hc = tl.zeros([BK, BV], dtype=tl.float32)
        # Since the whole DK dimension has to stay in SRAM, we face a severe
        # SRAM memory burden; sub-chunking along T alleviates that burden.
        for i_c in range(tl.cdiv(min(BT, T - i_t * BT), BC)):
            if HEAD_FIRST:
                p_kg = tl.make_block_ptr(kg + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
                p_bg = tl.make_block_ptr(bg + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
                p_w = tl.make_block_ptr(w + i_nh * T*K, (T, K), (K, 1), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))
                p_v = tl.make_block_ptr(v + i_nh * T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
                p_u = tl.make_block_ptr(u + i_nh * T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
                p_v_new = tl.make_block_ptr(v_new+i_nh*T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
            else:
                p_kg = tl.make_block_ptr(kg+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
                p_bg = tl.make_block_ptr(bg+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
                p_w = tl.make_block_ptr(w+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))
                p_v = tl.make_block_ptr(v+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
                p_u = tl.make_block_ptr(u+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
                p_v_new = tl.make_block_ptr(v_new+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT+i_c*BC, i_v * BV), (BC, BV), (1, 0))
            # [BK, BC]
            b_kg = tl.load(p_kg, boundary_check=(0, 1))
            b_v = tl.load(p_v, boundary_check=(0, 1))
            b_w = tl.load(p_w, boundary_check=(0, 1))
            b_bg = tl.load(p_bg, boundary_check=(0, 1))
            # v2 = w @ h + u, using the start-of-chunk state.
            b_v2 = tl.dot(b_w, b_h.to(b_w.dtype)) + tl.load(p_u, boundary_check=(0, 1))
            b_hc += tl.dot(b_kg, b_v)
            b_hc += tl.dot(b_bg.to(b_hc.dtype), b_v2)
            tl.store(p_v_new, b_v2.to(p_v_new.dtype.element_ty), boundary_check=(0, 1))

        # Decay the carried state by the gate at the last valid time step of
        # this chunk (clamped to T - 1 for a partial final chunk), then add
        # the chunk's contribution.
        last_idx = min((i_t + 1) * BT, T) - 1
        if HEAD_FIRST:
            b_g_last = tl.load(
                gk + i_nh * T * K + last_idx * K + tl.arange(0, BK),
                mask=tl.arange(0, BK) < K
            ).to(tl.float32)
        else:
            b_g_last = tl.load(
                gk + (bos + last_idx) * H * K + i_h * K + tl.arange(0, BK),
                mask=tl.arange(0, BK) < K
            ).to(tl.float32)
        b_h *= exp(b_g_last[:, None])
        b_h += b_hc

    if STORE_FINAL_STATE:
        p_ht = tl.make_block_ptr(ht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        tl.store(p_ht, b_h.to(p_ht.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))


def chunk_dplr_fwd_h(
    kg: torch.Tensor,
    v: torch.Tensor,
    w: torch.Tensor,
    u: torch.Tensor,
    bg: torch.Tensor,
    gk: torch.Tensor,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = False,
    offsets: Optional[torch.LongTensor] = None,
    indices: Optional[torch.LongTensor] = None,
    head_first: bool = True,
    chunk_size: int = 64
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """Launch `chunk_dplr_fwd_kernel_h` and return per-chunk states.

    Args:
        kg: Tensor of shape `[B, H, T, K]` if `head_first` else `[B, T, H, K]`.
        v: Tensor read by the kernel with the same indexing as `u`.
        w: Tensor read by the kernel with a `T x K` layout per sequence/head.
        u: Tensor whose last dimension defines `V`; `v_new` is allocated like it.
        bg: Tensor read by the kernel with the same indexing as `kg`.
        gk: Gate tensor; the kernel reads the row at the last time step of
            every chunk and applies `exp` to it.
        initial_state: Optional starting state, passed to the kernel as `h0`
            (presumably `[N, H, K, V]`, matching how it is indexed - confirm
            against callers).
        output_final_state: If True, allocate and return the final state.
        offsets: Optional sequence boundaries for variable-length inputs; the
            kernel treats `offsets[i]` / `offsets[i + 1]` as begin / end.
        indices: Required when `offsets` is given; only its length is used
            here, as the total number of chunks `NT`.
        head_first: Selects the `[B, H, T, ...]` vs `[B, T, H, ...]` layout.
        chunk_size: Upper bound on the chunk length `BT`.

    Returns:
        A tuple `(h, v_new, final_state)`:
        `h` has shape `[B, H, NT, K, V]` if `head_first` else
        `[B, NT, H, K, V]`; `v_new` has the shape and dtype of `u`;
        `final_state` is a float32 `[N, H, K, V]` tensor, or `None` when
        `output_final_state` is False.

    Raises:
        TypeError: If `offsets` is given without `indices`.
        AssertionError: If the head dimension `K` rounds up to more than 256.
    """
    if head_first:
        B, H, T, K, V = *kg.shape, u.shape[-1]
    else:
        B, T, H, K, V = *kg.shape, u.shape[-1]
    BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
    # N: the actual number of sequences in the batch with either equal or variable lengths
    if offsets is None:
        N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
    else:
        if indices is None:
            # Same exception type that `len(None)` used to raise here, but
            # with a message that names the actual problem.
            raise TypeError("`indices` must be provided together with `offsets`.")
        N, NT, chunk_offsets = len(offsets) - 1, len(indices), prepare_chunk_offsets(offsets, BT)
    BK = triton.next_power_of_2(K)
    assert BK <= 256, "current kernel does not support head dimension larger than 256."
    # H100 can have larger block size
    if check_shared_mem('hopper', kg.device.index):
        BV = 64
        BC = 64 if K <= 128 else 32
    elif check_shared_mem('ampere', kg.device.index):  # A100
        BV = 32
        BC = 32
    else:
        BV = 16
        BC = 16
    # A sub-chunk can never be longer than the chunk it subdivides.
    BC = min(BT, BC)
    NK = triton.cdiv(K, BK)
    NV = triton.cdiv(V, BV)
    assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization'

    if head_first:
        h = kg.new_empty(B, H, NT, K, V)
    else:
        h = kg.new_empty(B, NT, H, K, V)
    final_state = kg.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None
    v_new = torch.empty_like(u)

    grid = (NK, NV, N * H)
    chunk_dplr_fwd_kernel_h[grid](
        kg=kg,
        v=v,
        w=w,
        bg=bg,
        u=u,
        v_new=v_new,
        h=h,
        gk=gk,
        h0=initial_state,
        ht=final_state,
        offsets=offsets,
        chunk_offsets=chunk_offsets,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BC=BC,
        BK=BK,
        BV=BV,
        NT=NT,
        HEAD_FIRST=head_first
    )
    return h, v_new, final_state