|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
|
|
|
torch.set_float32_matmul_precision('high') |
|
|
|
|
|
def rwkv7_attn_pytorch( |
|
r,w,k,v, kk,a, |
|
BATCH_SIZE, SEQ_LEN, N_HEAD, HEAD_SIZE, |
|
xx, wkv_state_in |
|
): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
B,T,HC = w.shape |
|
|
|
|
|
chunk_size = 256 |
|
chunk_count = SEQ_LEN // chunk_size |
|
chunk_remainder = SEQ_LEN % chunk_size |
|
|
|
|
|
wkv_state_out = wkv_state_in.float() |
|
|
|
|
|
|
|
xx = xx.clone() |
|
|
|
|
|
for i in range(chunk_count): |
|
sta = i * chunk_size |
|
end = sta + chunk_size |
|
|
|
xx[:,sta:end], wkv_state_out = rwkv7_attn_pytorch_v2_chunk_w_compile_break( |
|
|
|
r[:,sta:end],w[:,sta:end],k[:,sta:end],v[:,sta:end], |
|
kk[:,sta:end],a[:,sta:end], |
|
BATCH_SIZE, chunk_size, N_HEAD, HEAD_SIZE, |
|
|
|
torch.zeros(B,chunk_size,HC, dtype=xx.dtype, device=xx.device), wkv_state_out |
|
) |
|
|
|
|
|
|
|
if chunk_remainder > 0: |
|
sta = chunk_count * chunk_size |
|
end = sta + chunk_remainder |
|
|
|
xx[:,sta:end], wkv_state_out = rwkv7_attn_pytorch_v2_chunk_w_compile_break( |
|
|
|
r[:,sta:end],w[:,sta:end],k[:,sta:end],v[:,sta:end], |
|
kk[:,sta:end],a[:,sta:end], |
|
BATCH_SIZE, chunk_remainder, N_HEAD, HEAD_SIZE, |
|
|
|
torch.zeros(B,chunk_remainder,HC, dtype=xx.dtype, device=xx.device), wkv_state_out, |
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
return xx, wkv_state_out.to(dtype=wkv_state_in.dtype) |
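# ---------------------------------------------------------------------------
# Minimal usage sketch for the chunked PyTorch backend above (shapes, dtypes and
# values are illustrative assumptions, not fixed requirements). Note that `w` is
# expected to already be the per-channel decay in (0, 1), i.e. exp(-exp(w_raw)).
#
#   B, T, H, C = 1, 512, 12, 64
#   r, w, k, v, kk, a = (torch.rand(B, T, H * C, dtype=torch.bfloat16) for _ in range(6))
#   xx = torch.zeros(B, T, H * C, dtype=torch.bfloat16)
#   s0 = torch.zeros(B, H, C, C, dtype=torch.float32)
#   out, sT = rwkv7_attn_pytorch(r, w, k, v, kk, a, B, T, H, C, xx, s0)
#   # out: [B, T, H*C] (same dtype as xx), sT: [B, H, C, C] (same dtype as s0)
# ---------------------------------------------------------------------------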
|
|
|
|
|
|
|
|
|
|
|
|
|
@torch.compiler.disable() |
|
def rwkv7_attn_pytorch_ref( |
|
r,w,k,v, kk,a, |
|
BATCH_SIZE, SEQ_LEN, N_HEAD, HEAD_SIZE, |
|
xx, wkv_state_in |
|
): |
|
|
|
|
|
|
|
vk_state = wkv_state_in.float() |
|
for t in range(SEQ_LEN): |
|
r_, w_, k_, v_, kk_, a_ = r[:,t], w[:,t], k[:,t], v[:,t], kk[:,t], a[:,t] |
|
vk = v_.view(BATCH_SIZE,N_HEAD,HEAD_SIZE,1) @ k_.view(BATCH_SIZE,N_HEAD,1,HEAD_SIZE) |
|
ab = (-kk_).view(BATCH_SIZE,N_HEAD,HEAD_SIZE,1) @ (kk_*a_).view(BATCH_SIZE,N_HEAD,1,HEAD_SIZE) |
|
vk_state = (vk_state * w_.view(BATCH_SIZE,N_HEAD,1,HEAD_SIZE).float() + vk_state @ ab.float() + vk.float()) |
|
xx[:,t] = ((vk_state.to(dtype=xx.dtype) @ r_.view(BATCH_SIZE,N_HEAD,HEAD_SIZE,1)).view(BATCH_SIZE,N_HEAD*HEAD_SIZE)) |
|
wkv_state_out = vk_state.to(dtype=wkv_state_in.dtype) |
|
return xx, wkv_state_out |
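# For reference, the loop above is a direct (unchunked) implementation of the RWKV7
# per-timestep state update. Per head, with S the [C, C] wkv state, it computes
# (this is a summary of the loop above, not additional behaviour):
#
#   S_t = S_{t-1} @ diag(w_t) + S_{t-1} @ ((-kk_t) (kk_t * a_t)^T) + v_t k_t^T
#   y_t = S_t @ r_t
#
# where w_t is the decay already mapped into (0, 1) by the caller.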
|
|
|
|
|
|
|
|
|
|
|
|
|
@torch.compiler.disable() |
|
def rwkv7_attn_pytorch_ref_fp32( |
|
r,w,k,v, kk, iclr, |
|
BATCH_SIZE, SEQ_LEN, N_HEAD, HEAD_SIZE, |
|
xx, wkv_state_in |
|
): |
|
|
|
|
|
|
|
w = (-w.float().exp()).exp() |
|
|
|
|
|
vk_state = wkv_state_in.float() |
|
|
|
a = -kk |
|
b = kk * iclr |
|
|
|
for t in range(SEQ_LEN): |
|
r_, w_, k_, v_, a_, b_= r[:,t].float(), w[:,t].float(), k[:,t].float(), v[:,t].float(), a[:,t].float(), b[:,t].float() |
|
|
|
vk = v_.view(BATCH_SIZE,N_HEAD,HEAD_SIZE,1) @ k_.view(BATCH_SIZE,N_HEAD,1,HEAD_SIZE) |
|
vk_state = (vk_state * w_.view(BATCH_SIZE,N_HEAD,1,HEAD_SIZE).float() + vk_state @ a_.float().view(BATCH_SIZE, N_HEAD,HEAD_SIZE,1) @ b_.view(BATCH_SIZE, N_HEAD,1,HEAD_SIZE) + vk.float()) |
|
xx[:,t] = ((vk_state @ r_.view(BATCH_SIZE,N_HEAD,HEAD_SIZE,1)).view(BATCH_SIZE,N_HEAD*HEAD_SIZE)).to(dtype=xx.dtype) |
|
|
|
wkv_state_out = vk_state.to(dtype=wkv_state_in.dtype) |
|
return xx, wkv_state_out |
|
|
|
|
|
def rwkv7_attn_pytorch_chunk( |
|
r,w,k,v, kk,a, |
|
BATCH_SIZE, N_HEAD, HEAD_SIZE, |
|
xx, wkv_state_in, |
|
offset=0, chunk_size=16 |
|
): |
|
''' |
|
Chunked version of the RWKV7 attention, for better performance. |
|
    If the chunk size is less than 128, this is generally compilable.

    This is used by the triton/cuda implementations, for the remaining (seq_len % 16) chunk.
|
''' |
|
|
|
|
|
|
|
vk_state = wkv_state_in.float() |
|
for i in range(chunk_size): |
|
t = offset + i |
|
r_, w_, k_, v_, kk_, a_ = r[:,t], w[:,t], k[:,t], v[:,t], kk[:,t], a[:,t] |
|
vk = v_.view(BATCH_SIZE,N_HEAD,HEAD_SIZE,1) @ k_.view(BATCH_SIZE,N_HEAD,1,HEAD_SIZE) |
|
ab = (-kk_).view(BATCH_SIZE,N_HEAD,HEAD_SIZE,1) @ (kk_*a_).view(BATCH_SIZE,N_HEAD,1,HEAD_SIZE) |
|
vk_state = (vk_state * w_.view(BATCH_SIZE,N_HEAD,1,HEAD_SIZE).float() + vk_state @ ab.float() + vk.float()) |
|
xx[:,t] = (vk_state.to(dtype=xx.dtype) @ r_.view(BATCH_SIZE,N_HEAD,HEAD_SIZE,1)).view(BATCH_SIZE,N_HEAD*HEAD_SIZE) |
|
wkv_state_out = vk_state.to(dtype=wkv_state_in.dtype) |
|
return xx, wkv_state_out |
|
|
|
|
|
def rwkv7_attn_pytorch_v2_chunk_w_compile_break( |
|
r,w,k,v, kk,a, |
|
BATCH_SIZE, SEQ_LEN, N_HEAD, HEAD_SIZE, |
|
xx, wkv_state_in |
|
): |
|
''' |
|
Chunked version of the RWKV7 attention, for better performance |
|
''' |
|
full_vk_ = v.view(BATCH_SIZE,SEQ_LEN,N_HEAD, HEAD_SIZE,1) @ k.view(BATCH_SIZE,SEQ_LEN,N_HEAD, 1,HEAD_SIZE) |
|
full_iclr_ = (kk * a).view(BATCH_SIZE,SEQ_LEN,N_HEAD,1,HEAD_SIZE) |
|
full_ab = (-kk).view(BATCH_SIZE,SEQ_LEN,N_HEAD, HEAD_SIZE,1) @ full_iclr_ |
|
|
|
wkv_xx = torch.empty(BATCH_SIZE,SEQ_LEN,N_HEAD,HEAD_SIZE,HEAD_SIZE, dtype=xx.dtype, device=xx.device) |
|
wkv_xx, wkv_state_out = rwkv7_attn_pytorch_v2_inner_w_compile_break( |
|
r,w, |
|
full_vk_, full_ab, |
|
BATCH_SIZE, SEQ_LEN, N_HEAD, HEAD_SIZE, |
|
wkv_xx, wkv_state_in |
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
xx[:] = (wkv_xx.to(dtype=xx.dtype) @ r.view(BATCH_SIZE,SEQ_LEN,N_HEAD,HEAD_SIZE,1)).view(BATCH_SIZE,SEQ_LEN,N_HEAD*HEAD_SIZE) |
|
|
|
return xx, wkv_state_out |
|
|
|
@torch.compiler.disable() |
|
def rwkv7_attn_pytorch_v2_inner_w_compile_break( |
|
r, w, |
|
full_vk_, full_ab, |
|
BATCH_SIZE, SEQ_LEN, N_HEAD, HEAD_SIZE, |
|
xx, wkv_state_in |
|
): |
|
''' |
|
Isolated sub-function with no compilation |
|
''' |
|
return rwkv7_attn_pytorch_v2_inner_jit( |
|
r, w, |
|
full_vk_, full_ab, |
|
BATCH_SIZE, SEQ_LEN, N_HEAD, HEAD_SIZE, |
|
xx, wkv_state_in |
|
) |
|
|
|
|
|
@torch.jit.script |
|
def rwkv7_attn_pytorch_v2_inner_jit( |
|
r, w, |
|
full_vk_, full_ab, |
|
BATCH_SIZE:int, SEQ_LEN:int, N_HEAD:int, HEAD_SIZE:int, |
|
wkv_xx, wkv_state_in |
|
): |
|
''' |
|
Isolated sub-function with JIT |
|
''' |
|
|
|
|
|
wkv_state = wkv_state_in |
|
for t in range(SEQ_LEN): |
|
|
|
|
|
|
|
|
|
|
|
wkv_state = (wkv_state * w[:,t].view(BATCH_SIZE,N_HEAD,1,HEAD_SIZE).float() + wkv_state @ full_ab[:,t].view(BATCH_SIZE,N_HEAD,HEAD_SIZE,HEAD_SIZE).float() + full_vk_[:,t].view(BATCH_SIZE,N_HEAD,HEAD_SIZE,HEAD_SIZE).float()).clone() |
|
wkv_xx[:,t] = wkv_state.to(dtype=w.dtype) |
|
return wkv_xx, wkv_state |
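# Note on the v2 split above: the per-chunk outer products (v k^T and the
# (-kk)(kk*a)^T "ab" term) are precomputed as batched matmuls, while the strictly
# sequential state recurrence runs in a TorchScript loop behind a
# torch.compiler.disable() boundary - keeping the data-dependent time loop out of
# any torch.compile graph that wraps the caller.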
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch, os, time |
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_ref_wkv_cuda_kernel(CHUNK_LEN = 16, HEAD_SIZE = 64): |
|
from torch.utils.cpp_extension import load |
|
|
|
|
|
load_name = "wind_backstepping" |
|
load_file = "wkv7" |
|
|
|
|
|
if load_name in torch.ops: |
|
return torch.ops.wind_backstepping |
|
|
|
|
|
print("[WARNING] Reference CUDA kernel does not support input RWKV state, and is used only for training/validaiton purposes") |
|
|
|
|
|
this_file_path = os.path.dirname(os.path.abspath(__file__)) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
flags = ['-res-usage', f'-D_C_={HEAD_SIZE}', f"-D_CHUNK_LEN_={CHUNK_LEN}", "--use_fast_math", "-O3", "-Xptxas -O3", "--extra-device-vectorization"] |
|
try: |
|
load(name=load_name, sources=[f'{this_file_path}/cuda/{load_file}_cuda.cu', f'{this_file_path}/cuda/{load_file}_op.cpp'], is_python_module=False, verbose=True, extra_cuda_cflags=flags) |
|
except Exception as e: |
|
print("[WARNING] Failed to load the kernel, trying again (sometimes the compiler has wierd race condition)...") |
|
time.sleep(2) |
|
load(name=load_name, sources=[f'{this_file_path}/cuda/{load_file}_cuda.cu', f'{this_file_path}/cuda/{load_file}_op.cpp'], is_python_module=False, verbose=True, extra_cuda_cflags=flags) |
|
|
|
|
|
return torch.ops.wind_backstepping |
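# Note: the kernel is compiled and registered at most once per process - later calls
# return the already-registered torch.ops.wind_backstepping entry. The CUDA sources
# are expected under a `cuda/` directory next to this file (wkv7_cuda.cu / wkv7_op.cpp).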
|
|
|
@torch.compiler.disable() |
|
def ref_wkv_cuda_forward(w,q,k,v,z,b, y,s,sa): |
|
torch.ops.wind_backstepping.forward(w,q,k,v,z,b, y,s,sa) |
|
|
|
@torch.compiler.disable() |
|
def ref_wkv_cuda_backward(w,q,k,v,z,b, dy,s,sa, dw,dq,dk,dv,dz,db): |
|
torch.ops.wind_backstepping.backward(w,q,k,v,z,b, dy,s,sa, dw,dq,dk,dv,dz,db) |
|
|
|
class RefCudaWindBackstepping(torch.autograd.Function): |
|
@staticmethod |
|
def forward(ctx, w,q,k,v,z,b): |
|
CHUNK_LEN=16 |
|
B,T,H,C = w.shape |
|
assert T%CHUNK_LEN == 0 |
|
assert all(i.dtype==torch.bfloat16 for i in [w,q,k,v,z,b]) |
|
assert all(i.is_contiguous() for i in [w,q,k,v,z,b]) |
|
y = torch.empty_like(v) |
|
s = torch.empty(B,H,T//CHUNK_LEN,C,C, dtype=torch.float32,device=w.device) |
|
sa = torch.empty(B,T,H,C, dtype=torch.float32,device=w.device) |
|
ref_wkv_cuda_forward(w,q,k,v,z,b, y,s,sa) |
|
ctx.save_for_backward(w,q,k,v,z,b,s,sa) |
|
return y |
|
@staticmethod |
|
def backward(ctx, dy): |
|
assert all(i.dtype==torch.bfloat16 for i in [dy]) |
|
assert all(i.is_contiguous() for i in [dy]) |
|
w,q,k,v,z,b,s,sa = ctx.saved_tensors |
|
dw,dq,dk,dv,dz,db = [torch.empty_like(x) for x in [w,q,k,v,z,b]] |
|
ref_wkv_cuda_backward(w,q,k,v,z,b, dy,s,sa, dw,dq,dk,dv,dz,db) |
|
return dw,dq,dk,dv,dz,db |
|
|
|
@torch.compiler.disable() |
|
def rwkv7_attn_cuda_ref(q,w,k,v, kk,iclr, HEAD_SIZE=64, s0=None): |
|
|
|
load_ref_wkv_cuda_kernel() |
|
|
|
|
|
B,T,HC = w.shape |
|
C = HEAD_SIZE |
|
H = HC//C |
|
|
|
|
|
    assert T % 16 == 0, 'reference cuda kernel only works with sequence lengths in multiples of 16'
|
|
|
|
|
s0 = torch.zeros(B,H,C,C, dtype=torch.float,device=w.device) if s0 is None else s0 |
|
|
|
|
|
q,w,k,v,a,b = [i.view(B,T,H,C) for i in [q,w,k,v,(-kk),(kk*iclr)]] |
|
|
|
|
|
xx = RefCudaWindBackstepping.apply(w,q,k,v,a,b) |
|
return xx.view(B,T,HC), s0.view(B,H,C,C) |
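# Note: the reference kernel has no state input, so the s0 passed in (or the zero
# state created above) is returned unchanged - consistent with the warning printed
# when the kernel is loaded.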
|
|
|
|
|
|
|
|
|
|
|
def load_wkv_cuda_kernel(CHUNK_LEN = 16, HEAD_SIZE = 64): |
|
from torch.utils.cpp_extension import load |
|
|
|
|
|
load_name = "state_wind_backstepping" |
|
load_file = "state_wkv7" |
|
|
|
|
|
if load_name in torch.ops: |
|
return torch.ops.state_wind_backstepping |
|
|
|
|
|
this_file_path = os.path.dirname(os.path.abspath(__file__)) |
|
|
|
|
|
|
|
flags = ['-res-usage', f'-D_C_={HEAD_SIZE}', f"-D_CHUNK_LEN_={CHUNK_LEN}", "--use_fast_math", "-O3", "-Xptxas -O3", "--extra-device-vectorization"] |
|
try: |
|
load(name=load_name, sources=[f'{this_file_path}/cuda/{load_file}_cuda.cu', f'{this_file_path}/cuda/{load_file}_op.cpp'], is_python_module=False, verbose=True, extra_cuda_cflags=flags) |
|
except Exception as e: |
|
print("[WARNING] Failed to load the kernel, trying again (sometimes the compiler has wierd race condition)...") |
|
time.sleep(2) |
|
load(name=load_name, sources=[f'{this_file_path}/cuda/{load_file}_cuda.cu', f'{this_file_path}/cuda/{load_file}_op.cpp'], is_python_module=False, verbose=True, extra_cuda_cflags=flags) |
|
|
|
|
|
return torch.ops.state_wind_backstepping |
|
|
|
@torch.compiler.disable() |
|
def wkv_cuda_forward(state, w,q,k,v,z,b, y,s,sa): |
|
torch.ops.state_wind_backstepping.forward(state, w,q,k,v,z,b, y,s,sa) |
|
|
|
@torch.compiler.disable() |
|
def wkv_cuda_backward(state, w,q,k,v,z,b, dy,s,sa, dw,dq,dk,dv,dz,db): |
|
torch.ops.state_wind_backstepping.backward(state, w,q,k,v,z,b, dy,s,sa, dw,dq,dk,dv,dz,db) |
|
|
|
class CudaWindBackstepping(torch.autograd.Function): |
|
@staticmethod |
|
def forward(ctx, s0, w,q,k,v,z,b): |
|
CHUNK_LEN=16 |
|
B,T,H,C = w.shape |
|
assert T%CHUNK_LEN == 0 |
|
assert all(i.dtype==torch.bfloat16 for i in [w,q,k,v,z,b]) |
|
assert all(i.is_contiguous() for i in [w,q,k,v,z,b]) |
|
y = torch.empty_like(v) |
|
s = torch.empty(B,H,T//CHUNK_LEN,C,C, dtype=torch.float32,device=w.device) |
|
sa = torch.empty(B,T,H,C, dtype=torch.float32,device=w.device) |
|
sOri = s0.clone() |
|
wkv_cuda_forward(s0, w,q,k,v,z,b, y,s,sa) |
|
ctx.save_for_backward(sOri, w,q,k,v,z,b,s,sa) |
|
return y |
|
@staticmethod |
|
def backward(ctx, dy): |
|
assert all(i.dtype==torch.bfloat16 for i in [dy]) |
|
assert all(i.is_contiguous() for i in [dy]) |
|
state,w,q,k,v,z,b,s,sa = ctx.saved_tensors |
|
dS0,dw,dq,dk,dv,dz,db = [torch.empty_like(x) for x in [state,w,q,k,v,z,b]] |
|
wkv_cuda_backward(state, w,q,k,v,z,b, dy,s,sa, dw,dq,dk,dv,dz,db) |
|
return dS0,dw,dq,dk,dv,dz,db |
|
|
|
@torch.compiler.disable() |
|
def rwkv7_attn_cuda(r,w,k,v, kk,iclr, HEAD_SIZE=64, s0=None): |
|
|
|
load_wkv_cuda_kernel() |
|
|
|
|
|
B,T,HC = w.shape |
|
|
|
|
|
chunk_remainder = T % 16 |
|
|
|
|
|
C = HEAD_SIZE |
|
H = HC//C |
|
|
|
|
|
s0 = torch.zeros(B,H,C,C, dtype=torch.float,device=w.device) if s0 is None else s0 |
|
sT = s0.to(dtype=torch.float) |
|
|
|
|
|
if chunk_remainder == 0: |
|
chunk_xx, chunk_sT = rwkv7_attn_cuda_chunk(r,w,k,v, kk,iclr, HEAD_SIZE, sT) |
|
return chunk_xx, chunk_sT.to(dtype=s0.dtype) |
|
|
|
|
|
chunks = T // 16 |
|
si = chunks * 16 |
|
|
|
|
|
chunk_xx, chunk_sT = rwkv7_attn_cuda_chunk( |
|
r[:,:si],w[:,:si],k[:,:si],v[:,:si], kk[:,:si],iclr[:,:si], |
|
HEAD_SIZE, s0 |
|
) |
|
|
|
|
|
remain_xx, last_sT = rwkv7_attn_pytorch_chunk( |
|
r[:,si:],torch.exp(-torch.exp(w[:,si:])),k[:,si:],v[:,si:], kk[:,si:],iclr[:,si:], |
|
B, H, C, |
|
torch.zeros(B, chunk_remainder, HC, device=w.device, dtype=w.dtype), |
|
chunk_sT, chunk_size=chunk_remainder |
|
) |
|
|
|
|
|
return torch.cat([chunk_xx.to(dtype=w.dtype), remain_xx.to(dtype=w.dtype)], dim=1), last_sT.to(dtype=s0.dtype) |
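# Note on the mixed path above: the first (T // 16) * 16 tokens go through the CUDA
# kernel, while the remaining T % 16 tokens fall back to the PyTorch chunk loop.
# The PyTorch chunk expects the decay already mapped into (0, 1), hence the
# torch.exp(-torch.exp(w[:, si:])) conversion applied only to the tail.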
|
|
|
|
|
def rwkv7_attn_cuda_chunk(r,w,k,v, kk,iclr, HEAD_SIZE=64, s0=None): |
|
''' |
|
    CUDA implementation running in blocks of 16 (hardcoded requirement for the kernel)
|
''' |
|
B,T,HC = w.shape |
|
    assert T % 16 == 0, 'pure cuda kernel only works in multiples of 16'
|
C = HEAD_SIZE |
|
H = HC//C |
|
|
|
|
|
a,b = -kk, (kk*iclr) |
|
r,w,k,v,a,b = [i.view(B,T,H,C) for i in [r,w,k,v,a,b]] |
|
|
|
if s0 is None: |
|
s1 = torch.zeros(B,H,C,C, dtype=torch.float,device=w.device) |
|
else: |
|
s1 = s0.clone() |
|
|
|
|
|
xx = CudaWindBackstepping.apply(s1,w,r,k,v,a,b) |
|
return xx.view(B,T,HC), s1.view(B,H,C,C) |
|
|
|
|
|
|
|
|
|
|
|
def rwkv7_attn_fla( |
|
r,w,k,v, kk,iclr, |
|
BATCH_SIZE, SEQ_LEN, N_HEAD, HEAD_SIZE, |
|
xx, wkv_state_in |
|
): |
|
from fla.ops.rwkv7.chunk import chunk_rwkv7 |
|
|
|
|
|
r,w,k,v,a,b = [i.view(BATCH_SIZE,SEQ_LEN,N_HEAD,-1) for i in [r,w,k,v,-kk,(kk*iclr)]] |
|
log_w = -w.float().exp() |
|
|
|
|
|
output, vk_state = chunk_rwkv7(r=r, log_w=log_w, k=k, v=v, a=a, b=b, initial_state=wkv_state_in.float(), output_final_state=True) |
|
return output, vk_state.to(dtype=wkv_state_in.dtype) |
|
|
|
def rwkv7_attn_fused_reccurent_fla( |
|
r,w,k,v, kk,iclr, |
|
BATCH_SIZE, SEQ_LEN, N_HEAD, HEAD_SIZE, |
|
xx, wkv_state_in |
|
): |
|
from fla.ops.rwkv7.fused_recurrent import fused_recurrent_rwkv7 |
|
|
|
|
|
r,w,k,v,a,b = [i.view(BATCH_SIZE,SEQ_LEN,N_HEAD,-1) for i in [r,w,k,v,-kk,(kk*iclr)]] |
|
log_w = -w.float().exp() |
|
|
|
|
|
output, vk_state = fused_recurrent_rwkv7(r=r, log_w=log_w, k=k, v=v, a=a, b=b, initial_state=wkv_state_in.float(), output_final_state=True) |
|
return output, vk_state.to(dtype=wkv_state_in.dtype) |
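# Note: both FLA wrappers above convert the decay into log space for the FLA kernels,
# log_w = log(exp(-exp(w))) = -exp(w), which is what `-w.float().exp()` computes.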
|
|
|
|
|
|
|
|
|
import torch |
|
import torch as th |
|
import triton |
|
import triton.language as tl |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@triton.jit |
|
def IND3(a,b,c,nb,nc): |
|
return (a*nb+b)*nc+c |
|
@triton.jit |
|
def IND4(a,b,c,d,nb,nc,nd): |
|
return ((a*nb+b)*nc+c)*nd+d |
|
@triton.jit |
|
def IND5(a,b,c,d,e,nb,nc,nd,ne): |
|
return (((a*nb+b)*nc+c)*nd+d)*ne+e |
|
|
|
@triton.jit |
|
def _prod(a,b): return a*b |
|
|
|
|
|
@triton.jit |
|
def tri_minv(A, n:tl.constexpr, prec:tl.constexpr): |
|
i = tl.arange(0,n) |
|
prod = (i[None,:]==i[:,None]).to(tl.float32) |
|
for j in range(n-1): |
|
prod += tl_dot(prec, prod, (A*((i[None,:]==j)*(i[:,None]>i[None,:]))).trans()) |
|
return prod.trans() |
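# tri_minv(A) effectively computes (I - A)^-1 for a strictly lower-triangular A by
# accumulating elementary column updates; the forward/backward kernels below use it
# to solve the within-chunk recurrence u = ab @ u + ab_u in one matmul,
# u = (I - ab)^-1 @ ab_u.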
|
|
|
@triton.jit |
|
def tl_dot(prec:tl.constexpr, a, b): |
|
if prec == 'fp32': |
|
return tl.dot(a.to(tl.float32),b.trans().to(tl.float32).trans(), allow_tf32=False) |
|
elif prec == 'tf32': |
|
return tl.dot(a.to(tl.float32),b.trans().to(tl.float32).trans(), allow_tf32=True) |
|
elif prec == 'bf16': |
|
return tl.dot(a.to(tl.bfloat16),b.trans().to(tl.bfloat16).trans(), allow_tf32=True) |
|
else: |
|
tl.static_assert(False) |
|
|
|
|
|
|
|
|
|
|
|
@triton.jit |
|
def fw_attn_triton(w_,q_,k_,v_,a_,b_, s0_,y_,s_,sT_, B:tl.constexpr,T:tl.constexpr,H:tl.constexpr,C:tl.constexpr,dT:tl.constexpr, prec:tl.constexpr): |
|
bi = tl.program_id(1) |
|
hi = tl.program_id(0) |
|
|
|
i = tl.arange(0,C)[None,:] |
|
state = tl.load(s0_+IND4(bi,hi,i.trans(),i, H,C,C)).to(tl.float32) |
|
for t0 in range(T//dT): |
|
t = t0*dT+tl.arange(0,dT)[:,None] |
|
sw = tl.load(w_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32) |
|
sq = tl.load(q_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32) |
|
sk = tl.load(k_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32) |
|
sv = tl.load(v_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32) |
|
sa = tl.load(a_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32) |
|
sb = tl.load(b_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32) |
|
|
|
w = (-sw.exp()).exp() |
|
fw = tl.reduce(w, 0, _prod, keep_dims=True) |
|
incl_pref = tl.cumprod(w,axis=0) |
|
non_incl_pref = incl_pref / w |
|
inv_incl_pref = 1 / incl_pref |
|
|
|
wq = sq * incl_pref |
|
wa = sa * non_incl_pref |
|
kwi = sk * inv_incl_pref |
|
bwi = sb * inv_incl_pref |
|
|
|
mask1 = (t > t.trans()) |
|
ab = tl_dot(prec, wa, bwi.trans()) * mask1 |
|
ak = tl_dot(prec, wa, kwi.trans()) * mask1 |
|
|
|
ab_inv = tri_minv(ab, dT, prec) |
|
|
|
ab_u = tl_dot(prec, ak, sv) + tl_dot(prec, wa, state.trans()) |
|
u = tl_dot(prec, ab_inv, ab_u) |
|
mask2 = (t >= t.trans()) |
|
qk = tl_dot(prec, wq, kwi.trans()) * mask2 |
|
qb = tl_dot(prec, wq, bwi.trans()) * mask2 |
|
yy = tl_dot(prec, qk, sv) + tl_dot(prec, qb, u) + tl_dot(prec, wq, state.trans()) |
|
tl.store(y_+IND4(bi,t,hi,i, T,H,C), yy.to(tl.bfloat16)) |
|
|
|
tl.store(s_+IND5(bi,hi,t0,i.trans(),i, H,T//dT,C,C), state.to(tl.float32)) |
|
state = state * fw + tl_dot(prec, sv.trans(), kwi*fw) + tl_dot(prec, u.trans(), bwi*fw) |
|
tl.store(sT_+IND4(bi,hi,i.trans(),i, H,C,C), state.to(tl.bfloat16)) |
|
|
|
@triton.jit |
|
def bw_attn_triton(w_,q_,k_,v_,a_,b_, dy_,s_,dsT_, dw_,dq_,dk_,dv_,da_,db_,ds0_, B:tl.constexpr,T:tl.constexpr,H:tl.constexpr,C:tl.constexpr,dT:tl.constexpr, prec:tl.constexpr): |
|
bi = tl.program_id(1) |
|
hi = tl.program_id(0) |
|
|
|
i = tl.arange(0,C)[None,:] |
|
dstate = tl.load(dsT_+IND4(bi,hi,i.trans(),i, H,C,C)).to(tl.float32) |
|
|
|
for t0 in range(T//dT-1,-1,-1): |
|
t = t0*dT+tl.arange(0,dT)[:,None] |
|
|
|
state = tl.load(s_+IND5(bi,hi,t0,i.trans(),i, H,T//dT,C,C)).to(tl.float32) |
|
|
|
sw = tl.load(w_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32) |
|
sq = tl.load(q_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32) |
|
sk = tl.load(k_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32) |
|
sv = tl.load(v_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32) |
|
sa = tl.load(a_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32) |
|
sb = tl.load(b_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32) |
|
sdy = tl.load(dy_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32) |
|
|
|
dw_fac = -sw.exp() |
|
w = dw_fac.exp() |
|
fw = tl.reduce(w, 0, _prod, keep_dims=True) |
|
incl_pref = tl.cumprod(w,axis=0) |
|
non_incl_pref = incl_pref / w |
|
inv_incl_pref = 1 / incl_pref |
|
|
|
wq = sq * incl_pref |
|
wa = sa * non_incl_pref |
|
kwi = sk * inv_incl_pref |
|
bwi = sb * inv_incl_pref |
|
|
|
mask1 = (t > t.trans()) |
|
ab = tl_dot(prec, wa, bwi.trans()) * mask1 |
|
ak = tl_dot(prec, wa, kwi.trans()) * mask1 |
|
|
|
ab_inv = tri_minv(ab, dT, prec) |
|
|
|
ab_u = tl_dot(prec, ak, sv) + tl_dot(prec, wa, state.trans()) |
|
u = tl_dot(prec, ab_inv, ab_u) |
|
mask2 = (t >= t.trans()) |
|
qk = tl_dot(prec, wq, kwi.trans()) * mask2 |
|
qb = tl_dot(prec, wq, bwi.trans()) * mask2 |
|
|
|
du = tl_dot(prec, qb.trans(), sdy) + tl_dot(prec, bwi*fw, dstate.trans()) |
|
dab_u = tl_dot(prec, ab_inv.trans(), du) |
|
|
|
dv = tl_dot(prec, qk.trans(), sdy) + tl_dot(prec, kwi*fw, dstate.trans()) + tl_dot(prec, ak.trans(), dab_u) |
|
tl.store(dv_+IND4(bi,t,hi,i, T,H,C), dv.to(tl.bfloat16)) |
|
|
|
dab = tl_dot(prec, tl_dot(prec, ab_inv.trans(), du), u.trans()) * mask1 |
|
dak = tl_dot(prec, dab_u, sv.trans()) * mask1 |
|
dab_u_state = tl_dot(prec, dab_u, state) |
|
da = non_incl_pref * (tl_dot(prec, dab, bwi) + tl_dot(prec, dak, kwi) + dab_u_state) |
|
tl.store(da_+IND4(bi,t,hi,i, T,H,C), da.to(tl.bfloat16)) |
|
|
|
dqb = tl_dot(prec, sdy, u.trans()) * mask2 |
|
dqk = tl_dot(prec, sdy, sv.trans()) * mask2 |
|
dy_state = tl_dot(prec, sdy, state) |
|
dq = incl_pref * (tl_dot(prec, dqb, bwi) + tl_dot(prec, dqk, kwi) + dy_state) |
|
tl.store(dq_+IND4(bi,t,hi,i, T,H,C), dq.to(tl.bfloat16)) |
|
|
|
fw_u_dstate = fw * tl_dot(prec, u, dstate) |
|
db = inv_incl_pref * (tl_dot(prec, dab.trans(), wa) + tl_dot(prec, dqb.trans(), wq) + fw_u_dstate) |
|
tl.store(db_+IND4(bi,t,hi,i, T,H,C), db.to(tl.bfloat16)) |
|
|
|
fw_v_dstate = fw * tl_dot(prec, sv, dstate) |
|
dk = inv_incl_pref * (tl_dot(prec, dak.trans(), wa) + tl_dot(prec, dqk.trans(), wq) + fw_v_dstate) |
|
tl.store(dk_+IND4(bi,t,hi,i, T,H,C), dk.to(tl.bfloat16)) |
|
|
|
dw0 = fw * tl.sum(state*dstate, axis=0,keep_dims=True) |
|
for k in range(t0*dT,t0*dT+dT): |
|
lmask = (t<k).trans() |
|
A = (tl_dot(prec, dab*lmask, bwi) + tl_dot(prec, dak*lmask, kwi)) * wa * (t>k) |
|
A += (tl_dot(prec, dqb*lmask, bwi) + tl_dot(prec, dqk*lmask, kwi)) * wq * (t>=k) |
|
A += (fw_v_dstate*kwi + fw_u_dstate*bwi) * (t<k) |
|
A += dab_u_state*wa * (t>k) + dy_state*wq * (t>=k) |
|
dw = tl.sum(A, axis=0,keep_dims=True) + dw0 |
|
|
|
wk = tl.load(w_+IND4(bi,k,hi,i, T,H,C)).to(tl.float32) |
|
dw *= -wk.exp() |
|
tl.store(dw_+IND4(bi,k,hi,i, T,H,C), dw.to(tl.bfloat16)) |
|
|
|
dstate = dstate * fw + tl_dot(prec, sdy.trans(), wq) + tl_dot(prec, dab_u.trans(), wa) |
|
tl.store(ds0_+IND4(bi,hi,i.trans(),i, H,C,C), dstate.to(tl.bfloat16)) |
|
|
|
class TritonRWKV7(th.autograd.Function): |
|
@staticmethod |
|
def forward(ctx, w,q,k,v,z,b,s0, dot_prec): |
|
K = 16 |
|
B,T,H,C = w.shape |
|
s0 = th.zeros(B,H,C,C, dtype=w.dtype,device=w.device) if s0 is None else s0 |
|
y = th.empty_like(v) |
|
sT = th.empty_like(s0) |
|
s = th.zeros(B,H,T//K,C,C, dtype=th.float32,device=w.device) |
|
fw_attn_triton[(H,B)](w,q,k,v,z,b, s0,y,s,sT, B,T,H,C,K, dot_prec) |
|
ctx.dot_prec = dot_prec |
|
ctx.save_for_backward(w,q,k,v,z,b,s) |
|
return y, sT |
|
@staticmethod |
|
def backward(ctx, dy, dsT): |
|
K = 16 |
|
w,q,k,v,z,b,s = ctx.saved_tensors |
|
B,T,H,C = w.shape |
|
dw,dq,dk,dv,dz,db,ds0 = [th.empty_like(x) for x in [w,q,k,v,z,b,dsT]] |
|
bw_attn_triton[(H,B)](w,q,k,v,z,b, dy,s,dsT, dw,dq,dk,dv,dz,db,ds0, B,T,H,C,K, ctx.dot_prec) |
|
return dw,dq,dk,dv,dz,db,ds0,None |
|
|
|
|
|
|
|
|
|
|
|
@triton.autotune(configs=[triton.Config({'dC': dC}, num_stages=1) for dC in [16,32,64]], key=['T','H','C','dT','prec']) |
|
@triton.jit |
|
def fw_attn_triton_bighead(w_,q_,k_,v_,a_,b_, s0_,y_,s_,sT_, wq_,wa_,kwi_,bwi_,fw_, B:tl.constexpr,T:tl.constexpr,H:tl.constexpr,C:tl.constexpr,dT:tl.constexpr, prec:tl.constexpr, dC:tl.constexpr): |
|
tl.static_assert(C%dC == 0) |
|
bi = tl.program_id(1) |
|
hi = tl.program_id(0) |
|
for i0 in range(0,C,dC): |
|
i = i0+tl.arange(0,dC)[None,:] |
|
for j0 in range(0,C,dC): |
|
j = j0+tl.arange(0,dC)[None,:] |
|
state = tl.load(s0_+IND4(bi,hi,i.trans(),j, H,C,C)).to(tl.float32) |
|
tl.store(s_+IND5(bi,hi,0,i.trans(),j, H,T//dT,C,C), state.to(tl.float32)) |
|
|
|
for t0 in range(T//dT): |
|
dt = tl.arange(0,dT)[:,None] |
|
t = t0*dT+dt |
|
tl.debug_barrier() |
|
for j0 in range(0,C,dC): |
|
j = j0+tl.arange(0,dC)[None,:] |
|
sw = tl.load(w_+IND4(bi,t,hi,j, T,H,C)).to(tl.float32) |
|
sq = tl.load(q_+IND4(bi,t,hi,j, T,H,C)).to(tl.float32) |
|
sk = tl.load(k_+IND4(bi,t,hi,j, T,H,C)).to(tl.float32) |
|
sa = tl.load(a_+IND4(bi,t,hi,j, T,H,C)).to(tl.float32) |
|
sb = tl.load(b_+IND4(bi,t,hi,j, T,H,C)).to(tl.float32) |
|
|
|
w = (-sw.exp()).exp() |
|
fw = tl.reduce(w, 0, _prod, keep_dims=True) |
|
incl_pref = tl.cumprod(w,axis=0) |
|
non_incl_pref = incl_pref / w |
|
inv_incl_pref = 1 / incl_pref |
|
|
|
wq = sq * incl_pref |
|
wa = sa * non_incl_pref |
|
kwi = sk * inv_incl_pref |
|
bwi = sb * inv_incl_pref |
|
|
|
tl.store(wq_+IND4(bi,hi,dt,j, H,dT,C), wq.to(tl.float32)) |
|
tl.store(wa_+IND4(bi,hi,dt,j, H,dT,C), wa.to(tl.float32)) |
|
tl.store(kwi_+IND4(bi,hi,dt,j, H,dT,C), kwi.to(tl.float32)) |
|
tl.store(bwi_+IND4(bi,hi,dt,j, H,dT,C), bwi.to(tl.float32)) |
|
tl.store(fw_+IND3(bi,hi,j, H,C), fw.to(tl.float32)) |
|
tl.debug_barrier() |
|
|
|
ab = tl.zeros((dT,dT), tl.float32) |
|
ak = tl.zeros((dT,dT), tl.float32) |
|
qb = tl.zeros((dT,dT), tl.float32) |
|
qk = tl.zeros((dT,dT), tl.float32) |
|
for j0 in range(0,C,dC): |
|
j = j0+tl.arange(0,dC)[None,:] |
|
|
|
wa = tl.load(wa_+IND4(bi,hi,dt,j, H,dT,C)).to(tl.float32) |
|
wq = tl.load(wq_+IND4(bi,hi,dt,j, H,dT,C)).to(tl.float32) |
|
bwi = tl.load(bwi_+IND4(bi,hi,dt,j, H,dT,C)).to(tl.float32) |
|
kwi = tl.load(kwi_+IND4(bi,hi,dt,j, H,dT,C)).to(tl.float32) |
|
|
|
sa = tl.load(a_+IND4(bi,t,hi,j, T,H,C)).to(tl.float32) |
|
sb = tl.load(b_+IND4(bi,t,hi,j, T,H,C)).to(tl.float32) |
|
|
|
ab += tl_dot(prec, wa, bwi.trans()) |
|
ak += tl_dot(prec, wa, kwi.trans()) |
|
qb += tl_dot(prec, wq, bwi.trans()) |
|
qk += tl_dot(prec, wq, kwi.trans()) |
|
|
|
mask1 = (t > t.trans()) |
|
mask2 = (t >= t.trans()) |
|
ab *= mask1 |
|
ak *= mask1 |
|
qb *= mask2 |
|
qk *= mask2 |
|
|
|
ab_inv = tri_minv(ab, dT, prec) |
|
|
|
for i0 in range(0,C,dC): |
|
i = i0+tl.arange(0,dC)[None,:] |
|
sv = tl.load(v_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32) |
|
|
|
wa_state = tl.zeros((dT,dC), tl.float32) |
|
wq_state = tl.zeros((dT,dC), tl.float32) |
|
for j0 in range(0,C,dC): |
|
j = j0+tl.arange(0,dC)[None,:] |
|
state = tl.load(s_+IND5(bi,hi,t0,i.trans(),j, H,T//dT,C,C)).to(tl.float32) |
|
wa = tl.load(wa_+IND4(bi,hi,dt,j, H,dT,C)).to(tl.float32) |
|
wq = tl.load(wq_+IND4(bi,hi,dt,j, H,dT,C)).to(tl.float32) |
|
wa_state += tl_dot(prec, wa, state.trans()) |
|
wq_state += tl_dot(prec, wq, state.trans()) |
|
|
|
ab_u = tl_dot(prec, ak, sv) + wa_state |
|
u = tl_dot(prec, ab_inv, ab_u) |
|
yy = tl_dot(prec, qk, sv) + tl_dot(prec, qb, u) + wq_state |
|
tl.store(y_+IND4(bi,t,hi,i, T,H,C), yy.to(tl.bfloat16)) |
|
|
|
for j0 in range(0,C,dC): |
|
j = j0+tl.arange(0,dC)[None,:] |
|
state = tl.load(s_+IND5(bi,hi,t0,i.trans(),j, H,T//dT,C,C)).to(tl.float32) |
|
kwi = tl.load(kwi_+IND4(bi,hi,dt,j, H,dT,C)).to(tl.float32) |
|
bwi = tl.load(bwi_+IND4(bi,hi,dt,j, H,dT,C)).to(tl.float32) |
|
fw = tl.load(fw_+IND3(bi,hi,j, H,C)) |
|
|
|
state = state * fw + tl_dot(prec, sv.trans(), kwi*fw) + tl_dot(prec, u.trans(), bwi*fw) |
|
|
|
if t0+1 < T//dT: |
|
tl.store(s_+IND5(bi,hi,t0+1,i.trans(),j, H,T//dT,C,C), state.to(tl.float32)) |
|
else: |
|
tl.store(sT_+IND4(bi,hi,i.trans(),j, H,C,C), state.to(tl.bfloat16)) |
|
|
|
|
|
@triton.autotune(configs=[triton.Config({'dC': dC}, num_stages=1) for dC in [16,32,64]], key=['T','H','C','dT','prec']) |
|
@triton.jit |
|
def bw_attn_triton_bighead(w_,q_,k_,v_,a_,b_, dy_,s_,dsT_,ds_, dw_,dq_,dk_,dv_,da_,db_,ds0_, wq_,wa_,kwi_,bwi_,fw_,u_,dab_u_, B:tl.constexpr,T:tl.constexpr,H:tl.constexpr,C:tl.constexpr,dT:tl.constexpr, prec:tl.constexpr, dC:tl.constexpr): |
|
tl.static_assert(C%dC == 0) |
|
bi = tl.program_id(1) |
|
hi = tl.program_id(0) |
|
|
|
for i0 in range(0,C,dC): |
|
i = i0+tl.arange(0,dC)[None,:] |
|
for j0 in range(0,C,dC): |
|
j = j0+tl.arange(0,dC)[None,:] |
|
dstate = tl.load(dsT_+IND4(bi,hi,i.trans(),j, H,C,C)).to(tl.float32) |
|
tl.store(ds_+IND4(bi,hi,i.trans(),j, H,C,C), dstate.to(tl.float32)) |
|
|
|
for t0 in range(T//dT-1,-1,-1): |
|
dt = tl.arange(0,dT)[:,None] |
|
t = t0*dT+dt |
|
tl.debug_barrier() |
|
for j0 in range(0,C,dC): |
|
j = j0+tl.arange(0,dC)[None,:] |
|
sw = tl.load(w_+IND4(bi,t,hi,j, T,H,C)).to(tl.float32) |
|
sq = tl.load(q_+IND4(bi,t,hi,j, T,H,C)).to(tl.float32) |
|
sk = tl.load(k_+IND4(bi,t,hi,j, T,H,C)).to(tl.float32) |
|
sa = tl.load(a_+IND4(bi,t,hi,j, T,H,C)).to(tl.float32) |
|
sb = tl.load(b_+IND4(bi,t,hi,j, T,H,C)).to(tl.float32) |
|
|
|
w = (-sw.exp()).exp() |
|
fw = tl.reduce(w, 0, _prod, keep_dims=True) |
|
incl_pref = tl.cumprod(w,axis=0) |
|
non_incl_pref = incl_pref / w |
|
inv_incl_pref = 1 / incl_pref |
|
|
|
wq = sq * incl_pref |
|
wa = sa * non_incl_pref |
|
kwi = sk * inv_incl_pref |
|
bwi = sb * inv_incl_pref |
|
|
|
tl.store(wq_+IND4(bi,hi,dt,j, H,dT,C), wq.to(tl.float32)) |
|
tl.store(wa_+IND4(bi,hi,dt,j, H,dT,C), wa.to(tl.float32)) |
|
tl.store(kwi_+IND4(bi,hi,dt,j, H,dT,C), kwi.to(tl.float32)) |
|
tl.store(bwi_+IND4(bi,hi,dt,j, H,dT,C), bwi.to(tl.float32)) |
|
tl.store(fw_+IND3(bi,hi,j, H,C), fw.to(tl.float32)) |
|
tl.debug_barrier() |
|
|
|
ab = tl.zeros((dT,dT), tl.float32) |
|
ak = tl.zeros((dT,dT), tl.float32) |
|
qb = tl.zeros((dT,dT), tl.float32) |
|
qk = tl.zeros((dT,dT), tl.float32) |
|
for j0 in range(0,C,dC): |
|
j = j0+tl.arange(0,dC)[None,:] |
|
|
|
wa = tl.load(wa_+IND4(bi,hi,dt,j, H,dT,C)).to(tl.float32) |
|
wq = tl.load(wq_+IND4(bi,hi,dt,j, H,dT,C)).to(tl.float32) |
|
bwi = tl.load(bwi_+IND4(bi,hi,dt,j, H,dT,C)).to(tl.float32) |
|
kwi = tl.load(kwi_+IND4(bi,hi,dt,j, H,dT,C)).to(tl.float32) |
|
|
|
sa = tl.load(a_+IND4(bi,t,hi,j, T,H,C)).to(tl.float32) |
|
sb = tl.load(b_+IND4(bi,t,hi,j, T,H,C)).to(tl.float32) |
|
|
|
ab += tl_dot(prec, wa, bwi.trans()) |
|
ak += tl_dot(prec, wa, kwi.trans()) |
|
qb += tl_dot(prec, wq, bwi.trans()) |
|
qk += tl_dot(prec, wq, kwi.trans()) |
|
|
|
mask1 = (t > t.trans()) |
|
mask2 = (t >= t.trans()) |
|
ab *= mask1 |
|
ak *= mask1 |
|
qb *= mask2 |
|
qk *= mask2 |
|
|
|
ab_inv = tri_minv(ab, dT, prec) |
|
|
|
dab = tl.zeros((dT,dT), tl.float32) |
|
dak = tl.zeros((dT,dT), tl.float32) |
|
dqb = tl.zeros((dT,dT), tl.float32) |
|
dqk = tl.zeros((dT,dT), tl.float32) |
|
|
|
tl.debug_barrier() |
|
for i0 in range(0,C,dC): |
|
i = i0+tl.arange(0,dC)[None,:] |
|
wa_state = tl.zeros((dT,dC), tl.float32) |
|
bwi_dw_dstate = tl.zeros((dT,dC), tl.float32) |
|
kwi_dw_dstate = tl.zeros((dT,dC), tl.float32) |
|
for j0 in range(0,C,dC): |
|
j = j0+tl.arange(0,dC)[None,:] |
|
state = tl.load(s_+IND5(bi,hi,t0,i.trans(),j, H,T//dT,C,C)).to(tl.float32) |
|
dstate = tl.load(ds_+IND4(bi,hi,i.trans(),j, H,C,C)).to(tl.float32) |
|
wa = tl.load(wa_+IND4(bi,hi,dt,j, H,dT,C)).to(tl.float32) |
|
bwi = tl.load(bwi_+IND4(bi,hi,dt,j, H,dT,C)).to(tl.float32) |
|
kwi = tl.load(kwi_+IND4(bi,hi,dt,j, H,dT,C)).to(tl.float32) |
|
fw = tl.load(fw_+IND3(bi,hi,j, H,C)) |
|
|
|
wa_state += tl_dot(prec, wa, state.trans()) |
|
bwi_dw_dstate += tl_dot(prec, bwi*fw, dstate.trans()) |
|
kwi_dw_dstate += tl_dot(prec, kwi*fw, dstate.trans()) |
|
|
|
sv = tl.load(v_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32) |
|
sdy = tl.load(dy_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32) |
|
|
|
ab_u = tl_dot(prec, ak, sv) + wa_state |
|
u = tl_dot(prec, ab_inv, ab_u) |
|
du = tl_dot(prec, qb.trans(), sdy) + bwi_dw_dstate |
|
dab_u = tl_dot(prec, ab_inv.trans(), du) |
|
|
|
tl.store(u_+IND4(bi,hi,dt,i, H,dT,C), u.to(tl.float32)) |
|
tl.store(dab_u_+IND4(bi,hi,dt,i, H,dT,C), dab_u.to(tl.float32)) |
|
|
|
dv = tl_dot(prec, qk.trans(), sdy) + kwi_dw_dstate + tl_dot(prec, ak.trans(), dab_u) |
|
tl.store(dv_+IND4(bi,t,hi,i, T,H,C), dv.to(tl.bfloat16)) |
|
|
|
dab += tl_dot(prec, dab_u, u.trans()) * mask1 |
|
dak += tl_dot(prec, dab_u, sv.trans()) * mask1 |
|
dqb += tl_dot(prec, sdy, u.trans()) * mask2 |
|
dqk += tl_dot(prec, sdy, sv.trans()) * mask2 |
|
tl.debug_barrier() |
|
|
|
for j0 in range(0,C,dC): |
|
j = j0+tl.arange(0,dC)[None,:] |
|
|
|
dy_state = tl.zeros((dT,dC), tl.float32) |
|
dab_u_state = tl.zeros((dT,dC), tl.float32) |
|
fw_u_dstate = tl.zeros((dT,dC), tl.float32) |
|
fw_v_dstate = tl.zeros((dT,dC), tl.float32) |
|
state_dstate = tl.zeros((1,dC), tl.float32) |
|
|
|
fw = tl.load(fw_+IND3(bi,hi,j, H,C)) |
|
wa = tl.load(wa_+IND4(bi,hi,dt,j, H,dT,C)).to(tl.float32) |
|
wq = tl.load(wq_+IND4(bi,hi,dt,j, H,dT,C)).to(tl.float32) |
|
for i0 in range(0,C,dC): |
|
i = i0+tl.arange(0,dC)[None,:] |
|
|
|
u = tl.load(u_+IND4(bi,hi,dt,i, H,dT,C)).to(tl.float32) |
|
dab_u = tl.load(dab_u_+IND4(bi,hi,dt,i, H,dT,C)).to(tl.float32) |
|
sv = tl.load(v_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32) |
|
sdy = tl.load(dy_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32) |
|
|
|
state = tl.load(s_+IND5(bi,hi,t0,i.trans(),j, H,T//dT,C,C)).to(tl.float32) |
|
tl.debug_barrier() |
|
dstate = tl.load(ds_+IND4(bi,hi,i.trans(),j, H,C,C)).to(tl.float32) |
|
tl.debug_barrier() |
|
|
|
dab_u_state += tl_dot(prec, dab_u, state) |
|
fw_u_dstate += fw * tl_dot(prec, u, dstate) |
|
fw_v_dstate += fw * tl_dot(prec, sv, dstate) |
|
dy_state += tl_dot(prec, sdy, state) |
|
|
|
state_dstate += tl.sum(state*dstate, axis=0,keep_dims=True) |
|
|
|
dstate = dstate * fw + tl_dot(prec, sdy.trans(), wq) + tl_dot(prec, dab_u.trans(), wa) |
|
if t0 > 0: |
|
tl.store(ds_+IND4(bi,hi,i.trans(),j, H,C,C), dstate.to(tl.float32)) |
|
else: |
|
tl.store(ds0_+IND4(bi,hi,i.trans(),j, H,C,C), dstate.to(tl.bfloat16)) |
|
|
|
sw = tl.load(w_+IND4(bi,t,hi,j, T,H,C)).to(tl.float32) |
|
w = (-sw.exp()).exp() |
|
incl_pref = tl.cumprod(w,axis=0) |
|
non_incl_pref = incl_pref / w |
|
inv_incl_pref = 1 / incl_pref |
|
|
|
bwi = tl.load(bwi_+IND4(bi,hi,dt,j, H,dT,C)).to(tl.float32) |
|
kwi = tl.load(kwi_+IND4(bi,hi,dt,j, H,dT,C)).to(tl.float32) |
|
|
|
da = non_incl_pref * (tl_dot(prec, dab, bwi) + tl_dot(prec, dak, kwi) + dab_u_state) |
|
tl.store(da_+IND4(bi,t,hi,j, T,H,C), da.to(tl.bfloat16)) |
|
|
|
dq = incl_pref * (tl_dot(prec, dqb, bwi) + tl_dot(prec, dqk, kwi) + dy_state) |
|
tl.store(dq_+IND4(bi,t,hi,j, T,H,C), dq.to(tl.bfloat16)) |
|
|
|
db = inv_incl_pref * (tl_dot(prec, dab.trans(), wa) + tl_dot(prec, dqb.trans(), wq) + fw_u_dstate) |
|
tl.store(db_+IND4(bi,t,hi,j, T,H,C), db.to(tl.bfloat16)) |
|
|
|
dk = inv_incl_pref * (tl_dot(prec, dak.trans(), wa) + tl_dot(prec, dqk.trans(), wq) + fw_v_dstate) |
|
tl.store(dk_+IND4(bi,t,hi,j, T,H,C), dk.to(tl.bfloat16)) |
|
|
|
dw0 = fw * state_dstate |
|
for k in range(t0*dT,t0*dT+dT): |
|
lmask = (t<k).trans() |
|
A = (tl_dot(prec, dab*lmask, bwi) + tl_dot(prec, dak*lmask, kwi)) * wa * (t>k) |
|
A += (tl_dot(prec, dqb*lmask, bwi) + tl_dot(prec, dqk*lmask, kwi)) * wq * (t>=k) |
|
A += (fw_v_dstate*kwi + fw_u_dstate*bwi) * (t<k) |
|
A += dab_u_state*wa * (t>k) + dy_state*wq * (t>=k) |
|
dw = tl.sum(A, axis=0,keep_dims=True) + dw0 |
|
|
|
wk = tl.load(w_+IND4(bi,k,hi,j, T,H,C)).to(tl.float32) |
|
dw *= -wk.exp() |
|
tl.store(dw_+IND4(bi,k,hi,j, T,H,C), dw.to(tl.bfloat16)) |
|
|
|
class TritonBigheadRWKV7(th.autograd.Function): |
|
@staticmethod |
|
def forward(ctx, w,q,k,v,a,b,s0, dot_prec): |
|
K = 16 |
|
B,T,H,C = w.shape |
|
assert T%K == 0 |
|
assert C%16 == 0 |
|
s0 = th.zeros(B,H,C,C, dtype=w.dtype,device=w.device) if s0 is None else s0 |
|
y = th.empty_like(v) |
|
sT = th.empty_like(s0) |
|
s = th.zeros(B,H,T//K,C,C, dtype=th.float32,device=w.device) |
|
wq,wa,kwi,bwi = [th.empty(B,H,K,C, dtype=th.float32,device=w.device) for i in range(4)] |
|
fw = th.empty(B,H,C, dtype=th.float32,device=w.device) |
|
fw_attn_triton_bighead[(H,B)](w,q,k,v,a,b, s0,y,s,sT, wq,wa,kwi,bwi,fw, B,T,H,C,K, dot_prec) |
|
ctx.dot_prec = dot_prec |
|
ctx.save_for_backward(w,q,k,v,a,b,s) |
|
return y, sT |
|
@staticmethod |
|
def backward(ctx, dy, dsT): |
|
K = 16 |
|
w,q,k,v,a,b,s = ctx.saved_tensors |
|
B,T,H,C = w.shape |
|
dw,dq,dk,dv,da,db,ds0 = [th.empty_like(x) for x in [w,q,k,v,a,b,dsT]] |
|
fw = th.empty(B,H,C, dtype=th.float32,device=w.device) |
|
ds = th.empty(B,H,C,C, dtype=th.float32,device=w.device) |
|
wq,wa,kwi,bwi,u,dab_u = [th.empty(B,H,K,C, dtype=th.float32,device=w.device) for i in range(6)] |
|
bw_attn_triton_bighead[(H,B)](w,q,k,v,a,b, dy,s,dsT,ds, dw,dq,dk,dv,da,db,ds0, wq,wa,kwi,bwi,fw,u,dab_u, B,T,H,C,K, ctx.dot_prec) |
|
return dw,dq,dk,dv,da,db,ds0,None |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def rwkv7_attn_triton(r,w,k,v, kk,iclr, HEAD_SIZE=64, dot_prec='fp32', s0=None): |
|
B,T,HC = w.shape |
|
|
|
|
|
chunk_remainder = T % 16 |
|
|
|
|
|
if chunk_remainder == 0: |
|
return rwkv7_attn_triton_chunk(r,w,k,v, kk,iclr, HEAD_SIZE, dot_prec, s0) |
|
|
|
|
|
C = HEAD_SIZE |
|
H = HC//C |
|
s0 = th.zeros(B,H,C,C, dtype=th.float,device=w.device) if s0 is None else s0 |
|
|
|
|
|
chunks = T // 16 |
|
si = chunks * 16 |
|
|
|
|
|
chunk_xx, chunk_sT = rwkv7_attn_triton_chunk( |
|
r[:,:si],w[:,:si],k[:,:si],v[:,:si], kk[:,:si],iclr[:,:si], |
|
HEAD_SIZE, dot_prec, s0 |
|
) |
|
|
|
|
|
remain_xx, last_sT = rwkv7_attn_pytorch_chunk( |
|
r[:,si:],torch.exp(-torch.exp(w[:,si:])),k[:,si:],v[:,si:], kk[:,si:],iclr[:,si:], |
|
B, H, C, torch.zeros(B, chunk_remainder, HC, device=w.device, dtype=w.dtype), |
|
chunk_sT, chunk_size=chunk_remainder |
|
) |
|
|
|
|
|
return torch.cat([chunk_xx.to(dtype=w.dtype), remain_xx.to(dtype=w.dtype)], dim=1), last_sT.to(dtype=s0.dtype) |
|
|
|
|
|
def rwkv7_attn_triton_chunk(r,w,k,v, kk,iclr, HEAD_SIZE=64, dot_prec='fp32', s0=None): |
|
''' |
|
Triton implementation running in blocks of 16 (hardcoded requirement for the kernel) |
|
''' |
|
B,T,HC = w.shape |
|
    assert T % 16 == 0, 'pure triton kernel only works in multiples of 16'
|
C = HEAD_SIZE |
|
H = HC//C |
|
|
|
|
|
r,w,k,v,a,b = [i.view(B,T,H,C) for i in [r,w,k,v,-kk,(kk*iclr)]] |
|
s0 = th.zeros(B,H,C,C, dtype=th.float,device=w.device) if s0 is None else s0 |
|
xx, sT = TritonRWKV7.apply(w,r,k,v,a,b,s0,dot_prec) |
|
return xx.view(B,T,HC), sT |
|
|
|
|
|
|
|
|
|
|
|
def rwkv7_attn_triton_bighead(r,w,k,v, kk,iclr, HEAD_SIZE=64, dot_prec='fp32', s0=None): |
|
B,T,HC = w.shape |
|
|
|
|
|
chunk_remainder = T % 16 |
|
|
|
|
|
if chunk_remainder == 0: |
|
return rwkv7_attn_triton_bighead_chunk(r,w,k,v, kk,iclr, HEAD_SIZE, dot_prec, s0) |
|
|
|
|
|
C = HEAD_SIZE |
|
H = HC//C |
|
s0 = th.zeros(B,H,C,C, dtype=th.float,device=w.device) if s0 is None else s0 |
|
|
|
|
|
chunks = T // 16 |
|
si = chunks * 16 |
|
|
|
|
|
chunk_xx, chunk_sT = rwkv7_attn_triton_bighead_chunk( |
|
r[:,:si],w[:,:si],k[:,:si],v[:,:si], kk[:,:si],iclr[:,:si], |
|
HEAD_SIZE, dot_prec, s0 |
|
) |
|
|
|
|
|
remain_xx, last_sT = rwkv7_attn_pytorch_chunk( |
|
r[:,si:],torch.exp(-torch.exp(w[:,si:])),k[:,si:],v[:,si:], kk[:,si:],iclr[:,si:], |
|
B, H, C, torch.zeros(B, chunk_remainder, HC, device=w.device, dtype=w.dtype), |
|
chunk_sT, chunk_size=chunk_remainder |
|
) |
|
|
|
|
|
return torch.cat([chunk_xx.to(dtype=w.dtype), remain_xx.to(dtype=w.dtype)], dim=1), last_sT.to(dtype=s0.dtype) |
|
|
|
|
|
def rwkv7_attn_triton_bighead_chunk(r,w,k,v, kk,iclr, HEAD_SIZE=64, dot_prec='fp32', s0=None): |
|
''' |
|
Triton implementation running in blocks of 16 (hardcoded requirement for the kernel) |
|
''' |
|
B,T,HC = w.shape |
|
    assert T % 16 == 0, 'pure triton kernel only works in multiples of 16'
|
C = HEAD_SIZE |
|
H = HC//C |
|
|
|
|
|
r,w,k,v,a,b = [i.view(B,T,H,C) for i in [r,w,k,v,-kk,(kk*iclr)]] |
|
s0 = th.zeros(B,H,C,C, dtype=th.float,device=w.device) if s0 is None else s0 |
|
xx, sT = TritonBigheadRWKV7.apply(w,r,k,v,a,b,s0,dot_prec) |
|
return xx.view(B,T,HC), sT |
|
|
|
|
|
|
|
|
|
|
|
from dataclasses import dataclass |
|
from typing import Optional |
|
from typing import Union |
|
import torch |
|
|
|
@dataclass |
|
class RWKV7BlockConfigMap: |
|
|
|
"""Configuration map for RWKV based models""" |
|
|
|
num_hidden_layers: int |
|
hidden_size: int |
|
|
|
head_size: int = 64 |
|
|
|
|
|
dropout_rate: float = 0.0 |
|
|
|
|
|
tmix_backend: str = "auto" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
layer_id: Optional[int] = None |
|
|
|
|
|
device: Union[torch.device, str, None] = None |
|
dtype: Union[torch.dtype, str, None] = None |
|
|
|
|
|
hidden_size_ffn: Optional[int] = None |
|
hidden_size_att: Optional[int] = None |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def __init__( |
|
self, |
|
num_hidden_layers: int, |
|
hidden_size: int, |
|
head_size: int = 64, |
|
dropout_rate: float = 0.0, |
|
tmix_backend: str = "auto", |
|
layer_id: Optional[int] = None, |
|
device: Union[torch.device, str, None] = None, |
|
dtype: Union[torch.dtype, str, None] = None, |
|
hidden_size_ffn: Optional[int] = None, |
|
hidden_size_att: Optional[int] = None, |
|
**kwargs |
|
) -> None: |
|
''' |
|
Constructor for the config |
|
''' |
|
self.num_hidden_layers = num_hidden_layers |
|
self.hidden_size = hidden_size |
|
self.head_size = head_size |
|
self.dropout_rate = dropout_rate |
|
self.tmix_backend = tmix_backend |
|
self.layer_id = layer_id |
|
self.device = device |
|
self.dtype = dtype |
|
self.hidden_size_ffn = hidden_size_ffn |
|
self.hidden_size_att = hidden_size_att |
|
|
|
|
|
|
|
|
|
|
|
def get_layer_id(self, fallback:int) -> int: |
|
''' |
|
Returns the layer id |
|
''' |
|
if self.layer_id is not None: |
|
return self.layer_id |
|
return fallback |
|
|
|
def get_device(self, fallback:str) -> torch.device: |
|
''' |
|
Returns the device |
|
''' |
|
if self.device is not None: |
|
return torch.device(self.device) |
|
if fallback is not None: |
|
return torch.device(fallback) |
|
return torch.get_default_device() |
|
|
|
def get_dtype(self, fallback:str) -> torch.dtype: |
|
''' |
|
Returns the dtype |
|
''' |
|
if self.dtype is not None: |
|
key = self.dtype |
|
else: |
|
key = fallback |
|
|
|
|
|
if isinstance(key, torch.dtype): |
|
return key |
|
|
|
|
|
ret = getattr(torch, key) |
|
assert isinstance(ret, torch.dtype), f"Invalid dtype: {self.dtype}" |
|
return ret |
|
|
|
|
|
|
|
def get_hidden_size_att(self) -> int: |
|
''' |
|
Returns the dimension of attention |
|
''' |
|
if self.hidden_size_att is not None: |
|
hidden_size_att = self.hidden_size_att |
|
else: |
|
hidden_size = self.hidden_size |
|
assert hidden_size % 32 == 0, f"hidden_size must be divisible by 32" |
|
hidden_size_att = hidden_size |
|
assert hidden_size_att % 32 == 0, f"hidden_size_att must be divisible by 32 ({hidden_size_att})" |
|
return hidden_size_att |
|
|
|
def get_hidden_size_ffn(self) -> int: |
|
''' |
|
Returns the dimension of feed forward network |
|
''' |
|
if self.hidden_size_ffn is not None: |
|
hidden_size_ffn = self.hidden_size_ffn |
|
else: |
|
hidden_size = self.hidden_size |
|
assert hidden_size % 32 == 0, f"hidden_size must be divisible by 32" |
|
hidden_size_ffn = hidden_size * 4 |
|
|
|
assert hidden_size_ffn % 32 == 0, f"hidden_size_ffn must be divisible by 32" |
|
return hidden_size_ffn |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def new_block_config_map(self, **kwargs) -> 'RWKV7BlockConfigMap': |
|
''' |
|
Returns a new config map with updated values |
|
''' |
|
|
|
new_dict = {} |
|
for key in RWKV7BlockConfigMap.__dataclass_fields__: |
|
if key in self.__dict__: |
|
new_dict[key] = self.__dict__[key] |
|
new_dict.update(kwargs) |
|
|
|
return RWKV7BlockConfigMap(**new_dict) |
|
|
|
@staticmethod |
|
def normalize(config_map: any) -> 'RWKV7BlockConfigMap': |
|
''' |
|
        Normalizes a dict, an object, or an existing RWKV7BlockConfigMap into an RWKV7BlockConfigMap
|
''' |
|
if isinstance(config_map, RWKV7BlockConfigMap): |
|
return config_map |
|
|
|
dict_obj = None |
|
if isinstance(config_map, dict): |
|
dict_obj = config_map |
|
elif hasattr(config_map, '__dict__'): |
|
dict_obj = config_map.__dict__ |
|
|
|
if dict_obj is not None: |
|
|
|
new_dict = {} |
|
for key, value in dict_obj.items(): |
|
if key in RWKV7BlockConfigMap.__dataclass_fields__: |
|
new_dict[key] = value |
|
return RWKV7BlockConfigMap(**new_dict) |
|
|
|
raise ValueError(f"Unsupported config_map type: {type(config_map)}") |
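# Minimal usage sketch for the config map (values are illustrative):
#
#   cfg = RWKV7BlockConfigMap(num_hidden_layers=12, hidden_size=768, tmix_backend="pytorch")
#   cfg = RWKV7BlockConfigMap.normalize({"num_hidden_layers": 12, "hidden_size": 768})
#   layer_cfg = cfg.new_block_config_map(layer_id=0)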
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
from torch import nn |
|
from typing import Union |
|
|
|
|
|
class RWKV7ChannelMix(torch.nn.Module): |
|
''' |
|
ChannelMix block for RWKV |
|
This is similar to transformer FFN block |
|
''' |
|
|
|
def __init__(self, configMap: Union[RWKV7BlockConfigMap, any]): |
|
''' |
|
Initialize the ChannelMix block. |
|
|
|
        Note: this does not initialize the parameter weights themselves;

        that is handled separately by the `init_parameters()` method
|
''' |
|
|
|
super().__init__() |
|
|
|
configMap:RWKV7BlockConfigMap = RWKV7BlockConfigMap.normalize(configMap) |
|
self.configMap = configMap |
|
|
|
|
|
hidden_size = configMap.hidden_size |
|
device = configMap.get_device(None) |
|
dtype = configMap.get_dtype('bfloat16') |
|
|
|
|
|
hidden_size_ffn = configMap.get_hidden_size_ffn() |
|
|
|
|
|
|
|
self.x_k = nn.Parameter(torch.empty(1, 1, hidden_size, device=device, dtype=dtype)) |
|
self.key = nn.Linear(hidden_size, hidden_size_ffn, bias=False, device=device, dtype=dtype) |
|
self.value = nn.Linear(hidden_size_ffn, hidden_size, bias=False, device=device, dtype=dtype) |
|
|
|
def init_parameters(self): |
|
''' |
|
Reset the parameters of the block, to an initial state used for training a model from scratch |
|
''' |
|
|
|
|
|
configMap = self.configMap |
|
hidden_size = configMap.hidden_size |
|
num_hidden_layers = configMap.num_hidden_layers |
|
|
|
|
|
layer_id = configMap.get_layer_id(0) |
|
device = configMap.get_device(None) |
|
dtype = configMap.get_dtype('bfloat16') |
|
|
|
|
|
hidden_size_ffn = configMap.get_hidden_size_ffn() |
|
|
|
|
|
|
|
with torch.no_grad(): |
|
ratio_1_to_almost0 = 1.0 - (layer_id / num_hidden_layers) |
|
ddd = torch.ones(1, 1, hidden_size) |
|
for i in range(hidden_size): |
|
ddd[0, 0, i] = i / hidden_size |
|
self.x_k = nn.Parameter( (1.0 - torch.pow(ddd, ratio_1_to_almost0**4)).to(device, dtype=dtype) ) |
|
|
|
self.key = nn.Linear(hidden_size, hidden_size_ffn, bias=False, device=device, dtype=dtype) |
|
self.value = nn.Linear(hidden_size_ffn, hidden_size, bias=False, device=device, dtype=dtype) |
|
|
|
def forward(self, x: torch.Tensor, last_state: torch.Tensor) -> tuple[torch.Tensor,torch.Tensor]: |
|
''' |
|
        Forward the channel mix block, given the input tokens and states.
|
|
|
Given: |
|
        - Incoming token embedding of shape [batch_size, seq_len, embedding_size]

        - Incoming channel mix shift state for each batch, of shape [batch_size, state_size]
|
|
|
Returns a pair |
|
- Output embedding of shape [batch_size, seq_len, embedding_size] |
|
- Output channel mix, shift state of shape [batch_size, state_size] |
|
''' |
|
|
|
|
|
|
|
|
|
|
|
|
|
dxprev = torch.cat((last_state.unsqueeze(1), x[:, :-1]), dim=1) - x |
|
xk = x + dxprev * self.x_k |
|
k = torch.relu( self.key(xk) ) ** 2 |
|
|
|
return self.value(k), x[:,-1] |
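    # The forward pass above computes, per token t (with x_prev taken from the
    # token-shift state for t == 0):
    #   xk_t  = x_t + (x_prev_t - x_t) * x_k
    #   out_t = W_value( relu(W_key(xk_t))^2 )
    # and returns the last token embedding as the new shift state.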
|
|
|
@torch.compile(mode="default", fullgraph=True) |
|
def forward_with_default_compile(self, in_x: torch.Tensor, in_state: torch.Tensor, out_x: torch.Tensor, out_state: torch.Tensor) -> tuple[torch.Tensor,torch.Tensor]: |
|
''' |
|
        Compiled variant of the forward function,

        with no new tensors being created for the output.

        Useful for static memory allocation optimizations during inference
|
''' |
|
out_x[:], out_state[:] = self.forward_with_reduce_compile(in_x, in_state) |
|
return out_x, out_state |
|
|
|
@torch.compile(mode="reduce-overhead", fullgraph=True) |
|
def forward_with_reduce_compile(self, in_x: torch.Tensor, in_state: torch.Tensor) -> tuple[torch.Tensor,torch.Tensor]: |
|
''' |
|
        Compiled variant of the forward function
|
''' |
|
return self.forward(in_x, in_state) |
|
|
|
def load_from_model_state_dict(self, model_state_dict: dict, layer_id:int, non_blocking:bool=True): |
|
''' |
|
        Given the full/partial RWKV model weights, loaded via `torch.load`,

        set up the current module weights, using the layer_id
|
''' |
|
|
|
current_state_dict = self.state_dict() |
|
|
|
|
|
for n in current_state_dict: |
|
model_key = f"blocks.{layer_id}.ffn.{n}" |
|
if model_key not in model_state_dict: |
|
continue |
|
|
|
|
|
try: |
|
current_state_dict[n].copy_(model_state_dict[model_key], non_blocking=non_blocking) |
|
except Exception as e: |
|
print(f"[ERROR] loading: {model_key} | model shape: {current_state_dict[n].shape} | weight shape: {model_state_dict[model_key].shape}") |
|
raise e |
|
|
|
|
|
|
|
|
|
import torch, math |
|
from torch import nn |
|
from torch import Tensor |
|
from typing import Union |
|
from torch.nn import functional as F |
|
|
|
|
|
|
|
|
|
triton = None |
|
if torch.cuda.is_available(): |
|
try: |
|
import triton |
|
    except ImportError:

        triton = None

        print("[WARNING] Triton not available, falling back to pytorch mode by default - this is significantly slower")

else:

    print("[WARNING] CUDA not available, falling back to pytorch mode by default - this is significantly slower")
|
|
|
class RWKV7TimeMix(torch.nn.Module): |
|
''' |
|
Time Mix block for RWKV V7 |
|
''' |
|
|
|
def __init__(self, configMap: Union[RWKV7BlockConfigMap, any]): |
|
''' |
|
Initialize the TimeMix block. |
|
|
|
        Note: this does not initialize the parameter weights themselves;

        that is handled separately by the `init_parameters()` method
|
''' |
|
super().__init__() |
|
|
|
configMap:RWKV7BlockConfigMap = RWKV7BlockConfigMap.normalize(configMap) |
|
self.configMap = configMap |
|
|
|
|
|
hidden_size = configMap.hidden_size |
|
|
|
|
|
|
|
layer_id = configMap.get_layer_id(0) |
|
self.layer_id = layer_id |
|
|
|
|
|
device = configMap.get_device(None) |
|
dtype = configMap.get_dtype('bfloat16') |
|
|
|
|
|
hidden_size_att = configMap.get_hidden_size_att() |
|
|
|
|
|
assert hidden_size == hidden_size_att, "hidden_size should be equal to hidden_size_att (@TODO: support different hidden_size and hidden_size_att)" |
|
|
|
|
|
head_size = configMap.head_size |
|
self.head_size = head_size |
|
|
|
|
|
n_head = hidden_size_att // head_size |
|
assert hidden_size_att % head_size == 0, "hidden_size_att should be divisible by head_size" |
|
self.n_head = n_head |
|
|
|
|
|
self.tmix_backend = configMap.tmix_backend |
|
|
|
|
|
|
|
|
|
with torch.no_grad(): |
|
|
|
def calc_lora_rank(exponent, multiplier): |
|
return max(1, round(hidden_size ** exponent * multiplier / 32)) * 32 |
|
D_DECAY_LORA = calc_lora_rank(0.5, 1.8) |
|
D_AAA_LORA = calc_lora_rank(0.5, 1.8) |
|
D_MV_LORA = calc_lora_rank(0.5, 1.3) |
|
D_GATE_LORA = calc_lora_rank(0.8, 0.6) |
|
|
|
self.x_r = nn.Parameter(torch.empty(1,1,hidden_size, device=device, dtype=dtype)) |
|
self.x_w = nn.Parameter(torch.empty(1,1,hidden_size, device=device, dtype=dtype)) |
|
self.x_k = nn.Parameter(torch.empty(1,1,hidden_size, device=device, dtype=dtype)) |
|
self.x_v = nn.Parameter(torch.empty(1,1,hidden_size, device=device, dtype=dtype)) |
|
self.x_a = nn.Parameter(torch.empty(1,1,hidden_size, device=device, dtype=dtype)) |
|
self.x_g = nn.Parameter(torch.empty(1,1,hidden_size, device=device, dtype=dtype)) |
|
|
|
self.w0 = nn.Parameter(torch.empty(1,1,hidden_size, device=device, dtype=dtype)) |
|
self.w1 = nn.Parameter(torch.empty(hidden_size, D_DECAY_LORA, device=device, dtype=dtype)) |
|
self.w2 = nn.Parameter(torch.empty(D_DECAY_LORA, hidden_size, device=device, dtype=dtype)) |
|
|
|
self.a0 = nn.Parameter(torch.empty(1,1,hidden_size, device=device, dtype=dtype)) |
|
self.a1 = nn.Parameter(torch.empty(hidden_size,D_AAA_LORA, device=device, dtype=dtype)) |
|
self.a2 = nn.Parameter(torch.empty(D_AAA_LORA,hidden_size, device=device, dtype=dtype)) |
|
|
|
if layer_id > 0: |
|
self.v0 = nn.Parameter(torch.empty(1,1,hidden_size, device=device, dtype=dtype)) |
|
self.v1 = nn.Parameter(torch.empty(hidden_size,D_MV_LORA, device=device, dtype=dtype)) |
|
self.v2 = nn.Parameter(torch.empty(D_MV_LORA,hidden_size, device=device, dtype=dtype)) |
|
|
|
self.g1 = nn.Parameter(torch.empty(hidden_size, D_GATE_LORA, device=device, dtype=dtype)) |
|
self.g2 = nn.Parameter(torch.empty(D_GATE_LORA, hidden_size, device=device, dtype=dtype)) |
|
|
|
self.k_k = nn.Parameter(torch.empty(1,1,hidden_size, device=device, dtype=dtype)) |
|
self.k_a = nn.Parameter(torch.empty(1,1,hidden_size, device=device, dtype=dtype)) |
|
self.r_k = nn.Parameter(torch.empty(n_head, head_size, device=device, dtype=dtype)) |
|
|
|
self.receptance = nn.Linear(hidden_size, hidden_size_att, bias=False, device=device, dtype=dtype) |
|
self.key = nn.Linear(hidden_size, hidden_size_att, bias=False, device=device, dtype=dtype) |
|
self.value = nn.Linear(hidden_size, hidden_size_att, bias=False, device=device, dtype=dtype) |
|
self.output = nn.Linear(hidden_size_att, hidden_size, bias=False, device=device, dtype=dtype) |
|
self.ln_x = nn.GroupNorm(n_head, hidden_size_att, device=device, dtype=dtype, eps=(1e-5)*head_size) |
|
|
|
def init_parameters(self): |
|
''' |
|
Reset the parameters of the block, to an initial state used for training a model from scratch |
|
''' |
|
configMap = self.configMap |
|
|
|
|
|
hidden_size = configMap.hidden_size |
|
num_hidden_layers = configMap.num_hidden_layers |
|
|
|
|
|
layer_id = self.layer_id |
|
|
|
|
|
device = configMap.get_device(None) |
|
dtype = configMap.get_dtype('bfloat16') |
|
|
|
|
|
hidden_size_att = configMap.get_hidden_size_att() |
|
|
|
|
|
assert hidden_size == hidden_size_att, "hidden_size should be equal to hidden_size_att (@TODO: support different hidden_size and hidden_size_att)" |
|
|
|
|
|
head_size = self.head_size |
|
|
|
|
|
n_head = hidden_size_att // head_size |
|
assert hidden_size_att % head_size == 0, "hidden_size_att should be divisible by head_size" |
|
|
|
|
|
|
|
with torch.no_grad(): |
|
ratio_0_to_1 = layer_id / (num_hidden_layers - 1) |
|
ratio_1_to_almost0 = 1.0 - (layer_id / num_hidden_layers) |
|
ddd = torch.ones(1, 1, hidden_size, device=device, dtype=dtype) |
|
for i in range(hidden_size): |
|
ddd[0, 0, i] = i / hidden_size |
|
|
|
|
|
def calc_lora_rank(exponent, multiplier): |
|
return max(1, round(hidden_size ** exponent * multiplier / 32)) * 32 |
|
D_DECAY_LORA = calc_lora_rank(0.5, 1.8) |
|
D_AAA_LORA = calc_lora_rank(0.5, 1.8) |
|
D_MV_LORA = calc_lora_rank(0.5, 1.3) |
|
D_GATE_LORA = calc_lora_rank(0.8, 0.6) |
|
|
|
self.x_r.copy_(1.0 - torch.pow(ddd, 0.2 * ratio_1_to_almost0)) |
|
self.x_w.copy_(1.0 - torch.pow(ddd, 0.9 * ratio_1_to_almost0)) |
|
self.x_k.copy_(1.0 - (torch.pow(ddd, 0.9 * ratio_1_to_almost0) + 0.4 * ratio_0_to_1)) |
|
self.x_v.copy_(1.0 - (torch.pow(ddd, 0.4 * ratio_1_to_almost0) + 0.6 * ratio_0_to_1)) |
|
self.x_a.copy_(1.0 - torch.pow(ddd, 0.9 * ratio_1_to_almost0)) |
|
self.x_g.copy_(1.0 - torch.pow(ddd, 0.2 * ratio_1_to_almost0)) |
|
|
|
def ortho_init(x, scale): |
|
x = x.to(device) |
|
shape = x.shape |
|
if len(shape) == 2: |
|
gain = math.sqrt(shape[0] / shape[1]) if shape[0] > shape[1] else 1 |
|
nn.init.orthogonal_(x, gain=gain * scale) |
|
elif len(shape) == 3: |
|
gain = math.sqrt(shape[1] / shape[2]) if shape[1] > shape[2] else 1 |
|
for i in range(shape[0]): |
|
nn.init.orthogonal_(x[i], gain=gain * scale) |
|
else: |
|
assert False |
|
return x.to(device, dtype=dtype) |
|
|
|
|
|
decay_speed = torch.ones(hidden_size, device=device, dtype=dtype) |
|
for n in range(hidden_size): |
|
decay_speed[n] = -7 + 5 * (n / (hidden_size - 1)) ** (0.85 + 1.0 * ratio_0_to_1 ** 0.5) |
|
|
|
self.w0.copy_(decay_speed.reshape(1,1,hidden_size).to(device, dtype=dtype) + 0.5) |
|
self.w1.copy_(torch.zeros(hidden_size, D_DECAY_LORA, device=device, dtype=dtype)) |
|
self.w2.copy_(ortho_init(torch.zeros(D_DECAY_LORA, hidden_size), 0.1)) |
|
|
|
|
|
self.a0.copy_(torch.zeros(1,1,hidden_size, device=device, dtype=dtype)) |
|
self.a1.copy_(torch.zeros(hidden_size, D_AAA_LORA, device=device, dtype=dtype)) |
|
self.a2.copy_(ortho_init(torch.zeros(D_AAA_LORA, hidden_size), 0.1)) |
|
|
|
|
|
if layer_id > 0: |
|
self.v0.copy_(torch.zeros(1,1,hidden_size, device=device, dtype=dtype)+1.0) |
|
self.v1.copy_(torch.zeros(hidden_size, D_MV_LORA, device=device, dtype=dtype)) |
|
self.v2.copy_(ortho_init(torch.zeros(D_MV_LORA, hidden_size), 0.1)) |
|
|
|
|
|
|
|
self.g1.copy_(torch.zeros(hidden_size, D_GATE_LORA, device=device, dtype=dtype)) |
|
self.g2.copy_(ortho_init(torch.zeros(D_GATE_LORA, hidden_size), 0.1)) |
|
|
|
self.k_k.copy_(torch.ones(1,1,hidden_size, device=device, dtype=dtype)*0.85) |
|
self.k_a.copy_(torch.ones(1,1,hidden_size, device=device, dtype=dtype)) |
|
self.r_k.copy_(torch.zeros(n_head,head_size, device=device, dtype=dtype)) |
|
|
|
self.receptance = nn.Linear(hidden_size, hidden_size_att, bias=False, device=device, dtype=dtype) |
|
self.key = nn.Linear(hidden_size, hidden_size_att, bias=False, device=device, dtype=dtype) |
|
self.value = nn.Linear(hidden_size, hidden_size_att, bias=False, device=device, dtype=dtype) |
|
self.output = nn.Linear(hidden_size_att, hidden_size, bias=False, device=device, dtype=dtype) |
|
self.ln_x = nn.GroupNorm(n_head, hidden_size_att, device=device, dtype=dtype, eps=(1e-5)*head_size) |
|
|
|
def forward(self, x:Tensor, shift_state_in:Tensor, wkv_state_in:Tensor, v_first_val:Tensor) -> tuple[Tensor,Tensor,Tensor,Tensor]: |
|
''' |
|
        Forward the time mix block, given the model weights and the input tokens and states.
|
|
|
Given: |
|
        - Incoming token embedding of shape [batch_size, seq_len, embedding_size]

        - Incoming states of shapes:
|
[batch_size, state_size] ## Token Shift state, |
|
[batch_size, n_head, head_size, head_size] ## WKV state |
|
- Incoming v_first_val of shape [batch_size, seq_len, embedding_size] |
|
|
|
|
|
        Returns a tuple of
|
- output embedding of shape [batch_size, seq_len, embedding_size] |
|
- output state of shapes: |
|
[batch_size, state_size] ## Token Shift state, |
|
[batch_size, n_head, head_size, head_size] ## WKV state |
|
- output v_first_val of shape [batch_size, seq_len, embedding_size] |
|
|
|
''' |
|
|
|
BATCH_SIZE, SEQ_LEN, IN_EMB_SIZE = x.size() |
|
N_HEAD = self.n_head |
|
HEAD_SIZE = self.head_size |
|
|
|
|
|
if wkv_state_in is None: |
|
wkv_state_in = torch.zeros(BATCH_SIZE,N_HEAD,HEAD_SIZE,HEAD_SIZE, dtype=torch.float,device=x.device)
|
else: |
|
wkv_state_in = wkv_state_in.clone() |
|
|
|
|
|
|
|
|
|
|
|
shift_state_out = x[:, -1] |
|
dxprev = torch.cat((shift_state_in.unsqueeze(1), x[:, :-1]), dim=1) - x |
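# Token shift: dxprev holds (previous token embedding - current token embedding),
# with the first position taken from the incoming shift state. The x + dxprev * mix
# lines below therefore interpolate per channel between the current and previous token.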
|
|
|
xr = x + dxprev * self.x_r |
|
xw = x + dxprev * self.x_w |
|
xk = x + dxprev * self.x_k |
|
xv = x + dxprev * self.x_v |
|
xa = x + dxprev * self.x_a |
|
xg = x + dxprev * self.x_g |
|
xx = dxprev |
|
|
|
r = self.receptance(xr) |
|
w = torch.tanh(xw @ self.w1) @ self.w2 |
|
k = self.key(xk) |
|
v = self.value(xv) |
|
g = torch.sigmoid(xg @ self.g1) @ self.g2 |
|
iclr = torch.sigmoid(self.a0 + (xa @ self.a1) @ self.a2) |
|
|
|
kk = F.normalize((k * self.k_k).view(BATCH_SIZE,SEQ_LEN,N_HEAD,-1), dim=-1, p=2.0).view(BATCH_SIZE, SEQ_LEN, IN_EMB_SIZE) |
|
k = k * (1 + (iclr-1) * self.k_a) |
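# iclr acts as a per-channel in-context learning rate: k = k * (1 + (iclr-1) * k_a)
# interpolates between the raw key (k_a=0) and k * iclr (k_a=1), while kk is the
# per-head L2-normalised key used in the state transition's outer-product term.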
|
|
|
if self.layer_id == 0: |
|
v_first_val = v |
|
else: |
|
v = v + (v_first_val - v) * torch.sigmoid(self.v0 + (xv @ self.v1) @ self.v2) |
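# Value residual: layer 0 publishes its value projection as v_first; every later layer
# blends its own v towards v_first with a learned sigmoid gate (low-rank v1/v2 plus bias v0).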
|
|
|
tmix_backend = self.tmix_backend.lower() |
|
if tmix_backend == "auto": |
|
if triton is None or self.receptance.weight.device.type == "cpu": |
|
tmix_backend = "pytorch" |
|
else: |
|
tmix_backend = "cuda" |
|
|
|
if tmix_backend == "pytorch_ref" or tmix_backend == "pytorch_ref_ori": |
|
|
|
|
|
|
|
w = torch.exp(-0.606531 * torch.sigmoid((self.w0 + w).float())) |
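# 0.606531 ~= exp(-0.5), so this materialises the decay directly:
# exp(-exp(-0.5) * sigmoid(w0 + w)) is mathematically the same decay the kernel
# backends below obtain by passing the log-space form w = -softplus(-(w0 + w)) - 0.5
# and applying exp(-exp(w)) inside the kernel.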
|
xx, wkv_state_out = rwkv7_attn_pytorch_ref(r, w, k, v, kk, iclr, BATCH_SIZE, SEQ_LEN, N_HEAD, HEAD_SIZE, xx, wkv_state_in) |
|
elif tmix_backend == "pytorch_ref_fp32": |
|
|
|
|
|
|
|
|
|
w = -F.softplus(-(self.w0 + w)) - 0.5 |
|
xx, wkv_state_out = rwkv7_attn_pytorch_ref_fp32(r, w, k, v, kk, iclr, BATCH_SIZE, SEQ_LEN, N_HEAD, HEAD_SIZE, xx, wkv_state_in) |
|
elif tmix_backend == "pytorch": |
|
|
|
|
|
|
|
w = torch.exp(-0.606531 * torch.sigmoid((self.w0 + w).float())) |
|
xx, wkv_state_out = rwkv7_attn_pytorch(r, w, k, v, kk, iclr, BATCH_SIZE, SEQ_LEN, N_HEAD, HEAD_SIZE, xx, wkv_state_in) |
|
elif tmix_backend == "triton": |
|
if triton is None: |
|
raise ValueError("Triton not available, unable to load triton kernel") |
|
|
|
w = -F.softplus(-(self.w0 + w)) - 0.5 |
|
xx, wkv_state_out = rwkv7_attn_triton(r, w, k, v, kk, iclr, s0=wkv_state_in) |
|
elif tmix_backend == "triton_bighead": |
|
if triton is None: |
|
raise ValueError("Triton not available, unable to load triton kernel") |
|
|
|
w = -F.softplus(-(self.w0 + w)) - 0.5 |
|
xx, wkv_state_out = rwkv7_attn_triton_bighead(r, w, k, v, kk, iclr, s0=wkv_state_in) |
|
elif tmix_backend == "cuda_ref": |
|
|
|
|
|
|
|
w = -F.softplus(-(self.w0 + w)) - 0.5 |
|
xx, wkv_state_out = rwkv7_attn_cuda_ref(r, w, k, v, kk, iclr, s0=wkv_state_in) |
|
elif tmix_backend == "cuda": |
|
|
|
|
|
|
|
w = -F.softplus(-(self.w0 + w)) - 0.5 |
|
xx, wkv_state_out = rwkv7_attn_cuda(r, w, k, v, kk, iclr, s0=wkv_state_in) |
|
elif tmix_backend == "fla": |
|
|
|
|
|
|
|
w = -F.softplus(-(self.w0 + w)) - 0.5 |
|
xx, wkv_state_out = rwkv7_attn_fla(r, w, k, v, kk, iclr, BATCH_SIZE, SEQ_LEN, N_HEAD, HEAD_SIZE, xx, wkv_state_in) |
|
elif tmix_backend == "fla_fused" or tmix_backend == "fused_fla": |
|
|
|
|
|
|
|
w = -F.softplus(-(self.w0 + w)) - 0.5 |
|
xx, wkv_state_out = rwkv7_attn_fused_reccurent_fla(r, w, k, v, kk, iclr, BATCH_SIZE, SEQ_LEN, N_HEAD, HEAD_SIZE, xx, wkv_state_in) |
|
else: |
|
raise ValueError(f"Unknown tmix_backend: {tmix_backend}") |
|
|
|
|
|
if wkv_state_in is not None: |
|
wkv_state_out = wkv_state_out.to(wkv_state_in.dtype) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
xx = self.ln_x(xx.view(BATCH_SIZE * SEQ_LEN, IN_EMB_SIZE)).view(BATCH_SIZE, SEQ_LEN, IN_EMB_SIZE) |
|
xx = xx + ((r.view(BATCH_SIZE,SEQ_LEN,N_HEAD,-1)*k.view(BATCH_SIZE,SEQ_LEN,N_HEAD,-1)*self.r_k).sum(dim=-1, keepdim=True) * v.view(BATCH_SIZE,SEQ_LEN,N_HEAD,-1)).view(BATCH_SIZE,SEQ_LEN,IN_EMB_SIZE) |
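# Per-head bonus term: a scalar (r*k weighted by r_k, summed over the head dimension)
# scales v and is added after the GroupNorm, as in the RWKV-7 reference output path.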
|
xx = self.output(xx * g) |
|
|
|
return xx, shift_state_out, wkv_state_out, v_first_val |
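# Hypothetical usage sketch (shapes only; identifiers are illustrative, not part of the API):
#   x       : [B, T, hidden_size]
#   shift   : [B, hidden_size]                         -- zeros for a fresh sequence
#   wkv     : [B, n_head, head_size, head_size] fp32   -- zeros for a fresh sequence
#   y, shift, wkv, v_first = tmix.forward(x, shift, wkv, v_first)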
|
|
|
@torch.compile(mode="default") |
|
def forward_with_default_compile(self, in_x:Tensor, shift_state_in:Tensor, wkv_state_in:Tensor, v_first_val_in:Tensor, out_x:Tensor, shift_state_out:Tensor, wkv_state_out:Tensor, v_first_val_out:Tensor) -> tuple[Tensor,Tensor,Tensor,Tensor]: |
|
''' |
|
Compiled variant of the forward function

With no new tensors being created for the output

Useful for static memory allocation optimizations during inference
|
''' |
|
out_x[:], shift_state_out[:], wkv_state_out[:], v_first_val_out[:] = self.forward(in_x, shift_state_in, wkv_state_in, v_first_val_in) |
|
return out_x, shift_state_out, wkv_state_out, v_first_val_out |
|
|
|
@torch.compile(mode="reduce-overhead") |
|
def forward_with_reduce_compile(self, in_x:Tensor, shift_state_in:Tensor, wkv_state_in:Tensor, v_first_val:Tensor) -> tuple[Tensor,Tensor,Tensor,Tensor]: |
|
''' |
|
Compiled variant of the forward function
|
With no input tensor being modified. |
|
Useful for reduce-overhead compile mode |
|
''' |
|
return self.forward(in_x, shift_state_in, wkv_state_in, v_first_val) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_from_model_state_dict(self, model_state_dict: dict, layer_id:int, non_blocking:bool=True): |
|
''' |
|
Given the Full/partial RWKV model weights, loaded via `torch.load` |
|
Set up the current module weights, using the layer_id
|
''' |
|
|
|
current_state_dict = self.state_dict() |
|
|
|
|
|
for n in current_state_dict: |
|
model_key = f"blocks.{layer_id}.att.{n}" |
|
if model_key not in model_state_dict: |
|
continue |
|
|
|
|
|
try: |
|
current_state_dict[n].copy_(model_state_dict[model_key], non_blocking=non_blocking) |
|
except Exception as e: |
|
print(f"[ERROR] loading: {model_key} | model shape: {current_state_dict[n].shape} | weight shape: {model_state_dict[model_key].shape}") |
|
raise e |
|
|
|
|
|
|
|
|
|
import torch |
|
from torch import nn |
|
from torch import Tensor |
|
from typing import Union |
|
from torch.nn import functional as F |
|
|
|
|
|
|
|
|
|
|
|
class RWKV7LayerBlock(torch.nn.Module): |
|
''' |
|
layer block for RWKV V7 |
|
''' |
|
|
|
def __init__(self, configMap: Union[RWKV7BlockConfigMap, any]): |
|
super().__init__() |
|
|
|
configMap:RWKV7BlockConfigMap = RWKV7BlockConfigMap.normalize(configMap) |
|
self.configMap = configMap |
|
|
|
|
|
hidden_size = configMap.hidden_size |
|
device = configMap.get_device(None) |
|
dtype = configMap.get_dtype('bfloat16') |
|
dropout_rate = configMap.dropout_rate |
|
|
|
|
|
layer_id = configMap.get_layer_id(-1) |
|
assert layer_id >= 0, f'layer_id must be >= 0, got {layer_id}' |
|
|
|
|
|
self.ln1 = nn.LayerNorm(hidden_size, device=device, dtype=dtype) |
|
self.ln2 = nn.LayerNorm(hidden_size, device=device, dtype=dtype) |
|
|
|
if layer_id == 0: |
|
self.ln0 = nn.LayerNorm(hidden_size, device=device, dtype=dtype) |
|
else: |
|
self.ln0 = nn.Identity(device=device) |
|
|
|
|
|
self.att = RWKV7TimeMix(configMap) |
|
self.ffn = RWKV7ChannelMix(configMap) |
|
|
|
|
|
if dropout_rate > 0.0: |
|
self.drop0 = nn.Dropout(p = dropout_rate)

self.drop1 = nn.Dropout(p = dropout_rate)
|
else: |
|
self.drop0 = nn.Identity(device=device) |
|
self.drop1 = nn.Identity(device=device) |
|
|
|
def init_parameters(self): |
|
''' |
|
Reset the parameters of the block, to an initial state used for training a model from scratch |
|
''' |
|
configMap = self.configMap |
|
|
|
|
|
hidden_size = configMap.hidden_size |
|
device = configMap.get_device(None) |
|
dtype = configMap.get_dtype('bfloat16') |
|
dropout_rate = configMap.dropout_rate |
|
|
|
|
|
layer_id = configMap.get_layer_id(-1) |
|
assert layer_id >= 0, f'layer_id must be >= 0, got {layer_id}' |
|
|
|
|
|
self.ln1 = nn.LayerNorm(hidden_size, device=device, dtype=dtype) |
|
self.ln2 = nn.LayerNorm(hidden_size, device=device, dtype=dtype) |
|
|
|
if layer_id == 0: |
|
self.ln0 = nn.LayerNorm(hidden_size, device=device, dtype=dtype) |
|
else: |
|
self.ln0 = nn.Identity(device=device) |
|
|
|
|
|
self.att.init_parameters() |
|
self.ffn.init_parameters() |
|
|
|
def forward( |
|
self, x:torch.Tensor, |
|
last_state: tuple[torch.Tensor,torch.Tensor,torch.Tensor], |
|
v_first:torch.Tensor |
|
) -> tuple[torch.Tensor,tuple[torch.Tensor,torch.Tensor,torch.Tensor],torch.Tensor]: |
|
''' |
|
Forward the block given the input tokens and the last state |
|
Last state is a tuple of the following |
|
- time mix shift state |
|
- time mix wkv state |
|
- channel mix shift state |
|
|
|
Returns a tuple of the output embedding, the next state, and v_first
|
''' |
|
|
|
|
|
|
|
|
|
|
|
|
|
ln0_out = self.ln0(x) |
|
|
|
|
|
|
|
|
|
|
|
att_out, tmix_shift, tmix_wkv, v_first = self.att( |
|
self.ln1(ln0_out), |
|
last_state[0], |
|
last_state[1], |
|
v_first |
|
) |
|
|
|
|
|
x = self.drop0(ln0_out + att_out) |
|
|
|
ffn_out, ffn_state = self.ffn( |
|
self.ln2(x), |
|
last_state[2] |
|
) |
|
|
|
|
|
x = self.drop1(x + ffn_out) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return x, (tmix_shift, tmix_wkv, ffn_state), v_first |
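# Hypothetical usage sketch (zeroed states; names are illustrative):
#   state = (torch.zeros(B, hidden_size),
#            torch.zeros(B, n_head, head_size, head_size),
#            torch.zeros(B, hidden_size))
#   y, state, v_first = block(x, state, v_first)   # v_first may be None for layer 0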
|
|
|
@torch.compile(mode="default") |
|
def forward_with_default_compile( |
|
self, |
|
in_x:torch.Tensor, |
|
in_state: tuple[torch.Tensor,torch.Tensor,torch.Tensor], |
|
in_v_first:torch.Tensor, |
|
out_x:torch.Tensor, |
|
out_state: tuple[torch.Tensor,torch.Tensor,torch.Tensor], |
|
out_v_first:torch.Tensor |
|
) -> tuple[torch.Tensor,tuple[torch.Tensor,torch.Tensor,torch.Tensor],torch.Tensor]: |
|
''' |
|
Compiled variant of the forward function

With no new tensors being created for the output

Useful for static memory allocation optimizations during inference
|
''' |
|
out_x[:], tmp_state, out_v_first[:] = self.forward(in_x, in_state, in_v_first) |
|
out_state[0][:], out_state[1][:], out_state[2][:] = tmp_state |
|
return out_x, out_state, out_v_first |
|
|
|
@torch.compile(mode="reduce-overhead") |
|
def forward_with_reduce_compile(self, in_x: torch.Tensor, in_state: tuple[torch.Tensor,torch.Tensor,torch.Tensor], in_v_first:torch.Tensor) -> tuple[torch.Tensor,tuple[torch.Tensor,torch.Tensor,torch.Tensor],torch.Tensor]: |
|
''' |
|
Compiled variant of the forward function
|
''' |
|
return self.forward(in_x, in_state, in_v_first) |
|
|
|
def load_from_model_state_dict(self, model_state_dict:dict, layer_id:int=-1, non_blocking:bool=True): |
|
''' |
|
Given the Full/partial RWKV model weights, load the block weights accordingly |
|
''' |
|
if layer_id <= -1: |
|
layer_id = self.configMap.get_layer_id(-1) |
|
assert layer_id >= 0, f'layer_id must be >= 0, got {layer_id}' |
|
|
|
|
|
current_state_dict = self.state_dict() |
|
|
|
|
|
for n in current_state_dict: |
|
model_key = f"blocks.{layer_id}.{n}" |
|
if model_key not in model_state_dict: |
|
continue |
|
|
|
|
|
try: |
|
current_state_dict[n].copy_(model_state_dict[model_key], non_blocking=non_blocking) |
|
except Exception as e: |
|
print(f"[ERROR] loading: {model_key} | model shape: {current_state_dict[n].shape} | weight shape: {model_state_dict[model_key].shape}") |
|
raise e |
|
|
|
|
|
|
|
|
|
from dataclasses import dataclass |
|
from typing import Optional |
|
from typing import Union |
|
import torch |
|
|
|
|
|
|
|
@dataclass |
|
class RWKV7GooseConfigMap(RWKV7BlockConfigMap): |
|
|
|
vocab_size: int = 65536 |
|
init_state_wkv: bool = False |
|
|
|
|
|
|
|
|
|
def __init__( |
|
self, |
|
vocab_size: int = 65536, |
|
init_state_wkv: bool = False, |
|
**kwargs |
|
) -> None: |
|
self.vocab_size = vocab_size |
|
self.init_state_wkv = init_state_wkv |
|
super().__init__(**kwargs) |
|
|
|
@staticmethod |
|
def normalize(config_map: any) -> 'RWKV7GooseConfigMap': |
|
''' |
|
Converts maps, objects, or RWKV7BlockConfigMap instances into an RWKV7GooseConfigMap
|
''' |
|
if isinstance(config_map, RWKV7GooseConfigMap): |
|
return config_map |
|
|
|
if isinstance(config_map, dict): |
|
return RWKV7GooseConfigMap(**config_map) |
|
|
|
if hasattr(config_map, '__dict__'): |
|
return RWKV7GooseConfigMap(**config_map.__dict__) |
|
|
|
raise ValueError(f"Unsupported config_map type: {type(config_map)}") |
|
|
|
@staticmethod |
|
def from_model_state_dict(state_dict: dict, **kwargs) -> 'RWKV7GooseConfigMap': |
|
''' |
|
Converts the state dict to the config map |
|
''' |
|
|
|
|
|
num_hidden_layers = 0 |
|
for key in state_dict.keys(): |
|
if key.startswith('blocks.'): |
|
idx = key.split('.')[1] |
|
num_hidden_layers = max(num_hidden_layers, int(idx)+1) |
|
|
|
|
|
if 'init_state.0.wkv' in state_dict: |
|
kwargs['init_state_wkv'] = True |
|
|
|
|
|
return RWKV7GooseConfigMap( |
|
num_hidden_layers=num_hidden_layers, |
|
hidden_size=state_dict['emb.weight'].shape[1], |
|
vocab_size=state_dict['emb.weight'].shape[0], |
|
|
|
|
|
|
|
head_size=state_dict['blocks.0.att.r_k'].shape[1], |
|
|
|
hidden_size_att=state_dict['blocks.0.att.key.weight'].shape[1], |
|
hidden_size_ffn=state_dict['blocks.0.ffn.key.weight'].shape[0], |
|
|
|
**kwargs |
|
) |
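# Hypothetical usage sketch, assuming a standard RWKV-7 checkpoint (file name is illustrative):
#   state_dict = torch.load("rwkv7.pth", map_location="cpu")
#   config = RWKV7GooseConfigMap.from_model_state_dict(state_dict)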
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
from torch import nn |
|
from torch import Tensor |
|
from typing import Union |
|
|
|
|
|
|
|
|
|
class RWKV7GooseModel(nn.Module): |
|
''' |
|
RWKV7 Goose Model architecture |
|
Simplified implementation |
|
''' |
|
|
|
def __init__(self, config: Union[RWKV7GooseConfigMap, any, None] = None): |
|
|
|
super().__init__() |
|
|
|
|
|
configMap:RWKV7GooseConfigMap = RWKV7GooseConfigMap.normalize(config) |
|
self.configMap = configMap |
|
|
|
|
|
num_hidden_layers = configMap.num_hidden_layers |
|
vocab_size = configMap.vocab_size |
|
device = configMap.get_device(None) |
|
dtype = configMap.get_dtype('bfloat16') |
|
hidden_size = configMap.hidden_size |
|
|
|
|
|
self.emb = nn.Embedding(vocab_size, hidden_size, device=device, dtype=dtype) |
|
|
|
|
|
blockList = [None]*num_hidden_layers |
|
for i in range(num_hidden_layers): |
|
blockList[i] = RWKV7LayerBlock(configMap.new_block_config_map(layer_id=i)) |
|
self.blocks = nn.ModuleList(blockList) |
|
|
|
|
|
self.ln_out = nn.LayerNorm(hidden_size, device=device, dtype=dtype) |
|
self.head = nn.Linear(hidden_size, vocab_size, bias=False, device=device, dtype=dtype) |
|
|
|
|
|
if configMap.init_state_wkv: |
|
stateTuneList = [None]*num_hidden_layers |
|
for i in range(num_hidden_layers): |
|
stateTuneList[i] = nn.ParameterDict({ |
|
"wkv": nn.Parameter(torch.zeros(hidden_size // 64, 64, 64, device=device, dtype=dtype)), |
|
}) |
|
self.init_state = nn.ParameterList(stateTuneList) |
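# Note: the trainable init_state assumes head_size == 64 (hidden_size // 64 heads).
# init_parameters() below re-creates the same parameters, but in float32 rather than
# the configured model dtype.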
|
|
|
def init_parameters(self): |
|
''' |
|
Reset the parameters of the block, to an initial state used for training a model from scratch |
|
''' |
|
|
|
|
|
configMap = self.configMap |
|
num_hidden_layers = configMap.num_hidden_layers |
|
vocab_size = configMap.vocab_size |
|
device = configMap.get_device(None) |
|
dtype = configMap.get_dtype('bfloat16') |
|
hidden_size = configMap.hidden_size |
|
|
|
|
|
for i in range(num_hidden_layers): |
|
self.blocks[i].init_parameters() |
|
|
|
|
|
self.emb = nn.Embedding(vocab_size, hidden_size, device=device, dtype=dtype) |
|
|
|
|
|
self.ln_out = nn.LayerNorm(hidden_size, device=device, dtype=dtype) |
|
if self.head is not None: |
|
self.head = nn.Linear(hidden_size, vocab_size, bias=False, device=device, dtype=dtype) |
|
|
|
|
|
if configMap.init_state_wkv: |
|
stateTuneList = [None]*num_hidden_layers |
|
for i in range(num_hidden_layers): |
|
stateTuneList[i] = nn.ParameterDict({ |
|
"wkv": nn.Parameter(torch.zeros(hidden_size // 64, 64, 64, device=device, dtype=torch.float)), |
|
}) |
|
self.init_state = nn.ParameterList(stateTuneList) |
|
|
|
def load_from_model_state_dict(self, state_dict: dict, non_blocking:bool=True): |
|
''' |
|
Given the Full/partial RWKV model weights, loaded via `torch.load` |
|
Set up the weights for every layer block, along with the embedding, ln_out and head weights
|
''' |
|
for i, block in enumerate(self.blocks): |
|
block.load_from_model_state_dict(state_dict, i, non_blocking=non_blocking) |
|
|
|
self.ln_out.weight.data.copy_(state_dict['ln_out.weight'], non_blocking=non_blocking)

self.ln_out.bias.data.copy_(state_dict['ln_out.bias'], non_blocking=non_blocking)

self.head.weight.data.copy_(state_dict['head.weight'], non_blocking=non_blocking)

self.emb.weight.data.copy_(state_dict['emb.weight'], non_blocking=non_blocking)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_init_state(self, batch_size:int=1, skip_init_state:bool=False) -> list[tuple[torch.Tensor,torch.Tensor,torch.Tensor]]: |
|
''' |
|
Get an initialized copy of the model state, for the given batch size |
|
''' |
|
|
|
hidden_size = self.configMap.hidden_size |
|
init_state_wkv = self.configMap.init_state_wkv |
|
num_hidden_layers = self.configMap.num_hidden_layers |
|
|
|
|
|
init_state = [ None for i in range(num_hidden_layers) ] |
|
for i in range(num_hidden_layers): |
|
device = self.blocks[i].ln1.weight.data.device |
|
dtype = self.blocks[i].ln1.weight.data.dtype |
|
|
|
|
|
|
|
wkv_state = torch.zeros(batch_size, hidden_size // 64, 64, 64, device=device, dtype=torch.float) |
|
if init_state_wkv and not skip_init_state:
|
init_wkv = self.init_state[i]["wkv"] |
|
for b in range(batch_size): |
|
wkv_state[b][:] = init_wkv |
|
|
|
|
|
init_state[i] = ( |
|
torch.zeros(batch_size, hidden_size, device=device, dtype=dtype), |
|
wkv_state, |
|
torch.zeros(batch_size, hidden_size, device=device, dtype=dtype) |
|
) |
|
return init_state |
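# The returned list holds one (token-shift state, wkv state, channel-mix shift state)
# tuple per layer; the wkv state is kept in float32 regardless of the model dtype.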
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def forward( |
|
self, idx:torch.Tensor, |
|
prv_stateList:list[tuple[torch.Tensor,torch.Tensor,torch.Tensor]] = None, |
|
ret_stateList:list[tuple[torch.Tensor,torch.Tensor,torch.Tensor]] = None, |
|
) -> tuple[torch.Tensor,list[tuple[torch.Tensor,torch.Tensor,torch.Tensor]]]: |
|
''' |
|
Forward the block set, given the input tokens and the last state |
|
Last state is a list of tuple of the following |
|
- time mix shift state |
|
- time mix wkv state |
|
- channel mix shift state |
|
|
|
Returns a pair of the output embedding and the next state |
|
''' |
|
|
|
if prv_stateList is None: |
|
prv_stateList = self.get_init_state(idx.shape[0]) |
|
|
|
|
|
if ret_stateList is None: |
|
ret_stateList = [ None for i in range(self.configMap.num_hidden_layers) ] |
|
return self._forward_internal(idx, prv_stateList, ret_stateList, overwrite_ret_tensor=False) |
|
|
|
|
|
return self._forward_internal(idx, prv_stateList, ret_stateList, overwrite_ret_tensor=True) |
|
|
|
def _forward_block_hook(self, |
|
block:RWKV7LayerBlock, |
|
x_hidden_state:torch.Tensor, |
|
prv_stateList:list[tuple[torch.Tensor,torch.Tensor,torch.Tensor]], |
|
v_first:torch.Tensor |
|
) -> tuple[torch.Tensor,tuple[torch.Tensor,torch.Tensor,torch.Tensor],torch.Tensor]: |
|
''' |
|
Forward block hook, which can easily be overridden

to implement gradient checkpointing in various trainers
|
''' |
|
x_hidden_state = x_hidden_state.to(block.ln1.weight.device, non_blocking=True) |
|
return block(x_hidden_state, prv_stateList, v_first) |
|
|
|
def _forward_internal( |
|
self, idx:torch.Tensor, |
|
prv_stateList:list[tuple[torch.Tensor,torch.Tensor,torch.Tensor]], |
|
ret_stateList:list[tuple[torch.Tensor,torch.Tensor,torch.Tensor]], |
|
overwrite_ret_tensor:bool=False |
|
) -> tuple[torch.Tensor,list[tuple[torch.Tensor,torch.Tensor,torch.Tensor]]]: |
|
''' |
|
Internal forward operation, which assumes the state is already initialized
|
Due to the lack of safety checks, this should not be used directly |
|
''' |
|
|
|
|
|
idx = idx.to(self.emb.weight.device, non_blocking=True) |
|
x_hidden_state = self.emb(idx) |
|
v_first = None |
|
|
|
|
|
if overwrite_ret_tensor: |
|
for i, block in enumerate(self.blocks): |
|
|
|
x_hidden_state, last_block_state, v_first = self._forward_block_hook(block, x_hidden_state, prv_stateList[i], v_first) |
|
ret_stateList[i][0][:] = last_block_state[0] |
|
ret_stateList[i][1][:] = last_block_state[1] |
|
ret_stateList[i][2][:] = last_block_state[2] |
|
else: |
|
ret_stateList = [ None for i in range( len(self.blocks) ) ] |
|
for i, block in enumerate(self.blocks): |
|
|
|
x_hidden_state, ret_sublist, v_first = self._forward_block_hook(block, x_hidden_state, prv_stateList[i], v_first) |
|
ret_stateList[i] = ret_sublist |
|
|
|
|
|
x_hidden_state = x_hidden_state.to(self.ln_out.weight.device, non_blocking=True) |
|
x_hidden_state = self.ln_out(x_hidden_state) |
|
x_out = self.head(x_hidden_state) |
|
|
|
|
|
return x_out, ret_stateList |
|
|
|
@torch.compile(mode="default") |
|
def forward_with_default_compile( |
|
self, idx:torch.Tensor, |
|
prv_stateList:list[tuple[torch.Tensor,torch.Tensor,torch.Tensor]], |
|
ret_stateList:list[tuple[torch.Tensor,torch.Tensor,torch.Tensor]], |
|
) -> tuple[torch.Tensor,list[tuple[torch.Tensor,torch.Tensor,torch.Tensor]]]: |
|
''' |
|
Compiled variant of the forward function

With no new tensors being created for the output

Useful for static memory allocation optimizations during inference
|
''' |
|
|
|
return self._forward_internal(idx, prv_stateList, ret_stateList, overwrite_ret_tensor=True) |
|
|
|
@torch.compile(mode="reduce-overhead") |
|
def forward_with_reduce_compile( |
|
self, in_idx:torch.Tensor, |
|
prv_stateList:list[tuple[torch.Tensor,torch.Tensor,torch.Tensor]] |
|
) -> tuple[torch.Tensor,list[tuple[torch.Tensor,torch.Tensor,torch.Tensor]]]: |
|
''' |
|
Compiled variant of the forward function; requires the previous state to be passed
|
''' |
|
return self._forward_internal(in_idx, prv_stateList, None, overwrite_ret_tensor=False) |
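# Hypothetical end-to-end usage sketch (identifiers are illustrative):
#   state_dict = torch.load("rwkv7.pth", map_location="cpu")
#   model = RWKV7GooseModel(RWKV7GooseConfigMap.from_model_state_dict(state_dict))
#   model.load_from_model_state_dict(state_dict)
#   logits, state = model(token_ids)                            # fresh state
#   logits, state = model(next_token_ids, prv_stateList=state)  # continue the sequence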
|
|
|
|