import logging
import torch
import triton
import triton.language as tl
from fla.utils import autocast_custom_bwd, autocast_custom_fwd, check_pytorch_version, input_guard, use_cuda_graph
logger = logging.getLogger(__name__)
if not check_pytorch_version('2.4'):
logger.warning('PyTorch < 2.4 detected - computations may be slower due to lack of optimizations')
@triton.autotune(
configs=[
triton.Config({'BLOCK_SIZE': block_size})
for block_size in [128, 256, 512, 1024, 2048, 4096, 8192]
],
key=['hidden_dim'],
use_cuda_graph=use_cuda_graph,
)
@triton.jit
def rwkv_seq_mix_kernel(
x_ptr,
x_prev_ptr,
mix_k_ptr,
output_ptr,
batch_size: tl.constexpr,
token_length,
hidden_dim: tl.constexpr,
BLOCK_SIZE: tl.constexpr
):
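    """
    Token-shift mix: out[b, t] = x[b, t] + (shift(x)[b, t] - x[b, t]) * mix_k,
    where shift(x)[b, 0] = x_prev[b] and shift(x)[b, t] = x[b, t - 1] for t > 0.
    Each program handles BLOCK_SIZE flattened (batch, token, feature) elements.
    """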
block_start = tl.program_id(0) * BLOCK_SIZE
    block_idx = block_start + tl.arange(0, BLOCK_SIZE)
total_seq_dim = token_length * hidden_dim
batch_idx = block_idx // total_seq_dim
seq_and_feat = block_idx % total_seq_dim
seq_idx = seq_and_feat // hidden_dim
feat_idx = seq_and_feat % hidden_dim
is_valid = (batch_idx < batch_size) & (seq_idx < token_length)
x_idx = batch_idx * total_seq_dim + seq_idx * hidden_dim + feat_idx
curr_x = tl.load(x_ptr + x_idx, mask=is_valid, other=0.0).to(tl.float32)
k_value = tl.load(mix_k_ptr + feat_idx).to(tl.float32)
is_first = seq_idx < 1
prev_state_idx = batch_idx * hidden_dim + feat_idx
prev_state = tl.load(x_prev_ptr + prev_state_idx,
mask=(is_first & is_valid),
other=0.0).to(tl.float32)
prev_x_idx = x_idx - hidden_dim
prev_x = tl.load(x_ptr + prev_x_idx,
mask=(~is_first & is_valid),
other=0.0).to(tl.float32)
prev_value = tl.where(is_first, prev_state, prev_x)
state_diff = prev_value - curr_x
mixed = state_diff * k_value
result = tl.cast(curr_x + mixed, dtype=output_ptr.dtype.element_ty, fp_downcast_rounding='rtne')
tl.store(output_ptr + x_idx, result, mask=is_valid)
@triton.jit
def rwkv_channel_mixing_pow_and_relu(
    in_ptr,
    out_ptr,
    numel,
    BLOCK_SIZE: tl.constexpr
):
    """Fused ReLU and square operation: x = ReLU(x)^2"""
    xoffset = tl.program_id(0) * BLOCK_SIZE
    xindex = xoffset + tl.arange(0, BLOCK_SIZE)
    # Guard the tail block when numel is not a multiple of BLOCK_SIZE
    mask = xindex < numel
    x = tl.load(in_ptr + xindex, mask=mask, other=0.0).to(tl.float32)
    x = tl.maximum(x, 0.0)
    x = tl.cast(x * x, dtype=out_ptr.dtype.element_ty, fp_downcast_rounding='rtne')
    tl.store(out_ptr + xindex, x, mask=mask)
def rwkv_mix_torch(x: torch.Tensor, x_prev: torch.Tensor, x_k: torch.Tensor):
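    """Reference PyTorch path for the token-shift mix: x + (shift(x) - x) * x_k."""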
if x_prev.dim() == 2:
x_prev = x_prev.unsqueeze(1) # (batch_size, 1, hidden_dim)
xx = torch.cat((x_prev, x[:, :-1, :]), dim=1) - x
k = x.addcmul(xx, x_k)
return k
def rwkv_relu_and_square_torch(x: torch.Tensor):
return torch.relu(x) ** 2
def rwkv_mix_fwd(x, x_prev, x_k):
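    """Launch rwkv_seq_mix_kernel on x of shape (batch, tokens, hidden) or (tokens, hidden)."""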
has_batch = x.dim() == 3
if has_batch:
batch_size, token_length, hidden_dim = x.shape
else:
token_length, hidden_dim = x.shape
batch_size = 1
x = x.unsqueeze(0)
x_prev = x_prev.unsqueeze(0)
token_length = x.shape[1]
hidden_dim = x.shape[2]
total_elements = batch_size * token_length * hidden_dim
output = torch.empty_like(x)
def grid(meta): return (
(total_elements + meta['BLOCK_SIZE'] - 1) // meta['BLOCK_SIZE'], # grid_0
1, # grid_1
1 # grid_2
)
rwkv_seq_mix_kernel[grid](
x.contiguous(),
x_prev.contiguous(),
x_k.squeeze(),
output,
batch_size=batch_size,
token_length=token_length,
hidden_dim=hidden_dim,
)
if not has_batch:
output = output.squeeze(0)
return output
def rwkv_relu_and_square_fwd(x: torch.Tensor, inplace: bool = True):
"""
Triton implementation of RWKV's ReLU and square operation
Args:
x: Input tensor
Returns:
Tensor after ReLU and square operations
"""
x = x.contiguous()
output = x if inplace else torch.empty_like(x)
def grid(meta): return (
(output.numel() + meta['BLOCK_SIZE'] - 1) // meta['BLOCK_SIZE'], # grid_0
1, # grid_1
1 # grid_2
)
    rwkv_channel_mixing_pow_and_relu[grid](
        x,
        output,
        output.numel(),
        BLOCK_SIZE=4096,
    )
return output
@triton.jit
def relu_square_bwd_kernel(
    out_ptr,
    forward_input_ptr,
    numel,
    BLOCK_SIZE: tl.constexpr
):
    """ReLU(x)^2 backward kernel (in place on out_ptr)
    grad_input = grad_output * 2 * x if x > 0 else 0
    """
    pid = tl.program_id(0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Guard the tail block when numel is not a multiple of BLOCK_SIZE
    mask = offsets < numel
    x = tl.load(forward_input_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
    grad = tl.load(out_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
    x = tl.maximum(x, 0.0)
    grad_input = grad * 2 * x
    tl.store(out_ptr + offsets, grad_input.to(out_ptr.dtype.element_ty), mask=mask)
@triton.autotune(
configs=[
triton.Config({'BLOCK_SIZE': block_size})
for block_size in [128, 256, 512, 1024, 2048, 4096, 8192]
],
key=['hidden_dim'],
use_cuda_graph=use_cuda_graph,
)
@triton.jit
def rwkv_mix_bwd_kernel(
dk1_ptr0,
xk_ptr,
dx_ptr,
dx_prev_ptr,
batch_size,
token_length,
hidden_dim: tl.constexpr,
BLOCK_SIZE: tl.constexpr
):
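    """
    Backward of the token-shift mix with respect to x and x_prev:
        dx[b, t]   = dk1[b, t] * (1 - x_k) + dk1[b, t + 1] * x_k   (second term dropped at the last token)
        dx_prev[b] = dk1[b, 0] * x_k
    where dk1 is the gradient flowing into the mixed output k1.
    """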
pid = tl.program_id(0)
offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
batch_idx = offsets // (token_length * hidden_dim)
seq_feat = offsets % (token_length * hidden_dim)
seq_idx = seq_feat // hidden_dim
feat_idx = seq_feat % hidden_dim
is_valid = offsets < (batch_size * token_length * hidden_dim)
dk1 = tl.load(dk1_ptr0 + offsets, mask=is_valid)
xk = tl.load(xk_ptr + feat_idx, mask=is_valid)
prod = dk1 * xk
mask_next = seq_idx < (token_length - 1)
next_offset = offsets + hidden_dim
dk1_next = tl.load(dk1_ptr0 + next_offset, mask=mask_next & is_valid, other=0.0)
prod_next = dk1_next * xk
dx_val = dk1 - prod + tl.where(mask_next, prod_next, 0.0)
dx_val = tl.cast(dx_val, dtype=dx_ptr.dtype.element_ty, fp_downcast_rounding='rtne')
tl.store(dx_ptr + offsets, dx_val, mask=is_valid)
dx_prev_offset = batch_idx * hidden_dim + feat_idx
is_first_step = seq_idx == 0
tl.store(
dx_prev_ptr + dx_prev_offset,
tl.cast(prod, dtype=dx_prev_ptr.dtype.element_ty),
        mask=is_first_step & is_valid
)
@torch.compile(fullgraph=True)
def compute_x_k_grad(dk1, x, x_prev):
"""
Args:
dk1: (batch*seq_len, hidden_dim)
x: (batch, seq_len, hidden_dim)
x_prev: (batch, hidden_dim) or (batch, 1, hidden_dim)
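    Returns:
        grad_x_k: (1, 1, hidden_dim), the sum over batch and tokens of dk1 * (shift(x) - x)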
"""
if x_prev.dim() == 2:
x_prev = x_prev.unsqueeze(1) # (batch, 1, hidden_dim)
xx = torch.cat((x_prev, x[:, :-1, :]), dim=1) - x # (batch, seq_len, hidden_dim)
# (hidden_dim,) --> (1, 1, hidden_dim)
grad_x_k = (dk1 * xx.reshape(-1, x.shape[2])).sum(dim=0).view(1, 1, -1)
return grad_x_k
def rwkv_channel_mixing_bwd(grad_output, x, x_prev, x_k, key_weight, value_weight, k1, k1_K, k, inplace=True):
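    """
    Full channel-mixing backward. Given the upstream gradient grad_output and the
    recomputed intermediates k1 (mixed input), k1_K (k1 @ key_weight) and k (ReLU(k1_K)^2),
    returns gradients for x, x_prev, x_k, key_weight and value_weight.
    """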
batch_size = x.shape[0] if x.dim() == 3 else 1
seq_len, n_embd = x.shape[-2], x.shape[-1]
dV = k.transpose(-2, -1) @ grad_output
dk = grad_output @ value_weight.transpose(-2, -1)
BLOCK_SIZE = 4096
grid = ((dk.numel() + BLOCK_SIZE - 1) // BLOCK_SIZE,)
    relu_square_bwd_kernel[grid](
        dk,
        k1_K,
        dk.numel(),
        BLOCK_SIZE=BLOCK_SIZE
    )
dK = k1.transpose(-2, -1) @ dk
dk1 = dk @ key_weight.transpose(-2, -1)
dk1 = dk1.view(-1, n_embd).contiguous()
dk_reduced = compute_x_k_grad(dk1, x, x_prev)
dx_prev = torch.empty_like(x_prev) if not inplace else x_prev
dx = torch.empty_like(x) if not inplace else x
def grid(meta): return ((batch_size * seq_len * n_embd + meta['BLOCK_SIZE'] - 1) // meta['BLOCK_SIZE'], 1, 1)
    rwkv_mix_bwd_kernel[grid](
dk1,
x_k.squeeze(),
dx,
dx_prev,
batch_size,
seq_len,
n_embd,
)
    # dx_prev has the same shape as x_prev: (batch_size, n_embd) or (batch_size, 1, n_embd)
return dx, dx_prev, dk_reduced, dK, dV
class Rwkv7ChannelMixing(torch.autograd.Function):
@staticmethod
@input_guard
@autocast_custom_fwd
def forward(ctx, x, x_prev, x_k, key_weight, value_weight, inplace: bool = True):
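        """
        Channel mixing: out = ReLU((x + (shift(x) - x) * x_k) @ key_weight)^2 @ value_weight.
        Only the raw inputs are saved; k1, k1_K and k are recomputed in backward to save memory.
        """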
k1 = rwkv_mix_fwd(x, x_prev, x_k)
k1_K = k1 @ key_weight
k = rwkv_relu_and_square_fwd(k1_K, inplace=True)
ctx.save_for_backward(x, x_prev, x_k, key_weight, value_weight)
ctx.inplace = inplace
return k @ value_weight
@staticmethod
@input_guard
@autocast_custom_bwd
def backward(ctx, dkv):
x, x_prev, x_k, key_weight, value_weight = ctx.saved_tensors
k1 = rwkv_mix_fwd(x, x_prev, x_k)
k1_K = k1 @ key_weight
k = rwkv_relu_and_square_fwd(k1_K, inplace=False)
dx, dx_prev, dk_reduced, dK, dV = rwkv_channel_mixing_bwd(
dkv, x, x_prev, x_k, key_weight, value_weight, k1, k1_K, k, ctx.inplace)
return dx, dx_prev, dk_reduced, dK, dV, None
def channel_mixing_rwkv7(x: torch.Tensor, x_prev: torch.Tensor, x_k: torch.Tensor,
key_weight: torch.Tensor, value_weight: torch.Tensor, inplace: bool = True):
assert x.dim() == 3
    # x[:, -1] is the last token of each sequence, used as x_prev for the next chunk
    return Rwkv7ChannelMixing.apply(x, x_prev, x_k, key_weight, value_weight, inplace), x[:, -1]
def channel_mixing_rwkv7_torch(x, x_prev, x_k, key_weight, value_weight):
k1 = rwkv_mix_torch(x, x_prev, x_k)
k1_K = k1 @ key_weight
k = rwkv_relu_and_square_torch(k1_K)
    return k @ value_weight, x[:, -1]
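if __name__ == "__main__":
    # Minimal usage sketch comparing the Triton path against the PyTorch reference.
    # The shapes, dtype and CUDA device below are illustrative assumptions, not
    # requirements of the kernels; hidden sizes only need to be consistent across tensors.
    B, T, C, H = 2, 16, 64, 256
    device, dtype = 'cuda', torch.bfloat16
    x = torch.randn(B, T, C, device=device, dtype=dtype)
    x_prev = torch.randn(B, C, device=device, dtype=dtype)
    x_k = torch.randn(1, 1, C, device=device, dtype=dtype)
    key_weight = torch.randn(C, H, device=device, dtype=dtype)
    value_weight = torch.randn(H, C, device=device, dtype=dtype)
    out, last_token = channel_mixing_rwkv7(x, x_prev, x_k, key_weight, value_weight)
    ref, _ = channel_mixing_rwkv7_torch(x, x_prev, x_k, key_weight, value_weight)
    print('max abs diff vs torch reference:', (out.float() - ref.float()).abs().max().item())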