# ========== AUTO GENERATED FILE ========= # This file is auto generated by 'hf_builder.py', do not edit this file directly # As part of the RWKV/RWKV-block project # ========== =================== ========= # ---------------- # block/kernel/rwkv7_attn_pytorch.py # ---------------- import torch # Enable tensorfloat 32 torch.set_float32_matmul_precision('high') # Handles the RWKV v7 attention mechanic, in pure pytorch def rwkv7_attn_pytorch( r,w,k,v, kk,a, BATCH_SIZE, SEQ_LEN, N_HEAD, HEAD_SIZE, xx, wkv_state_in ): ### Reference implement # return rwkv7_attn_pytorch_ref( # r,w,k,v, kk,a, # BATCH_SIZE, SEQ_LEN, N_HEAD, HEAD_SIZE, # xx, wkv_state_in # ) ### # # This works, but it has too much of a vram overhead ### # return rwkv7_attn_pytorch_v2_chunk_w_compile_break( # r,w,k,v, kk,a, # BATCH_SIZE, SEQ_LEN, N_HEAD, HEAD_SIZE, # xx, wkv_state_in # ) ### # > per 9k chunk, per block, on a 4090 ... # > with forward_with_reduce_compile on the timemix ... # # Somehow... # The reference implement takes: 2281ms # The chunked version takes: 389ms (chunksize 256) # Get the shape B,T,HC = w.shape # Compute the chunks chunk_size = 256 chunk_count = SEQ_LEN // chunk_size chunk_remainder = SEQ_LEN % chunk_size # The wkv_state_out wkv_state_out = wkv_state_in.float() # # List of tensor to build # xlist = [] xx = xx.clone() # Loop over the chunks for i in range(chunk_count): sta = i * chunk_size end = sta + chunk_size xx[:,sta:end], wkv_state_out = rwkv7_attn_pytorch_v2_chunk_w_compile_break( # xpart, wkv_state_out = rwkv7_attn_pytorch_chunk_with_w_compile_break( r[:,sta:end],w[:,sta:end],k[:,sta:end],v[:,sta:end], kk[:,sta:end],a[:,sta:end], BATCH_SIZE, chunk_size, N_HEAD, HEAD_SIZE, # xx[:,sta:end], wkv_state_out torch.zeros(B,chunk_size,HC, dtype=xx.dtype, device=xx.device), wkv_state_out ) # xlist.append(xpart) # Handle the remainder if chunk_remainder > 0: sta = chunk_count * chunk_size end = sta + chunk_remainder xx[:,sta:end], wkv_state_out = rwkv7_attn_pytorch_v2_chunk_w_compile_break( # xpart, wkv_state_out = rwkv7_attn_pytorch_chunk_with_w_compile_break( r[:,sta:end],w[:,sta:end],k[:,sta:end],v[:,sta:end], kk[:,sta:end],a[:,sta:end], BATCH_SIZE, chunk_remainder, N_HEAD, HEAD_SIZE, # xx[:,sta:end], wkv_state_out, torch.zeros(B,chunk_remainder,HC, dtype=xx.dtype, device=xx.device), wkv_state_out, # offset=0, chunk_size=chunk_remainder ) # xlist.append(xpart) # # Concatenate the list # xx = torch_cat_no_compiler(xlist, dim=1) # Return the output return xx, wkv_state_out.to(dtype=wkv_state_in.dtype) #################################################################################################### # Working reference copy, that has been validated to be "identical" to the reference implementation # However this has known pytorch compilation issues, hence the modified chunk wise version is used # instead for an approximate 5x speed up #################################################################################################### @torch.compiler.disable() def rwkv7_attn_pytorch_ref( r,w,k,v, kk,a, BATCH_SIZE, SEQ_LEN, N_HEAD, HEAD_SIZE, xx, wkv_state_in ): ######## pure pytorch method # See: https://github.com/BlinkDL/RWKV-LM/blob/d4c42b2cac10f8f3896ce153e2310dc763662b7a/RWKV-v7/rwkv_v7_demo_fast.py#L238 ######## vk_state = wkv_state_in.float() for t in range(SEQ_LEN): r_, w_, k_, v_, kk_, a_ = r[:,t], w[:,t], k[:,t], v[:,t], kk[:,t], a[:,t] vk = v_.view(BATCH_SIZE,N_HEAD,HEAD_SIZE,1) @ k_.view(BATCH_SIZE,N_HEAD,1,HEAD_SIZE) ab = (-kk_).view(BATCH_SIZE,N_HEAD,HEAD_SIZE,1) @ (kk_*a_).view(BATCH_SIZE,N_HEAD,1,HEAD_SIZE) vk_state = (vk_state * w_.view(BATCH_SIZE,N_HEAD,1,HEAD_SIZE).float() + vk_state @ ab.float() + vk.float()) xx[:,t] = ((vk_state.to(dtype=xx.dtype) @ r_.view(BATCH_SIZE,N_HEAD,HEAD_SIZE,1)).view(BATCH_SIZE,N_HEAD*HEAD_SIZE)) wkv_state_out = vk_state.to(dtype=wkv_state_in.dtype) return xx, wkv_state_out #################################################################################################### #################################################################################################### # Modified reference computation done in fp32, # with changes made to bring the result closer to the cuda kernel #################################################################################################### @torch.compiler.disable() def rwkv7_attn_pytorch_ref_fp32( r,w,k,v, kk, iclr, BATCH_SIZE, SEQ_LEN, N_HEAD, HEAD_SIZE, xx, wkv_state_in ): ######## pure pytorch method (modified for fp32) # See: https://github.com/BlinkDL/RWKV-LM/blob/d4c42b2cac10f8f3896ce153e2310dc763662b7a/RWKV-v7/rwkv_v7_demo_fast.py#L238 ######## w = (-w.float().exp()).exp() # wkv_state_in = torch.zeros(BATCH_SIZE,N_HEAD,HEAD_SIZE,HEAD_SIZE, dtype=torch.float,device=w.device) vk_state = wkv_state_in.float() a = -kk b = kk * iclr for t in range(SEQ_LEN): r_, w_, k_, v_, a_, b_= r[:,t].float(), w[:,t].float(), k[:,t].float(), v[:,t].float(), a[:,t].float(), b[:,t].float() # ab = (-kk_).view(BATCH_SIZE,N_HEAD,HEAD_SIZE,1) @ (kk_*a_).view(BATCH_SIZE,N_HEAD,1,HEAD_SIZE) vk = v_.view(BATCH_SIZE,N_HEAD,HEAD_SIZE,1) @ k_.view(BATCH_SIZE,N_HEAD,1,HEAD_SIZE) vk_state = (vk_state * w_.view(BATCH_SIZE,N_HEAD,1,HEAD_SIZE).float() + vk_state @ a_.float().view(BATCH_SIZE, N_HEAD,HEAD_SIZE,1) @ b_.view(BATCH_SIZE, N_HEAD,1,HEAD_SIZE) + vk.float()) xx[:,t] = ((vk_state @ r_.view(BATCH_SIZE,N_HEAD,HEAD_SIZE,1)).view(BATCH_SIZE,N_HEAD*HEAD_SIZE)).to(dtype=xx.dtype) wkv_state_out = vk_state.to(dtype=wkv_state_in.dtype) return xx, wkv_state_out #################################################################################################### def rwkv7_attn_pytorch_chunk( r,w,k,v, kk,a, BATCH_SIZE, N_HEAD, HEAD_SIZE, xx, wkv_state_in, offset=0, chunk_size=16 ): ''' Chunked version of the RWKV7 attention, for better performance. If the chunk size is less then 128, this is generally compilable This is used by the triton/cuda implement, for the remaining % 16 chunks ''' ######## pure pytorch method # See: https://github.com/BlinkDL/RWKV-LM/blob/d4c42b2cac10f8f3896ce153e2310dc763662b7a/RWKV-v7/rwkv_v7_demo_fast.py#L238 ######## vk_state = wkv_state_in.float() for i in range(chunk_size): t = offset + i r_, w_, k_, v_, kk_, a_ = r[:,t], w[:,t], k[:,t], v[:,t], kk[:,t], a[:,t] vk = v_.view(BATCH_SIZE,N_HEAD,HEAD_SIZE,1) @ k_.view(BATCH_SIZE,N_HEAD,1,HEAD_SIZE) ab = (-kk_).view(BATCH_SIZE,N_HEAD,HEAD_SIZE,1) @ (kk_*a_).view(BATCH_SIZE,N_HEAD,1,HEAD_SIZE) vk_state = (vk_state * w_.view(BATCH_SIZE,N_HEAD,1,HEAD_SIZE).float() + vk_state @ ab.float() + vk.float()) xx[:,t] = (vk_state.to(dtype=xx.dtype) @ r_.view(BATCH_SIZE,N_HEAD,HEAD_SIZE,1)).view(BATCH_SIZE,N_HEAD*HEAD_SIZE) wkv_state_out = vk_state.to(dtype=wkv_state_in.dtype) return xx, wkv_state_out def rwkv7_attn_pytorch_v2_chunk_w_compile_break( r,w,k,v, kk,a, BATCH_SIZE, SEQ_LEN, N_HEAD, HEAD_SIZE, xx, wkv_state_in ): ''' Chunked version of the RWKV7 attention, for better performance ''' full_vk_ = v.view(BATCH_SIZE,SEQ_LEN,N_HEAD, HEAD_SIZE,1) @ k.view(BATCH_SIZE,SEQ_LEN,N_HEAD, 1,HEAD_SIZE) full_iclr_ = (kk * a).view(BATCH_SIZE,SEQ_LEN,N_HEAD,1,HEAD_SIZE) full_ab = (-kk).view(BATCH_SIZE,SEQ_LEN,N_HEAD, HEAD_SIZE,1) @ full_iclr_ wkv_xx = torch.empty(BATCH_SIZE,SEQ_LEN,N_HEAD,HEAD_SIZE,HEAD_SIZE, dtype=xx.dtype, device=xx.device) wkv_xx, wkv_state_out = rwkv7_attn_pytorch_v2_inner_w_compile_break( r,w, full_vk_, full_ab, BATCH_SIZE, SEQ_LEN, N_HEAD, HEAD_SIZE, wkv_xx, wkv_state_in # xx, wkv_state_in ) # if BATCH_SIZE != 1: # print("BATCH_SIZE != 1 : ", BATCH_SIZE) # if SEQ_LEN != 256: # print("SEQ_LEN != 256 : ", SEQ_LEN) # xx[:,t] = ((wkv_state.to(dtype=xx.dtype) @ r_.view(BATCH_SIZE,N_HEAD,HEAD_SIZE,1)).view(BATCH_SIZE,N_HEAD*HEAD_SIZE)) xx[:] = (wkv_xx.to(dtype=xx.dtype) @ r.view(BATCH_SIZE,SEQ_LEN,N_HEAD,HEAD_SIZE,1)).view(BATCH_SIZE,SEQ_LEN,N_HEAD*HEAD_SIZE) return xx, wkv_state_out @torch.compiler.disable() def rwkv7_attn_pytorch_v2_inner_w_compile_break( r, w, full_vk_, full_ab, BATCH_SIZE, SEQ_LEN, N_HEAD, HEAD_SIZE, xx, wkv_state_in ): ''' Isolated sub-function with no compilation ''' return rwkv7_attn_pytorch_v2_inner_jit( r, w, full_vk_, full_ab, BATCH_SIZE, SEQ_LEN, N_HEAD, HEAD_SIZE, xx, wkv_state_in ) # @torch.compile(fullgraph=True) @torch.jit.script def rwkv7_attn_pytorch_v2_inner_jit( r, w, full_vk_, full_ab, BATCH_SIZE:int, SEQ_LEN:int, N_HEAD:int, HEAD_SIZE:int, wkv_xx, wkv_state_in ): ''' Isolated sub-function with JIT ''' # wkv_xx = torch.zeros(BATCH_SIZE,SEQ_LEN,N_HEAD,HEAD_SIZE,HEAD_SIZE, dtype=xx.dtype,device=xx.device) # wkv_state_in = torch.zeros(BATCH_SIZE,N_HEAD,HEAD_SIZE,HEAD_SIZE, dtype=torch.float,device=w.device) wkv_state = wkv_state_in for t in range(SEQ_LEN): # r_ = r[:,t] # w_ = w[:,t] # vk = full_vk_[:,t].view(BATCH_SIZE,N_HEAD,HEAD_SIZE,HEAD_SIZE) # ab = full_ab[:,t].view(BATCH_SIZE,N_HEAD,HEAD_SIZE,HEAD_SIZE) wkv_state = (wkv_state * w[:,t].view(BATCH_SIZE,N_HEAD,1,HEAD_SIZE).float() + wkv_state @ full_ab[:,t].view(BATCH_SIZE,N_HEAD,HEAD_SIZE,HEAD_SIZE).float() + full_vk_[:,t].view(BATCH_SIZE,N_HEAD,HEAD_SIZE,HEAD_SIZE).float()).clone() wkv_xx[:,t] = wkv_state.to(dtype=w.dtype) return wkv_xx, wkv_state # xx[:,t] = ((wkv_state.to(dtype=xx.dtype) @ r_.view(BATCH_SIZE,N_HEAD,HEAD_SIZE,1)).view(BATCH_SIZE,N_HEAD*HEAD_SIZE)) # return xx, wkv_state # ---------------- # block/kernel/rwkv7_attn_cuda.py # ---------------- import torch, os, time # from .rwkv7_attn_pytorch import rwkv7_attn_pytorch_chunk #################################################################################################### # Stateless reference implementation #################################################################################################### def load_ref_wkv_cuda_kernel(CHUNK_LEN = 16, HEAD_SIZE = 64): from torch.utils.cpp_extension import load # load_name = f"wind_backstepping_C{HEAD_SIZE}_L{CHUNK_LEN}" load_name = "wind_backstepping" load_file = "wkv7" # Check if the load_name is already loaded if load_name in torch.ops: return torch.ops.wind_backstepping # Logging of warning usage for reference implementation print("[WARNING] Reference CUDA kernel does not support input RWKV state, and is used only for training/validaiton purposes") # Get the this script file path, to cmpute the cuda path this_file_path = os.path.dirname(os.path.abspath(__file__)) # # Get the device compute capability # cuda_device = torch.cuda.current_device() # compute_capability = torch.cuda.get_device_capability(cuda_device) # compute_capability_str = f"{compute_capability[0]}{compute_capability[1]}" # print("[INFO] Using compute capability:", compute_capability_str) # Load the kernel, there is some wierd edge condition in compilation, # that try catching.... and trying again.... sometimes work? flags = ['-res-usage', f'-D_C_={HEAD_SIZE}', f"-D_CHUNK_LEN_={CHUNK_LEN}", "--use_fast_math", "-O3", "-Xptxas -O3", "--extra-device-vectorization"] # try: load(name=load_name, sources=[f'{this_file_path}/cuda/{load_file}_cuda.cu', f'{this_file_path}/cuda/{load_file}_op.cpp'], is_python_module=False, verbose=True, extra_cuda_cflags=flags) except Exception as e: print("[WARNING] Failed to load the kernel, trying again (sometimes the compiler has wierd race condition)...") time.sleep(2) # Somehow this works, with minor compilation error, that passes on subsequent reruns load(name=load_name, sources=[f'{this_file_path}/cuda/{load_file}_cuda.cu', f'{this_file_path}/cuda/{load_file}_op.cpp'], is_python_module=False, verbose=True, extra_cuda_cflags=flags) # Return the loaded kernel return torch.ops.wind_backstepping @torch.compiler.disable() def ref_wkv_cuda_forward(w,q,k,v,z,b, y,s,sa): torch.ops.wind_backstepping.forward(w,q,k,v,z,b, y,s,sa) @torch.compiler.disable() def ref_wkv_cuda_backward(w,q,k,v,z,b, dy,s,sa, dw,dq,dk,dv,dz,db): torch.ops.wind_backstepping.backward(w,q,k,v,z,b, dy,s,sa, dw,dq,dk,dv,dz,db) class RefCudaWindBackstepping(torch.autograd.Function): @staticmethod def forward(ctx, w,q,k,v,z,b): CHUNK_LEN=16 B,T,H,C = w.shape assert T%CHUNK_LEN == 0 assert all(i.dtype==torch.bfloat16 for i in [w,q,k,v,z,b]) assert all(i.is_contiguous() for i in [w,q,k,v,z,b]) y = torch.empty_like(v) s = torch.empty(B,H,T//CHUNK_LEN,C,C, dtype=torch.float32,device=w.device) sa = torch.empty(B,T,H,C, dtype=torch.float32,device=w.device) ref_wkv_cuda_forward(w,q,k,v,z,b, y,s,sa) ctx.save_for_backward(w,q,k,v,z,b,s,sa) return y @staticmethod def backward(ctx, dy): assert all(i.dtype==torch.bfloat16 for i in [dy]) assert all(i.is_contiguous() for i in [dy]) w,q,k,v,z,b,s,sa = ctx.saved_tensors dw,dq,dk,dv,dz,db = [torch.empty_like(x) for x in [w,q,k,v,z,b]] ref_wkv_cuda_backward(w,q,k,v,z,b, dy,s,sa, dw,dq,dk,dv,dz,db) return dw,dq,dk,dv,dz,db @torch.compiler.disable() def rwkv7_attn_cuda_ref(q,w,k,v, kk,iclr, HEAD_SIZE=64, s0=None): # Preload the kernel load_ref_wkv_cuda_kernel() # Get the shape B,T,HC = w.shape C = HEAD_SIZE H = HC//C # Assert that the chunk is multiple of 16 assert T % 16 == 0, 'reference cuda, only works in multiple of 16' # Initialize the state, if not provided - for compatibility (THE STATE IS NOT UPDATED) s0 = torch.zeros(B,H,C,C, dtype=torch.float,device=w.device) if s0 is None else s0 # Handling the cuda kernel q,w,k,v,a,b = [i.view(B,T,H,C) for i in [q,w,k,v,(-kk),(kk*iclr)]] # Forward with backprop xx = RefCudaWindBackstepping.apply(w,q,k,v,a,b) return xx.view(B,T,HC), s0.view(B,H,C,C) #################################################################################################### # State based cuda code #################################################################################################### def load_wkv_cuda_kernel(CHUNK_LEN = 16, HEAD_SIZE = 64): from torch.utils.cpp_extension import load # load_name = f"wind_backstepping_C{HEAD_SIZE}_L{CHUNK_LEN}" load_name = "state_wind_backstepping" load_file = "state_wkv7" # Check if the load_name is already loaded if load_name in torch.ops: return torch.ops.state_wind_backstepping # Get the this script file path, to cmpute the cuda path this_file_path = os.path.dirname(os.path.abspath(__file__)) # Load the kernel, there is some wierd edge condition in compilation, # that try catching.... and trying again.... sometimes work? flags = ['-res-usage', f'-D_C_={HEAD_SIZE}', f"-D_CHUNK_LEN_={CHUNK_LEN}", "--use_fast_math", "-O3", "-Xptxas -O3", "--extra-device-vectorization"] # try: load(name=load_name, sources=[f'{this_file_path}/cuda/{load_file}_cuda.cu', f'{this_file_path}/cuda/{load_file}_op.cpp'], is_python_module=False, verbose=True, extra_cuda_cflags=flags) except Exception as e: print("[WARNING] Failed to load the kernel, trying again (sometimes the compiler has wierd race condition)...") time.sleep(2) # Somehow this works, with minor compilation error, that passes on subsequent reruns load(name=load_name, sources=[f'{this_file_path}/cuda/{load_file}_cuda.cu', f'{this_file_path}/cuda/{load_file}_op.cpp'], is_python_module=False, verbose=True, extra_cuda_cflags=flags) # Return the loaded kernel return torch.ops.state_wind_backstepping @torch.compiler.disable() def wkv_cuda_forward(state, w,q,k,v,z,b, y,s,sa): torch.ops.state_wind_backstepping.forward(state, w,q,k,v,z,b, y,s,sa) @torch.compiler.disable() def wkv_cuda_backward(state, w,q,k,v,z,b, dy,s,sa, dw,dq,dk,dv,dz,db): torch.ops.state_wind_backstepping.backward(state, w,q,k,v,z,b, dy,s,sa, dw,dq,dk,dv,dz,db) class CudaWindBackstepping(torch.autograd.Function): @staticmethod def forward(ctx, s0, w,q,k,v,z,b): CHUNK_LEN=16 B,T,H,C = w.shape assert T%CHUNK_LEN == 0 assert all(i.dtype==torch.bfloat16 for i in [w,q,k,v,z,b]) assert all(i.is_contiguous() for i in [w,q,k,v,z,b]) y = torch.empty_like(v) s = torch.empty(B,H,T//CHUNK_LEN,C,C, dtype=torch.float32,device=w.device) sa = torch.empty(B,T,H,C, dtype=torch.float32,device=w.device) sOri = s0.clone() wkv_cuda_forward(s0, w,q,k,v,z,b, y,s,sa) ctx.save_for_backward(sOri, w,q,k,v,z,b,s,sa) return y @staticmethod def backward(ctx, dy): assert all(i.dtype==torch.bfloat16 for i in [dy]) assert all(i.is_contiguous() for i in [dy]) state,w,q,k,v,z,b,s,sa = ctx.saved_tensors dS0,dw,dq,dk,dv,dz,db = [torch.empty_like(x) for x in [state,w,q,k,v,z,b]] wkv_cuda_backward(state, w,q,k,v,z,b, dy,s,sa, dw,dq,dk,dv,dz,db) return dS0,dw,dq,dk,dv,dz,db @torch.compiler.disable() def rwkv7_attn_cuda(r,w,k,v, kk,iclr, HEAD_SIZE=64, s0=None): # Preload the kernel load_wkv_cuda_kernel() # Get the shape B,T,HC = w.shape # Check if the chunk is multiple of 16 chunk_remainder = T % 16 # Initialize the state C = HEAD_SIZE H = HC//C # Initialize the state s0 = torch.zeros(B,H,C,C, dtype=torch.float,device=w.device) if s0 is None else s0 sT = s0.to(dtype=torch.float) # Optimize the call, if chunk is multiple of 16 if chunk_remainder == 0: chunk_xx, chunk_sT = rwkv7_attn_cuda_chunk(r,w,k,v, kk,iclr, HEAD_SIZE, sT) return chunk_xx, chunk_sT.to(dtype=s0.dtype) # Compute the number of chunks chunks = T // 16 si = chunks * 16 # Get the chunked output chunk_xx, chunk_sT = rwkv7_attn_cuda_chunk( r[:,:si],w[:,:si],k[:,:si],v[:,:si], kk[:,:si],iclr[:,:si], HEAD_SIZE, s0 ) # Get the remainder remain_xx, last_sT = rwkv7_attn_pytorch_chunk( r[:,si:],torch.exp(-torch.exp(w[:,si:])),k[:,si:],v[:,si:], kk[:,si:],iclr[:,si:], B, H, C, torch.zeros(B, chunk_remainder, HC, device=w.device, dtype=w.dtype), chunk_sT, chunk_size=chunk_remainder ) # Concatenate and return results return torch.cat([chunk_xx.to(dtype=w.dtype), remain_xx.to(dtype=w.dtype)], dim=1), last_sT.to(dtype=s0.dtype) def rwkv7_attn_cuda_chunk(r,w,k,v, kk,iclr, HEAD_SIZE=64, s0=None): ''' Triton implementation running in blocks of 16 (hardcoded requirement for the kernel) ''' B,T,HC = w.shape assert T % 16 == 0, 'pure cuda, only works in multiple of 16' C = HEAD_SIZE H = HC//C # Handling the cuda kernel a,b = -kk, (kk*iclr) r,w,k,v,a,b = [i.view(B,T,H,C) for i in [r,w,k,v,a,b]] if s0 is None: s1 = torch.zeros(B,H,C,C, dtype=torch.float,device=w.device) else: s1 = s0.clone() # Forward with backprop xx = CudaWindBackstepping.apply(s1,w,r,k,v,a,b) return xx.view(B,T,HC), s1.view(B,H,C,C) # ---------------- # block/kernel/rwkv7_attn_fla.py # ---------------- def rwkv7_attn_fla( r,w,k,v, kk,iclr, BATCH_SIZE, SEQ_LEN, N_HEAD, HEAD_SIZE, xx, wkv_state_in ): from fla.ops.rwkv7.chunk import chunk_rwkv7 # Preprocessing the FLA r,w,k,v,a,b = [i.view(BATCH_SIZE,SEQ_LEN,N_HEAD,-1) for i in [r,w,k,v,-kk,(kk*iclr)]] log_w = -w.float().exp() # Run the FLA output, vk_state = chunk_rwkv7(r=r, log_w=log_w, k=k, v=v, a=a, b=b, initial_state=wkv_state_in.float(), output_final_state=True) return output, vk_state.to(dtype=wkv_state_in.dtype) def rwkv7_attn_fused_reccurent_fla( r,w,k,v, kk,iclr, BATCH_SIZE, SEQ_LEN, N_HEAD, HEAD_SIZE, xx, wkv_state_in ): from fla.ops.rwkv7.fused_recurrent import fused_recurrent_rwkv7 # Preprocessing the FLA r,w,k,v,a,b = [i.view(BATCH_SIZE,SEQ_LEN,N_HEAD,-1) for i in [r,w,k,v,-kk,(kk*iclr)]] log_w = -w.float().exp() # Run the FLA output, vk_state = fused_recurrent_rwkv7(r=r, log_w=log_w, k=k, v=v, a=a, b=b, initial_state=wkv_state_in.float(), output_final_state=True) return output, vk_state.to(dtype=wkv_state_in.dtype) # ---------------- # block/kernel/rwkv7_attn_triton.py # ---------------- import torch import torch as th import triton import triton.language as tl #################################################################################################### # Triton specific coding (aka mostly songlin & Johan Sokrates Wind stuff) # # Copyright (c) 2024, Johan Sokrates Wind, licensed under MIT # https://github.com/johanwind/wind_rwkv/blob/main/LICENSE #################################################################################################### # ------------------------- # Triton "smallhead" and "bighead" common code # ------------------------- @triton.jit def IND3(a,b,c,nb,nc): return (a*nb+b)*nc+c @triton.jit def IND4(a,b,c,d,nb,nc,nd): return ((a*nb+b)*nc+c)*nd+d @triton.jit def IND5(a,b,c,d,e,nb,nc,nd,ne): return (((a*nb+b)*nc+c)*nd+d)*ne+e @triton.jit def _prod(a,b): return a*b # inv(I-A) where A is a strictly lower triangular nxn matrix @triton.jit def tri_minv(A, n:tl.constexpr, prec:tl.constexpr): i = tl.arange(0,n) prod = (i[None,:]==i[:,None]).to(tl.float32) for j in range(n-1): prod += tl_dot(prec, prod, (A*((i[None,:]==j)*(i[:,None]>i[None,:]))).trans()) return prod.trans() @triton.jit def tl_dot(prec:tl.constexpr, a, b): if prec == 'fp32': return tl.dot(a.to(tl.float32),b.trans().to(tl.float32).trans(), allow_tf32=False) elif prec == 'tf32': return tl.dot(a.to(tl.float32),b.trans().to(tl.float32).trans(), allow_tf32=True) elif prec == 'bf16': return tl.dot(a.to(tl.bfloat16),b.trans().to(tl.bfloat16).trans(), allow_tf32=True) else: tl.static_assert(False) # ------------------------- # Triton "smallhead" code # ------------------------- @triton.jit def fw_attn_triton(w_,q_,k_,v_,a_,b_, s0_,y_,s_,sT_, B:tl.constexpr,T:tl.constexpr,H:tl.constexpr,C:tl.constexpr,dT:tl.constexpr, prec:tl.constexpr): bi = tl.program_id(1) hi = tl.program_id(0) i = tl.arange(0,C)[None,:] state = tl.load(s0_+IND4(bi,hi,i.trans(),i, H,C,C)).to(tl.float32) for t0 in range(T//dT): t = t0*dT+tl.arange(0,dT)[:,None] sw = tl.load(w_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32) sq = tl.load(q_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32) sk = tl.load(k_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32) sv = tl.load(v_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32) sa = tl.load(a_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32) sb = tl.load(b_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32) w = (-sw.exp()).exp() fw = tl.reduce(w, 0, _prod, keep_dims=True) incl_pref = tl.cumprod(w,axis=0) non_incl_pref = incl_pref / w inv_incl_pref = 1 / incl_pref wq = sq * incl_pref wa = sa * non_incl_pref kwi = sk * inv_incl_pref bwi = sb * inv_incl_pref mask1 = (t > t.trans()) ab = tl_dot(prec, wa, bwi.trans()) * mask1 ak = tl_dot(prec, wa, kwi.trans()) * mask1 ab_inv = tri_minv(ab, dT, prec) ab_u = tl_dot(prec, ak, sv) + tl_dot(prec, wa, state.trans()) u = tl_dot(prec, ab_inv, ab_u) mask2 = (t >= t.trans()) qk = tl_dot(prec, wq, kwi.trans()) * mask2 qb = tl_dot(prec, wq, bwi.trans()) * mask2 yy = tl_dot(prec, qk, sv) + tl_dot(prec, qb, u) + tl_dot(prec, wq, state.trans()) tl.store(y_+IND4(bi,t,hi,i, T,H,C), yy.to(tl.bfloat16)) tl.store(s_+IND5(bi,hi,t0,i.trans(),i, H,T//dT,C,C), state.to(tl.float32)) state = state * fw + tl_dot(prec, sv.trans(), kwi*fw) + tl_dot(prec, u.trans(), bwi*fw) tl.store(sT_+IND4(bi,hi,i.trans(),i, H,C,C), state.to(tl.bfloat16)) @triton.jit def bw_attn_triton(w_,q_,k_,v_,a_,b_, dy_,s_,dsT_, dw_,dq_,dk_,dv_,da_,db_,ds0_, B:tl.constexpr,T:tl.constexpr,H:tl.constexpr,C:tl.constexpr,dT:tl.constexpr, prec:tl.constexpr): bi = tl.program_id(1) hi = tl.program_id(0) i = tl.arange(0,C)[None,:] dstate = tl.load(dsT_+IND4(bi,hi,i.trans(),i, H,C,C)).to(tl.float32) for t0 in range(T//dT-1,-1,-1): t = t0*dT+tl.arange(0,dT)[:,None] state = tl.load(s_+IND5(bi,hi,t0,i.trans(),i, H,T//dT,C,C)).to(tl.float32) sw = tl.load(w_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32) sq = tl.load(q_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32) sk = tl.load(k_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32) sv = tl.load(v_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32) sa = tl.load(a_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32) sb = tl.load(b_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32) sdy = tl.load(dy_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32) dw_fac = -sw.exp() w = dw_fac.exp() fw = tl.reduce(w, 0, _prod, keep_dims=True) incl_pref = tl.cumprod(w,axis=0) non_incl_pref = incl_pref / w inv_incl_pref = 1 / incl_pref wq = sq * incl_pref wa = sa * non_incl_pref kwi = sk * inv_incl_pref bwi = sb * inv_incl_pref mask1 = (t > t.trans()) ab = tl_dot(prec, wa, bwi.trans()) * mask1 ak = tl_dot(prec, wa, kwi.trans()) * mask1 ab_inv = tri_minv(ab, dT, prec) ab_u = tl_dot(prec, ak, sv) + tl_dot(prec, wa, state.trans()) u = tl_dot(prec, ab_inv, ab_u) mask2 = (t >= t.trans()) qk = tl_dot(prec, wq, kwi.trans()) * mask2 qb = tl_dot(prec, wq, bwi.trans()) * mask2 du = tl_dot(prec, qb.trans(), sdy) + tl_dot(prec, bwi*fw, dstate.trans()) dab_u = tl_dot(prec, ab_inv.trans(), du) dv = tl_dot(prec, qk.trans(), sdy) + tl_dot(prec, kwi*fw, dstate.trans()) + tl_dot(prec, ak.trans(), dab_u) tl.store(dv_+IND4(bi,t,hi,i, T,H,C), dv.to(tl.bfloat16)) dab = tl_dot(prec, tl_dot(prec, ab_inv.trans(), du), u.trans()) * mask1 dak = tl_dot(prec, dab_u, sv.trans()) * mask1 dab_u_state = tl_dot(prec, dab_u, state) da = non_incl_pref * (tl_dot(prec, dab, bwi) + tl_dot(prec, dak, kwi) + dab_u_state) tl.store(da_+IND4(bi,t,hi,i, T,H,C), da.to(tl.bfloat16)) dqb = tl_dot(prec, sdy, u.trans()) * mask2 dqk = tl_dot(prec, sdy, sv.trans()) * mask2 dy_state = tl_dot(prec, sdy, state) dq = incl_pref * (tl_dot(prec, dqb, bwi) + tl_dot(prec, dqk, kwi) + dy_state) tl.store(dq_+IND4(bi,t,hi,i, T,H,C), dq.to(tl.bfloat16)) fw_u_dstate = fw * tl_dot(prec, u, dstate) db = inv_incl_pref * (tl_dot(prec, dab.trans(), wa) + tl_dot(prec, dqb.trans(), wq) + fw_u_dstate) tl.store(db_+IND4(bi,t,hi,i, T,H,C), db.to(tl.bfloat16)) fw_v_dstate = fw * tl_dot(prec, sv, dstate) dk = inv_incl_pref * (tl_dot(prec, dak.trans(), wa) + tl_dot(prec, dqk.trans(), wq) + fw_v_dstate) tl.store(dk_+IND4(bi,t,hi,i, T,H,C), dk.to(tl.bfloat16)) dw0 = fw * tl.sum(state*dstate, axis=0,keep_dims=True) for k in range(t0*dT,t0*dT+dT): lmask = (tk) A += (tl_dot(prec, dqb*lmask, bwi) + tl_dot(prec, dqk*lmask, kwi)) * wq * (t>=k) A += (fw_v_dstate*kwi + fw_u_dstate*bwi) * (tk) + dy_state*wq * (t>=k) dw = tl.sum(A, axis=0,keep_dims=True) + dw0 wk = tl.load(w_+IND4(bi,k,hi,i, T,H,C)).to(tl.float32) dw *= -wk.exp() tl.store(dw_+IND4(bi,k,hi,i, T,H,C), dw.to(tl.bfloat16)) dstate = dstate * fw + tl_dot(prec, sdy.trans(), wq) + tl_dot(prec, dab_u.trans(), wa) tl.store(ds0_+IND4(bi,hi,i.trans(),i, H,C,C), dstate.to(tl.bfloat16)) class TritonRWKV7(th.autograd.Function): @staticmethod def forward(ctx, w,q,k,v,z,b,s0, dot_prec): K = 16 B,T,H,C = w.shape s0 = th.zeros(B,H,C,C, dtype=w.dtype,device=w.device) if s0 is None else s0 y = th.empty_like(v) sT = th.empty_like(s0) s = th.zeros(B,H,T//K,C,C, dtype=th.float32,device=w.device) fw_attn_triton[(H,B)](w,q,k,v,z,b, s0,y,s,sT, B,T,H,C,K, dot_prec) ctx.dot_prec = dot_prec ctx.save_for_backward(w,q,k,v,z,b,s) return y, sT @staticmethod def backward(ctx, dy, dsT): K = 16 w,q,k,v,z,b,s = ctx.saved_tensors B,T,H,C = w.shape dw,dq,dk,dv,dz,db,ds0 = [th.empty_like(x) for x in [w,q,k,v,z,b,dsT]] bw_attn_triton[(H,B)](w,q,k,v,z,b, dy,s,dsT, dw,dq,dk,dv,dz,db,ds0, B,T,H,C,K, ctx.dot_prec) return dw,dq,dk,dv,dz,db,ds0,None # ------------------------- # Triton "bighead" code # ------------------------- @triton.autotune(configs=[triton.Config({'dC': dC}, num_stages=1) for dC in [16,32,64]], key=['T','H','C','dT','prec']) @triton.jit def fw_attn_triton_bighead(w_,q_,k_,v_,a_,b_, s0_,y_,s_,sT_, wq_,wa_,kwi_,bwi_,fw_, B:tl.constexpr,T:tl.constexpr,H:tl.constexpr,C:tl.constexpr,dT:tl.constexpr, prec:tl.constexpr, dC:tl.constexpr): tl.static_assert(C%dC == 0) bi = tl.program_id(1) hi = tl.program_id(0) for i0 in range(0,C,dC): i = i0+tl.arange(0,dC)[None,:] for j0 in range(0,C,dC): j = j0+tl.arange(0,dC)[None,:] state = tl.load(s0_+IND4(bi,hi,i.trans(),j, H,C,C)).to(tl.float32) tl.store(s_+IND5(bi,hi,0,i.trans(),j, H,T//dT,C,C), state.to(tl.float32)) for t0 in range(T//dT): dt = tl.arange(0,dT)[:,None] t = t0*dT+dt tl.debug_barrier() for j0 in range(0,C,dC): j = j0+tl.arange(0,dC)[None,:] sw = tl.load(w_+IND4(bi,t,hi,j, T,H,C)).to(tl.float32) sq = tl.load(q_+IND4(bi,t,hi,j, T,H,C)).to(tl.float32) sk = tl.load(k_+IND4(bi,t,hi,j, T,H,C)).to(tl.float32) sa = tl.load(a_+IND4(bi,t,hi,j, T,H,C)).to(tl.float32) sb = tl.load(b_+IND4(bi,t,hi,j, T,H,C)).to(tl.float32) w = (-sw.exp()).exp() fw = tl.reduce(w, 0, _prod, keep_dims=True) incl_pref = tl.cumprod(w,axis=0) non_incl_pref = incl_pref / w inv_incl_pref = 1 / incl_pref wq = sq * incl_pref wa = sa * non_incl_pref kwi = sk * inv_incl_pref bwi = sb * inv_incl_pref tl.store(wq_+IND4(bi,hi,dt,j, H,dT,C), wq.to(tl.float32)) tl.store(wa_+IND4(bi,hi,dt,j, H,dT,C), wa.to(tl.float32)) tl.store(kwi_+IND4(bi,hi,dt,j, H,dT,C), kwi.to(tl.float32)) tl.store(bwi_+IND4(bi,hi,dt,j, H,dT,C), bwi.to(tl.float32)) tl.store(fw_+IND3(bi,hi,j, H,C), fw.to(tl.float32)) tl.debug_barrier() ab = tl.zeros((dT,dT), tl.float32) ak = tl.zeros((dT,dT), tl.float32) qb = tl.zeros((dT,dT), tl.float32) qk = tl.zeros((dT,dT), tl.float32) for j0 in range(0,C,dC): j = j0+tl.arange(0,dC)[None,:] wa = tl.load(wa_+IND4(bi,hi,dt,j, H,dT,C)).to(tl.float32) wq = tl.load(wq_+IND4(bi,hi,dt,j, H,dT,C)).to(tl.float32) bwi = tl.load(bwi_+IND4(bi,hi,dt,j, H,dT,C)).to(tl.float32) kwi = tl.load(kwi_+IND4(bi,hi,dt,j, H,dT,C)).to(tl.float32) sa = tl.load(a_+IND4(bi,t,hi,j, T,H,C)).to(tl.float32) sb = tl.load(b_+IND4(bi,t,hi,j, T,H,C)).to(tl.float32) ab += tl_dot(prec, wa, bwi.trans()) ak += tl_dot(prec, wa, kwi.trans()) qb += tl_dot(prec, wq, bwi.trans()) qk += tl_dot(prec, wq, kwi.trans()) mask1 = (t > t.trans()) mask2 = (t >= t.trans()) ab *= mask1 ak *= mask1 qb *= mask2 qk *= mask2 ab_inv = tri_minv(ab, dT, prec) for i0 in range(0,C,dC): i = i0+tl.arange(0,dC)[None,:] sv = tl.load(v_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32) wa_state = tl.zeros((dT,dC), tl.float32) wq_state = tl.zeros((dT,dC), tl.float32) for j0 in range(0,C,dC): j = j0+tl.arange(0,dC)[None,:] state = tl.load(s_+IND5(bi,hi,t0,i.trans(),j, H,T//dT,C,C)).to(tl.float32) wa = tl.load(wa_+IND4(bi,hi,dt,j, H,dT,C)).to(tl.float32) wq = tl.load(wq_+IND4(bi,hi,dt,j, H,dT,C)).to(tl.float32) wa_state += tl_dot(prec, wa, state.trans()) wq_state += tl_dot(prec, wq, state.trans()) ab_u = tl_dot(prec, ak, sv) + wa_state u = tl_dot(prec, ab_inv, ab_u) yy = tl_dot(prec, qk, sv) + tl_dot(prec, qb, u) + wq_state tl.store(y_+IND4(bi,t,hi,i, T,H,C), yy.to(tl.bfloat16)) for j0 in range(0,C,dC): j = j0+tl.arange(0,dC)[None,:] state = tl.load(s_+IND5(bi,hi,t0,i.trans(),j, H,T//dT,C,C)).to(tl.float32) kwi = tl.load(kwi_+IND4(bi,hi,dt,j, H,dT,C)).to(tl.float32) bwi = tl.load(bwi_+IND4(bi,hi,dt,j, H,dT,C)).to(tl.float32) fw = tl.load(fw_+IND3(bi,hi,j, H,C)) state = state * fw + tl_dot(prec, sv.trans(), kwi*fw) + tl_dot(prec, u.trans(), bwi*fw) if t0+1 < T//dT: tl.store(s_+IND5(bi,hi,t0+1,i.trans(),j, H,T//dT,C,C), state.to(tl.float32)) else: tl.store(sT_+IND4(bi,hi,i.trans(),j, H,C,C), state.to(tl.bfloat16)) @triton.autotune(configs=[triton.Config({'dC': dC}, num_stages=1) for dC in [16,32,64]], key=['T','H','C','dT','prec']) @triton.jit def bw_attn_triton_bighead(w_,q_,k_,v_,a_,b_, dy_,s_,dsT_,ds_, dw_,dq_,dk_,dv_,da_,db_,ds0_, wq_,wa_,kwi_,bwi_,fw_,u_,dab_u_, B:tl.constexpr,T:tl.constexpr,H:tl.constexpr,C:tl.constexpr,dT:tl.constexpr, prec:tl.constexpr, dC:tl.constexpr): tl.static_assert(C%dC == 0) bi = tl.program_id(1) hi = tl.program_id(0) for i0 in range(0,C,dC): i = i0+tl.arange(0,dC)[None,:] for j0 in range(0,C,dC): j = j0+tl.arange(0,dC)[None,:] dstate = tl.load(dsT_+IND4(bi,hi,i.trans(),j, H,C,C)).to(tl.float32) tl.store(ds_+IND4(bi,hi,i.trans(),j, H,C,C), dstate.to(tl.float32)) for t0 in range(T//dT-1,-1,-1): dt = tl.arange(0,dT)[:,None] t = t0*dT+dt tl.debug_barrier() for j0 in range(0,C,dC): j = j0+tl.arange(0,dC)[None,:] sw = tl.load(w_+IND4(bi,t,hi,j, T,H,C)).to(tl.float32) sq = tl.load(q_+IND4(bi,t,hi,j, T,H,C)).to(tl.float32) sk = tl.load(k_+IND4(bi,t,hi,j, T,H,C)).to(tl.float32) sa = tl.load(a_+IND4(bi,t,hi,j, T,H,C)).to(tl.float32) sb = tl.load(b_+IND4(bi,t,hi,j, T,H,C)).to(tl.float32) w = (-sw.exp()).exp() fw = tl.reduce(w, 0, _prod, keep_dims=True) incl_pref = tl.cumprod(w,axis=0) non_incl_pref = incl_pref / w inv_incl_pref = 1 / incl_pref wq = sq * incl_pref wa = sa * non_incl_pref kwi = sk * inv_incl_pref bwi = sb * inv_incl_pref tl.store(wq_+IND4(bi,hi,dt,j, H,dT,C), wq.to(tl.float32)) tl.store(wa_+IND4(bi,hi,dt,j, H,dT,C), wa.to(tl.float32)) tl.store(kwi_+IND4(bi,hi,dt,j, H,dT,C), kwi.to(tl.float32)) tl.store(bwi_+IND4(bi,hi,dt,j, H,dT,C), bwi.to(tl.float32)) tl.store(fw_+IND3(bi,hi,j, H,C), fw.to(tl.float32)) tl.debug_barrier() ab = tl.zeros((dT,dT), tl.float32) ak = tl.zeros((dT,dT), tl.float32) qb = tl.zeros((dT,dT), tl.float32) qk = tl.zeros((dT,dT), tl.float32) for j0 in range(0,C,dC): j = j0+tl.arange(0,dC)[None,:] wa = tl.load(wa_+IND4(bi,hi,dt,j, H,dT,C)).to(tl.float32) wq = tl.load(wq_+IND4(bi,hi,dt,j, H,dT,C)).to(tl.float32) bwi = tl.load(bwi_+IND4(bi,hi,dt,j, H,dT,C)).to(tl.float32) kwi = tl.load(kwi_+IND4(bi,hi,dt,j, H,dT,C)).to(tl.float32) sa = tl.load(a_+IND4(bi,t,hi,j, T,H,C)).to(tl.float32) sb = tl.load(b_+IND4(bi,t,hi,j, T,H,C)).to(tl.float32) ab += tl_dot(prec, wa, bwi.trans()) ak += tl_dot(prec, wa, kwi.trans()) qb += tl_dot(prec, wq, bwi.trans()) qk += tl_dot(prec, wq, kwi.trans()) mask1 = (t > t.trans()) mask2 = (t >= t.trans()) ab *= mask1 ak *= mask1 qb *= mask2 qk *= mask2 ab_inv = tri_minv(ab, dT, prec) dab = tl.zeros((dT,dT), tl.float32) dak = tl.zeros((dT,dT), tl.float32) dqb = tl.zeros((dT,dT), tl.float32) dqk = tl.zeros((dT,dT), tl.float32) tl.debug_barrier() for i0 in range(0,C,dC): i = i0+tl.arange(0,dC)[None,:] wa_state = tl.zeros((dT,dC), tl.float32) bwi_dw_dstate = tl.zeros((dT,dC), tl.float32) kwi_dw_dstate = tl.zeros((dT,dC), tl.float32) for j0 in range(0,C,dC): j = j0+tl.arange(0,dC)[None,:] state = tl.load(s_+IND5(bi,hi,t0,i.trans(),j, H,T//dT,C,C)).to(tl.float32) dstate = tl.load(ds_+IND4(bi,hi,i.trans(),j, H,C,C)).to(tl.float32) wa = tl.load(wa_+IND4(bi,hi,dt,j, H,dT,C)).to(tl.float32) bwi = tl.load(bwi_+IND4(bi,hi,dt,j, H,dT,C)).to(tl.float32) kwi = tl.load(kwi_+IND4(bi,hi,dt,j, H,dT,C)).to(tl.float32) fw = tl.load(fw_+IND3(bi,hi,j, H,C)) wa_state += tl_dot(prec, wa, state.trans()) bwi_dw_dstate += tl_dot(prec, bwi*fw, dstate.trans()) kwi_dw_dstate += tl_dot(prec, kwi*fw, dstate.trans()) sv = tl.load(v_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32) sdy = tl.load(dy_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32) ab_u = tl_dot(prec, ak, sv) + wa_state u = tl_dot(prec, ab_inv, ab_u) du = tl_dot(prec, qb.trans(), sdy) + bwi_dw_dstate dab_u = tl_dot(prec, ab_inv.trans(), du) tl.store(u_+IND4(bi,hi,dt,i, H,dT,C), u.to(tl.float32)) tl.store(dab_u_+IND4(bi,hi,dt,i, H,dT,C), dab_u.to(tl.float32)) dv = tl_dot(prec, qk.trans(), sdy) + kwi_dw_dstate + tl_dot(prec, ak.trans(), dab_u) tl.store(dv_+IND4(bi,t,hi,i, T,H,C), dv.to(tl.bfloat16)) dab += tl_dot(prec, dab_u, u.trans()) * mask1 dak += tl_dot(prec, dab_u, sv.trans()) * mask1 dqb += tl_dot(prec, sdy, u.trans()) * mask2 dqk += tl_dot(prec, sdy, sv.trans()) * mask2 tl.debug_barrier() for j0 in range(0,C,dC): j = j0+tl.arange(0,dC)[None,:] dy_state = tl.zeros((dT,dC), tl.float32) dab_u_state = tl.zeros((dT,dC), tl.float32) fw_u_dstate = tl.zeros((dT,dC), tl.float32) fw_v_dstate = tl.zeros((dT,dC), tl.float32) state_dstate = tl.zeros((1,dC), tl.float32) fw = tl.load(fw_+IND3(bi,hi,j, H,C)) wa = tl.load(wa_+IND4(bi,hi,dt,j, H,dT,C)).to(tl.float32) wq = tl.load(wq_+IND4(bi,hi,dt,j, H,dT,C)).to(tl.float32) for i0 in range(0,C,dC): i = i0+tl.arange(0,dC)[None,:] u = tl.load(u_+IND4(bi,hi,dt,i, H,dT,C)).to(tl.float32) dab_u = tl.load(dab_u_+IND4(bi,hi,dt,i, H,dT,C)).to(tl.float32) sv = tl.load(v_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32) sdy = tl.load(dy_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32) state = tl.load(s_+IND5(bi,hi,t0,i.trans(),j, H,T//dT,C,C)).to(tl.float32) tl.debug_barrier() dstate = tl.load(ds_+IND4(bi,hi,i.trans(),j, H,C,C)).to(tl.float32) tl.debug_barrier() dab_u_state += tl_dot(prec, dab_u, state) fw_u_dstate += fw * tl_dot(prec, u, dstate) fw_v_dstate += fw * tl_dot(prec, sv, dstate) dy_state += tl_dot(prec, sdy, state) state_dstate += tl.sum(state*dstate, axis=0,keep_dims=True) dstate = dstate * fw + tl_dot(prec, sdy.trans(), wq) + tl_dot(prec, dab_u.trans(), wa) if t0 > 0: tl.store(ds_+IND4(bi,hi,i.trans(),j, H,C,C), dstate.to(tl.float32)) else: tl.store(ds0_+IND4(bi,hi,i.trans(),j, H,C,C), dstate.to(tl.bfloat16)) sw = tl.load(w_+IND4(bi,t,hi,j, T,H,C)).to(tl.float32) w = (-sw.exp()).exp() incl_pref = tl.cumprod(w,axis=0) non_incl_pref = incl_pref / w inv_incl_pref = 1 / incl_pref bwi = tl.load(bwi_+IND4(bi,hi,dt,j, H,dT,C)).to(tl.float32) kwi = tl.load(kwi_+IND4(bi,hi,dt,j, H,dT,C)).to(tl.float32) da = non_incl_pref * (tl_dot(prec, dab, bwi) + tl_dot(prec, dak, kwi) + dab_u_state) tl.store(da_+IND4(bi,t,hi,j, T,H,C), da.to(tl.bfloat16)) dq = incl_pref * (tl_dot(prec, dqb, bwi) + tl_dot(prec, dqk, kwi) + dy_state) tl.store(dq_+IND4(bi,t,hi,j, T,H,C), dq.to(tl.bfloat16)) db = inv_incl_pref * (tl_dot(prec, dab.trans(), wa) + tl_dot(prec, dqb.trans(), wq) + fw_u_dstate) tl.store(db_+IND4(bi,t,hi,j, T,H,C), db.to(tl.bfloat16)) dk = inv_incl_pref * (tl_dot(prec, dak.trans(), wa) + tl_dot(prec, dqk.trans(), wq) + fw_v_dstate) tl.store(dk_+IND4(bi,t,hi,j, T,H,C), dk.to(tl.bfloat16)) dw0 = fw * state_dstate for k in range(t0*dT,t0*dT+dT): lmask = (tk) A += (tl_dot(prec, dqb*lmask, bwi) + tl_dot(prec, dqk*lmask, kwi)) * wq * (t>=k) A += (fw_v_dstate*kwi + fw_u_dstate*bwi) * (tk) + dy_state*wq * (t>=k) dw = tl.sum(A, axis=0,keep_dims=True) + dw0 wk = tl.load(w_+IND4(bi,k,hi,j, T,H,C)).to(tl.float32) dw *= -wk.exp() tl.store(dw_+IND4(bi,k,hi,j, T,H,C), dw.to(tl.bfloat16)) class TritonBigheadRWKV7(th.autograd.Function): @staticmethod def forward(ctx, w,q,k,v,a,b,s0, dot_prec): K = 16 B,T,H,C = w.shape assert T%K == 0 assert C%16 == 0 s0 = th.zeros(B,H,C,C, dtype=w.dtype,device=w.device) if s0 is None else s0 y = th.empty_like(v) sT = th.empty_like(s0) s = th.zeros(B,H,T//K,C,C, dtype=th.float32,device=w.device) wq,wa,kwi,bwi = [th.empty(B,H,K,C, dtype=th.float32,device=w.device) for i in range(4)] fw = th.empty(B,H,C, dtype=th.float32,device=w.device) fw_attn_triton_bighead[(H,B)](w,q,k,v,a,b, s0,y,s,sT, wq,wa,kwi,bwi,fw, B,T,H,C,K, dot_prec) ctx.dot_prec = dot_prec ctx.save_for_backward(w,q,k,v,a,b,s) return y, sT @staticmethod def backward(ctx, dy, dsT): K = 16 w,q,k,v,a,b,s = ctx.saved_tensors B,T,H,C = w.shape dw,dq,dk,dv,da,db,ds0 = [th.empty_like(x) for x in [w,q,k,v,a,b,dsT]] fw = th.empty(B,H,C, dtype=th.float32,device=w.device) ds = th.empty(B,H,C,C, dtype=th.float32,device=w.device) wq,wa,kwi,bwi,u,dab_u = [th.empty(B,H,K,C, dtype=th.float32,device=w.device) for i in range(6)] bw_attn_triton_bighead[(H,B)](w,q,k,v,a,b, dy,s,dsT,ds, dw,dq,dk,dv,da,db,ds0, wq,wa,kwi,bwi,fw,u,dab_u, B,T,H,C,K, ctx.dot_prec) return dw,dq,dk,dv,da,db,ds0,None #################################################################################################### # Start of pytorch code #################################################################################################### # from .rwkv7_attn_pytorch import rwkv7_attn_pytorch_chunk # ------------------------- # Pytorch "smallhead" code # ------------------------- def rwkv7_attn_triton(r,w,k,v, kk,iclr, HEAD_SIZE=64, dot_prec='fp32', s0=None): B,T,HC = w.shape # Check if the chunk is multiple of 16 chunk_remainder = T % 16 # Optimize the call, if chunk is multiple of 16 if chunk_remainder == 0: return rwkv7_attn_triton_chunk(r,w,k,v, kk,iclr, HEAD_SIZE, dot_prec, s0) # Initialize the state C = HEAD_SIZE H = HC//C s0 = th.zeros(B,H,C,C, dtype=th.float,device=w.device) if s0 is None else s0 # Compute the number of chunks chunks = T // 16 si = chunks * 16 # Get the chunked output chunk_xx, chunk_sT = rwkv7_attn_triton_chunk( r[:,:si],w[:,:si],k[:,:si],v[:,:si], kk[:,:si],iclr[:,:si], HEAD_SIZE, dot_prec, s0 ) # Get the remainder remain_xx, last_sT = rwkv7_attn_pytorch_chunk( r[:,si:],torch.exp(-torch.exp(w[:,si:])),k[:,si:],v[:,si:], kk[:,si:],iclr[:,si:], B, H, C, torch.zeros(B, chunk_remainder, HC, device=w.device, dtype=w.dtype), chunk_sT, chunk_size=chunk_remainder ) # Concatenate and return results return torch.cat([chunk_xx.to(dtype=w.dtype), remain_xx.to(dtype=w.dtype)], dim=1), last_sT.to(dtype=s0.dtype) def rwkv7_attn_triton_chunk(r,w,k,v, kk,iclr, HEAD_SIZE=64, dot_prec='fp32', s0=None): ''' Triton implementation running in blocks of 16 (hardcoded requirement for the kernel) ''' B,T,HC = w.shape assert T % 16 == 0, 'pure triton, only works in multiple of 16' C = HEAD_SIZE H = HC//C # Moving the triton specific operations into the chunk steps r,w,k,v,a,b = [i.view(B,T,H,C) for i in [r,w,k,v,-kk,(kk*iclr)]] s0 = th.zeros(B,H,C,C, dtype=th.float,device=w.device) if s0 is None else s0 xx, sT = TritonRWKV7.apply(w,r,k,v,a,b,s0,dot_prec) return xx.view(B,T,HC), sT # ------------------------- # Pytorch "bighead" code # ------------------------- def rwkv7_attn_triton_bighead(r,w,k,v, kk,iclr, HEAD_SIZE=64, dot_prec='fp32', s0=None): B,T,HC = w.shape # Check if the chunk is multiple of 16 chunk_remainder = T % 16 # Optimize the call, if chunk is multiple of 16 if chunk_remainder == 0: return rwkv7_attn_triton_bighead_chunk(r,w,k,v, kk,iclr, HEAD_SIZE, dot_prec, s0) # Initialize the state C = HEAD_SIZE H = HC//C s0 = th.zeros(B,H,C,C, dtype=th.float,device=w.device) if s0 is None else s0 # Compute the number of chunks chunks = T // 16 si = chunks * 16 # Get the chunked output chunk_xx, chunk_sT = rwkv7_attn_triton_bighead_chunk( r[:,:si],w[:,:si],k[:,:si],v[:,:si], kk[:,:si],iclr[:,:si], HEAD_SIZE, dot_prec, s0 ) # Get the remainder remain_xx, last_sT = rwkv7_attn_pytorch_chunk( r[:,si:],torch.exp(-torch.exp(w[:,si:])),k[:,si:],v[:,si:], kk[:,si:],iclr[:,si:], B, H, C, torch.zeros(B, chunk_remainder, HC, device=w.device, dtype=w.dtype), chunk_sT, chunk_size=chunk_remainder ) # Concatenate and return results return torch.cat([chunk_xx.to(dtype=w.dtype), remain_xx.to(dtype=w.dtype)], dim=1), last_sT.to(dtype=s0.dtype) def rwkv7_attn_triton_bighead_chunk(r,w,k,v, kk,iclr, HEAD_SIZE=64, dot_prec='fp32', s0=None): ''' Triton implementation running in blocks of 16 (hardcoded requirement for the kernel) ''' B,T,HC = w.shape assert T % 16 == 0, 'pure triton, only works in multiple of 16' C = HEAD_SIZE H = HC//C # Moving the triton specific operations into the chunk steps r,w,k,v,a,b = [i.view(B,T,H,C) for i in [r,w,k,v,-kk,(kk*iclr)]] s0 = th.zeros(B,H,C,C, dtype=th.float,device=w.device) if s0 is None else s0 xx, sT = TritonBigheadRWKV7.apply(w,r,k,v,a,b,s0,dot_prec) return xx.view(B,T,HC), sT # ---------------- # block/rwkv7_block_config_map.py # ---------------- from dataclasses import dataclass from typing import Optional from typing import Union import torch @dataclass class RWKV7BlockConfigMap: """Configuration map for RWKV based models""" # Key properties for the block / model num_hidden_layers: int hidden_size: int head_size: int = 64 # Dropout rate, should only be used in training dropout_rate: float = 0.0 # Implementation backend to use tmix_backend: str = "auto" # --- # OPTIONAL PROPS # # Optional properties which can be derived # or can be overwritten by the user # --- # Current layer_id of the block layer_id: Optional[int] = None # Device and Data type device: Union[torch.device, str, None] = None dtype: Union[torch.dtype, str, None] = None # Channel mix / FFN block dimension size hidden_size_ffn: Optional[int] = None hidden_size_att: Optional[int] = None # # number of heads # n_head: Optional[int] = None # --- # Initializer, with excess arg ignore # --- def __init__( self, num_hidden_layers: int, hidden_size: int, head_size: int = 64, dropout_rate: float = 0.0, tmix_backend: str = "auto", layer_id: Optional[int] = None, device: Union[torch.device, str, None] = None, dtype: Union[torch.dtype, str, None] = None, hidden_size_ffn: Optional[int] = None, hidden_size_att: Optional[int] = None, **kwargs ) -> None: ''' Constructor for the config ''' self.num_hidden_layers = num_hidden_layers self.hidden_size = hidden_size self.head_size = head_size self.dropout_rate = dropout_rate self.tmix_backend = tmix_backend self.layer_id = layer_id self.device = device self.dtype = dtype self.hidden_size_ffn = hidden_size_ffn self.hidden_size_att = hidden_size_att # --- # OPTIONAL PROPS FETCHER # --- def get_layer_id(self, fallback:int) -> int: ''' Returns the layer id ''' if self.layer_id is not None: return self.layer_id return fallback def get_device(self, fallback:str) -> torch.device: ''' Returns the device ''' if self.device is not None: return torch.device(self.device) if fallback is not None: return torch.device(fallback) return torch.get_default_device() def get_dtype(self, fallback:str) -> torch.dtype: ''' Returns the dtype ''' if self.dtype is not None: key = self.dtype else: key = fallback # if dtype is already torch.dtype if isinstance(key, torch.dtype): return key # Get and Check if the dtype is instance of torch.dtype ret = getattr(torch, key) assert isinstance(ret, torch.dtype), f"Invalid dtype: {self.dtype}" return ret # --- def get_hidden_size_att(self) -> int: ''' Returns the dimension of attention ''' if self.hidden_size_att is not None: hidden_size_att = self.hidden_size_att else: hidden_size = self.hidden_size assert hidden_size % 32 == 0, f"hidden_size must be divisible by 32" hidden_size_att = hidden_size assert hidden_size_att % 32 == 0, f"hidden_size_att must be divisible by 32 ({hidden_size_att})" return hidden_size_att def get_hidden_size_ffn(self) -> int: ''' Returns the dimension of feed forward network ''' if self.hidden_size_ffn is not None: hidden_size_ffn = self.hidden_size_ffn else: hidden_size = self.hidden_size assert hidden_size % 32 == 0, f"hidden_size must be divisible by 32" hidden_size_ffn = hidden_size * 4 assert hidden_size_ffn % 32 == 0, f"hidden_size_ffn must be divisible by 32" return hidden_size_ffn # def get_n_head(self) -> int: # ''' # Returns the number of heads # ''' # if self.n_head is not None: # n_head = self.n_head # else: # hidden_size_att = self.get_hidden_size_att() # n_head = self.get_hidden_size_att() // self.head_size # assert hidden_size_att % n_head == 0 , f"hidden_size_att must be divisible by head_size ({self.head_size})" # # return n_head # --- # Duplicator & Normalizer # --- def new_block_config_map(self, **kwargs) -> 'RWKV7BlockConfigMap': ''' Returns a new config map with updated values ''' new_dict = {} for key in RWKV7BlockConfigMap.__dataclass_fields__: if key in self.__dict__: new_dict[key] = self.__dict__[key] new_dict.update(kwargs) return RWKV7BlockConfigMap(**new_dict) @staticmethod def normalize(config_map: any) -> 'RWKV7BlockConfigMap': ''' Converts either maps, objs or RWKV7BlockConfigMap ''' if isinstance(config_map, RWKV7BlockConfigMap): return config_map dict_obj = None if isinstance(config_map, dict): dict_obj = config_map elif hasattr(config_map, '__dict__'): dict_obj = config_map.__dict__ if dict_obj is not None: # Filter for only valeus in RWKV7BlockConfigMap new_dict = {} for key, value in dict_obj.items(): if key in RWKV7BlockConfigMap.__dataclass_fields__: new_dict[key] = value return RWKV7BlockConfigMap(**new_dict) raise ValueError(f"Unsupported config_map type: {type(config_map)}") # ---------------- # block/rwkv7_channel_mix.py # ---------------- import torch from torch import nn from typing import Union # from .rwkv7_block_config_map import RWKV7BlockConfigMap class RWKV7ChannelMix(torch.nn.Module): ''' ChannelMix block for RWKV This is similar to transformer FFN block ''' def __init__(self, configMap: Union[RWKV7BlockConfigMap, any]): ''' Initialize the ChannelMix block. Note: this does not initialize the parameter weights itself which would depend on the `init_parameters()` method ''' super().__init__() configMap:RWKV7BlockConfigMap = RWKV7BlockConfigMap.normalize(configMap) self.configMap = configMap # Get various props hidden_size = configMap.hidden_size device = configMap.get_device(None) dtype = configMap.get_dtype('bfloat16') # By default, hidden_size_ffn = hidden_size * 4 hidden_size_ffn = configMap.get_hidden_size_ffn() # Build the various params # --- self.x_k = nn.Parameter(torch.empty(1, 1, hidden_size, device=device, dtype=dtype)) self.key = nn.Linear(hidden_size, hidden_size_ffn, bias=False, device=device, dtype=dtype) self.value = nn.Linear(hidden_size_ffn, hidden_size, bias=False, device=device, dtype=dtype) def init_parameters(self): ''' Reset the parameters of the block, to an initial state used for training a model from scratch ''' # Get required props configMap = self.configMap hidden_size = configMap.hidden_size num_hidden_layers = configMap.num_hidden_layers # Get optional props layer_id = configMap.get_layer_id(0) device = configMap.get_device(None) dtype = configMap.get_dtype('bfloat16') # By default, hidden_size_ffn = hidden_size * 4 hidden_size_ffn = configMap.get_hidden_size_ffn() # Reset the various params # --- with torch.no_grad(): # fancy init of time_mix ratio_1_to_almost0 = 1.0 - (layer_id / num_hidden_layers) # 1 to ~0 ddd = torch.ones(1, 1, hidden_size) for i in range(hidden_size): ddd[0, 0, i] = i / hidden_size self.x_k = nn.Parameter( (1.0 - torch.pow(ddd, ratio_1_to_almost0**4)).to(device, dtype=dtype) ) self.key = nn.Linear(hidden_size, hidden_size_ffn, bias=False, device=device, dtype=dtype) self.value = nn.Linear(hidden_size_ffn, hidden_size, bias=False, device=device, dtype=dtype) def forward(self, x: torch.Tensor, last_state: torch.Tensor) -> tuple[torch.Tensor,torch.Tensor]: ''' Forwarding channel mix given the input tokens and states. Given: - Incoming token embedding size of shape [batch_size, seq_len, embedding_size] - Incoming channel mix, shift states of the various batches [batch_size, state_size] Returns a pair - Output embedding of shape [batch_size, seq_len, embedding_size] - Output channel mix, shift state of shape [batch_size, state_size] ''' # last_state = last_state.to(self.key.weight.device) ########## ## x070 ########## dxprev = torch.cat((last_state.unsqueeze(1), x[:, :-1]), dim=1) - x xk = x + dxprev * self.x_k k = torch.relu( self.key(xk) ) ** 2 return self.value(k), x[:,-1] @torch.compile(mode="default", fullgraph=True) def forward_with_default_compile(self, in_x: torch.Tensor, in_state: torch.Tensor, out_x: torch.Tensor, out_state: torch.Tensor) -> tuple[torch.Tensor,torch.Tensor]: ''' Compiled varient of the forward function With no new tensors being created for the output Useful for static memory allocation optimizations inference ''' out_x[:], out_state[:] = self.forward_with_reduce_compile(in_x, in_state) return out_x, out_state @torch.compile(mode="reduce-overhead", fullgraph=True) def forward_with_reduce_compile(self, in_x: torch.Tensor, in_state: torch.Tensor) -> tuple[torch.Tensor,torch.Tensor]: ''' Compiled varient of the forward function ''' return self.forward(in_x, in_state) def load_from_model_state_dict(self, model_state_dict: dict, layer_id:int, non_blocking:bool=True): ''' Given the Full/partial RWKV model weights, loaded via `torch.load` Setup the the current module weights, using the layer_id ''' # Get the current state_dict current_state_dict = self.state_dict() # Iterate each parameter in the state_dict, and compare from the model for n in current_state_dict: model_key = f"blocks.{layer_id}.ffn.{n}" if model_key not in model_state_dict: continue # Copy the values from the state_dict try: current_state_dict[n].copy_(model_state_dict[model_key], non_blocking=non_blocking) except Exception as e: print(f"[ERROR] loading: {model_key} | model shape: {current_state_dict[n].shape} | weight shape: {model_state_dict[model_key].shape}") raise e # ---------------- # block/rwkv7_time_mix.py # ---------------- import torch, math from torch import nn from torch import Tensor from typing import Union from torch.nn import functional as F # from .rwkv7_block_config_map import RWKV7BlockConfigMap # Check for triton package, if GPU is available triton = None if torch.cuda.is_available(): try: import triton except ImportError: triton = None else: print("[WARNING] Triton not available, falling back to pytorch mode by default - this is significantly slower") class RWKV7TimeMix(torch.nn.Module): ''' Time Mix block for RWKV V7 ''' def __init__(self, configMap: Union[RWKV7BlockConfigMap, any]): ''' Initialize the TimeMix block. Note: this does not initialize the parameter weights itself which would depend on the `init_parameters()` method ''' super().__init__() configMap:RWKV7BlockConfigMap = RWKV7BlockConfigMap.normalize(configMap) self.configMap = configMap # Get required props hidden_size = configMap.hidden_size # num_hidden_layers = configMap.num_hidden_layers # Get the layer id layer_id = configMap.get_layer_id(0) self.layer_id = layer_id # Get optional props device = configMap.get_device(None) dtype = configMap.get_dtype('bfloat16') # By default, hidden_size_ffn = hidden_size hidden_size_att = configMap.get_hidden_size_att() # Assert hidden_size == hidden_size_att, until we support different hidden_size and hidden_size_att assert hidden_size == hidden_size_att, "hidden_size should be equal to hidden_size_att (@TODO: support different hidden_size and hidden_size_att)" # Head size settings head_size = configMap.head_size self.head_size = head_size # Number of heads n_head = hidden_size_att // head_size assert hidden_size_att % head_size == 0, "hidden_size_att should be divisible by head_size" self.n_head = n_head # Backend self.tmix_backend = configMap.tmix_backend # Build the various params # --- with torch.no_grad(): # Note: for some data, you can reduce D_GATE_LORA or even remove this gate def calc_lora_rank(exponent, multiplier): return max(1, round(hidden_size ** exponent * multiplier / 32)) * 32 D_DECAY_LORA = calc_lora_rank(0.5, 1.8) D_AAA_LORA = calc_lora_rank(0.5, 1.8) D_MV_LORA = calc_lora_rank(0.5, 1.3) D_GATE_LORA = calc_lora_rank(0.8, 0.6) self.x_r = nn.Parameter(torch.empty(1,1,hidden_size, device=device, dtype=dtype)) self.x_w = nn.Parameter(torch.empty(1,1,hidden_size, device=device, dtype=dtype)) self.x_k = nn.Parameter(torch.empty(1,1,hidden_size, device=device, dtype=dtype)) self.x_v = nn.Parameter(torch.empty(1,1,hidden_size, device=device, dtype=dtype)) self.x_a = nn.Parameter(torch.empty(1,1,hidden_size, device=device, dtype=dtype)) self.x_g = nn.Parameter(torch.empty(1,1,hidden_size, device=device, dtype=dtype)) self.w0 = nn.Parameter(torch.empty(1,1,hidden_size, device=device, dtype=dtype)) self.w1 = nn.Parameter(torch.empty(hidden_size, D_DECAY_LORA, device=device, dtype=dtype)) self.w2 = nn.Parameter(torch.empty(D_DECAY_LORA, hidden_size, device=device, dtype=dtype)) self.a0 = nn.Parameter(torch.empty(1,1,hidden_size, device=device, dtype=dtype)) self.a1 = nn.Parameter(torch.empty(hidden_size,D_AAA_LORA, device=device, dtype=dtype)) self.a2 = nn.Parameter(torch.empty(D_AAA_LORA,hidden_size, device=device, dtype=dtype)) if layer_id > 0: self.v0 = nn.Parameter(torch.empty(1,1,hidden_size, device=device, dtype=dtype)) self.v1 = nn.Parameter(torch.empty(hidden_size,D_MV_LORA, device=device, dtype=dtype)) self.v2 = nn.Parameter(torch.empty(D_MV_LORA,hidden_size, device=device, dtype=dtype)) self.g1 = nn.Parameter(torch.empty(hidden_size, D_GATE_LORA, device=device, dtype=dtype)) self.g2 = nn.Parameter(torch.empty(D_GATE_LORA, hidden_size, device=device, dtype=dtype)) self.k_k = nn.Parameter(torch.empty(1,1,hidden_size, device=device, dtype=dtype)) self.k_a = nn.Parameter(torch.empty(1,1,hidden_size, device=device, dtype=dtype)) self.r_k = nn.Parameter(torch.empty(n_head, head_size, device=device, dtype=dtype)) self.receptance = nn.Linear(hidden_size, hidden_size_att, bias=False, device=device, dtype=dtype) self.key = nn.Linear(hidden_size, hidden_size_att, bias=False, device=device, dtype=dtype) self.value = nn.Linear(hidden_size, hidden_size_att, bias=False, device=device, dtype=dtype) self.output = nn.Linear(hidden_size_att, hidden_size, bias=False, device=device, dtype=dtype) self.ln_x = nn.GroupNorm(n_head, hidden_size_att, device=device, dtype=dtype, eps=(1e-5)*head_size) def init_parameters(self): ''' Reset the parameters of the block, to an initial state used for training a model from scratch ''' configMap = self.configMap # Get required props hidden_size = configMap.hidden_size num_hidden_layers = configMap.num_hidden_layers # Get the layer id layer_id = self.layer_id # Get optional props device = configMap.get_device(None) dtype = configMap.get_dtype('bfloat16') # By default, hidden_size_ffn = hidden_size hidden_size_att = configMap.get_hidden_size_att() # Assert hidden_size == hidden_size_att, until we support different hidden_size and hidden_size_att assert hidden_size == hidden_size_att, "hidden_size should be equal to hidden_size_att (@TODO: support different hidden_size and hidden_size_att)" # Head size settings head_size = self.head_size # Number of heads n_head = hidden_size_att // head_size assert hidden_size_att % head_size == 0, "hidden_size_att should be divisible by head_size" # Reset the various params # --- with torch.no_grad(): ratio_0_to_1 = layer_id / (num_hidden_layers - 1) # 0 to 1 ratio_1_to_almost0 = 1.0 - (layer_id / num_hidden_layers) # 1 to ~0 ddd = torch.ones(1, 1, hidden_size, device=device, dtype=dtype) for i in range(hidden_size): ddd[0, 0, i] = i / hidden_size # Note: for some data, you can reduce D_GATE_LORA or even remove this gate def calc_lora_rank(exponent, multiplier): return max(1, round(hidden_size ** exponent * multiplier / 32)) * 32 D_DECAY_LORA = calc_lora_rank(0.5, 1.8) D_AAA_LORA = calc_lora_rank(0.5, 1.8) D_MV_LORA = calc_lora_rank(0.5, 1.3) D_GATE_LORA = calc_lora_rank(0.8, 0.6) self.x_r.copy_(1.0 - torch.pow(ddd, 0.2 * ratio_1_to_almost0)) self.x_w.copy_(1.0 - torch.pow(ddd, 0.9 * ratio_1_to_almost0)) self.x_k.copy_(1.0 - (torch.pow(ddd, 0.9 * ratio_1_to_almost0) + 0.4 * ratio_0_to_1)) self.x_v.copy_(1.0 - (torch.pow(ddd, 0.4 * ratio_1_to_almost0) + 0.6 * ratio_0_to_1)) self.x_a.copy_(1.0 - torch.pow(ddd, 0.9 * ratio_1_to_almost0)) self.x_g.copy_(1.0 - torch.pow(ddd, 0.2 * ratio_1_to_almost0)) def ortho_init(x, scale): x = x.to(device) shape = x.shape if len(shape) == 2: gain = math.sqrt(shape[0] / shape[1]) if shape[0] > shape[1] else 1 nn.init.orthogonal_(x, gain=gain * scale) elif len(shape) == 3: gain = math.sqrt(shape[1] / shape[2]) if shape[1] > shape[2] else 1 for i in range(shape[0]): nn.init.orthogonal_(x[i], gain=gain * scale) else: assert False return x.to(device, dtype=dtype) # D_DECAY_LORA = max(32, int(round( (1.8*(hidden_size**0.5)) /32)*32)) decay_speed = torch.ones(hidden_size, device=device, dtype=dtype) for n in range(hidden_size): decay_speed[n] = -7 + 5 * (n / (hidden_size - 1)) ** (0.85 + 1.0 * ratio_0_to_1 ** 0.5) self.w0.copy_(decay_speed.reshape(1,1,hidden_size).to(device, dtype=dtype) + 0.5) # !!! 0.5 comes from F.softplus !!! self.w1.copy_(torch.zeros(hidden_size, D_DECAY_LORA, device=device, dtype=dtype)) self.w2.copy_(ortho_init(torch.zeros(D_DECAY_LORA, hidden_size), 0.1)) # D_AAA_LORA = max(32, int(round( (1.8*(hidden_size**0.5)) /32)*32)) # suggestion self.a0.copy_(torch.zeros(1,1,hidden_size, device=device, dtype=dtype)) self.a1.copy_(torch.zeros(hidden_size, D_AAA_LORA, device=device, dtype=dtype)) self.a2.copy_(ortho_init(torch.zeros(D_AAA_LORA, hidden_size), 0.1)) # D_MV_LORA = max(32, int(round( (1.3*(hidden_size**0.5)) /32)*32)) # suggestion if layer_id > 0: self.v0.copy_(torch.zeros(1,1,hidden_size, device=device, dtype=dtype)+1.0) self.v1.copy_(torch.zeros(hidden_size, D_MV_LORA, device=device, dtype=dtype)) self.v2.copy_(ortho_init(torch.zeros(D_MV_LORA, hidden_size), 0.1)) # D_GATE_LORA = max(32, int(round( (0.6*(hidden_size**0.8)) /32)*32)) # suggestion # Note: for some data, you can reduce D_GATE_LORA or even remove this gate self.g1.copy_(torch.zeros(hidden_size, D_GATE_LORA, device=device, dtype=dtype)) self.g2.copy_(ortho_init(torch.zeros(D_GATE_LORA, hidden_size), 0.1)) self.k_k.copy_(torch.ones(1,1,hidden_size, device=device, dtype=dtype)*0.85) self.k_a.copy_(torch.ones(1,1,hidden_size, device=device, dtype=dtype)) self.r_k.copy_(torch.zeros(n_head,head_size, device=device, dtype=dtype)) self.receptance = nn.Linear(hidden_size, hidden_size_att, bias=False, device=device, dtype=dtype) self.key = nn.Linear(hidden_size, hidden_size_att, bias=False, device=device, dtype=dtype) self.value = nn.Linear(hidden_size, hidden_size_att, bias=False, device=device, dtype=dtype) self.output = nn.Linear(hidden_size_att, hidden_size, bias=False, device=device, dtype=dtype) self.ln_x = nn.GroupNorm(n_head, hidden_size_att, device=device, dtype=dtype, eps=(1e-5)*head_size) def forward(self, x:Tensor, shift_state_in:Tensor, wkv_state_in:Tensor, v_first_val:Tensor) -> tuple[Tensor,Tensor,Tensor,Tensor]: ''' forwarding time mix given the model weights and the input tokens and states. Given: - Incoming token embedding size of shape [batch_size, seq_len, embedding_size] - Incoming states containing of shapes: [batch_size, state_size] ## Token Shift state, [batch_size, n_head, head_size, head_size] ## WKV state - Incoming v_first_val of shape [batch_size, seq_len, embedding_size] Returns a pair - output embedding of shape [batch_size, seq_len, embedding_size] - output state of shapes: [batch_size, state_size] ## Token Shift state, [batch_size, n_head, head_size, head_size] ## WKV state - output v_first_val of shape [batch_size, seq_len, embedding_size] ''' # Get the sizing BATCH_SIZE, SEQ_LEN, IN_EMB_SIZE = x.size() N_HEAD = self.n_head HEAD_SIZE = self.head_size # Ensure wkv_state_in is initialized if wkv_state_in is None: wkv_state_in = torch.zeros(BATCH_SIZE,N_HEAD,HEAD_SIZE,HEAD_SIZE, dtype=torch.float,device=w.device) else: wkv_state_in = wkv_state_in.clone() ########## ## x070 ########## shift_state_out = x[:, -1] dxprev = torch.cat((shift_state_in.unsqueeze(1), x[:, :-1]), dim=1) - x xr = x + dxprev * self.x_r xw = x + dxprev * self.x_w xk = x + dxprev * self.x_k xv = x + dxprev * self.x_v xa = x + dxprev * self.x_a xg = x + dxprev * self.x_g xx = dxprev r = self.receptance(xr) w = torch.tanh(xw @ self.w1) @ self.w2 k = self.key(xk) v = self.value(xv) g = torch.sigmoid(xg @ self.g1) @ self.g2 iclr = torch.sigmoid(self.a0 + (xa @ self.a1) @ self.a2) # a is "in-context learning rate" kk = F.normalize((k * self.k_k).view(BATCH_SIZE,SEQ_LEN,N_HEAD,-1), dim=-1, p=2.0).view(BATCH_SIZE, SEQ_LEN, IN_EMB_SIZE) k = k * (1 + (iclr-1) * self.k_a) if self.layer_id == 0: v_first_val = v # store the v of the first layer else: v = v + (v_first_val - v) * torch.sigmoid(self.v0 + (xv @ self.v1) @ self.v2) # add value residual tmix_backend = self.tmix_backend.lower() if tmix_backend == "auto": if triton is None or self.receptance.weight.device.type == "cpu": tmix_backend = "pytorch" else: tmix_backend = "cuda" if tmix_backend == "pytorch_ref" or tmix_backend == "pytorch_ref_ori": # Pure pytorch mode for rwkv attention # from .kernel.rwkv7_attn_pytorch import rwkv7_attn_pytorch_ref # Reference minimal compilation version w = torch.exp(-0.606531 * torch.sigmoid((self.w0 + w).float())) # 0.606531 = exp(-0.5) xx, wkv_state_out = rwkv7_attn_pytorch_ref(r, w, k, v, kk, iclr, BATCH_SIZE, SEQ_LEN, N_HEAD, HEAD_SIZE, xx, wkv_state_in) elif tmix_backend == "pytorch_ref_fp32": # Pure pytorch mode for rwkv attention # from .kernel.rwkv7_attn_pytorch import rwkv7_attn_pytorch_ref_fp32 # Modified to follow the same logic as "cuda" version # w = torch.exp(-0.606531 * torch.sigmoid((self.w0 + w).float())) # 0.606531 = exp(-0.5) w = -F.softplus(-(self.w0 + w)) - 0.5 xx, wkv_state_out = rwkv7_attn_pytorch_ref_fp32(r, w, k, v, kk, iclr, BATCH_SIZE, SEQ_LEN, N_HEAD, HEAD_SIZE, xx, wkv_state_in) elif tmix_backend == "pytorch": # Pure pytorch mode for rwkv attention # from .kernel.rwkv7_attn_pytorch import rwkv7_attn_pytorch # Tweaked pytorch compile varient w = torch.exp(-0.606531 * torch.sigmoid((self.w0 + w).float())) # 0.606531 = exp(-0.5) xx, wkv_state_out = rwkv7_attn_pytorch(r, w, k, v, kk, iclr, BATCH_SIZE, SEQ_LEN, N_HEAD, HEAD_SIZE, xx, wkv_state_in) elif tmix_backend == "triton": if triton is None: raise ValueError("Triton not available, unable to load triton kernel") # from .kernel.rwkv7_attn_triton import rwkv7_attn_triton w = -F.softplus(-(self.w0 + w)) - 0.5 xx, wkv_state_out = rwkv7_attn_triton(r, w, k, v, kk, iclr, s0=wkv_state_in) elif tmix_backend == "triton_bighead": if triton is None: raise ValueError("Triton not available, unable to load triton kernel") # from .kernel.rwkv7_attn_triton import rwkv7_attn_triton_bighead w = -F.softplus(-(self.w0 + w)) - 0.5 xx, wkv_state_out = rwkv7_attn_triton_bighead(r, w, k, v, kk, iclr, s0=wkv_state_in) elif tmix_backend == "cuda_ref": # Cuda based method for rwkv attention # from .kernel.rwkv7_attn_cuda import rwkv7_attn_cuda_ref # Reference cuda version (no state output) w = -F.softplus(-(self.w0 + w)) - 0.5 xx, wkv_state_out = rwkv7_attn_cuda_ref(r, w, k, v, kk, iclr, s0=wkv_state_in) elif tmix_backend == "cuda": # Cuda based method for rwkv attention # from .kernel.rwkv7_attn_cuda import rwkv7_attn_cuda # Modified cuda version (with state output) w = -F.softplus(-(self.w0 + w)) - 0.5 xx, wkv_state_out = rwkv7_attn_cuda(r, w, k, v, kk, iclr, s0=wkv_state_in) elif tmix_backend == "fla": # FLA based method for rwkv attention # from .kernel.rwkv7_attn_fla import rwkv7_attn_fla # FLA runs with the softplus w w = -F.softplus(-(self.w0 + w)) - 0.5 xx, wkv_state_out = rwkv7_attn_fla(r, w, k, v, kk, iclr, BATCH_SIZE, SEQ_LEN, N_HEAD, HEAD_SIZE, xx, wkv_state_in) elif tmix_backend == "fla_fused" or tmix_backend == "fused_fla": # FLA based method for rwkv attention # from .kernel.rwkv7_attn_fla import rwkv7_attn_fused_reccurent_fla # FLA runs with the softplus w w = -F.softplus(-(self.w0 + w)) - 0.5 xx, wkv_state_out = rwkv7_attn_fused_reccurent_fla(r, w, k, v, kk, iclr, BATCH_SIZE, SEQ_LEN, N_HEAD, HEAD_SIZE, xx, wkv_state_in) else: raise ValueError(f"Unknown tmix_backend: {tmix_backend}") # wkv_state_in normalization of type if wkv_state_in is not None: wkv_state_out = wkv_state_out.to(wkv_state_in.dtype) ######## cuda-based method # wkv_state_out = wkv_state_in.clone() # w = -F.softplus(-(self.w0 + w)) - 0.5 # soft-clamp to (-inf, -0.5) # xx = RWKV7_OP(wkv_state_out, r, w, k, v, -kk, kk*a) ######## cuda-based method xx = self.ln_x(xx.view(BATCH_SIZE * SEQ_LEN, IN_EMB_SIZE)).view(BATCH_SIZE, SEQ_LEN, IN_EMB_SIZE) xx = xx + ((r.view(BATCH_SIZE,SEQ_LEN,N_HEAD,-1)*k.view(BATCH_SIZE,SEQ_LEN,N_HEAD,-1)*self.r_k).sum(dim=-1, keepdim=True) * v.view(BATCH_SIZE,SEQ_LEN,N_HEAD,-1)).view(BATCH_SIZE,SEQ_LEN,IN_EMB_SIZE) xx = self.output(xx * g) return xx, shift_state_out, wkv_state_out, v_first_val @torch.compile(mode="default") def forward_with_default_compile(self, in_x:Tensor, shift_state_in:Tensor, wkv_state_in:Tensor, v_first_val_in:Tensor, out_x:Tensor, shift_state_out:Tensor, wkv_state_out:Tensor, v_first_val_out:Tensor) -> tuple[Tensor,Tensor,Tensor,Tensor]: ''' Compiled varient of the forward function With no new tensors being created for the output Useful for static memory allocation optimizations inference ''' out_x[:], shift_state_out[:], wkv_state_out[:], v_first_val_out[:] = self.forward(in_x, shift_state_in, wkv_state_in, v_first_val_in) return out_x, shift_state_out, wkv_state_out, v_first_val_out @torch.compile(mode="reduce-overhead") def forward_with_reduce_compile(self, in_x:Tensor, shift_state_in:Tensor, wkv_state_in:Tensor, v_first_val:Tensor) -> tuple[Tensor,Tensor,Tensor,Tensor]: ''' Compiled varient of the forward function With no input tensor being modified. Useful for reduce-overhead compile mode ''' return self.forward(in_x, shift_state_in, wkv_state_in, v_first_val) # --------------------------------- # # Model state handling # # --------------------------------- def load_from_model_state_dict(self, model_state_dict: dict, layer_id:int, non_blocking:bool=True): ''' Given the Full/partial RWKV model weights, loaded via `torch.load` Setup the the current module weights, using the layer_id ''' # Get the current state_dict current_state_dict = self.state_dict() # Iterate each parameter in the state_dict, and compare from the model for n in current_state_dict: model_key = f"blocks.{layer_id}.att.{n}" if model_key not in model_state_dict: continue # Copy the values from the state_dict try: current_state_dict[n].copy_(model_state_dict[model_key], non_blocking=non_blocking) except Exception as e: print(f"[ERROR] loading: {model_key} | model shape: {current_state_dict[n].shape} | weight shape: {model_state_dict[model_key].shape}") raise e # ---------------- # block/rwkv7_layer_block.py # ---------------- import torch from torch import nn from torch import Tensor from typing import Union from torch.nn import functional as F # from .rwkv7_block_config_map import RWKV7BlockConfigMap # from .rwkv7_channel_mix import RWKV7ChannelMix # from .rwkv7_time_mix import RWKV7TimeMix class RWKV7LayerBlock(torch.nn.Module): ''' layer block for RWKV V7 ''' def __init__(self, configMap: Union[RWKV7BlockConfigMap, any]): super().__init__() configMap:RWKV7BlockConfigMap = RWKV7BlockConfigMap.normalize(configMap) self.configMap = configMap # Get required props hidden_size = configMap.hidden_size device = configMap.get_device(None) dtype = configMap.get_dtype('bfloat16') dropout_rate = configMap.dropout_rate # Get valid layer_id layer_id = configMap.get_layer_id(-1) assert layer_id >= 0, f'layer_id must be >= 0, got {layer_id}' # Setup the layernorms, and mixes self.ln1 = nn.LayerNorm(hidden_size, device=device, dtype=dtype) self.ln2 = nn.LayerNorm(hidden_size, device=device, dtype=dtype) if layer_id == 0: self.ln0 = nn.LayerNorm(hidden_size, device=device, dtype=dtype) else: self.ln0 = nn.Identity(device=device) # Setup the time and channel mix self.att = RWKV7TimeMix(configMap) self.ffn = RWKV7ChannelMix(configMap) # Setup droupout at block level if dropout_rate > 0.0: self.drop0 = nn.Dropout(p = dropout_rate,device=device) self.drop1 = nn.Dropout(p = dropout_rate,device=device) else: self.drop0 = nn.Identity(device=device) self.drop1 = nn.Identity(device=device) def init_parameters(self): ''' Reset the parameters of the block, to an initial state used for training a model from scratch ''' configMap = self.configMap # Get required props hidden_size = configMap.hidden_size device = configMap.get_device(None) dtype = configMap.get_dtype('bfloat16') dropout_rate = configMap.dropout_rate # Get valid layer_id layer_id = configMap.get_layer_id(-1) assert layer_id >= 0, f'layer_id must be >= 0, got {layer_id}' # Redo the Setup for the layernorms, and mixes self.ln1 = nn.LayerNorm(hidden_size, device=device, dtype=dtype) self.ln2 = nn.LayerNorm(hidden_size, device=device, dtype=dtype) if layer_id == 0: self.ln0 = nn.LayerNorm(hidden_size, device=device, dtype=dtype) else: self.ln0 = nn.Identity(device=device) # Call the sub blocks init_parameters self.att.init_parameters() self.ffn.init_parameters() def forward( self, x:torch.Tensor, last_state: tuple[torch.Tensor,torch.Tensor,torch.Tensor], v_first:torch.Tensor ) -> tuple[torch.Tensor,tuple[torch.Tensor,torch.Tensor,torch.Tensor],torch.Tensor]: ''' Forward the block given the input tokens and the last state Last state is a tuple of the following - time mix shift state - time mix wkv state - channel mix shift state Returns a pair of the output embedding, v_first and the next state ''' # # Ensure everything is in the right device # x = x.to(self.ln1.weight.device) # last_state = [ s.to(self.ln1.weight.device) for s in last_state ] # Note, that this only applies for layer 0 ln0_out = self.ln0(x) # assert self.ln1(x) is not None # assert last_state.tmix_shift is not None # assert last_state.tmix_wkv is not None att_out, tmix_shift, tmix_wkv, v_first = self.att( self.ln1(ln0_out), last_state[0], # tmix_shift, last_state[1], # tmix_wkv v_first ) # x = x + att_out x = self.drop0(ln0_out + att_out) ffn_out, ffn_state = self.ffn( self.ln2(x), last_state[2] # cmix_shift, ) # x = x + ffn_out x = self.drop1(x + ffn_out) # # Debugging for NaN # layer_id = self.configMap.get_layer_id(-1) # assert torch.isnan(att_out).sum() == 0, f'NaN detected att_out @ layer {layer_id}' # assert torch.isnan(ffn_out).sum() == 0, f'NaN detected ffn_out @ layer {layer_id}' # assert torch.isnan(v_first).sum() == 0, f'NaN detected v_first @ layer {layer_id}' # assert torch.isnan(tmix_shift).sum() == 0, f'NaN detected tmix_shift @ layer {layer_id}' # assert torch.isnan(tmix_wkv).sum() == 0, f'NaN detected tmix_wkv @ layer {layer_id}' # assert torch.isnan(ffn_state).sum() == 0, f'NaN detected ffn_state @ layer {layer_id}' # assert torch.isnan(x).sum() == 0, f'NaN detected block out @ layer {layer_id}' return x, (tmix_shift, tmix_wkv, ffn_state), v_first @torch.compile(mode="default") def forward_with_default_compile( self, in_x:torch.Tensor, in_state: tuple[torch.Tensor,torch.Tensor,torch.Tensor], in_v_first:torch.Tensor, out_x:torch.Tensor, out_state: tuple[torch.Tensor,torch.Tensor,torch.Tensor], out_v_first:torch.Tensor ) -> tuple[torch.Tensor,tuple[torch.Tensor,torch.Tensor,torch.Tensor],torch.Tensor]: ''' Compiled varient of the forward function With no new tensors being created for the output Useful for static memory allocation optimizations inference ''' out_x[:], tmp_state, out_v_first[:] = self.forward(in_x, in_state, in_v_first) out_state[0][:], out_state[1][:], out_state[2][:] = tmp_state return out_x, out_state, out_v_first @torch.compile(mode="reduce-overhead") def forward_with_reduce_compile(self, in_x: torch.Tensor, in_state: tuple[torch.Tensor,torch.Tensor,torch.Tensor], in_v_first:torch.Tensor) -> tuple[torch.Tensor,tuple[torch.Tensor,torch.Tensor,torch.Tensor],torch.Tensor]: ''' Compiled varient of the forward function ''' return self.forward(in_x, in_state, in_v_first) def load_from_model_state_dict(self, model_state_dict:dict, layer_id:int=-1, non_blocking:bool=True): ''' Given the Full/partial RWKV model weights, load the block weights accordingly ''' if layer_id <= -1: layer_id = self.configMap.get_layer_id(-1) assert layer_id >= 0, f'layer_id must be >= 0, got {layer_id}' # Get the current state_dict current_state_dict = self.state_dict() # Iterate each parameter in the state_dict, and compare from the model for n in current_state_dict: model_key = f"blocks.{layer_id}.{n}" if model_key not in model_state_dict: continue # Copy the values from the state_dict try: current_state_dict[n].copy_(model_state_dict[model_key], non_blocking=non_blocking) except Exception as e: print(f"[ERROR] loading: {model_key} | model shape: {current_state_dict[n].shape} | weight shape: {model_state_dict[model_key].shape}") raise e # ---------------- # model/rwkv7_goose_config_map.py # ---------------- from dataclasses import dataclass from typing import Optional from typing import Union import torch # from ..block.rwkv7_block_config_map import RWKV7BlockConfigMap @dataclass class RWKV7GooseConfigMap(RWKV7BlockConfigMap): # This is the world tokenizer size vocab_size: int = 65536 init_state_wkv: bool = False # --- # Initializer, with excess arg ignore # --- def __init__( self, vocab_size: int = 65536, init_state_wkv: bool = False, **kwargs ) -> None: self.vocab_size = vocab_size self.init_state_wkv = init_state_wkv super().__init__(**kwargs) @staticmethod def normalize(config_map: any) -> 'RWKV7GooseConfigMap': ''' Converts either maps, objs or RWKV7BlockConfigMap ''' if isinstance(config_map, RWKV7GooseConfigMap): return config_map if isinstance(config_map, dict): return RWKV7GooseConfigMap(**config_map) if hasattr(config_map, '__dict__'): return RWKV7GooseConfigMap(**config_map.__dict__) raise ValueError(f"Unsupported config_map type: {type(config_map)}") @staticmethod def from_model_state_dict(state_dict: dict, **kwargs) -> 'RWKV7GooseConfigMap': ''' Converts the state dict to the config map ''' # Iterate and count the layers num_hidden_layers = 0 for key in state_dict.keys(): if key.startswith('blocks.'): idx = key.split('.')[1] num_hidden_layers = max(num_hidden_layers, int(idx)+1) # Enable wkv_state if 'init_state.0.wkv' in state_dict: kwargs['init_state_wkv'] = True # Initialize the config map, with the configured values return RWKV7GooseConfigMap( num_hidden_layers=num_hidden_layers, hidden_size=state_dict['emb.weight'].shape[1], vocab_size=state_dict['emb.weight'].shape[0], # init_state_wkv=hasattr(state_dict, 'init_state.0.wkv'), # n_head=state_dict['blocks.0.att.r_k'].shape[0], head_size=state_dict['blocks.0.att.r_k'].shape[1], hidden_size_att=state_dict['blocks.0.att.key.weight'].shape[1], hidden_size_ffn=state_dict['blocks.0.ffn.key.weight'].shape[0], **kwargs ) # ---------------- # model/rwkv7_goose_model.py # ---------------- import torch from torch import nn from torch import Tensor from typing import Union # from .rwkv7_goose_config_map import RWKV7GooseConfigMap # from ..block.rwkv7_layer_block import RWKV7LayerBlock class RWKV7GooseModel(nn.Module): ''' RWKV7 Goose Model architecture Simplified implementation ''' def __init__(self, config: Union[RWKV7GooseConfigMap, any, None] = None): # Initialize the base class super().__init__() # Normalize the config configMap:RWKV7GooseConfigMap = RWKV7GooseConfigMap.normalize(config) self.configMap = configMap # Get the required prop num_hidden_layers = configMap.num_hidden_layers vocab_size = configMap.vocab_size device = configMap.get_device(None) dtype = configMap.get_dtype('bfloat16') hidden_size = configMap.hidden_size # Embedding layer self.emb = nn.Embedding(vocab_size, hidden_size, device=device, dtype=dtype) # main block layers blockList = [None]*num_hidden_layers for i in range(num_hidden_layers): blockList[i] = RWKV7LayerBlock(configMap.new_block_config_map(layer_id=i)) self.blocks = nn.ModuleList(blockList) # ln_out and head self.ln_out = nn.LayerNorm(hidden_size, device=device, dtype=dtype) self.head = nn.Linear(hidden_size, vocab_size, bias=False, device=device, dtype=dtype) # init state tuning support if configMap.init_state_wkv: stateTuneList = [None]*num_hidden_layers for i in range(num_hidden_layers): stateTuneList[i] = nn.ParameterDict({ "wkv": nn.Parameter(torch.zeros(hidden_size // 64, 64, 64, device=device, dtype=dtype)), }) self.init_state = nn.ParameterList(stateTuneList) def init_parameters(self): ''' Reset the parameters of the block, to an initial state used for training a model from scratch ''' # Get the required prop configMap = self.configMap num_hidden_layers = configMap.num_hidden_layers vocab_size = configMap.vocab_size device = configMap.get_device(None) dtype = configMap.get_dtype('bfloat16') hidden_size = configMap.hidden_size # Iterate and reset the blocks for i in range(num_hidden_layers): self.blocks[i].init_parameters() # Reinit the Embedding layer self.emb = nn.Embedding(vocab_size, hidden_size, device=device, dtype=dtype) # Reinit the ln_out and head self.ln_out = nn.LayerNorm(hidden_size, device=device, dtype=dtype) if self.head is not None: self.head = nn.Linear(hidden_size, vocab_size, bias=False, device=device, dtype=dtype) # Reinit the init state tuning support if configMap.init_state_wkv: stateTuneList = [None]*num_hidden_layers for i in range(num_hidden_layers): stateTuneList[i] = nn.ParameterDict({ "wkv": nn.Parameter(torch.zeros(hidden_size // 64, 64, 64, device=device, dtype=torch.float)), }) self.init_state = nn.ParameterList(stateTuneList) def load_from_model_state_dict(self, state_dict: dict, non_blocking:bool=True): ''' Given the Full/partial RWKV model weights, loaded via `torch.load` Setup the RWKV_TimeMix model weights, using the layer_id ''' for i, block in enumerate(self.blocks): block.load_from_model_state_dict(state_dict, i, non_blocking=non_blocking) self.ln_out.weight.data.copy_(state_dict['ln_out.weight'], non_blocking=True) self.ln_out.bias.data.copy_(state_dict['ln_out.bias'], non_blocking=True) self.head.weight.data.copy_(state_dict['head.weight'], non_blocking=True) self.emb.weight.data.copy_(state_dict['emb.weight'], non_blocking=True) ### --- ### ### Init state handling ### ### --- def get_init_state(self, batch_size:int=1, skip_init_state:bool=False) -> list[tuple[torch.Tensor,torch.Tensor,torch.Tensor]]: ''' Get an initialized copy of the model state, for the given batch size ''' # Get required configs hidden_size = self.configMap.hidden_size init_state_wkv = self.configMap.init_state_wkv num_hidden_layers = self.configMap.num_hidden_layers # Prepare the initial state init_state = [ None for i in range(num_hidden_layers) ] for i in range(num_hidden_layers): device = self.blocks[i].ln1.weight.data.device dtype = self.blocks[i].ln1.weight.data.dtype # Use the saved init_state if enabled # TODO: Consider letting the wkv_state dtype be a parameter wkv_state = torch.zeros(batch_size, hidden_size // 64, 64, 64, device=device, dtype=torch.float) if init_state_wkv and skip_init_state == False: init_wkv = self.init_state[i]["wkv"] for b in range(batch_size): wkv_state[b][:] = init_wkv # Prepare the state init_state[i] = ( torch.zeros(batch_size, hidden_size, device=device, dtype=dtype), wkv_state, torch.zeros(batch_size, hidden_size, device=device, dtype=dtype) ) return init_state ### --- ### ### Model Forward ### ### --- def forward( self, idx:torch.Tensor, prv_stateList:list[tuple[torch.Tensor,torch.Tensor,torch.Tensor]] = None, ret_stateList:list[tuple[torch.Tensor,torch.Tensor,torch.Tensor]] = None, ) -> tuple[torch.Tensor,list[tuple[torch.Tensor,torch.Tensor,torch.Tensor]]]: ''' Forward the block set, given the input tokens and the last state Last state is a list of tuple of the following - time mix shift state - time mix wkv state - channel mix shift state Returns a pair of the output embedding and the next state ''' # Prepare the state, with the batch size if prv_stateList is None: prv_stateList = self.get_init_state(idx.shape[0]) # If no return state is set, let _forward_internal, set it up if ret_stateList is None: ret_stateList = [ None for i in range(self.configMap.num_hidden_layers) ] return self._forward_internal(idx, prv_stateList, ret_stateList, overwrite_ret_tensor=False) # Forward internally return self._forward_internal(idx, prv_stateList, ret_stateList, overwrite_ret_tensor=True) def _forward_block_hook(self, block:RWKV7LayerBlock, x_hidden_state:torch.Tensor, prv_stateList:list[tuple[torch.Tensor,torch.Tensor,torch.Tensor]], v_first:torch.Tensor ) -> tuple[torch.Tensor,tuple[torch.Tensor,torch.Tensor,torch.Tensor],torch.Tensor]: ''' Forward block hook operation, that is easily overridable. To implement gradient checkpointing for use in various trainers ''' x_hidden_state = x_hidden_state.to(block.ln1.weight.device, non_blocking=True) return block(x_hidden_state, prv_stateList, v_first) def _forward_internal( self, idx:torch.Tensor, prv_stateList:list[tuple[torch.Tensor,torch.Tensor,torch.Tensor]], ret_stateList:list[tuple[torch.Tensor,torch.Tensor,torch.Tensor]], overwrite_ret_tensor:bool=False ) -> tuple[torch.Tensor,list[tuple[torch.Tensor,torch.Tensor,torch.Tensor]]]: ''' Internal forward operations, which assumes the state is already initialized Due to the lack of safety checks, this should not be used directly ''' # Lets get the embedding idx = idx.to(self.emb.weight.device, non_blocking=True) x_hidden_state = self.emb(idx) v_first = None # Iterate the block layers, compute the x_hidden_state if overwrite_ret_tensor: for i, block in enumerate(self.blocks): # x_hidden_state, last_block_state, v_first = block(x_hidden_state, prv_stateList[i], v_first) x_hidden_state, last_block_state, v_first = self._forward_block_hook(block, x_hidden_state, prv_stateList[i], v_first) ret_stateList[i][0][:] = last_block_state[0] ret_stateList[i][1][:] = last_block_state[1] ret_stateList[i][2][:] = last_block_state[2] else: ret_stateList = [ None for i in range( len(self.blocks) ) ] for i, block in enumerate(self.blocks): # x_hidden_state, ret_sublist, v_first = block(x_hidden_state, prv_stateList[i], v_first) x_hidden_state, ret_sublist, v_first = self._forward_block_hook(block, x_hidden_state, prv_stateList[i], v_first) ret_stateList[i] = ret_sublist # Final layer norm, and head x_hidden_state = x_hidden_state.to(self.ln_out.weight.device, non_blocking=True) x_hidden_state = self.ln_out(x_hidden_state) x_out = self.head(x_hidden_state) # Return the output and the state list return x_out, ret_stateList @torch.compile(mode="default") def forward_with_default_compile( self, idx:torch.Tensor, prv_stateList:list[tuple[torch.Tensor,torch.Tensor,torch.Tensor]], ret_stateList:list[tuple[torch.Tensor,torch.Tensor,torch.Tensor]], ) -> tuple[torch.Tensor,list[tuple[torch.Tensor,torch.Tensor,torch.Tensor]]]: ''' Compiled varient of the forward function With no new tensors being created for the output Useful for static memory allocation optimizations inference ''' # Forward internally return self._forward_internal(idx, prv_stateList, ret_stateList, overwrite_ret_tensor=True) @torch.compile(mode="reduce-overhead") def forward_with_reduce_compile( self, in_idx:torch.Tensor, prv_stateList:list[tuple[torch.Tensor,torch.Tensor,torch.Tensor]] ) -> tuple[torch.Tensor,list[tuple[torch.Tensor,torch.Tensor,torch.Tensor]]]: ''' Compiled varient of the forward function, requires previous state to be passed ''' return self._forward_internal(in_idx, prv_stateList, None, overwrite_ret_tensor=False)