# ========== AUTO GENERATED FILE =========
# This file is auto generated by 'hf_builder.py', do not edit this file directly
# As part of the RWKV/RWKV-block project
# ========== =================== =========
# ----------------
# block/kernel/rwkv7_attn_pytorch.py
# ----------------
import torch
# Enable tensorfloat 32
torch.set_float32_matmul_precision('high')
# Handles the RWKV v7 attention mechanism, in pure pytorch
def rwkv7_attn_pytorch(
r,w,k,v, kk,a,
BATCH_SIZE, SEQ_LEN, N_HEAD, HEAD_SIZE,
xx, wkv_state_in
):
    ### Reference implementation
# return rwkv7_attn_pytorch_ref(
# r,w,k,v, kk,a,
# BATCH_SIZE, SEQ_LEN, N_HEAD, HEAD_SIZE,
# xx, wkv_state_in
# )
###
# # This works, but it has too much of a vram overhead
###
# return rwkv7_attn_pytorch_v2_chunk_w_compile_break(
# r,w,k,v, kk,a,
# BATCH_SIZE, SEQ_LEN, N_HEAD, HEAD_SIZE,
# xx, wkv_state_in
# )
###
# > per 9k chunk, per block, on a 4090 ...
# > with forward_with_reduce_compile on the timemix ...
#
# Somehow...
    # The reference implementation takes: 2281ms
    # The chunked version takes: 389ms (chunk size 256)
# Get the shape
B,T,HC = w.shape
# Compute the chunks
chunk_size = 256
chunk_count = SEQ_LEN // chunk_size
chunk_remainder = SEQ_LEN % chunk_size
    # Initialize wkv_state_out as an fp32 copy of the incoming state
wkv_state_out = wkv_state_in.float()
# # List of tensor to build
# xlist = []
xx = xx.clone()
# Loop over the chunks
for i in range(chunk_count):
sta = i * chunk_size
end = sta + chunk_size
xx[:,sta:end], wkv_state_out = rwkv7_attn_pytorch_v2_chunk_w_compile_break(
# xpart, wkv_state_out = rwkv7_attn_pytorch_chunk_with_w_compile_break(
r[:,sta:end],w[:,sta:end],k[:,sta:end],v[:,sta:end],
kk[:,sta:end],a[:,sta:end],
BATCH_SIZE, chunk_size, N_HEAD, HEAD_SIZE,
# xx[:,sta:end], wkv_state_out
torch.zeros(B,chunk_size,HC, dtype=xx.dtype, device=xx.device), wkv_state_out
)
# xlist.append(xpart)
# Handle the remainder
if chunk_remainder > 0:
sta = chunk_count * chunk_size
end = sta + chunk_remainder
xx[:,sta:end], wkv_state_out = rwkv7_attn_pytorch_v2_chunk_w_compile_break(
# xpart, wkv_state_out = rwkv7_attn_pytorch_chunk_with_w_compile_break(
r[:,sta:end],w[:,sta:end],k[:,sta:end],v[:,sta:end],
kk[:,sta:end],a[:,sta:end],
BATCH_SIZE, chunk_remainder, N_HEAD, HEAD_SIZE,
# xx[:,sta:end], wkv_state_out,
torch.zeros(B,chunk_remainder,HC, dtype=xx.dtype, device=xx.device), wkv_state_out,
# offset=0, chunk_size=chunk_remainder
)
# xlist.append(xpart)
# # Concatenate the list
# xx = torch_cat_no_compiler(xlist, dim=1)
# Return the output
return xx, wkv_state_out.to(dtype=wkv_state_in.dtype)
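# Illustrative usage sketch for rwkv7_attn_pytorch (kept as a comment so nothing runs on import).
# The sizes below are assumptions; note the decay `w` is expected here in its already-exponentiated
# exp(-exp(.)) form, matching the reference loop below.
#
#   B, T, H, C = 1, 512, 12, 64
#   r, k, v, kk, a = (torch.randn(B, T, H*C, dtype=torch.bfloat16) for _ in range(5))
#   w = torch.exp(-torch.exp(torch.randn(B, T, H*C, dtype=torch.bfloat16)))
#   xx = torch.zeros(B, T, H*C, dtype=torch.bfloat16)
#   s0 = torch.zeros(B, H, C, C, dtype=torch.float)
#   out, s1 = rwkv7_attn_pytorch(r, w, k, v, kk, a, B, T, H, C, xx, s0)
#   # out: [B, T, H*C] per-token outputs, s1: [B, H, C, C] final wkv state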
####################################################################################################
# Working reference copy, validated to be "identical" to the reference implementation.
# However, it has known pytorch compilation issues, hence the modified chunk-wise version is used
# instead, for an approximately 5x speed-up
####################################################################################################
@torch.compiler.disable()
def rwkv7_attn_pytorch_ref(
r,w,k,v, kk,a,
BATCH_SIZE, SEQ_LEN, N_HEAD, HEAD_SIZE,
xx, wkv_state_in
):
######## pure pytorch method
# See: https://github.com/BlinkDL/RWKV-LM/blob/d4c42b2cac10f8f3896ce153e2310dc763662b7a/RWKV-v7/rwkv_v7_demo_fast.py#L238
########
vk_state = wkv_state_in.float()
for t in range(SEQ_LEN):
r_, w_, k_, v_, kk_, a_ = r[:,t], w[:,t], k[:,t], v[:,t], kk[:,t], a[:,t]
vk = v_.view(BATCH_SIZE,N_HEAD,HEAD_SIZE,1) @ k_.view(BATCH_SIZE,N_HEAD,1,HEAD_SIZE)
ab = (-kk_).view(BATCH_SIZE,N_HEAD,HEAD_SIZE,1) @ (kk_*a_).view(BATCH_SIZE,N_HEAD,1,HEAD_SIZE)
vk_state = (vk_state * w_.view(BATCH_SIZE,N_HEAD,1,HEAD_SIZE).float() + vk_state @ ab.float() + vk.float())
xx[:,t] = ((vk_state.to(dtype=xx.dtype) @ r_.view(BATCH_SIZE,N_HEAD,HEAD_SIZE,1)).view(BATCH_SIZE,N_HEAD*HEAD_SIZE))
wkv_state_out = vk_state.to(dtype=wkv_state_in.dtype)
return xx, wkv_state_out
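# For reference, the per-timestep recurrence the loop above implements, per head
# (S is the [HEAD_SIZE, HEAD_SIZE] wkv state; the broadcast S * w is equivalent to S @ diag(w)):
#
#   S_t = S_{t-1} @ diag(w_t)  +  S_{t-1} @ ( (-kk_t) (kk_t * a_t)^T )  +  v_t k_t^T
#   y_t = S_t @ r_t
#
# i.e. a per-channel decayed state, plus the rank-1 "in-context erase/replace" term (ab),
# plus the usual rank-1 key/value write, read out against the receptance r_t.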
####################################################################################################
####################################################################################################
# Modified reference computation done in fp32,
# with changes made to bring the result closer to the cuda kernel
####################################################################################################
@torch.compiler.disable()
def rwkv7_attn_pytorch_ref_fp32(
r,w,k,v, kk, iclr,
BATCH_SIZE, SEQ_LEN, N_HEAD, HEAD_SIZE,
xx, wkv_state_in
):
######## pure pytorch method (modified for fp32)
# See: https://github.com/BlinkDL/RWKV-LM/blob/d4c42b2cac10f8f3896ce153e2310dc763662b7a/RWKV-v7/rwkv_v7_demo_fast.py#L238
########
w = (-w.float().exp()).exp()
# wkv_state_in = torch.zeros(BATCH_SIZE,N_HEAD,HEAD_SIZE,HEAD_SIZE, dtype=torch.float,device=w.device)
vk_state = wkv_state_in.float()
a = -kk
b = kk * iclr
for t in range(SEQ_LEN):
r_, w_, k_, v_, a_, b_= r[:,t].float(), w[:,t].float(), k[:,t].float(), v[:,t].float(), a[:,t].float(), b[:,t].float()
# ab = (-kk_).view(BATCH_SIZE,N_HEAD,HEAD_SIZE,1) @ (kk_*a_).view(BATCH_SIZE,N_HEAD,1,HEAD_SIZE)
vk = v_.view(BATCH_SIZE,N_HEAD,HEAD_SIZE,1) @ k_.view(BATCH_SIZE,N_HEAD,1,HEAD_SIZE)
vk_state = (vk_state * w_.view(BATCH_SIZE,N_HEAD,1,HEAD_SIZE).float() + vk_state @ a_.float().view(BATCH_SIZE, N_HEAD,HEAD_SIZE,1) @ b_.view(BATCH_SIZE, N_HEAD,1,HEAD_SIZE) + vk.float())
xx[:,t] = ((vk_state @ r_.view(BATCH_SIZE,N_HEAD,HEAD_SIZE,1)).view(BATCH_SIZE,N_HEAD*HEAD_SIZE)).to(dtype=xx.dtype)
wkv_state_out = vk_state.to(dtype=wkv_state_in.dtype)
return xx, wkv_state_out
####################################################################################################
def rwkv7_attn_pytorch_chunk(
r,w,k,v, kk,a,
BATCH_SIZE, N_HEAD, HEAD_SIZE,
xx, wkv_state_in,
offset=0, chunk_size=16
):
'''
    Chunked version of the RWKV7 attention, for better performance.
    If the chunk size is less than 128, this is generally compilable.
    This is used by the triton/cuda implementations, for the remaining (T % 16) steps
'''
######## pure pytorch method
# See: https://github.com/BlinkDL/RWKV-LM/blob/d4c42b2cac10f8f3896ce153e2310dc763662b7a/RWKV-v7/rwkv_v7_demo_fast.py#L238
########
vk_state = wkv_state_in.float()
for i in range(chunk_size):
t = offset + i
r_, w_, k_, v_, kk_, a_ = r[:,t], w[:,t], k[:,t], v[:,t], kk[:,t], a[:,t]
vk = v_.view(BATCH_SIZE,N_HEAD,HEAD_SIZE,1) @ k_.view(BATCH_SIZE,N_HEAD,1,HEAD_SIZE)
ab = (-kk_).view(BATCH_SIZE,N_HEAD,HEAD_SIZE,1) @ (kk_*a_).view(BATCH_SIZE,N_HEAD,1,HEAD_SIZE)
vk_state = (vk_state * w_.view(BATCH_SIZE,N_HEAD,1,HEAD_SIZE).float() + vk_state @ ab.float() + vk.float())
xx[:,t] = (vk_state.to(dtype=xx.dtype) @ r_.view(BATCH_SIZE,N_HEAD,HEAD_SIZE,1)).view(BATCH_SIZE,N_HEAD*HEAD_SIZE)
wkv_state_out = vk_state.to(dtype=wkv_state_in.dtype)
return xx, wkv_state_out
def rwkv7_attn_pytorch_v2_chunk_w_compile_break(
r,w,k,v, kk,a,
BATCH_SIZE, SEQ_LEN, N_HEAD, HEAD_SIZE,
xx, wkv_state_in
):
'''
Chunked version of the RWKV7 attention, for better performance
'''
full_vk_ = v.view(BATCH_SIZE,SEQ_LEN,N_HEAD, HEAD_SIZE,1) @ k.view(BATCH_SIZE,SEQ_LEN,N_HEAD, 1,HEAD_SIZE)
full_iclr_ = (kk * a).view(BATCH_SIZE,SEQ_LEN,N_HEAD,1,HEAD_SIZE)
full_ab = (-kk).view(BATCH_SIZE,SEQ_LEN,N_HEAD, HEAD_SIZE,1) @ full_iclr_
wkv_xx = torch.empty(BATCH_SIZE,SEQ_LEN,N_HEAD,HEAD_SIZE,HEAD_SIZE, dtype=xx.dtype, device=xx.device)
wkv_xx, wkv_state_out = rwkv7_attn_pytorch_v2_inner_w_compile_break(
r,w,
full_vk_, full_ab,
BATCH_SIZE, SEQ_LEN, N_HEAD, HEAD_SIZE,
wkv_xx, wkv_state_in
# xx, wkv_state_in
)
# if BATCH_SIZE != 1:
# print("BATCH_SIZE != 1 : ", BATCH_SIZE)
# if SEQ_LEN != 256:
# print("SEQ_LEN != 256 : ", SEQ_LEN)
# xx[:,t] = ((wkv_state.to(dtype=xx.dtype) @ r_.view(BATCH_SIZE,N_HEAD,HEAD_SIZE,1)).view(BATCH_SIZE,N_HEAD*HEAD_SIZE))
xx[:] = (wkv_xx.to(dtype=xx.dtype) @ r.view(BATCH_SIZE,SEQ_LEN,N_HEAD,HEAD_SIZE,1)).view(BATCH_SIZE,SEQ_LEN,N_HEAD*HEAD_SIZE)
return xx, wkv_state_out
@torch.compiler.disable()
def rwkv7_attn_pytorch_v2_inner_w_compile_break(
r, w,
full_vk_, full_ab,
BATCH_SIZE, SEQ_LEN, N_HEAD, HEAD_SIZE,
xx, wkv_state_in
):
'''
    Isolated sub-function, excluded from torch.compile (acts as a graph break)
'''
return rwkv7_attn_pytorch_v2_inner_jit(
r, w,
full_vk_, full_ab,
BATCH_SIZE, SEQ_LEN, N_HEAD, HEAD_SIZE,
xx, wkv_state_in
)
# @torch.compile(fullgraph=True)
@torch.jit.script
def rwkv7_attn_pytorch_v2_inner_jit(
r, w,
full_vk_, full_ab,
BATCH_SIZE:int, SEQ_LEN:int, N_HEAD:int, HEAD_SIZE:int,
wkv_xx, wkv_state_in
):
'''
    Isolated sub-function compiled with torch.jit.script
'''
# wkv_xx = torch.zeros(BATCH_SIZE,SEQ_LEN,N_HEAD,HEAD_SIZE,HEAD_SIZE, dtype=xx.dtype,device=xx.device)
# wkv_state_in = torch.zeros(BATCH_SIZE,N_HEAD,HEAD_SIZE,HEAD_SIZE, dtype=torch.float,device=w.device)
wkv_state = wkv_state_in
for t in range(SEQ_LEN):
# r_ = r[:,t]
# w_ = w[:,t]
# vk = full_vk_[:,t].view(BATCH_SIZE,N_HEAD,HEAD_SIZE,HEAD_SIZE)
# ab = full_ab[:,t].view(BATCH_SIZE,N_HEAD,HEAD_SIZE,HEAD_SIZE)
wkv_state = (wkv_state * w[:,t].view(BATCH_SIZE,N_HEAD,1,HEAD_SIZE).float() + wkv_state @ full_ab[:,t].view(BATCH_SIZE,N_HEAD,HEAD_SIZE,HEAD_SIZE).float() + full_vk_[:,t].view(BATCH_SIZE,N_HEAD,HEAD_SIZE,HEAD_SIZE).float()).clone()
wkv_xx[:,t] = wkv_state.to(dtype=w.dtype)
return wkv_xx, wkv_state
# xx[:,t] = ((wkv_state.to(dtype=xx.dtype) @ r_.view(BATCH_SIZE,N_HEAD,HEAD_SIZE,1)).view(BATCH_SIZE,N_HEAD*HEAD_SIZE))
# return xx, wkv_state
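# Note on the wrapper layering above: the sequential state recurrence does not capture well
# under torch.compile, so the chunk-level function stays compilable, the
# @torch.compiler.disable() shim forces a graph break, and the actual time-step loop runs
# under torch.jit.script instead. The batched vk / ab outer products and the final readout
# against r are done once per chunk outside the loop, so only the cheap state update is serial.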
# ----------------
# block/kernel/rwkv7_attn_cuda.py
# ----------------
import torch, os, time
# from .rwkv7_attn_pytorch import rwkv7_attn_pytorch_chunk
####################################################################################################
# Stateless reference implementation
####################################################################################################
def load_ref_wkv_cuda_kernel(CHUNK_LEN = 16, HEAD_SIZE = 64):
from torch.utils.cpp_extension import load
# load_name = f"wind_backstepping_C{HEAD_SIZE}_L{CHUNK_LEN}"
load_name = "wind_backstepping"
load_file = "wkv7"
# Check if the load_name is already loaded
if load_name in torch.ops:
return torch.ops.wind_backstepping
    # Log a usage warning for the reference implementation
    print("[WARNING] Reference CUDA kernel does not support input RWKV state, and is used only for training/validation purposes")
    # Get this script's file path, to compute the cuda path
this_file_path = os.path.dirname(os.path.abspath(__file__))
# # Get the device compute capability
# cuda_device = torch.cuda.current_device()
# compute_capability = torch.cuda.get_device_capability(cuda_device)
# compute_capability_str = f"{compute_capability[0]}{compute_capability[1]}"
# print("[INFO] Using compute capability:", compute_capability_str)
    # Load the kernel. There is some weird edge condition in compilation,
    # where catching the exception and trying again sometimes works
flags = ['-res-usage', f'-D_C_={HEAD_SIZE}', f"-D_CHUNK_LEN_={CHUNK_LEN}", "--use_fast_math", "-O3", "-Xptxas -O3", "--extra-device-vectorization"] #
try:
load(name=load_name, sources=[f'{this_file_path}/cuda/{load_file}_cuda.cu', f'{this_file_path}/cuda/{load_file}_op.cpp'], is_python_module=False, verbose=True, extra_cuda_cflags=flags)
except Exception as e:
        print("[WARNING] Failed to load the kernel, trying again (sometimes the compiler has a weird race condition)...")
        time.sleep(2) # Somehow this works; the minor compilation error clears on subsequent reruns
load(name=load_name, sources=[f'{this_file_path}/cuda/{load_file}_cuda.cu', f'{this_file_path}/cuda/{load_file}_op.cpp'], is_python_module=False, verbose=True, extra_cuda_cflags=flags)
# Return the loaded kernel
return torch.ops.wind_backstepping
@torch.compiler.disable()
def ref_wkv_cuda_forward(w,q,k,v,z,b, y,s,sa):
torch.ops.wind_backstepping.forward(w,q,k,v,z,b, y,s,sa)
@torch.compiler.disable()
def ref_wkv_cuda_backward(w,q,k,v,z,b, dy,s,sa, dw,dq,dk,dv,dz,db):
torch.ops.wind_backstepping.backward(w,q,k,v,z,b, dy,s,sa, dw,dq,dk,dv,dz,db)
class RefCudaWindBackstepping(torch.autograd.Function):
@staticmethod
def forward(ctx, w,q,k,v,z,b):
CHUNK_LEN=16
B,T,H,C = w.shape
assert T%CHUNK_LEN == 0
assert all(i.dtype==torch.bfloat16 for i in [w,q,k,v,z,b])
assert all(i.is_contiguous() for i in [w,q,k,v,z,b])
y = torch.empty_like(v)
s = torch.empty(B,H,T//CHUNK_LEN,C,C, dtype=torch.float32,device=w.device)
sa = torch.empty(B,T,H,C, dtype=torch.float32,device=w.device)
ref_wkv_cuda_forward(w,q,k,v,z,b, y,s,sa)
ctx.save_for_backward(w,q,k,v,z,b,s,sa)
return y
@staticmethod
def backward(ctx, dy):
assert all(i.dtype==torch.bfloat16 for i in [dy])
assert all(i.is_contiguous() for i in [dy])
w,q,k,v,z,b,s,sa = ctx.saved_tensors
dw,dq,dk,dv,dz,db = [torch.empty_like(x) for x in [w,q,k,v,z,b]]
ref_wkv_cuda_backward(w,q,k,v,z,b, dy,s,sa, dw,dq,dk,dv,dz,db)
return dw,dq,dk,dv,dz,db
@torch.compiler.disable()
def rwkv7_attn_cuda_ref(q,w,k,v, kk,iclr, HEAD_SIZE=64, s0=None):
# Preload the kernel
load_ref_wkv_cuda_kernel()
# Get the shape
B,T,HC = w.shape
C = HEAD_SIZE
H = HC//C
# Assert that the chunk is multiple of 16
    assert T % 16 == 0, 'reference cuda, only works in multiples of 16'
# Initialize the state, if not provided - for compatibility (THE STATE IS NOT UPDATED)
s0 = torch.zeros(B,H,C,C, dtype=torch.float,device=w.device) if s0 is None else s0
# Handling the cuda kernel
q,w,k,v,a,b = [i.view(B,T,H,C) for i in [q,w,k,v,(-kk),(kk*iclr)]]
# Forward with backprop
xx = RefCudaWindBackstepping.apply(w,q,k,v,a,b)
return xx.view(B,T,HC), s0.view(B,H,C,C)
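# Illustrative call sketch (comment only; the sizes are assumptions). The reference kernel
# ignores any incoming state and returns s0 untouched, so it is only suitable for
# stateless training/validation:
#
#   B, T, H, C = 2, 256, 12, 64   # T must be a multiple of 16
#   q, w, k, v, kk, iclr = (torch.randn(B, T, H*C, device='cuda', dtype=torch.bfloat16)
#                           for _ in range(6))
#   y, s_out = rwkv7_attn_cuda_ref(q, w, k, v, kk, iclr, HEAD_SIZE=C)
#   # y: [B, T, H*C], s_out: zeros of shape [B, H, C, C]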
####################################################################################################
# State based cuda code
####################################################################################################
def load_wkv_cuda_kernel(CHUNK_LEN = 16, HEAD_SIZE = 64):
from torch.utils.cpp_extension import load
# load_name = f"wind_backstepping_C{HEAD_SIZE}_L{CHUNK_LEN}"
load_name = "state_wind_backstepping"
load_file = "state_wkv7"
# Check if the load_name is already loaded
if load_name in torch.ops:
return torch.ops.state_wind_backstepping
    # Get this script's file path, to compute the cuda path
this_file_path = os.path.dirname(os.path.abspath(__file__))
    # Load the kernel. There is some weird edge condition in compilation,
    # where catching the exception and trying again sometimes works
flags = ['-res-usage', f'-D_C_={HEAD_SIZE}', f"-D_CHUNK_LEN_={CHUNK_LEN}", "--use_fast_math", "-O3", "-Xptxas -O3", "--extra-device-vectorization"] #
try:
load(name=load_name, sources=[f'{this_file_path}/cuda/{load_file}_cuda.cu', f'{this_file_path}/cuda/{load_file}_op.cpp'], is_python_module=False, verbose=True, extra_cuda_cflags=flags)
except Exception as e:
        print("[WARNING] Failed to load the kernel, trying again (sometimes the compiler has a weird race condition)...")
        time.sleep(2) # Somehow this works; the minor compilation error clears on subsequent reruns
load(name=load_name, sources=[f'{this_file_path}/cuda/{load_file}_cuda.cu', f'{this_file_path}/cuda/{load_file}_op.cpp'], is_python_module=False, verbose=True, extra_cuda_cflags=flags)
# Return the loaded kernel
return torch.ops.state_wind_backstepping
@torch.compiler.disable()
def wkv_cuda_forward(state, w,q,k,v,z,b, y,s,sa):
torch.ops.state_wind_backstepping.forward(state, w,q,k,v,z,b, y,s,sa)
@torch.compiler.disable()
def wkv_cuda_backward(state, w,q,k,v,z,b, dy,s,sa, dw,dq,dk,dv,dz,db):
torch.ops.state_wind_backstepping.backward(state, w,q,k,v,z,b, dy,s,sa, dw,dq,dk,dv,dz,db)
class CudaWindBackstepping(torch.autograd.Function):
@staticmethod
def forward(ctx, s0, w,q,k,v,z,b):
CHUNK_LEN=16
B,T,H,C = w.shape
assert T%CHUNK_LEN == 0
assert all(i.dtype==torch.bfloat16 for i in [w,q,k,v,z,b])
assert all(i.is_contiguous() for i in [w,q,k,v,z,b])
y = torch.empty_like(v)
s = torch.empty(B,H,T//CHUNK_LEN,C,C, dtype=torch.float32,device=w.device)
sa = torch.empty(B,T,H,C, dtype=torch.float32,device=w.device)
sOri = s0.clone()
wkv_cuda_forward(s0, w,q,k,v,z,b, y,s,sa)
ctx.save_for_backward(sOri, w,q,k,v,z,b,s,sa)
return y
@staticmethod
def backward(ctx, dy):
assert all(i.dtype==torch.bfloat16 for i in [dy])
assert all(i.is_contiguous() for i in [dy])
state,w,q,k,v,z,b,s,sa = ctx.saved_tensors
dS0,dw,dq,dk,dv,dz,db = [torch.empty_like(x) for x in [state,w,q,k,v,z,b]]
wkv_cuda_backward(state, w,q,k,v,z,b, dy,s,sa, dw,dq,dk,dv,dz,db)
return dS0,dw,dq,dk,dv,dz,db
@torch.compiler.disable()
def rwkv7_attn_cuda(r,w,k,v, kk,iclr, HEAD_SIZE=64, s0=None):
# Preload the kernel
load_wkv_cuda_kernel()
# Get the shape
B,T,HC = w.shape
    # Check if the sequence length is a multiple of 16
chunk_remainder = T % 16
    # Compute the head count and head size
C = HEAD_SIZE
H = HC//C
# Initialize the state
s0 = torch.zeros(B,H,C,C, dtype=torch.float,device=w.device) if s0 is None else s0
sT = s0.to(dtype=torch.float)
    # Optimize the call, if the sequence length is a multiple of 16
if chunk_remainder == 0:
chunk_xx, chunk_sT = rwkv7_attn_cuda_chunk(r,w,k,v, kk,iclr, HEAD_SIZE, sT)
return chunk_xx, chunk_sT.to(dtype=s0.dtype)
# Compute the number of chunks
chunks = T // 16
si = chunks * 16
# Get the chunked output
chunk_xx, chunk_sT = rwkv7_attn_cuda_chunk(
r[:,:si],w[:,:si],k[:,:si],v[:,:si], kk[:,:si],iclr[:,:si],
HEAD_SIZE, s0
)
# Get the remainder
remain_xx, last_sT = rwkv7_attn_pytorch_chunk(
r[:,si:],torch.exp(-torch.exp(w[:,si:])),k[:,si:],v[:,si:], kk[:,si:],iclr[:,si:],
B, H, C,
torch.zeros(B, chunk_remainder, HC, device=w.device, dtype=w.dtype),
chunk_sT, chunk_size=chunk_remainder
)
# Concatenate and return results
return torch.cat([chunk_xx.to(dtype=w.dtype), remain_xx.to(dtype=w.dtype)], dim=1), last_sT.to(dtype=s0.dtype)
def rwkv7_attn_cuda_chunk(r,w,k,v, kk,iclr, HEAD_SIZE=64, s0=None):
'''
    CUDA implementation running in blocks of 16 (hardcoded requirement for the kernel)
'''
B,T,HC = w.shape
    assert T % 16 == 0, 'pure cuda, only works in multiples of 16'
C = HEAD_SIZE
H = HC//C
# Handling the cuda kernel
a,b = -kk, (kk*iclr)
r,w,k,v,a,b = [i.view(B,T,H,C) for i in [r,w,k,v,a,b]]
if s0 is None:
s1 = torch.zeros(B,H,C,C, dtype=torch.float,device=w.device)
else:
s1 = s0.clone()
# Forward with backprop
xx = CudaWindBackstepping.apply(s1,w,r,k,v,a,b)
return xx.view(B,T,HC), s1.view(B,H,C,C)
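# How rwkv7_attn_cuda splits a sequence whose length is not a multiple of 16
# (comment only; T=100 is an assumed example):
#
#   T = 100  ->  si = 96 (6 chunks of 16 via the CUDA kernel), remainder = 4
#   tokens [0:96]   -> rwkv7_attn_cuda_chunk(...)      # bf16 kernel, state s0 -> chunk_sT
#   tokens [96:100] -> rwkv7_attn_pytorch_chunk(...)   # pytorch loop, state chunk_sT -> last_sT
#   return cat([chunk_xx, remain_xx], dim=1), last_sT
#
# Note the pytorch tail is handed exp(-exp(w)), since the pure-pytorch path expects the
# already-exponentiated decay, while the CUDA kernel is given `w` un-exponentiated.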
# ----------------
# block/kernel/rwkv7_attn_fla.py
# ----------------
def rwkv7_attn_fla(
r,w,k,v, kk,iclr,
BATCH_SIZE, SEQ_LEN, N_HEAD, HEAD_SIZE,
xx, wkv_state_in
):
from fla.ops.rwkv7.chunk import chunk_rwkv7
# Preprocessing the FLA
r,w,k,v,a,b = [i.view(BATCH_SIZE,SEQ_LEN,N_HEAD,-1) for i in [r,w,k,v,-kk,(kk*iclr)]]
log_w = -w.float().exp()
# Run the FLA
output, vk_state = chunk_rwkv7(r=r, log_w=log_w, k=k, v=v, a=a, b=b, initial_state=wkv_state_in.float(), output_final_state=True)
return output, vk_state.to(dtype=wkv_state_in.dtype)
def rwkv7_attn_fused_reccurent_fla(
r,w,k,v, kk,iclr,
BATCH_SIZE, SEQ_LEN, N_HEAD, HEAD_SIZE,
xx, wkv_state_in
):
from fla.ops.rwkv7.fused_recurrent import fused_recurrent_rwkv7
# Preprocessing the FLA
r,w,k,v,a,b = [i.view(BATCH_SIZE,SEQ_LEN,N_HEAD,-1) for i in [r,w,k,v,-kk,(kk*iclr)]]
log_w = -w.float().exp()
# Run the FLA
output, vk_state = fused_recurrent_rwkv7(r=r, log_w=log_w, k=k, v=v, a=a, b=b, initial_state=wkv_state_in.float(), output_final_state=True)
return output, vk_state.to(dtype=wkv_state_in.dtype)
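# Illustrative call sketch for the FLA backends (comment only; requires the `fla` package,
# sizes are assumptions). Unlike the pytorch/cuda paths above, these wrappers return the
# FLA output tensor as-is, without flattening the head dimension back into [B, T, H*C]:
#
#   B, T, H, C = 1, 128, 12, 64
#   r, w, k, v, kk, iclr = (torch.randn(B, T, H*C, device='cuda', dtype=torch.bfloat16)
#                           for _ in range(6))
#   s0 = torch.zeros(B, H, C, C, device='cuda', dtype=torch.float)
#   y, s1 = rwkv7_attn_fla(r, w, k, v, kk, iclr, B, T, H, C, None, s0)   # the xx arg is unused here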
# ----------------
# block/kernel/rwkv7_attn_triton.py
# ----------------
import torch
import torch as th
import triton
import triton.language as tl
####################################################################################################
# Triton specific code (mostly from songlin & Johan Sokrates Wind)
#
# Copyright (c) 2024, Johan Sokrates Wind, licensed under MIT
# https://github.com/johanwind/wind_rwkv/blob/main/LICENSE
####################################################################################################
# -------------------------
# Triton "smallhead" and "bighead" common code
# -------------------------
@triton.jit
def IND3(a,b,c,nb,nc):
return (a*nb+b)*nc+c
@triton.jit
def IND4(a,b,c,d,nb,nc,nd):
return ((a*nb+b)*nc+c)*nd+d
@triton.jit
def IND5(a,b,c,d,e,nb,nc,nd,ne):
return (((a*nb+b)*nc+c)*nd+d)*ne+e
@triton.jit
def _prod(a,b): return a*b
# inv(I-A) where A is a strictly lower triangular nxn matrix
@triton.jit
def tri_minv(A, n:tl.constexpr, prec:tl.constexpr):
i = tl.arange(0,n)
prod = (i[None,:]==i[:,None]).to(tl.float32)
for j in range(n-1):
prod += tl_dot(prec, prod, (A*((i[None,:]==j)*(i[:,None]>i[None,:]))).trans())
return prod.trans()
@triton.jit
def tl_dot(prec:tl.constexpr, a, b):
if prec == 'fp32':
return tl.dot(a.to(tl.float32),b.trans().to(tl.float32).trans(), allow_tf32=False)
elif prec == 'tf32':
return tl.dot(a.to(tl.float32),b.trans().to(tl.float32).trans(), allow_tf32=True)
elif prec == 'bf16':
return tl.dot(a.to(tl.bfloat16),b.trans().to(tl.bfloat16).trans(), allow_tf32=True)
else:
tl.static_assert(False)
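# Note: tri_minv relies on A being strictly lower triangular, hence nilpotent (A^n == 0 for an
# n x n matrix), so inv(I - A) exists without any division and equals the finite series
#   I + A + A^2 + ... + A^(n-1)
# The loop folds in one masked column of the strictly-lower part of A per iteration,
# with the dot-product precision selected by `prec`.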
# -------------------------
# Triton "smallhead" code
# -------------------------
@triton.jit
def fw_attn_triton(w_,q_,k_,v_,a_,b_, s0_,y_,s_,sT_, B:tl.constexpr,T:tl.constexpr,H:tl.constexpr,C:tl.constexpr,dT:tl.constexpr, prec:tl.constexpr):
bi = tl.program_id(1)
hi = tl.program_id(0)
i = tl.arange(0,C)[None,:]
state = tl.load(s0_+IND4(bi,hi,i.trans(),i, H,C,C)).to(tl.float32)
for t0 in range(T//dT):
t = t0*dT+tl.arange(0,dT)[:,None]
sw = tl.load(w_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32)
sq = tl.load(q_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32)
sk = tl.load(k_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32)
sv = tl.load(v_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32)
sa = tl.load(a_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32)
sb = tl.load(b_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32)
w = (-sw.exp()).exp()
fw = tl.reduce(w, 0, _prod, keep_dims=True)
incl_pref = tl.cumprod(w,axis=0)
non_incl_pref = incl_pref / w
inv_incl_pref = 1 / incl_pref
wq = sq * incl_pref
wa = sa * non_incl_pref
kwi = sk * inv_incl_pref
bwi = sb * inv_incl_pref
mask1 = (t > t.trans())
ab = tl_dot(prec, wa, bwi.trans()) * mask1
ak = tl_dot(prec, wa, kwi.trans()) * mask1
ab_inv = tri_minv(ab, dT, prec)
ab_u = tl_dot(prec, ak, sv) + tl_dot(prec, wa, state.trans())
u = tl_dot(prec, ab_inv, ab_u)
mask2 = (t >= t.trans())
qk = tl_dot(prec, wq, kwi.trans()) * mask2
qb = tl_dot(prec, wq, bwi.trans()) * mask2
yy = tl_dot(prec, qk, sv) + tl_dot(prec, qb, u) + tl_dot(prec, wq, state.trans())
tl.store(y_+IND4(bi,t,hi,i, T,H,C), yy.to(tl.bfloat16))
tl.store(s_+IND5(bi,hi,t0,i.trans(),i, H,T//dT,C,C), state.to(tl.float32))
state = state * fw + tl_dot(prec, sv.trans(), kwi*fw) + tl_dot(prec, u.trans(), bwi*fw)
tl.store(sT_+IND4(bi,hi,i.trans(),i, H,C,C), state.to(tl.bfloat16))
@triton.jit
def bw_attn_triton(w_,q_,k_,v_,a_,b_, dy_,s_,dsT_, dw_,dq_,dk_,dv_,da_,db_,ds0_, B:tl.constexpr,T:tl.constexpr,H:tl.constexpr,C:tl.constexpr,dT:tl.constexpr, prec:tl.constexpr):
bi = tl.program_id(1)
hi = tl.program_id(0)
i = tl.arange(0,C)[None,:]
dstate = tl.load(dsT_+IND4(bi,hi,i.trans(),i, H,C,C)).to(tl.float32)
for t0 in range(T//dT-1,-1,-1):
t = t0*dT+tl.arange(0,dT)[:,None]
state = tl.load(s_+IND5(bi,hi,t0,i.trans(),i, H,T//dT,C,C)).to(tl.float32)
sw = tl.load(w_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32)
sq = tl.load(q_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32)
sk = tl.load(k_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32)
sv = tl.load(v_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32)
sa = tl.load(a_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32)
sb = tl.load(b_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32)
sdy = tl.load(dy_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32)
dw_fac = -sw.exp()
w = dw_fac.exp()
fw = tl.reduce(w, 0, _prod, keep_dims=True)
incl_pref = tl.cumprod(w,axis=0)
non_incl_pref = incl_pref / w
inv_incl_pref = 1 / incl_pref
wq = sq * incl_pref
wa = sa * non_incl_pref
kwi = sk * inv_incl_pref
bwi = sb * inv_incl_pref
mask1 = (t > t.trans())
ab = tl_dot(prec, wa, bwi.trans()) * mask1
ak = tl_dot(prec, wa, kwi.trans()) * mask1
ab_inv = tri_minv(ab, dT, prec)
ab_u = tl_dot(prec, ak, sv) + tl_dot(prec, wa, state.trans())
u = tl_dot(prec, ab_inv, ab_u)
mask2 = (t >= t.trans())
qk = tl_dot(prec, wq, kwi.trans()) * mask2
qb = tl_dot(prec, wq, bwi.trans()) * mask2
du = tl_dot(prec, qb.trans(), sdy) + tl_dot(prec, bwi*fw, dstate.trans())
dab_u = tl_dot(prec, ab_inv.trans(), du)
dv = tl_dot(prec, qk.trans(), sdy) + tl_dot(prec, kwi*fw, dstate.trans()) + tl_dot(prec, ak.trans(), dab_u)
tl.store(dv_+IND4(bi,t,hi,i, T,H,C), dv.to(tl.bfloat16))
dab = tl_dot(prec, tl_dot(prec, ab_inv.trans(), du), u.trans()) * mask1
dak = tl_dot(prec, dab_u, sv.trans()) * mask1
dab_u_state = tl_dot(prec, dab_u, state)
da = non_incl_pref * (tl_dot(prec, dab, bwi) + tl_dot(prec, dak, kwi) + dab_u_state)
tl.store(da_+IND4(bi,t,hi,i, T,H,C), da.to(tl.bfloat16))
dqb = tl_dot(prec, sdy, u.trans()) * mask2
dqk = tl_dot(prec, sdy, sv.trans()) * mask2
dy_state = tl_dot(prec, sdy, state)
dq = incl_pref * (tl_dot(prec, dqb, bwi) + tl_dot(prec, dqk, kwi) + dy_state)
tl.store(dq_+IND4(bi,t,hi,i, T,H,C), dq.to(tl.bfloat16))
fw_u_dstate = fw * tl_dot(prec, u, dstate)
db = inv_incl_pref * (tl_dot(prec, dab.trans(), wa) + tl_dot(prec, dqb.trans(), wq) + fw_u_dstate)
tl.store(db_+IND4(bi,t,hi,i, T,H,C), db.to(tl.bfloat16))
fw_v_dstate = fw * tl_dot(prec, sv, dstate)
dk = inv_incl_pref * (tl_dot(prec, dak.trans(), wa) + tl_dot(prec, dqk.trans(), wq) + fw_v_dstate)
tl.store(dk_+IND4(bi,t,hi,i, T,H,C), dk.to(tl.bfloat16))
dw0 = fw * tl.sum(state*dstate, axis=0,keep_dims=True)
for k in range(t0*dT,t0*dT+dT):
lmask = (t<k).trans()
A = (tl_dot(prec, dab*lmask, bwi) + tl_dot(prec, dak*lmask, kwi)) * wa * (t>k)
A += (tl_dot(prec, dqb*lmask, bwi) + tl_dot(prec, dqk*lmask, kwi)) * wq * (t>=k)
A += (fw_v_dstate*kwi + fw_u_dstate*bwi) * (t<k)
A += dab_u_state*wa * (t>k) + dy_state*wq * (t>=k)
dw = tl.sum(A, axis=0,keep_dims=True) + dw0
wk = tl.load(w_+IND4(bi,k,hi,i, T,H,C)).to(tl.float32)
dw *= -wk.exp()
tl.store(dw_+IND4(bi,k,hi,i, T,H,C), dw.to(tl.bfloat16))
dstate = dstate * fw + tl_dot(prec, sdy.trans(), wq) + tl_dot(prec, dab_u.trans(), wa)
tl.store(ds0_+IND4(bi,hi,i.trans(),i, H,C,C), dstate.to(tl.bfloat16))
class TritonRWKV7(th.autograd.Function):
@staticmethod
def forward(ctx, w,q,k,v,z,b,s0, dot_prec):
K = 16
B,T,H,C = w.shape
s0 = th.zeros(B,H,C,C, dtype=w.dtype,device=w.device) if s0 is None else s0
y = th.empty_like(v)
sT = th.empty_like(s0)
s = th.zeros(B,H,T//K,C,C, dtype=th.float32,device=w.device)
fw_attn_triton[(H,B)](w,q,k,v,z,b, s0,y,s,sT, B,T,H,C,K, dot_prec)
ctx.dot_prec = dot_prec
ctx.save_for_backward(w,q,k,v,z,b,s)
return y, sT
@staticmethod
def backward(ctx, dy, dsT):
K = 16
w,q,k,v,z,b,s = ctx.saved_tensors
B,T,H,C = w.shape
dw,dq,dk,dv,dz,db,ds0 = [th.empty_like(x) for x in [w,q,k,v,z,b,dsT]]
bw_attn_triton[(H,B)](w,q,k,v,z,b, dy,s,dsT, dw,dq,dk,dv,dz,db,ds0, B,T,H,C,K, ctx.dot_prec)
return dw,dq,dk,dv,dz,db,ds0,None
# -------------------------
# Triton "bighead" code
# -------------------------
@triton.autotune(configs=[triton.Config({'dC': dC}, num_stages=1) for dC in [16,32,64]], key=['T','H','C','dT','prec'])
@triton.jit
def fw_attn_triton_bighead(w_,q_,k_,v_,a_,b_, s0_,y_,s_,sT_, wq_,wa_,kwi_,bwi_,fw_, B:tl.constexpr,T:tl.constexpr,H:tl.constexpr,C:tl.constexpr,dT:tl.constexpr, prec:tl.constexpr, dC:tl.constexpr):
tl.static_assert(C%dC == 0)
bi = tl.program_id(1)
hi = tl.program_id(0)
for i0 in range(0,C,dC):
i = i0+tl.arange(0,dC)[None,:]
for j0 in range(0,C,dC):
j = j0+tl.arange(0,dC)[None,:]
state = tl.load(s0_+IND4(bi,hi,i.trans(),j, H,C,C)).to(tl.float32)
tl.store(s_+IND5(bi,hi,0,i.trans(),j, H,T//dT,C,C), state.to(tl.float32))
for t0 in range(T//dT):
dt = tl.arange(0,dT)[:,None]
t = t0*dT+dt
tl.debug_barrier()
for j0 in range(0,C,dC):
j = j0+tl.arange(0,dC)[None,:]
sw = tl.load(w_+IND4(bi,t,hi,j, T,H,C)).to(tl.float32)
sq = tl.load(q_+IND4(bi,t,hi,j, T,H,C)).to(tl.float32)
sk = tl.load(k_+IND4(bi,t,hi,j, T,H,C)).to(tl.float32)
sa = tl.load(a_+IND4(bi,t,hi,j, T,H,C)).to(tl.float32)
sb = tl.load(b_+IND4(bi,t,hi,j, T,H,C)).to(tl.float32)
w = (-sw.exp()).exp()
fw = tl.reduce(w, 0, _prod, keep_dims=True)
incl_pref = tl.cumprod(w,axis=0)
non_incl_pref = incl_pref / w
inv_incl_pref = 1 / incl_pref
wq = sq * incl_pref
wa = sa * non_incl_pref
kwi = sk * inv_incl_pref
bwi = sb * inv_incl_pref
tl.store(wq_+IND4(bi,hi,dt,j, H,dT,C), wq.to(tl.float32))
tl.store(wa_+IND4(bi,hi,dt,j, H,dT,C), wa.to(tl.float32))
tl.store(kwi_+IND4(bi,hi,dt,j, H,dT,C), kwi.to(tl.float32))
tl.store(bwi_+IND4(bi,hi,dt,j, H,dT,C), bwi.to(tl.float32))
tl.store(fw_+IND3(bi,hi,j, H,C), fw.to(tl.float32))
tl.debug_barrier()
ab = tl.zeros((dT,dT), tl.float32)
ak = tl.zeros((dT,dT), tl.float32)
qb = tl.zeros((dT,dT), tl.float32)
qk = tl.zeros((dT,dT), tl.float32)
for j0 in range(0,C,dC):
j = j0+tl.arange(0,dC)[None,:]
wa = tl.load(wa_+IND4(bi,hi,dt,j, H,dT,C)).to(tl.float32)
wq = tl.load(wq_+IND4(bi,hi,dt,j, H,dT,C)).to(tl.float32)
bwi = tl.load(bwi_+IND4(bi,hi,dt,j, H,dT,C)).to(tl.float32)
kwi = tl.load(kwi_+IND4(bi,hi,dt,j, H,dT,C)).to(tl.float32)
sa = tl.load(a_+IND4(bi,t,hi,j, T,H,C)).to(tl.float32)
sb = tl.load(b_+IND4(bi,t,hi,j, T,H,C)).to(tl.float32)
ab += tl_dot(prec, wa, bwi.trans())
ak += tl_dot(prec, wa, kwi.trans())
qb += tl_dot(prec, wq, bwi.trans())
qk += tl_dot(prec, wq, kwi.trans())
mask1 = (t > t.trans())
mask2 = (t >= t.trans())
ab *= mask1
ak *= mask1
qb *= mask2
qk *= mask2
ab_inv = tri_minv(ab, dT, prec)
for i0 in range(0,C,dC):
i = i0+tl.arange(0,dC)[None,:]
sv = tl.load(v_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32)
wa_state = tl.zeros((dT,dC), tl.float32)
wq_state = tl.zeros((dT,dC), tl.float32)
for j0 in range(0,C,dC):
j = j0+tl.arange(0,dC)[None,:]
state = tl.load(s_+IND5(bi,hi,t0,i.trans(),j, H,T//dT,C,C)).to(tl.float32)
wa = tl.load(wa_+IND4(bi,hi,dt,j, H,dT,C)).to(tl.float32)
wq = tl.load(wq_+IND4(bi,hi,dt,j, H,dT,C)).to(tl.float32)
wa_state += tl_dot(prec, wa, state.trans())
wq_state += tl_dot(prec, wq, state.trans())
ab_u = tl_dot(prec, ak, sv) + wa_state
u = tl_dot(prec, ab_inv, ab_u)
yy = tl_dot(prec, qk, sv) + tl_dot(prec, qb, u) + wq_state
tl.store(y_+IND4(bi,t,hi,i, T,H,C), yy.to(tl.bfloat16))
for j0 in range(0,C,dC):
j = j0+tl.arange(0,dC)[None,:]
state = tl.load(s_+IND5(bi,hi,t0,i.trans(),j, H,T//dT,C,C)).to(tl.float32)
kwi = tl.load(kwi_+IND4(bi,hi,dt,j, H,dT,C)).to(tl.float32)
bwi = tl.load(bwi_+IND4(bi,hi,dt,j, H,dT,C)).to(tl.float32)
fw = tl.load(fw_+IND3(bi,hi,j, H,C))
state = state * fw + tl_dot(prec, sv.trans(), kwi*fw) + tl_dot(prec, u.trans(), bwi*fw)
if t0+1 < T//dT:
tl.store(s_+IND5(bi,hi,t0+1,i.trans(),j, H,T//dT,C,C), state.to(tl.float32))
else:
tl.store(sT_+IND4(bi,hi,i.trans(),j, H,C,C), state.to(tl.bfloat16))
@triton.autotune(configs=[triton.Config({'dC': dC}, num_stages=1) for dC in [16,32,64]], key=['T','H','C','dT','prec'])
@triton.jit
def bw_attn_triton_bighead(w_,q_,k_,v_,a_,b_, dy_,s_,dsT_,ds_, dw_,dq_,dk_,dv_,da_,db_,ds0_, wq_,wa_,kwi_,bwi_,fw_,u_,dab_u_, B:tl.constexpr,T:tl.constexpr,H:tl.constexpr,C:tl.constexpr,dT:tl.constexpr, prec:tl.constexpr, dC:tl.constexpr):
tl.static_assert(C%dC == 0)
bi = tl.program_id(1)
hi = tl.program_id(0)
for i0 in range(0,C,dC):
i = i0+tl.arange(0,dC)[None,:]
for j0 in range(0,C,dC):
j = j0+tl.arange(0,dC)[None,:]
dstate = tl.load(dsT_+IND4(bi,hi,i.trans(),j, H,C,C)).to(tl.float32)
tl.store(ds_+IND4(bi,hi,i.trans(),j, H,C,C), dstate.to(tl.float32))
for t0 in range(T//dT-1,-1,-1):
dt = tl.arange(0,dT)[:,None]
t = t0*dT+dt
tl.debug_barrier()
for j0 in range(0,C,dC):
j = j0+tl.arange(0,dC)[None,:]
sw = tl.load(w_+IND4(bi,t,hi,j, T,H,C)).to(tl.float32)
sq = tl.load(q_+IND4(bi,t,hi,j, T,H,C)).to(tl.float32)
sk = tl.load(k_+IND4(bi,t,hi,j, T,H,C)).to(tl.float32)
sa = tl.load(a_+IND4(bi,t,hi,j, T,H,C)).to(tl.float32)
sb = tl.load(b_+IND4(bi,t,hi,j, T,H,C)).to(tl.float32)
w = (-sw.exp()).exp()
fw = tl.reduce(w, 0, _prod, keep_dims=True)
incl_pref = tl.cumprod(w,axis=0)
non_incl_pref = incl_pref / w
inv_incl_pref = 1 / incl_pref
wq = sq * incl_pref
wa = sa * non_incl_pref
kwi = sk * inv_incl_pref
bwi = sb * inv_incl_pref
tl.store(wq_+IND4(bi,hi,dt,j, H,dT,C), wq.to(tl.float32))
tl.store(wa_+IND4(bi,hi,dt,j, H,dT,C), wa.to(tl.float32))
tl.store(kwi_+IND4(bi,hi,dt,j, H,dT,C), kwi.to(tl.float32))
tl.store(bwi_+IND4(bi,hi,dt,j, H,dT,C), bwi.to(tl.float32))
tl.store(fw_+IND3(bi,hi,j, H,C), fw.to(tl.float32))
tl.debug_barrier()
ab = tl.zeros((dT,dT), tl.float32)
ak = tl.zeros((dT,dT), tl.float32)
qb = tl.zeros((dT,dT), tl.float32)
qk = tl.zeros((dT,dT), tl.float32)
for j0 in range(0,C,dC):
j = j0+tl.arange(0,dC)[None,:]
wa = tl.load(wa_+IND4(bi,hi,dt,j, H,dT,C)).to(tl.float32)
wq = tl.load(wq_+IND4(bi,hi,dt,j, H,dT,C)).to(tl.float32)
bwi = tl.load(bwi_+IND4(bi,hi,dt,j, H,dT,C)).to(tl.float32)
kwi = tl.load(kwi_+IND4(bi,hi,dt,j, H,dT,C)).to(tl.float32)
sa = tl.load(a_+IND4(bi,t,hi,j, T,H,C)).to(tl.float32)
sb = tl.load(b_+IND4(bi,t,hi,j, T,H,C)).to(tl.float32)
ab += tl_dot(prec, wa, bwi.trans())
ak += tl_dot(prec, wa, kwi.trans())
qb += tl_dot(prec, wq, bwi.trans())
qk += tl_dot(prec, wq, kwi.trans())
mask1 = (t > t.trans())
mask2 = (t >= t.trans())
ab *= mask1
ak *= mask1
qb *= mask2
qk *= mask2
ab_inv = tri_minv(ab, dT, prec)
dab = tl.zeros((dT,dT), tl.float32)
dak = tl.zeros((dT,dT), tl.float32)
dqb = tl.zeros((dT,dT), tl.float32)
dqk = tl.zeros((dT,dT), tl.float32)
tl.debug_barrier()
for i0 in range(0,C,dC):
i = i0+tl.arange(0,dC)[None,:]
wa_state = tl.zeros((dT,dC), tl.float32)
bwi_dw_dstate = tl.zeros((dT,dC), tl.float32)
kwi_dw_dstate = tl.zeros((dT,dC), tl.float32)
for j0 in range(0,C,dC):
j = j0+tl.arange(0,dC)[None,:]
state = tl.load(s_+IND5(bi,hi,t0,i.trans(),j, H,T//dT,C,C)).to(tl.float32)
dstate = tl.load(ds_+IND4(bi,hi,i.trans(),j, H,C,C)).to(tl.float32)
wa = tl.load(wa_+IND4(bi,hi,dt,j, H,dT,C)).to(tl.float32)
bwi = tl.load(bwi_+IND4(bi,hi,dt,j, H,dT,C)).to(tl.float32)
kwi = tl.load(kwi_+IND4(bi,hi,dt,j, H,dT,C)).to(tl.float32)
fw = tl.load(fw_+IND3(bi,hi,j, H,C))
wa_state += tl_dot(prec, wa, state.trans())
bwi_dw_dstate += tl_dot(prec, bwi*fw, dstate.trans())
kwi_dw_dstate += tl_dot(prec, kwi*fw, dstate.trans())
sv = tl.load(v_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32)
sdy = tl.load(dy_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32)
ab_u = tl_dot(prec, ak, sv) + wa_state
u = tl_dot(prec, ab_inv, ab_u)
du = tl_dot(prec, qb.trans(), sdy) + bwi_dw_dstate
dab_u = tl_dot(prec, ab_inv.trans(), du)
tl.store(u_+IND4(bi,hi,dt,i, H,dT,C), u.to(tl.float32))
tl.store(dab_u_+IND4(bi,hi,dt,i, H,dT,C), dab_u.to(tl.float32))
dv = tl_dot(prec, qk.trans(), sdy) + kwi_dw_dstate + tl_dot(prec, ak.trans(), dab_u)
tl.store(dv_+IND4(bi,t,hi,i, T,H,C), dv.to(tl.bfloat16))
dab += tl_dot(prec, dab_u, u.trans()) * mask1
dak += tl_dot(prec, dab_u, sv.trans()) * mask1
dqb += tl_dot(prec, sdy, u.trans()) * mask2
dqk += tl_dot(prec, sdy, sv.trans()) * mask2
tl.debug_barrier()
for j0 in range(0,C,dC):
j = j0+tl.arange(0,dC)[None,:]
dy_state = tl.zeros((dT,dC), tl.float32)
dab_u_state = tl.zeros((dT,dC), tl.float32)
fw_u_dstate = tl.zeros((dT,dC), tl.float32)
fw_v_dstate = tl.zeros((dT,dC), tl.float32)
state_dstate = tl.zeros((1,dC), tl.float32)
fw = tl.load(fw_+IND3(bi,hi,j, H,C))
wa = tl.load(wa_+IND4(bi,hi,dt,j, H,dT,C)).to(tl.float32)
wq = tl.load(wq_+IND4(bi,hi,dt,j, H,dT,C)).to(tl.float32)
for i0 in range(0,C,dC):
i = i0+tl.arange(0,dC)[None,:]
u = tl.load(u_+IND4(bi,hi,dt,i, H,dT,C)).to(tl.float32)
dab_u = tl.load(dab_u_+IND4(bi,hi,dt,i, H,dT,C)).to(tl.float32)
sv = tl.load(v_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32)
sdy = tl.load(dy_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32)
state = tl.load(s_+IND5(bi,hi,t0,i.trans(),j, H,T//dT,C,C)).to(tl.float32)
tl.debug_barrier()
dstate = tl.load(ds_+IND4(bi,hi,i.trans(),j, H,C,C)).to(tl.float32)
tl.debug_barrier()
dab_u_state += tl_dot(prec, dab_u, state)
fw_u_dstate += fw * tl_dot(prec, u, dstate)
fw_v_dstate += fw * tl_dot(prec, sv, dstate)
dy_state += tl_dot(prec, sdy, state)
state_dstate += tl.sum(state*dstate, axis=0,keep_dims=True)
dstate = dstate * fw + tl_dot(prec, sdy.trans(), wq) + tl_dot(prec, dab_u.trans(), wa)
if t0 > 0:
tl.store(ds_+IND4(bi,hi,i.trans(),j, H,C,C), dstate.to(tl.float32))
else:
tl.store(ds0_+IND4(bi,hi,i.trans(),j, H,C,C), dstate.to(tl.bfloat16))
sw = tl.load(w_+IND4(bi,t,hi,j, T,H,C)).to(tl.float32)
w = (-sw.exp()).exp()
incl_pref = tl.cumprod(w,axis=0)
non_incl_pref = incl_pref / w
inv_incl_pref = 1 / incl_pref
bwi = tl.load(bwi_+IND4(bi,hi,dt,j, H,dT,C)).to(tl.float32)
kwi = tl.load(kwi_+IND4(bi,hi,dt,j, H,dT,C)).to(tl.float32)
da = non_incl_pref * (tl_dot(prec, dab, bwi) + tl_dot(prec, dak, kwi) + dab_u_state)
tl.store(da_+IND4(bi,t,hi,j, T,H,C), da.to(tl.bfloat16))
dq = incl_pref * (tl_dot(prec, dqb, bwi) + tl_dot(prec, dqk, kwi) + dy_state)
tl.store(dq_+IND4(bi,t,hi,j, T,H,C), dq.to(tl.bfloat16))
db = inv_incl_pref * (tl_dot(prec, dab.trans(), wa) + tl_dot(prec, dqb.trans(), wq) + fw_u_dstate)
tl.store(db_+IND4(bi,t,hi,j, T,H,C), db.to(tl.bfloat16))
dk = inv_incl_pref * (tl_dot(prec, dak.trans(), wa) + tl_dot(prec, dqk.trans(), wq) + fw_v_dstate)
tl.store(dk_+IND4(bi,t,hi,j, T,H,C), dk.to(tl.bfloat16))
dw0 = fw * state_dstate
for k in range(t0*dT,t0*dT+dT):
lmask = (t<k).trans()
A = (tl_dot(prec, dab*lmask, bwi) + tl_dot(prec, dak*lmask, kwi)) * wa * (t>k)
A += (tl_dot(prec, dqb*lmask, bwi) + tl_dot(prec, dqk*lmask, kwi)) * wq * (t>=k)
A += (fw_v_dstate*kwi + fw_u_dstate*bwi) * (t<k)
A += dab_u_state*wa * (t>k) + dy_state*wq * (t>=k)
dw = tl.sum(A, axis=0,keep_dims=True) + dw0
wk = tl.load(w_+IND4(bi,k,hi,j, T,H,C)).to(tl.float32)
dw *= -wk.exp()
tl.store(dw_+IND4(bi,k,hi,j, T,H,C), dw.to(tl.bfloat16))
class TritonBigheadRWKV7(th.autograd.Function):
@staticmethod
def forward(ctx, w,q,k,v,a,b,s0, dot_prec):
K = 16
B,T,H,C = w.shape
assert T%K == 0
assert C%16 == 0
s0 = th.zeros(B,H,C,C, dtype=w.dtype,device=w.device) if s0 is None else s0
y = th.empty_like(v)
sT = th.empty_like(s0)
s = th.zeros(B,H,T//K,C,C, dtype=th.float32,device=w.device)
wq,wa,kwi,bwi = [th.empty(B,H,K,C, dtype=th.float32,device=w.device) for i in range(4)]
fw = th.empty(B,H,C, dtype=th.float32,device=w.device)
fw_attn_triton_bighead[(H,B)](w,q,k,v,a,b, s0,y,s,sT, wq,wa,kwi,bwi,fw, B,T,H,C,K, dot_prec)
ctx.dot_prec = dot_prec
ctx.save_for_backward(w,q,k,v,a,b,s)
return y, sT
@staticmethod
def backward(ctx, dy, dsT):
K = 16
w,q,k,v,a,b,s = ctx.saved_tensors
B,T,H,C = w.shape
dw,dq,dk,dv,da,db,ds0 = [th.empty_like(x) for x in [w,q,k,v,a,b,dsT]]
fw = th.empty(B,H,C, dtype=th.float32,device=w.device)
ds = th.empty(B,H,C,C, dtype=th.float32,device=w.device)
wq,wa,kwi,bwi,u,dab_u = [th.empty(B,H,K,C, dtype=th.float32,device=w.device) for i in range(6)]
bw_attn_triton_bighead[(H,B)](w,q,k,v,a,b, dy,s,dsT,ds, dw,dq,dk,dv,da,db,ds0, wq,wa,kwi,bwi,fw,u,dab_u, B,T,H,C,K, ctx.dot_prec)
return dw,dq,dk,dv,da,db,ds0,None
####################################################################################################
# Start of pytorch code
####################################################################################################
# from .rwkv7_attn_pytorch import rwkv7_attn_pytorch_chunk
# -------------------------
# Pytorch-side wrapper for the triton "smallhead" kernel
# -------------------------
def rwkv7_attn_triton(r,w,k,v, kk,iclr, HEAD_SIZE=64, dot_prec='fp32', s0=None):
B,T,HC = w.shape
    # Check if the sequence length is a multiple of 16
chunk_remainder = T % 16
    # Optimize the call, if the sequence length is a multiple of 16
if chunk_remainder == 0:
return rwkv7_attn_triton_chunk(r,w,k,v, kk,iclr, HEAD_SIZE, dot_prec, s0)
# Initialize the state
C = HEAD_SIZE
H = HC//C
s0 = th.zeros(B,H,C,C, dtype=th.float,device=w.device) if s0 is None else s0
# Compute the number of chunks
chunks = T // 16
si = chunks * 16
# Get the chunked output
chunk_xx, chunk_sT = rwkv7_attn_triton_chunk(
r[:,:si],w[:,:si],k[:,:si],v[:,:si], kk[:,:si],iclr[:,:si],
HEAD_SIZE, dot_prec, s0
)
# Get the remainder
remain_xx, last_sT = rwkv7_attn_pytorch_chunk(
r[:,si:],torch.exp(-torch.exp(w[:,si:])),k[:,si:],v[:,si:], kk[:,si:],iclr[:,si:],
B, H, C, torch.zeros(B, chunk_remainder, HC, device=w.device, dtype=w.dtype),
chunk_sT, chunk_size=chunk_remainder
)
# Concatenate and return results
return torch.cat([chunk_xx.to(dtype=w.dtype), remain_xx.to(dtype=w.dtype)], dim=1), last_sT.to(dtype=s0.dtype)
def rwkv7_attn_triton_chunk(r,w,k,v, kk,iclr, HEAD_SIZE=64, dot_prec='fp32', s0=None):
'''
Triton implementation running in blocks of 16 (hardcoded requirement for the kernel)
'''
B,T,HC = w.shape
    assert T % 16 == 0, 'pure triton, only works in multiples of 16'
C = HEAD_SIZE
H = HC//C
# Moving the triton specific operations into the chunk steps
r,w,k,v,a,b = [i.view(B,T,H,C) for i in [r,w,k,v,-kk,(kk*iclr)]]
s0 = th.zeros(B,H,C,C, dtype=th.float,device=w.device) if s0 is None else s0
xx, sT = TritonRWKV7.apply(w,r,k,v,a,b,s0,dot_prec)
return xx.view(B,T,HC), sT
# -------------------------
# Pytorch-side wrapper for the triton "bighead" kernel
# -------------------------
def rwkv7_attn_triton_bighead(r,w,k,v, kk,iclr, HEAD_SIZE=64, dot_prec='fp32', s0=None):
B,T,HC = w.shape
    # Check if the sequence length is a multiple of 16
chunk_remainder = T % 16
    # Optimize the call, if the sequence length is a multiple of 16
if chunk_remainder == 0:
return rwkv7_attn_triton_bighead_chunk(r,w,k,v, kk,iclr, HEAD_SIZE, dot_prec, s0)
# Initialize the state
C = HEAD_SIZE
H = HC//C
s0 = th.zeros(B,H,C,C, dtype=th.float,device=w.device) if s0 is None else s0
# Compute the number of chunks
chunks = T // 16
si = chunks * 16
# Get the chunked output
chunk_xx, chunk_sT = rwkv7_attn_triton_bighead_chunk(
r[:,:si],w[:,:si],k[:,:si],v[:,:si], kk[:,:si],iclr[:,:si],
HEAD_SIZE, dot_prec, s0
)
# Get the remainder
remain_xx, last_sT = rwkv7_attn_pytorch_chunk(
r[:,si:],torch.exp(-torch.exp(w[:,si:])),k[:,si:],v[:,si:], kk[:,si:],iclr[:,si:],
B, H, C, torch.zeros(B, chunk_remainder, HC, device=w.device, dtype=w.dtype),
chunk_sT, chunk_size=chunk_remainder
)
# Concatenate and return results
return torch.cat([chunk_xx.to(dtype=w.dtype), remain_xx.to(dtype=w.dtype)], dim=1), last_sT.to(dtype=s0.dtype)
def rwkv7_attn_triton_bighead_chunk(r,w,k,v, kk,iclr, HEAD_SIZE=64, dot_prec='fp32', s0=None):
'''
Triton implementation running in blocks of 16 (hardcoded requirement for the kernel)
'''
B,T,HC = w.shape
    assert T % 16 == 0, 'pure triton, only works in multiples of 16'
C = HEAD_SIZE
H = HC//C
# Moving the triton specific operations into the chunk steps
r,w,k,v,a,b = [i.view(B,T,H,C) for i in [r,w,k,v,-kk,(kk*iclr)]]
s0 = th.zeros(B,H,C,C, dtype=th.float,device=w.device) if s0 is None else s0
xx, sT = TritonBigheadRWKV7.apply(w,r,k,v,a,b,s0,dot_prec)
return xx.view(B,T,HC), sT
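# Illustrative call sketch for the triton entrypoints (comment only; sizes are assumptions).
# The "smallhead" kernel keeps the full CxC state per program, while the "bighead" variant
# tiles the head dimension in dC blocks (autotuned over 16/32/64) for larger head sizes:
#
#   y, sT = rwkv7_attn_triton(r, w, k, v, kk, iclr, HEAD_SIZE=64, dot_prec='fp32', s0=s0)
#   y, sT = rwkv7_attn_triton_bighead(r, w, k, v, kk, iclr, HEAD_SIZE=128, dot_prec='fp32', s0=s0)
#
# Both accept any sequence length: the 16-aligned prefix goes through the kernel, and any
# remainder is finished by rwkv7_attn_pytorch_chunk with the state carried across.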
# ----------------
# block/rwkv7_block_config_map.py
# ----------------
from dataclasses import dataclass
from typing import Optional
from typing import Union
import torch
@dataclass
class RWKV7BlockConfigMap:
"""Configuration map for RWKV based models"""
# Key properties for the block / model
num_hidden_layers: int
hidden_size: int
head_size: int = 64
# Dropout rate, should only be used in training
dropout_rate: float = 0.0
# Implementation backend to use
tmix_backend: str = "auto"
# ---
# OPTIONAL PROPS
#
# Optional properties which can be derived
# or can be overwritten by the user
# ---
# Current layer_id of the block
layer_id: Optional[int] = None
# Device and Data type
device: Union[torch.device, str, None] = None
dtype: Union[torch.dtype, str, None] = None
# Channel mix / FFN block dimension size
hidden_size_ffn: Optional[int] = None
hidden_size_att: Optional[int] = None
# # number of heads
# n_head: Optional[int] = None
# ---
    # Initializer, with excess args ignored
# ---
def __init__(
self,
num_hidden_layers: int,
hidden_size: int,
head_size: int = 64,
dropout_rate: float = 0.0,
tmix_backend: str = "auto",
layer_id: Optional[int] = None,
device: Union[torch.device, str, None] = None,
dtype: Union[torch.dtype, str, None] = None,
hidden_size_ffn: Optional[int] = None,
hidden_size_att: Optional[int] = None,
**kwargs
) -> None:
'''
Constructor for the config
'''
self.num_hidden_layers = num_hidden_layers
self.hidden_size = hidden_size
self.head_size = head_size
self.dropout_rate = dropout_rate
self.tmix_backend = tmix_backend
self.layer_id = layer_id
self.device = device
self.dtype = dtype
self.hidden_size_ffn = hidden_size_ffn
self.hidden_size_att = hidden_size_att
# ---
# OPTIONAL PROPS FETCHER
# ---
def get_layer_id(self, fallback:int) -> int:
'''
Returns the layer id
'''
if self.layer_id is not None:
return self.layer_id
return fallback
def get_device(self, fallback:str) -> torch.device:
'''
Returns the device
'''
if self.device is not None:
return torch.device(self.device)
if fallback is not None:
return torch.device(fallback)
return torch.get_default_device()
def get_dtype(self, fallback:str) -> torch.dtype:
'''
Returns the dtype
'''
if self.dtype is not None:
key = self.dtype
else:
key = fallback
# if dtype is already torch.dtype
if isinstance(key, torch.dtype):
return key
        # Get and check that the resolved dtype is an instance of torch.dtype
ret = getattr(torch, key)
assert isinstance(ret, torch.dtype), f"Invalid dtype: {self.dtype}"
return ret
# ---
def get_hidden_size_att(self) -> int:
'''
Returns the dimension of attention
'''
if self.hidden_size_att is not None:
hidden_size_att = self.hidden_size_att
else:
hidden_size = self.hidden_size
assert hidden_size % 32 == 0, f"hidden_size must be divisible by 32"
hidden_size_att = hidden_size
assert hidden_size_att % 32 == 0, f"hidden_size_att must be divisible by 32 ({hidden_size_att})"
return hidden_size_att
def get_hidden_size_ffn(self) -> int:
'''
Returns the dimension of feed forward network
'''
if self.hidden_size_ffn is not None:
hidden_size_ffn = self.hidden_size_ffn
else:
hidden_size = self.hidden_size
assert hidden_size % 32 == 0, f"hidden_size must be divisible by 32"
hidden_size_ffn = hidden_size * 4
assert hidden_size_ffn % 32 == 0, f"hidden_size_ffn must be divisible by 32"
return hidden_size_ffn
# def get_n_head(self) -> int:
# '''
# Returns the number of heads
# '''
# if self.n_head is not None:
# n_head = self.n_head
# else:
# hidden_size_att = self.get_hidden_size_att()
# n_head = self.get_hidden_size_att() // self.head_size
# assert hidden_size_att % n_head == 0 , f"hidden_size_att must be divisible by head_size ({self.head_size})"
#
# return n_head
# ---
# Duplicator & Normalizer
# ---
def new_block_config_map(self, **kwargs) -> 'RWKV7BlockConfigMap':
'''
Returns a new config map with updated values
'''
new_dict = {}
for key in RWKV7BlockConfigMap.__dataclass_fields__:
if key in self.__dict__:
new_dict[key] = self.__dict__[key]
new_dict.update(kwargs)
return RWKV7BlockConfigMap(**new_dict)
@staticmethod
def normalize(config_map: any) -> 'RWKV7BlockConfigMap':
'''
        Converts maps, objects, or RWKV7BlockConfigMap instances into an RWKV7BlockConfigMap
'''
if isinstance(config_map, RWKV7BlockConfigMap):
return config_map
dict_obj = None
if isinstance(config_map, dict):
dict_obj = config_map
elif hasattr(config_map, '__dict__'):
dict_obj = config_map.__dict__
if dict_obj is not None:
            # Filter for only values in RWKV7BlockConfigMap
new_dict = {}
for key, value in dict_obj.items():
if key in RWKV7BlockConfigMap.__dataclass_fields__:
new_dict[key] = value
return RWKV7BlockConfigMap(**new_dict)
raise ValueError(f"Unsupported config_map type: {type(config_map)}")
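# Illustrative usage sketch (comment only; the values are assumptions):
#
#   cfg = RWKV7BlockConfigMap(num_hidden_layers=24, hidden_size=1024,
#                             tmix_backend="triton", dtype="bfloat16")
#   cfg2 = RWKV7BlockConfigMap.normalize({"num_hidden_layers": 24, "hidden_size": 1024,
#                                         "unknown_key": 123})   # unknown keys are dropped
#   layer3_cfg = cfg.new_block_config_map(layer_id=3)            # per-layer copy
#   cfg.get_hidden_size_ffn()                                    # -> 4096 (hidden_size * 4)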
# ----------------
# block/rwkv7_channel_mix.py
# ----------------
import torch
from torch import nn
from typing import Union
# from .rwkv7_block_config_map import RWKV7BlockConfigMap
class RWKV7ChannelMix(torch.nn.Module):
'''
ChannelMix block for RWKV
    This is similar to the transformer FFN block
'''
def __init__(self, configMap: Union[RWKV7BlockConfigMap, any]):
'''
Initialize the ChannelMix block.
        Note: this does not initialize the parameter weights themselves;
        that is done by the `init_parameters()` method
'''
super().__init__()
configMap:RWKV7BlockConfigMap = RWKV7BlockConfigMap.normalize(configMap)
self.configMap = configMap
# Get various props
hidden_size = configMap.hidden_size
device = configMap.get_device(None)
dtype = configMap.get_dtype('bfloat16')
# By default, hidden_size_ffn = hidden_size * 4
hidden_size_ffn = configMap.get_hidden_size_ffn()
# Build the various params
# ---
self.x_k = nn.Parameter(torch.empty(1, 1, hidden_size, device=device, dtype=dtype))
self.key = nn.Linear(hidden_size, hidden_size_ffn, bias=False, device=device, dtype=dtype)
self.value = nn.Linear(hidden_size_ffn, hidden_size, bias=False, device=device, dtype=dtype)
def init_parameters(self):
'''
Reset the parameters of the block, to an initial state used for training a model from scratch
'''
# Get required props
configMap = self.configMap
hidden_size = configMap.hidden_size
num_hidden_layers = configMap.num_hidden_layers
# Get optional props
layer_id = configMap.get_layer_id(0)
device = configMap.get_device(None)
dtype = configMap.get_dtype('bfloat16')
# By default, hidden_size_ffn = hidden_size * 4
hidden_size_ffn = configMap.get_hidden_size_ffn()
# Reset the various params
# ---
        with torch.no_grad(): # fancy init of the channel-mix ratios
ratio_1_to_almost0 = 1.0 - (layer_id / num_hidden_layers) # 1 to ~0
ddd = torch.ones(1, 1, hidden_size)
for i in range(hidden_size):
ddd[0, 0, i] = i / hidden_size
self.x_k = nn.Parameter( (1.0 - torch.pow(ddd, ratio_1_to_almost0**4)).to(device, dtype=dtype) )
self.key = nn.Linear(hidden_size, hidden_size_ffn, bias=False, device=device, dtype=dtype)
self.value = nn.Linear(hidden_size_ffn, hidden_size, bias=False, device=device, dtype=dtype)
def forward(self, x: torch.Tensor, last_state: torch.Tensor) -> tuple[torch.Tensor,torch.Tensor]:
'''
Forwarding channel mix given the input tokens and states.
Given:
- Incoming token embedding size of shape [batch_size, seq_len, embedding_size]
- Incoming channel mix, shift states of the various batches [batch_size, state_size]
Returns a pair
- Output embedding of shape [batch_size, seq_len, embedding_size]
- Output channel mix, shift state of shape [batch_size, state_size]
'''
# last_state = last_state.to(self.key.weight.device)
##########
## x070
##########
dxprev = torch.cat((last_state.unsqueeze(1), x[:, :-1]), dim=1) - x
xk = x + dxprev * self.x_k
k = torch.relu( self.key(xk) ) ** 2
return self.value(k), x[:,-1]
@torch.compile(mode="default", fullgraph=True)
def forward_with_default_compile(self, in_x: torch.Tensor, in_state: torch.Tensor, out_x: torch.Tensor, out_state: torch.Tensor) -> tuple[torch.Tensor,torch.Tensor]:
'''
        Compiled variant of the forward function,
        with no new tensors being created for the output.
        Useful for static memory allocation optimizations during inference
'''
out_x[:], out_state[:] = self.forward_with_reduce_compile(in_x, in_state)
return out_x, out_state
@torch.compile(mode="reduce-overhead", fullgraph=True)
def forward_with_reduce_compile(self, in_x: torch.Tensor, in_state: torch.Tensor) -> tuple[torch.Tensor,torch.Tensor]:
'''
        Compiled variant of the forward function
'''
return self.forward(in_x, in_state)
def load_from_model_state_dict(self, model_state_dict: dict, layer_id:int, non_blocking:bool=True):
'''
        Given the full/partial RWKV model weights, loaded via `torch.load`,
        set up the current module weights, using the layer_id
'''
# Get the current state_dict
current_state_dict = self.state_dict()
# Iterate each parameter in the state_dict, and compare from the model
for n in current_state_dict:
model_key = f"blocks.{layer_id}.ffn.{n}"
if model_key not in model_state_dict:
continue
# Copy the values from the state_dict
try:
current_state_dict[n].copy_(model_state_dict[model_key], non_blocking=non_blocking)
except Exception as e:
print(f"[ERROR] loading: {model_key} | model shape: {current_state_dict[n].shape} | weight shape: {model_state_dict[model_key].shape}")
raise e
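# Illustrative forward sketch (comment only; a minimal CPU example with assumed sizes):
#
#   cmix = RWKV7ChannelMix({"num_hidden_layers": 24, "hidden_size": 1024, "layer_id": 0})
#   cmix.init_parameters()
#   x = torch.randn(2, 16, 1024, dtype=torch.bfloat16)        # [batch, seq_len, hidden]
#   shift_in = torch.zeros(2, 1024, dtype=torch.bfloat16)     # last token embedding of the previous step
#   y, shift_out = cmix(x, shift_in)                          # y: [2, 16, 1024], shift_out == x[:, -1]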
# ----------------
# block/rwkv7_time_mix.py
# ----------------
import torch, math
from torch import nn
from torch import Tensor
from typing import Union
from torch.nn import functional as F
# from .rwkv7_block_config_map import RWKV7BlockConfigMap
# Check for triton package, if GPU is available
triton = None
if torch.cuda.is_available():
try:
import triton
except ImportError:
triton = None
else:
    print("[WARNING] CUDA (and therefore triton) not available, falling back to pytorch mode by default - this is significantly slower")
class RWKV7TimeMix(torch.nn.Module):
'''
Time Mix block for RWKV V7
'''
def __init__(self, configMap: Union[RWKV7BlockConfigMap, any]):
'''
Initialize the TimeMix block.
        Note: this does not initialize the parameter weights themselves;
        that is done by the `init_parameters()` method
'''
super().__init__()
configMap:RWKV7BlockConfigMap = RWKV7BlockConfigMap.normalize(configMap)
self.configMap = configMap
# Get required props
hidden_size = configMap.hidden_size
# num_hidden_layers = configMap.num_hidden_layers
# Get the layer id
layer_id = configMap.get_layer_id(0)
self.layer_id = layer_id
# Get optional props
device = configMap.get_device(None)
dtype = configMap.get_dtype('bfloat16')
        # By default, hidden_size_att = hidden_size
hidden_size_att = configMap.get_hidden_size_att()
# Assert hidden_size == hidden_size_att, until we support different hidden_size and hidden_size_att
assert hidden_size == hidden_size_att, "hidden_size should be equal to hidden_size_att (@TODO: support different hidden_size and hidden_size_att)"
# Head size settings
head_size = configMap.head_size
self.head_size = head_size
# Number of heads
n_head = hidden_size_att // head_size
assert hidden_size_att % head_size == 0, "hidden_size_att should be divisible by head_size"
self.n_head = n_head
# Backend
self.tmix_backend = configMap.tmix_backend
# Build the various params
# ---
with torch.no_grad():
# Note: for some data, you can reduce D_GATE_LORA or even remove this gate
def calc_lora_rank(exponent, multiplier):
return max(1, round(hidden_size ** exponent * multiplier / 32)) * 32
D_DECAY_LORA = calc_lora_rank(0.5, 1.8)
D_AAA_LORA = calc_lora_rank(0.5, 1.8)
D_MV_LORA = calc_lora_rank(0.5, 1.3)
D_GATE_LORA = calc_lora_rank(0.8, 0.6)
self.x_r = nn.Parameter(torch.empty(1,1,hidden_size, device=device, dtype=dtype))
self.x_w = nn.Parameter(torch.empty(1,1,hidden_size, device=device, dtype=dtype))
self.x_k = nn.Parameter(torch.empty(1,1,hidden_size, device=device, dtype=dtype))
self.x_v = nn.Parameter(torch.empty(1,1,hidden_size, device=device, dtype=dtype))
self.x_a = nn.Parameter(torch.empty(1,1,hidden_size, device=device, dtype=dtype))
self.x_g = nn.Parameter(torch.empty(1,1,hidden_size, device=device, dtype=dtype))
self.w0 = nn.Parameter(torch.empty(1,1,hidden_size, device=device, dtype=dtype))
self.w1 = nn.Parameter(torch.empty(hidden_size, D_DECAY_LORA, device=device, dtype=dtype))
self.w2 = nn.Parameter(torch.empty(D_DECAY_LORA, hidden_size, device=device, dtype=dtype))
self.a0 = nn.Parameter(torch.empty(1,1,hidden_size, device=device, dtype=dtype))
self.a1 = nn.Parameter(torch.empty(hidden_size,D_AAA_LORA, device=device, dtype=dtype))
self.a2 = nn.Parameter(torch.empty(D_AAA_LORA,hidden_size, device=device, dtype=dtype))
if layer_id > 0:
self.v0 = nn.Parameter(torch.empty(1,1,hidden_size, device=device, dtype=dtype))
self.v1 = nn.Parameter(torch.empty(hidden_size,D_MV_LORA, device=device, dtype=dtype))
self.v2 = nn.Parameter(torch.empty(D_MV_LORA,hidden_size, device=device, dtype=dtype))
self.g1 = nn.Parameter(torch.empty(hidden_size, D_GATE_LORA, device=device, dtype=dtype))
self.g2 = nn.Parameter(torch.empty(D_GATE_LORA, hidden_size, device=device, dtype=dtype))
self.k_k = nn.Parameter(torch.empty(1,1,hidden_size, device=device, dtype=dtype))
self.k_a = nn.Parameter(torch.empty(1,1,hidden_size, device=device, dtype=dtype))
self.r_k = nn.Parameter(torch.empty(n_head, head_size, device=device, dtype=dtype))
self.receptance = nn.Linear(hidden_size, hidden_size_att, bias=False, device=device, dtype=dtype)
self.key = nn.Linear(hidden_size, hidden_size_att, bias=False, device=device, dtype=dtype)
self.value = nn.Linear(hidden_size, hidden_size_att, bias=False, device=device, dtype=dtype)
self.output = nn.Linear(hidden_size_att, hidden_size, bias=False, device=device, dtype=dtype)
self.ln_x = nn.GroupNorm(n_head, hidden_size_att, device=device, dtype=dtype, eps=(1e-5)*head_size)
def init_parameters(self):
'''
Reset the parameters of the block, to an initial state used for training a model from scratch
'''
configMap = self.configMap
# Get required props
hidden_size = configMap.hidden_size
num_hidden_layers = configMap.num_hidden_layers
# Get the layer id
layer_id = self.layer_id
# Get optional props
device = configMap.get_device(None)
dtype = configMap.get_dtype('bfloat16')
        # By default, hidden_size_att = hidden_size
hidden_size_att = configMap.get_hidden_size_att()
# Assert hidden_size == hidden_size_att, until we support different hidden_size and hidden_size_att
assert hidden_size == hidden_size_att, "hidden_size should be equal to hidden_size_att (@TODO: support different hidden_size and hidden_size_att)"
# Head size settings
head_size = self.head_size
# Number of heads
n_head = hidden_size_att // head_size
assert hidden_size_att % head_size == 0, "hidden_size_att should be divisible by head_size"
# Reset the various params
# ---
with torch.no_grad():
ratio_0_to_1 = layer_id / (num_hidden_layers - 1) # 0 to 1
ratio_1_to_almost0 = 1.0 - (layer_id / num_hidden_layers) # 1 to ~0
ddd = torch.ones(1, 1, hidden_size, device=device, dtype=dtype)
for i in range(hidden_size):
ddd[0, 0, i] = i / hidden_size
# Note: for some data, you can reduce D_GATE_LORA or even remove this gate
def calc_lora_rank(exponent, multiplier):
return max(1, round(hidden_size ** exponent * multiplier / 32)) * 32
D_DECAY_LORA = calc_lora_rank(0.5, 1.8)
D_AAA_LORA = calc_lora_rank(0.5, 1.8)
D_MV_LORA = calc_lora_rank(0.5, 1.3)
D_GATE_LORA = calc_lora_rank(0.8, 0.6)
self.x_r.copy_(1.0 - torch.pow(ddd, 0.2 * ratio_1_to_almost0))
self.x_w.copy_(1.0 - torch.pow(ddd, 0.9 * ratio_1_to_almost0))
self.x_k.copy_(1.0 - (torch.pow(ddd, 0.9 * ratio_1_to_almost0) + 0.4 * ratio_0_to_1))
self.x_v.copy_(1.0 - (torch.pow(ddd, 0.4 * ratio_1_to_almost0) + 0.6 * ratio_0_to_1))
self.x_a.copy_(1.0 - torch.pow(ddd, 0.9 * ratio_1_to_almost0))
self.x_g.copy_(1.0 - torch.pow(ddd, 0.2 * ratio_1_to_almost0))
def ortho_init(x, scale):
x = x.to(device)
shape = x.shape
if len(shape) == 2:
gain = math.sqrt(shape[0] / shape[1]) if shape[0] > shape[1] else 1
nn.init.orthogonal_(x, gain=gain * scale)
elif len(shape) == 3:
gain = math.sqrt(shape[1] / shape[2]) if shape[1] > shape[2] else 1
for i in range(shape[0]):
nn.init.orthogonal_(x[i], gain=gain * scale)
else:
assert False
return x.to(device, dtype=dtype)
# D_DECAY_LORA = max(32, int(round( (1.8*(hidden_size**0.5)) /32)*32))
decay_speed = torch.ones(hidden_size, device=device, dtype=dtype)
for n in range(hidden_size):
decay_speed[n] = -7 + 5 * (n / (hidden_size - 1)) ** (0.85 + 1.0 * ratio_0_to_1 ** 0.5)
self.w0.copy_(decay_speed.reshape(1,1,hidden_size).to(device, dtype=dtype) + 0.5) # !!! 0.5 comes from F.softplus !!!
self.w1.copy_(torch.zeros(hidden_size, D_DECAY_LORA, device=device, dtype=dtype))
self.w2.copy_(ortho_init(torch.zeros(D_DECAY_LORA, hidden_size), 0.1))
# D_AAA_LORA = max(32, int(round( (1.8*(hidden_size**0.5)) /32)*32)) # suggestion
self.a0.copy_(torch.zeros(1,1,hidden_size, device=device, dtype=dtype))
self.a1.copy_(torch.zeros(hidden_size, D_AAA_LORA, device=device, dtype=dtype))
self.a2.copy_(ortho_init(torch.zeros(D_AAA_LORA, hidden_size), 0.1))
# D_MV_LORA = max(32, int(round( (1.3*(hidden_size**0.5)) /32)*32)) # suggestion
if layer_id > 0:
self.v0.copy_(torch.zeros(1,1,hidden_size, device=device, dtype=dtype)+1.0)
self.v1.copy_(torch.zeros(hidden_size, D_MV_LORA, device=device, dtype=dtype))
self.v2.copy_(ortho_init(torch.zeros(D_MV_LORA, hidden_size), 0.1))
# D_GATE_LORA = max(32, int(round( (0.6*(hidden_size**0.8)) /32)*32)) # suggestion
# Note: for some data, you can reduce D_GATE_LORA or even remove this gate
self.g1.copy_(torch.zeros(hidden_size, D_GATE_LORA, device=device, dtype=dtype))
self.g2.copy_(ortho_init(torch.zeros(D_GATE_LORA, hidden_size), 0.1))
self.k_k.copy_(torch.ones(1,1,hidden_size, device=device, dtype=dtype)*0.85)
self.k_a.copy_(torch.ones(1,1,hidden_size, device=device, dtype=dtype))
self.r_k.copy_(torch.zeros(n_head,head_size, device=device, dtype=dtype))
self.receptance = nn.Linear(hidden_size, hidden_size_att, bias=False, device=device, dtype=dtype)
self.key = nn.Linear(hidden_size, hidden_size_att, bias=False, device=device, dtype=dtype)
self.value = nn.Linear(hidden_size, hidden_size_att, bias=False, device=device, dtype=dtype)
self.output = nn.Linear(hidden_size_att, hidden_size, bias=False, device=device, dtype=dtype)
self.ln_x = nn.GroupNorm(n_head, hidden_size_att, device=device, dtype=dtype, eps=(1e-5)*head_size)
def forward(self, x:Tensor, shift_state_in:Tensor, wkv_state_in:Tensor, v_first_val:Tensor) -> tuple[Tensor,Tensor,Tensor,Tensor]:
'''
Forward the time mix, given the model weights, the input token embeddings and the states.
Given:
- Incoming token embeddings of shape [batch_size, seq_len, embedding_size]
- Incoming states of shapes:
[batch_size, state_size] ## Token Shift state,
[batch_size, n_head, head_size, head_size] ## WKV state
- Incoming v_first_val of shape [batch_size, seq_len, embedding_size]
Returns a tuple of
- output embedding of shape [batch_size, seq_len, embedding_size]
- output state of shapes:
[batch_size, state_size] ## Token Shift state,
[batch_size, n_head, head_size, head_size] ## WKV state
- output v_first_val of shape [batch_size, seq_len, embedding_size]
'''
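# Usage sketch (illustrative only; sizes and variable names are assumed, not prescribed):
#   tmix  = RWKV7TimeMix(config)                                # e.g. hidden_size=512, head_size=64
#   x     = torch.randn(1, 16, 512, dtype=torch.bfloat16)
#   shift = torch.zeros(1, 512, dtype=torch.bfloat16)
#   wkv   = torch.zeros(1, 8, 64, 64, dtype=torch.float)
#   out, shift_out, wkv_out, v_first = tmix(x, shift, wkv, None)   # a None v_first is fine on layer 0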
# Get the sizing
BATCH_SIZE, SEQ_LEN, IN_EMB_SIZE = x.size()
N_HEAD = self.n_head
HEAD_SIZE = self.head_size
# Ensure wkv_state_in is initialized
if wkv_state_in is None:
wkv_state_in = torch.zeros(BATCH_SIZE,N_HEAD,HEAD_SIZE,HEAD_SIZE, dtype=torch.float,device=x.device)
else:
wkv_state_in = wkv_state_in.clone()
##########
## x070
##########
shift_state_out = x[:, -1]
dxprev = torch.cat((shift_state_in.unsqueeze(1), x[:, :-1]), dim=1) - x
xr = x + dxprev * self.x_r
xw = x + dxprev * self.x_w
xk = x + dxprev * self.x_k
xv = x + dxprev * self.x_v
xa = x + dxprev * self.x_a
xg = x + dxprev * self.x_g
xx = dxprev
r = self.receptance(xr)
w = torch.tanh(xw @ self.w1) @ self.w2
k = self.key(xk)
v = self.value(xv)
g = torch.sigmoid(xg @ self.g1) @ self.g2
iclr = torch.sigmoid(self.a0 + (xa @ self.a1) @ self.a2) # a is "in-context learning rate"
kk = F.normalize((k * self.k_k).view(BATCH_SIZE,SEQ_LEN,N_HEAD,-1), dim=-1, p=2.0).view(BATCH_SIZE, SEQ_LEN, IN_EMB_SIZE)
k = k * (1 + (iclr-1) * self.k_a)
if self.layer_id == 0:
v_first_val = v # store the v of the first layer
else:
v = v + (v_first_val - v) * torch.sigmoid(self.v0 + (xv @ self.v1) @ self.v2) # add value residual
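# Note on the two w parameterisations used below (illustrative, not executed here):
# the pure-pytorch backends precompute the decay exp(-exp(w)) directly, while the
# triton/cuda/fla backends are handed the soft-clamped log-decay and apply the
# exponentiation internally (following the reference RWKV-v7 kernels).
# Both forms agree, since exp(-F.softplus(-z) - 0.5) == exp(-0.5) * sigmoid(z) ~= 0.606531 * sigmoid(z):
#   z = torch.randn(4, 8)                              # stand-in for (w0 + w)
#   a = torch.exp(-0.606531 * torch.sigmoid(z))        # "pytorch" backend form
#   b = torch.exp(-torch.exp(-F.softplus(-z) - 0.5))   # kernel-side form
#   assert torch.allclose(a, b, atol=1e-5)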
tmix_backend = self.tmix_backend.lower()
if tmix_backend == "auto":
if triton is None or self.receptance.weight.device.type == "cpu":
tmix_backend = "pytorch"
else:
tmix_backend = "cuda"
if tmix_backend == "pytorch_ref" or tmix_backend == "pytorch_ref_ori":
# Pure pytorch mode for rwkv attention
# from .kernel.rwkv7_attn_pytorch import rwkv7_attn_pytorch_ref
# Reference minimal compilation version
w = torch.exp(-0.606531 * torch.sigmoid((self.w0 + w).float())) # 0.606531 = exp(-0.5)
xx, wkv_state_out = rwkv7_attn_pytorch_ref(r, w, k, v, kk, iclr, BATCH_SIZE, SEQ_LEN, N_HEAD, HEAD_SIZE, xx, wkv_state_in)
elif tmix_backend == "pytorch_ref_fp32":
# Pure pytorch mode for rwkv attention
# from .kernel.rwkv7_attn_pytorch import rwkv7_attn_pytorch_ref_fp32
# Modified to follow the same logic as "cuda" version
# w = torch.exp(-0.606531 * torch.sigmoid((self.w0 + w).float())) # 0.606531 = exp(-0.5)
w = -F.softplus(-(self.w0 + w)) - 0.5
xx, wkv_state_out = rwkv7_attn_pytorch_ref_fp32(r, w, k, v, kk, iclr, BATCH_SIZE, SEQ_LEN, N_HEAD, HEAD_SIZE, xx, wkv_state_in)
elif tmix_backend == "pytorch":
# Pure pytorch mode for rwkv attention
# from .kernel.rwkv7_attn_pytorch import rwkv7_attn_pytorch
# Tweaked pytorch compile variant
w = torch.exp(-0.606531 * torch.sigmoid((self.w0 + w).float())) # 0.606531 = exp(-0.5)
xx, wkv_state_out = rwkv7_attn_pytorch(r, w, k, v, kk, iclr, BATCH_SIZE, SEQ_LEN, N_HEAD, HEAD_SIZE, xx, wkv_state_in)
elif tmix_backend == "triton":
if triton is None:
raise ValueError("Triton not available, unable to load triton kernel")
# from .kernel.rwkv7_attn_triton import rwkv7_attn_triton
w = -F.softplus(-(self.w0 + w)) - 0.5
xx, wkv_state_out = rwkv7_attn_triton(r, w, k, v, kk, iclr, s0=wkv_state_in)
elif tmix_backend == "triton_bighead":
if triton is None:
raise ValueError("Triton not available, unable to load triton kernel")
# from .kernel.rwkv7_attn_triton import rwkv7_attn_triton_bighead
w = -F.softplus(-(self.w0 + w)) - 0.5
xx, wkv_state_out = rwkv7_attn_triton_bighead(r, w, k, v, kk, iclr, s0=wkv_state_in)
elif tmix_backend == "cuda_ref":
# Cuda based method for rwkv attention
# from .kernel.rwkv7_attn_cuda import rwkv7_attn_cuda_ref
# Reference cuda version (no state output)
w = -F.softplus(-(self.w0 + w)) - 0.5
xx, wkv_state_out = rwkv7_attn_cuda_ref(r, w, k, v, kk, iclr, s0=wkv_state_in)
elif tmix_backend == "cuda":
# Cuda based method for rwkv attention
# from .kernel.rwkv7_attn_cuda import rwkv7_attn_cuda
# Modified cuda version (with state output)
w = -F.softplus(-(self.w0 + w)) - 0.5
xx, wkv_state_out = rwkv7_attn_cuda(r, w, k, v, kk, iclr, s0=wkv_state_in)
elif tmix_backend == "fla":
# FLA based method for rwkv attention
# from .kernel.rwkv7_attn_fla import rwkv7_attn_fla
# FLA runs with the softplus w
w = -F.softplus(-(self.w0 + w)) - 0.5
xx, wkv_state_out = rwkv7_attn_fla(r, w, k, v, kk, iclr, BATCH_SIZE, SEQ_LEN, N_HEAD, HEAD_SIZE, xx, wkv_state_in)
elif tmix_backend == "fla_fused" or tmix_backend == "fused_fla":
# FLA based method for rwkv attention
# from .kernel.rwkv7_attn_fla import rwkv7_attn_fused_reccurent_fla
# FLA runs with the softplus w
w = -F.softplus(-(self.w0 + w)) - 0.5
xx, wkv_state_out = rwkv7_attn_fused_reccurent_fla(r, w, k, v, kk, iclr, BATCH_SIZE, SEQ_LEN, N_HEAD, HEAD_SIZE, xx, wkv_state_in)
else:
raise ValueError(f"Unknown tmix_backend: {tmix_backend}")
# Normalize wkv_state_out back to the input state's dtype
if wkv_state_in is not None:
wkv_state_out = wkv_state_out.to(wkv_state_in.dtype)
######## cuda-based method
# wkv_state_out = wkv_state_in.clone()
# w = -F.softplus(-(self.w0 + w)) - 0.5 # soft-clamp to (-inf, -0.5)
# xx = RWKV7_OP(wkv_state_out, r, w, k, v, -kk, kk*a)
######## cuda-based method
xx = self.ln_x(xx.view(BATCH_SIZE * SEQ_LEN, IN_EMB_SIZE)).view(BATCH_SIZE, SEQ_LEN, IN_EMB_SIZE)
xx = xx + ((r.view(BATCH_SIZE,SEQ_LEN,N_HEAD,-1)*k.view(BATCH_SIZE,SEQ_LEN,N_HEAD,-1)*self.r_k).sum(dim=-1, keepdim=True) * v.view(BATCH_SIZE,SEQ_LEN,N_HEAD,-1)).view(BATCH_SIZE,SEQ_LEN,IN_EMB_SIZE)
xx = self.output(xx * g)
return xx, shift_state_out, wkv_state_out, v_first_val
@torch.compile(mode="default")
def forward_with_default_compile(self, in_x:Tensor, shift_state_in:Tensor, wkv_state_in:Tensor, v_first_val_in:Tensor, out_x:Tensor, shift_state_out:Tensor, wkv_state_out:Tensor, v_first_val_out:Tensor) -> tuple[Tensor,Tensor,Tensor,Tensor]:
'''
Compiled variant of the forward function
With no new tensors being created for the output
Useful for static memory allocation optimizations during inference
'''
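# Usage sketch (illustrative, buffer shapes assumed): preallocate the output tensors once,
# then reuse them across calls so the compiled graph always writes into the same storage:
#   out_x   = torch.empty_like(in_x)
#   s_out   = torch.empty_like(shift_state_in)
#   wkv_out = torch.empty_like(wkv_state_in)
#   v_out   = torch.empty_like(in_x)
#   tmix.forward_with_default_compile(in_x, shift_state_in, wkv_state_in, v_first, out_x, s_out, wkv_out, v_out)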
out_x[:], shift_state_out[:], wkv_state_out[:], v_first_val_out[:] = self.forward(in_x, shift_state_in, wkv_state_in, v_first_val_in)
return out_x, shift_state_out, wkv_state_out, v_first_val_out
@torch.compile(mode="reduce-overhead")
def forward_with_reduce_compile(self, in_x:Tensor, shift_state_in:Tensor, wkv_state_in:Tensor, v_first_val:Tensor) -> tuple[Tensor,Tensor,Tensor,Tensor]:
'''
Compiled variant of the forward function
With no input tensor being modified.
Useful for reduce-overhead compile mode
'''
return self.forward(in_x, shift_state_in, wkv_state_in, v_first_val)
# ---------------------------------
#
# Model state handling
#
# ---------------------------------
def load_from_model_state_dict(self, model_state_dict: dict, layer_id:int, non_blocking:bool=True):
'''
Given the full/partial RWKV model weights, loaded via `torch.load`
Set up the current module weights, using the layer_id
'''
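# For example (layer 3 chosen purely for illustration): the parameter "r_k" of this
# module is read from the model key "blocks.3.att.r_k" when layer_id == 3.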
# Get the current state_dict
current_state_dict = self.state_dict()
# Iterate each parameter in the state_dict, and compare from the model
for n in current_state_dict:
model_key = f"blocks.{layer_id}.att.{n}"
if model_key not in model_state_dict:
continue
# Copy the values from the state_dict
try:
current_state_dict[n].copy_(model_state_dict[model_key], non_blocking=non_blocking)
except Exception as e:
print(f"[ERROR] loading: {model_key} | model shape: {current_state_dict[n].shape} | weight shape: {model_state_dict[model_key].shape}")
raise e
# ----------------
# block/rwkv7_layer_block.py
# ----------------
import torch
from torch import nn
from torch import Tensor
from typing import Union
from torch.nn import functional as F
# from .rwkv7_block_config_map import RWKV7BlockConfigMap
# from .rwkv7_channel_mix import RWKV7ChannelMix
# from .rwkv7_time_mix import RWKV7TimeMix
class RWKV7LayerBlock(torch.nn.Module):
'''
Layer block for RWKV v7
'''
def __init__(self, configMap: Union[RWKV7BlockConfigMap, any]):
super().__init__()
configMap:RWKV7BlockConfigMap = RWKV7BlockConfigMap.normalize(configMap)
self.configMap = configMap
# Get required props
hidden_size = configMap.hidden_size
device = configMap.get_device(None)
dtype = configMap.get_dtype('bfloat16')
dropout_rate = configMap.dropout_rate
# Get valid layer_id
layer_id = configMap.get_layer_id(-1)
assert layer_id >= 0, f'layer_id must be >= 0, got {layer_id}'
# Setup the layernorms, and mixes
self.ln1 = nn.LayerNorm(hidden_size, device=device, dtype=dtype)
self.ln2 = nn.LayerNorm(hidden_size, device=device, dtype=dtype)
if layer_id == 0:
self.ln0 = nn.LayerNorm(hidden_size, device=device, dtype=dtype)
else:
self.ln0 = nn.Identity(device=device)
# Setup the time and channel mix
self.att = RWKV7TimeMix(configMap)
self.ffn = RWKV7ChannelMix(configMap)
# Setup dropout at the block level
if dropout_rate > 0.0:
# Note: nn.Dropout holds no parameters, so it takes no device/dtype arguments
self.drop0 = nn.Dropout(p = dropout_rate)
self.drop1 = nn.Dropout(p = dropout_rate)
else:
self.drop0 = nn.Identity(device=device)
self.drop1 = nn.Identity(device=device)
def init_parameters(self):
'''
Reset the parameters of the block, to an initial state used for training a model from scratch
'''
configMap = self.configMap
# Get required props
hidden_size = configMap.hidden_size
device = configMap.get_device(None)
dtype = configMap.get_dtype('bfloat16')
dropout_rate = configMap.dropout_rate
# Get valid layer_id
layer_id = configMap.get_layer_id(-1)
assert layer_id >= 0, f'layer_id must be >= 0, got {layer_id}'
# Redo the Setup for the layernorms, and mixes
self.ln1 = nn.LayerNorm(hidden_size, device=device, dtype=dtype)
self.ln2 = nn.LayerNorm(hidden_size, device=device, dtype=dtype)
if layer_id == 0:
self.ln0 = nn.LayerNorm(hidden_size, device=device, dtype=dtype)
else:
self.ln0 = nn.Identity(device=device)
# Call the sub blocks init_parameters
self.att.init_parameters()
self.ffn.init_parameters()
def forward(
self, x:torch.Tensor,
last_state: tuple[torch.Tensor,torch.Tensor,torch.Tensor],
v_first:torch.Tensor
) -> tuple[torch.Tensor,tuple[torch.Tensor,torch.Tensor,torch.Tensor],torch.Tensor]:
'''
Forward the block given the input tokens and the last state
Last state is a tuple of the following
- time mix shift state
- time mix wkv state
- channel mix shift state
Returns a tuple of the output embedding, the next state, and v_first
'''
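# Usage sketch (illustrative, variable names assumed): a single block step threading its state
#   x_out, (tmix_shift, tmix_wkv, cmix_shift), v_first = block(x, (tmix_shift, tmix_wkv, cmix_shift), v_first)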
# # Ensure everything is in the right device
# x = x.to(self.ln1.weight.device)
# last_state = [ s.to(self.ln1.weight.device) for s in last_state ]
# Note, that this only applies for layer 0
ln0_out = self.ln0(x)
# assert self.ln1(x) is not None
# assert last_state.tmix_shift is not None
# assert last_state.tmix_wkv is not None
att_out, tmix_shift, tmix_wkv, v_first = self.att(
self.ln1(ln0_out),
last_state[0], # tmix_shift,
last_state[1], # tmix_wkv
v_first
)
# x = x + att_out
x = self.drop0(ln0_out + att_out)
ffn_out, ffn_state = self.ffn(
self.ln2(x),
last_state[2] # cmix_shift,
)
# x = x + ffn_out
x = self.drop1(x + ffn_out)
# # Debugging for NaN
# layer_id = self.configMap.get_layer_id(-1)
# assert torch.isnan(att_out).sum() == 0, f'NaN detected att_out @ layer {layer_id}'
# assert torch.isnan(ffn_out).sum() == 0, f'NaN detected ffn_out @ layer {layer_id}'
# assert torch.isnan(v_first).sum() == 0, f'NaN detected v_first @ layer {layer_id}'
# assert torch.isnan(tmix_shift).sum() == 0, f'NaN detected tmix_shift @ layer {layer_id}'
# assert torch.isnan(tmix_wkv).sum() == 0, f'NaN detected tmix_wkv @ layer {layer_id}'
# assert torch.isnan(ffn_state).sum() == 0, f'NaN detected ffn_state @ layer {layer_id}'
# assert torch.isnan(x).sum() == 0, f'NaN detected block out @ layer {layer_id}'
return x, (tmix_shift, tmix_wkv, ffn_state), v_first
@torch.compile(mode="default")
def forward_with_default_compile(
self,
in_x:torch.Tensor,
in_state: tuple[torch.Tensor,torch.Tensor,torch.Tensor],
in_v_first:torch.Tensor,
out_x:torch.Tensor,
out_state: tuple[torch.Tensor,torch.Tensor,torch.Tensor],
out_v_first:torch.Tensor
) -> tuple[torch.Tensor,tuple[torch.Tensor,torch.Tensor,torch.Tensor],torch.Tensor]:
'''
Compiled variant of the forward function
With no new tensors being created for the output
Useful for static memory allocation optimizations during inference
'''
out_x[:], tmp_state, out_v_first[:] = self.forward(in_x, in_state, in_v_first)
out_state[0][:], out_state[1][:], out_state[2][:] = tmp_state
return out_x, out_state, out_v_first
@torch.compile(mode="reduce-overhead")
def forward_with_reduce_compile(self, in_x: torch.Tensor, in_state: tuple[torch.Tensor,torch.Tensor,torch.Tensor], in_v_first:torch.Tensor) -> tuple[torch.Tensor,tuple[torch.Tensor,torch.Tensor,torch.Tensor],torch.Tensor]:
'''
Compiled variant of the forward function
'''
return self.forward(in_x, in_state, in_v_first)
def load_from_model_state_dict(self, model_state_dict:dict, layer_id:int=-1, non_blocking:bool=True):
'''
Given the full/partial RWKV model weights, load the block weights accordingly
'''
if layer_id <= -1:
layer_id = self.configMap.get_layer_id(-1)
assert layer_id >= 0, f'layer_id must be >= 0, got {layer_id}'
# Get the current state_dict
current_state_dict = self.state_dict()
# Iterate each parameter in the state_dict, and compare from the model
for n in current_state_dict:
model_key = f"blocks.{layer_id}.{n}"
if model_key not in model_state_dict:
continue
# Copy the values from the state_dict
try:
current_state_dict[n].copy_(model_state_dict[model_key], non_blocking=non_blocking)
except Exception as e:
print(f"[ERROR] loading: {model_key} | model shape: {current_state_dict[n].shape} | weight shape: {model_state_dict[model_key].shape}")
raise e
# ----------------
# model/rwkv7_goose_config_map.py
# ----------------
from dataclasses import dataclass
from typing import Optional
from typing import Union
import torch
# from ..block.rwkv7_block_config_map import RWKV7BlockConfigMap
@dataclass
class RWKV7GooseConfigMap(RWKV7BlockConfigMap):
# This is the world tokenizer size
vocab_size: int = 65536
init_state_wkv: bool = False
# ---
# Initializer, with excess arg ignore
# ---
def __init__(
self,
vocab_size: int = 65536,
init_state_wkv: bool = False,
**kwargs
) -> None:
self.vocab_size = vocab_size
self.init_state_wkv = init_state_wkv
super().__init__(**kwargs)
@staticmethod
def normalize(config_map: any) -> 'RWKV7GooseConfigMap':
'''
Converts either maps, objs or RWKV7BlockConfigMap
'''
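# For example (hedged, config values assumed rather than required):
#   cfg = RWKV7GooseConfigMap.normalize({"num_hidden_layers": 12, "hidden_size": 768})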
if isinstance(config_map, RWKV7GooseConfigMap):
return config_map
if isinstance(config_map, dict):
return RWKV7GooseConfigMap(**config_map)
if hasattr(config_map, '__dict__'):
return RWKV7GooseConfigMap(**config_map.__dict__)
raise ValueError(f"Unsupported config_map type: {type(config_map)}")
@staticmethod
def from_model_state_dict(state_dict: dict, **kwargs) -> 'RWKV7GooseConfigMap':
'''
Converts the state dict to the config map
'''
# Iterate and count the layers
num_hidden_layers = 0
for key in state_dict.keys():
if key.startswith('blocks.'):
idx = key.split('.')[1]
num_hidden_layers = max(num_hidden_layers, int(idx)+1)
# Enable wkv_state
if 'init_state.0.wkv' in state_dict:
kwargs['init_state_wkv'] = True
# Initialize the config map, with the configured values
return RWKV7GooseConfigMap(
num_hidden_layers=num_hidden_layers,
hidden_size=state_dict['emb.weight'].shape[1],
vocab_size=state_dict['emb.weight'].shape[0],
# init_state_wkv=hasattr(state_dict, 'init_state.0.wkv'),
# n_head=state_dict['blocks.0.att.r_k'].shape[0],
head_size=state_dict['blocks.0.att.r_k'].shape[1],
hidden_size_att=state_dict['blocks.0.att.key.weight'].shape[0],
hidden_size_ffn=state_dict['blocks.0.ffn.key.weight'].shape[0],
**kwargs
)
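# Usage sketch (hedged; the checkpoint path is hypothetical): building a config and model
# straight from a raw RWKV state dict:
#   sd    = torch.load("rwkv7-goose.pth", map_location="cpu")
#   cfg   = RWKV7GooseConfigMap.from_model_state_dict(sd)
#   model = RWKV7GooseModel(cfg)
#   model.load_from_model_state_dict(sd)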
# ----------------
# model/rwkv7_goose_model.py
# ----------------
import torch
from torch import nn
from torch import Tensor
from typing import Union
# from .rwkv7_goose_config_map import RWKV7GooseConfigMap
# from ..block.rwkv7_layer_block import RWKV7LayerBlock
class RWKV7GooseModel(nn.Module):
'''
RWKV7 Goose Model architecture
Simplified implementation
'''
def __init__(self, config: Union[RWKV7GooseConfigMap, any, None] = None):
# Initialize the base class
super().__init__()
# Normalize the config
configMap:RWKV7GooseConfigMap = RWKV7GooseConfigMap.normalize(config)
self.configMap = configMap
# Get the required prop
num_hidden_layers = configMap.num_hidden_layers
vocab_size = configMap.vocab_size
device = configMap.get_device(None)
dtype = configMap.get_dtype('bfloat16')
hidden_size = configMap.hidden_size
# Embedding layer
self.emb = nn.Embedding(vocab_size, hidden_size, device=device, dtype=dtype)
# main block layers
blockList = [None]*num_hidden_layers
for i in range(num_hidden_layers):
blockList[i] = RWKV7LayerBlock(configMap.new_block_config_map(layer_id=i))
self.blocks = nn.ModuleList(blockList)
# ln_out and head
self.ln_out = nn.LayerNorm(hidden_size, device=device, dtype=dtype)
self.head = nn.Linear(hidden_size, vocab_size, bias=False, device=device, dtype=dtype)
# init state tuning support
if configMap.init_state_wkv:
stateTuneList = [None]*num_hidden_layers
for i in range(num_hidden_layers):
stateTuneList[i] = nn.ParameterDict({
"wkv": nn.Parameter(torch.zeros(hidden_size // 64, 64, 64, device=device, dtype=dtype)),
})
self.init_state = nn.ParameterList(stateTuneList)
def init_parameters(self):
'''
Reset the parameters of the model, to an initial state used for training a model from scratch
'''
# Get the required prop
configMap = self.configMap
num_hidden_layers = configMap.num_hidden_layers
vocab_size = configMap.vocab_size
device = configMap.get_device(None)
dtype = configMap.get_dtype('bfloat16')
hidden_size = configMap.hidden_size
# Iterate and reset the blocks
for i in range(num_hidden_layers):
self.blocks[i].init_parameters()
# Reinit the Embedding layer
self.emb = nn.Embedding(vocab_size, hidden_size, device=device, dtype=dtype)
# Reinit the ln_out and head
self.ln_out = nn.LayerNorm(hidden_size, device=device, dtype=dtype)
if self.head is not None:
self.head = nn.Linear(hidden_size, vocab_size, bias=False, device=device, dtype=dtype)
# Reinit the init state tuning support
if configMap.init_state_wkv:
stateTuneList = [None]*num_hidden_layers
for i in range(num_hidden_layers):
stateTuneList[i] = nn.ParameterDict({
"wkv": nn.Parameter(torch.zeros(hidden_size // 64, 64, 64, device=device, dtype=torch.float)),
})
self.init_state = nn.ParameterList(stateTuneList)
def load_from_model_state_dict(self, state_dict: dict, non_blocking:bool=True):
'''
Given the full/partial RWKV model weights, loaded via `torch.load`
Set up the embedding, block, ln_out and head weights accordingly
'''
for i, block in enumerate(self.blocks):
block.load_from_model_state_dict(state_dict, i, non_blocking=non_blocking)
self.ln_out.weight.data.copy_(state_dict['ln_out.weight'], non_blocking=non_blocking)
self.ln_out.bias.data.copy_(state_dict['ln_out.bias'], non_blocking=non_blocking)
self.head.weight.data.copy_(state_dict['head.weight'], non_blocking=non_blocking)
self.emb.weight.data.copy_(state_dict['emb.weight'], non_blocking=non_blocking)
### ---
###
### Init state handling
###
### ---
def get_init_state(self, batch_size:int=1, skip_init_state:bool=False) -> list[tuple[torch.Tensor,torch.Tensor,torch.Tensor]]:
'''
Get an initialized copy of the model state, for the given batch size
'''
# Get required configs
hidden_size = self.configMap.hidden_size
init_state_wkv = self.configMap.init_state_wkv
num_hidden_layers = self.configMap.num_hidden_layers
# Prepare the initial state
init_state = [ None for i in range(num_hidden_layers) ]
for i in range(num_hidden_layers):
device = self.blocks[i].ln1.weight.data.device
dtype = self.blocks[i].ln1.weight.data.dtype
# Use the saved init_state if enabled
# TODO: Consider letting the wkv_state dtype be a parameter
wkv_state = torch.zeros(batch_size, hidden_size // 64, 64, 64, device=device, dtype=torch.float)
if init_state_wkv and not skip_init_state:
init_wkv = self.init_state[i]["wkv"]
for b in range(batch_size):
wkv_state[b][:] = init_wkv
# Prepare the state
init_state[i] = (
torch.zeros(batch_size, hidden_size, device=device, dtype=dtype),
wkv_state,
torch.zeros(batch_size, hidden_size, device=device, dtype=dtype)
)
return init_state
### ---
###
### Model Forward
###
### ---
def forward(
self, idx:torch.Tensor,
prv_stateList:list[tuple[torch.Tensor,torch.Tensor,torch.Tensor]] = None,
ret_stateList:list[tuple[torch.Tensor,torch.Tensor,torch.Tensor]] = None,
) -> tuple[torch.Tensor,list[tuple[torch.Tensor,torch.Tensor,torch.Tensor]]]:
'''
Forward the block set, given the input tokens and the last state
Last state is a list of tuple of the following
- time mix shift state
- time mix wkv state
- channel mix shift state
Returns a pair of the output embedding and the next state
'''
# Prepare the state, with the batch size
if prv_stateList is None:
prv_stateList = self.get_init_state(idx.shape[0])
# If no return state is set, let _forward_internal set it up
if ret_stateList is None:
ret_stateList = [ None for i in range(self.configMap.num_hidden_layers) ]
return self._forward_internal(idx, prv_stateList, ret_stateList, overwrite_ret_tensor=False)
# Forward internally
return self._forward_internal(idx, prv_stateList, ret_stateList, overwrite_ret_tensor=True)
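# Usage sketch (illustrative; token tensors assumed): threading the state across two calls
#   state = model.get_init_state(batch_size=2)
#   logits, state = model(tokens_a, state)   # tokens_a: LongTensor of shape [2, seq_len]
#   logits, state = model(tokens_b, state)   # carries the recurrent state forward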
def _forward_block_hook(self,
block:RWKV7LayerBlock,
x_hidden_state:torch.Tensor,
prv_stateList:list[tuple[torch.Tensor,torch.Tensor,torch.Tensor]],
v_first:torch.Tensor
) -> tuple[torch.Tensor,tuple[torch.Tensor,torch.Tensor,torch.Tensor],torch.Tensor]:
'''
Forward block hook operation that can be easily overridden,
to implement gradient checkpointing in various trainers
'''
x_hidden_state = x_hidden_state.to(block.ln1.weight.device, non_blocking=True)
return block(x_hidden_state, prv_stateList, v_first)
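# Override sketch (hedged; the subclass name is hypothetical): per-block gradient checkpointing
#   import torch.utils.checkpoint as cp
#   class CheckpointedGoose(RWKV7GooseModel):
#       def _forward_block_hook(self, block, x_hidden_state, prv_stateList, v_first):
#           x_hidden_state = x_hidden_state.to(block.ln1.weight.device, non_blocking=True)
#           return cp.checkpoint(block, x_hidden_state, prv_stateList, v_first, use_reentrant=False)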
def _forward_internal(
self, idx:torch.Tensor,
prv_stateList:list[tuple[torch.Tensor,torch.Tensor,torch.Tensor]],
ret_stateList:list[tuple[torch.Tensor,torch.Tensor,torch.Tensor]],
overwrite_ret_tensor:bool=False
) -> tuple[torch.Tensor,list[tuple[torch.Tensor,torch.Tensor,torch.Tensor]]]:
'''
Internal forward operation, which assumes the state is already initialized
Due to the lack of safety checks, this should not be used directly
'''
# Get the token embeddings
idx = idx.to(self.emb.weight.device, non_blocking=True)
x_hidden_state = self.emb(idx)
v_first = None
# Iterate the block layers, compute the x_hidden_state
if overwrite_ret_tensor:
for i, block in enumerate(self.blocks):
# x_hidden_state, last_block_state, v_first = block(x_hidden_state, prv_stateList[i], v_first)
x_hidden_state, last_block_state, v_first = self._forward_block_hook(block, x_hidden_state, prv_stateList[i], v_first)
ret_stateList[i][0][:] = last_block_state[0]
ret_stateList[i][1][:] = last_block_state[1]
ret_stateList[i][2][:] = last_block_state[2]
else:
ret_stateList = [ None for i in range( len(self.blocks) ) ]
for i, block in enumerate(self.blocks):
# x_hidden_state, ret_sublist, v_first = block(x_hidden_state, prv_stateList[i], v_first)
x_hidden_state, ret_sublist, v_first = self._forward_block_hook(block, x_hidden_state, prv_stateList[i], v_first)
ret_stateList[i] = ret_sublist
# Final layer norm, and head
x_hidden_state = x_hidden_state.to(self.ln_out.weight.device, non_blocking=True)
x_hidden_state = self.ln_out(x_hidden_state)
x_out = self.head(x_hidden_state)
# Return the output and the state list
return x_out, ret_stateList
@torch.compile(mode="default")
def forward_with_default_compile(
self, idx:torch.Tensor,
prv_stateList:list[tuple[torch.Tensor,torch.Tensor,torch.Tensor]],
ret_stateList:list[tuple[torch.Tensor,torch.Tensor,torch.Tensor]],
) -> tuple[torch.Tensor,list[tuple[torch.Tensor,torch.Tensor,torch.Tensor]]]:
'''
Compiled variant of the forward function
With no new tensors being created for the output
Useful for static memory allocation optimizations during inference
'''
# Forward internally
return self._forward_internal(idx, prv_stateList, ret_stateList, overwrite_ret_tensor=True)
@torch.compile(mode="reduce-overhead")
def forward_with_reduce_compile(
self, in_idx:torch.Tensor,
prv_stateList:list[tuple[torch.Tensor,torch.Tensor,torch.Tensor]]
) -> tuple[torch.Tensor,list[tuple[torch.Tensor,torch.Tensor,torch.Tensor]]]:
'''
Compiled variant of the forward function; requires the previous state to be passed
'''
return self._forward_internal(in_idx, prv_stateList, None, overwrite_ret_tensor=False)