import logging

import torch
import triton
import triton.language as tl

from fla.utils import autocast_custom_bwd, autocast_custom_fwd, check_pytorch_version, input_guard, use_cuda_graph

logger = logging.getLogger(__name__)

if not check_pytorch_version('2.4'):
    logger.warning('PyTorch < 2.4 detected - computations may be slower due to lack of optimizations')


@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE': block_size})
        for block_size in [128, 256, 512, 1024, 2048, 4096, 8192]
    ],
    key=['hidden_dim'],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit
def rwkv_seq_mix_kernel(
    x_ptr,
    x_prev_ptr,
    mix_k_ptr,
    output_ptr,
    batch_size: tl.constexpr,
    token_length,
    hidden_dim: tl.constexpr,
    BLOCK_SIZE: tl.constexpr
):
    """Token-shift mixing: out[b, t] = x[b, t] + (x[b, t-1] - x[b, t]) * mix_k,
    where x[b, -1] is taken from the carried-over state x_prev[b].

    Each program handles BLOCK_SIZE elements of the flattened
    (batch, token, feature) index space.
    """
    block_start = tl.program_id(0) * BLOCK_SIZE
    block_idx = block_start + tl.arange(0, BLOCK_SIZE)

    # Decompose the flat index into (batch, token, feature) coordinates.
    total_seq_dim = token_length * hidden_dim
    batch_idx = block_idx // total_seq_dim
    seq_and_feat = block_idx % total_seq_dim
    seq_idx = seq_and_feat // hidden_dim
    feat_idx = seq_and_feat % hidden_dim

    is_valid = (batch_idx < batch_size) & (seq_idx < token_length)

    x_idx = batch_idx * total_seq_dim + seq_idx * hidden_dim + feat_idx

    curr_x = tl.load(x_ptr + x_idx, mask=is_valid, other=0.0).to(tl.float32)
    k_value = tl.load(mix_k_ptr + feat_idx).to(tl.float32)

    # The first token of every sequence mixes with the carried-over state x_prev ...
    is_first = seq_idx < 1
    prev_state_idx = batch_idx * hidden_dim + feat_idx
    prev_state = tl.load(x_prev_ptr + prev_state_idx,
                         mask=(is_first & is_valid),
                         other=0.0).to(tl.float32)

    # ... every other token mixes with the preceding token of x.
    prev_x_idx = x_idx - hidden_dim
    prev_x = tl.load(x_ptr + prev_x_idx,
                     mask=(~is_first & is_valid),
                     other=0.0).to(tl.float32)

    prev_value = tl.where(is_first, prev_state, prev_x)
    state_diff = prev_value - curr_x
    mixed = state_diff * k_value
    result = tl.cast(curr_x + mixed, dtype=output_ptr.dtype.element_ty, fp_downcast_rounding='rtne')
    tl.store(output_ptr + x_idx, result, mask=is_valid)


@triton.jit
def rwkv_channel_mixing_pow_and_relu(
    in_ptr,
    out_ptr,
    BLOCK_SIZE: tl.constexpr
):
    """Fused ReLU and power operation: x = ReLU(x)^2.

    Loads and stores are unmasked, so the trailing block may touch memory past
    the end of the tensor unless its element count is a multiple of BLOCK_SIZE.
    """
    xoffset = tl.program_id(0) * BLOCK_SIZE
    xindex = xoffset + tl.arange(0, BLOCK_SIZE)
    x = tl.load(in_ptr + xindex, None)
    x = tl.maximum(x, 0.0).to(tl.float32)
    x = tl.cast(x * x, dtype=out_ptr.dtype.element_ty, fp_downcast_rounding='rtne')
    tl.store(out_ptr + xindex, x, None)


def rwkv_mix_torch(x: torch.Tensor, x_prev: torch.Tensor, x_k: torch.Tensor):
    """Pure-PyTorch reference for the token-shift mixing."""
    if x_prev.dim() == 2:
        x_prev = x_prev.unsqueeze(1)
    xx = torch.cat((x_prev, x[:, :-1, :]), dim=1) - x
    k = x.addcmul(xx, x_k)
    return k


def rwkv_relu_and_square_torch(x: torch.Tensor):
    """Pure-PyTorch reference for ReLU(x)^2."""
    return torch.relu(x) ** 2


def rwkv_mix_fwd(x, x_prev, x_k):
    """Launch rwkv_seq_mix_kernel.

    Accepts x of shape (batch, seq_len, hidden_dim) or (seq_len, hidden_dim);
    x_prev holds the hidden state carried over from the previous chunk.
    """
    has_batch = x.dim() == 3

    if has_batch:
        batch_size, token_length, hidden_dim = x.shape
    else:
        token_length, hidden_dim = x.shape
        batch_size = 1
        x = x.unsqueeze(0)
        x_prev = x_prev.unsqueeze(0)

    total_elements = batch_size * token_length * hidden_dim
    output = torch.empty_like(x)

    def grid(meta):
        return ((total_elements + meta['BLOCK_SIZE'] - 1) // meta['BLOCK_SIZE'], 1, 1)

    rwkv_seq_mix_kernel[grid](
        x.contiguous(),
        x_prev.contiguous(),
        x_k.squeeze(),
        output,
        batch_size=batch_size,
        token_length=token_length,
        hidden_dim=hidden_dim,
    )
    if not has_batch:
        output = output.squeeze(0)
    return output
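

# A minimal usage sketch for rwkv_mix_fwd (shapes are illustrative assumptions,
# not requirements of the function):
#   x      = torch.randn(2, 16, 64, device='cuda')   # (batch, seq_len, hidden_dim)
#   x_prev = torch.randn(2, 64, device='cuda')       # last token of the previous chunk
#   x_k    = torch.randn(1, 1, 64, device='cuda')    # learned mixing coefficients
#   out    = rwkv_mix_fwd(x, x_prev, x_k)            # matches rwkv_mix_torch(x, x_prev, x_k)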


def rwkv_relu_and_square_fwd(x: torch.Tensor, inplace: bool = True):
    """
    Triton implementation of RWKV's ReLU-and-square operation.

    Args:
        x: Input tensor.
        inplace: If True, overwrite ``x`` with the result instead of allocating
            a new output tensor.

    Returns:
        Tensor after the ReLU and square operations.
    """
    x = x.contiguous()
    output = x if inplace else torch.empty_like(x)

    def grid(meta):
        return ((output.numel() + meta['BLOCK_SIZE'] - 1) // meta['BLOCK_SIZE'], 1, 1)

    rwkv_channel_mixing_pow_and_relu[grid](
        x,
        output,
        BLOCK_SIZE=4096,
    )
    return output
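

# A minimal usage sketch (assumes a CUDA device; the sizes are chosen so the
# element count is a multiple of the fixed 4096 block size):
#   h = torch.randn(2, 16, 256, device='cuda')
#   y = rwkv_relu_and_square_fwd(h, inplace=False)
#   torch.testing.assert_close(y, rwkv_relu_and_square_torch(h))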


@triton.jit
def relu_square_bwd_kernel(
    out_ptr,
    forward_input_ptr,
    BLOCK_SIZE: tl.constexpr
):
    """Backward of ReLU(x)^2, computed in place:

        grad_input = grad_output * 2 * x if x > 0 else 0

    ``out_ptr`` holds grad_output on entry and grad_input on exit.  Loads and
    stores are unmasked, so the trailing block may touch memory past the end of
    the tensor unless its element count is a multiple of BLOCK_SIZE.
    """
    pid = tl.program_id(0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)

    x = tl.load(forward_input_ptr + offsets).to(tl.float32)
    grad = tl.load(out_ptr + offsets).to(tl.float32)

    # d/dx ReLU(x)^2 = 2 * ReLU(x)
    x = tl.maximum(x, 0.0)
    grad_input = grad * 2 * x

    tl.store(out_ptr + offsets, grad_input.to(out_ptr.dtype.element_ty))
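

# A hedged verification sketch against autograd (assumes a CUDA device; the
# tensor size equals BLOCK_SIZE so a single program covers it exactly):
#   x = torch.randn(4096, device='cuda', requires_grad=True)
#   (torch.relu(x) ** 2).sum().backward()
#   g = torch.ones(4096, device='cuda')                     # incoming grad_output
#   relu_square_bwd_kernel[(1,)](g, x.detach(), BLOCK_SIZE=4096)
#   torch.testing.assert_close(g, x.grad)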


@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE': block_size})
        for block_size in [128, 256, 512, 1024, 2048, 4096, 8192]
    ],
    key=['hidden_dim'],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit
def rwkv_mix_bwd_kernel(
    dk1_ptr0,
    xk_ptr,
    dx_ptr,
    dx_prev_ptr,
    batch_size,
    token_length,
    hidden_dim: tl.constexpr,
    BLOCK_SIZE: tl.constexpr
):
    """Backward of the token-shift mixing.

    Since k1[t] = x[t] + (x[t-1] - x[t]) * x_k (with x[-1] taken from x_prev):
        dx[t]   = dk1[t] * (1 - x_k) + dk1[t + 1] * x_k   (the second term drops at t = T - 1)
        dx_prev = dk1[0] * x_k
    """
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)

    # Decompose the flat offset into (batch, token, feature) coordinates.
    batch_idx = offsets // (token_length * hidden_dim)
    seq_feat = offsets % (token_length * hidden_dim)
    seq_idx = seq_feat // hidden_dim
    feat_idx = seq_feat % hidden_dim

    is_valid = offsets < (batch_size * token_length * hidden_dim)

    dk1 = tl.load(dk1_ptr0 + offsets, mask=is_valid)
    xk = tl.load(xk_ptr + feat_idx, mask=is_valid)
    prod = dk1 * xk

    # Contribution of dk1[t + 1] * x_k for every token except the last one.
    mask_next = seq_idx < (token_length - 1)
    next_offset = offsets + hidden_dim
    dk1_next = tl.load(dk1_ptr0 + next_offset, mask=mask_next & is_valid, other=0.0)
    prod_next = dk1_next * xk
    dx_val = dk1 - prod + tl.where(mask_next, prod_next, 0.0)
    dx_val = tl.cast(dx_val, dtype=dx_ptr.dtype.element_ty, fp_downcast_rounding='rtne')
    tl.store(dx_ptr + offsets, dx_val, mask=is_valid)

    # Only the first token of each sequence reads x_prev, so only that position
    # contributes to its gradient.
    dx_prev_offset = batch_idx * hidden_dim + feat_idx
    is_first_step = seq_idx == 0

    tl.store(
        dx_prev_ptr + dx_prev_offset,
        tl.cast(prod, dtype=dx_prev_ptr.dtype.element_ty),
        mask=is_first_step & is_valid
    )


@torch.compile(fullgraph=True)
def compute_x_k_grad(dk1, x, x_prev):
    """Gradient of the loss w.r.t. the mixing coefficient x_k.

    Args:
        dk1: (batch * seq_len, hidden_dim)
        x: (batch, seq_len, hidden_dim)
        x_prev: (batch, hidden_dim) or (batch, 1, hidden_dim)

    Returns:
        grad_x_k: (1, 1, hidden_dim)
    """
    if x_prev.dim() == 2:
        x_prev = x_prev.unsqueeze(1)
    xx = torch.cat((x_prev, x[:, :-1, :]), dim=1) - x

    grad_x_k = (dk1 * xx.reshape(-1, x.shape[2])).sum(dim=0).view(1, 1, -1)
    return grad_x_k
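

# Note on the reduction above: in the forward pass k1 = x + (shift(x) - x) * x_k,
# where shift(x) prepends x_prev and drops the last token.  x_k is broadcast over
# batch and time, so its gradient is the sum of dk1 * (shift(x) - x) over every
# (batch, token) position, which is exactly the sum(dim=0) taken in compute_x_k_grad.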


def rwkv_channel_mixing_bwd(grad_output, x, x_prev, x_k, key_weight, value_weight, k1, k1_K, k, inplace=True):
    batch_size = x.shape[0] if x.dim() == 3 else 1
    seq_len, n_embd = x.shape[-2], x.shape[-1]

    # Backward through the value projection: out = k @ value_weight.
    dV = k.transpose(-2, -1) @ grad_output
    dk = grad_output @ value_weight.transpose(-2, -1)

    # Backward through ReLU(.)^2, overwriting dk in place (k1_K is the saved pre-activation).
    BLOCK_SIZE = 4096
    relu_grid = ((dk.numel() + BLOCK_SIZE - 1) // BLOCK_SIZE,)
    relu_square_bwd_kernel[relu_grid](
        dk,
        k1_K,
        BLOCK_SIZE=BLOCK_SIZE
    )

    # Backward through the key projection: k1_K = k1 @ key_weight.
    dK = k1.transpose(-2, -1) @ dk
    dk1 = dk @ key_weight.transpose(-2, -1)
    dk1 = dk1.view(-1, n_embd).contiguous()

    # Gradients of the mixing coefficient and of x / x_prev.
    dk_reduced = compute_x_k_grad(dk1, x, x_prev)
    dx_prev = torch.empty_like(x_prev) if not inplace else x_prev
    dx = torch.empty_like(x) if not inplace else x

    def grid(meta):
        return ((batch_size * seq_len * n_embd + meta['BLOCK_SIZE'] - 1) // meta['BLOCK_SIZE'], 1, 1)

    rwkv_mix_bwd_kernel[grid](
        dk1,
        x_k.squeeze(),
        dx,
        dx_prev,
        batch_size,
        seq_len,
        n_embd,
    )

    return dx, dx_prev, dk_reduced, dK, dV


class Rwkv7ChannelMixing(torch.autograd.Function):
    """Fused RWKV-7 channel mixing: token shift -> key projection -> ReLU^2 -> value projection."""

    @staticmethod
    @input_guard
    @autocast_custom_fwd
    def forward(ctx, x, x_prev, x_k, key_weight, value_weight, inplace: bool = True):
        k1 = rwkv_mix_fwd(x, x_prev, x_k)
        k1_K = k1 @ key_weight
        k = rwkv_relu_and_square_fwd(k1_K, inplace=True)
        # Only the inputs are saved; k1, k1_K and k are recomputed in backward.
        ctx.save_for_backward(x, x_prev, x_k, key_weight, value_weight)
        ctx.inplace = inplace
        return k @ value_weight

    @staticmethod
    @input_guard
    @autocast_custom_bwd
    def backward(ctx, dkv):
        x, x_prev, x_k, key_weight, value_weight = ctx.saved_tensors
        k1 = rwkv_mix_fwd(x, x_prev, x_k)
        k1_K = k1 @ key_weight
        k = rwkv_relu_and_square_fwd(k1_K, inplace=False)
        dx, dx_prev, dk_reduced, dK, dV = rwkv_channel_mixing_bwd(
            dkv, x, x_prev, x_k, key_weight, value_weight, k1, k1_K, k, ctx.inplace)
        return dx, dx_prev, dk_reduced, dK, dV, None


def channel_mixing_rwkv7(x: torch.Tensor, x_prev: torch.Tensor, x_k: torch.Tensor,
                         key_weight: torch.Tensor, value_weight: torch.Tensor, inplace: bool = True):
    assert x.dim() == 3, 'x must have shape (batch, seq_len, hidden_dim)'

    return Rwkv7ChannelMixing.apply(x, x_prev, x_k, key_weight, value_weight, inplace), x[:, -1]


def channel_mixing_rwkv7_torch(x, x_prev, x_k, key_weight, value_weight):
    k1 = rwkv_mix_torch(x, x_prev, x_k)
    k1_K = k1 @ key_weight
    k = rwkv_relu_and_square_torch(k1_K)
    return k @ value_weight, x[:, -1]
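

if __name__ == '__main__':
    # Minimal self-check sketch, not part of the library API.  It assumes a CUDA
    # device (the Triton kernels require one) and uses illustrative shapes whose
    # element count is a multiple of the fixed 4096 block size used by the
    # unmasked element-wise kernels above.
    if torch.cuda.is_available():
        torch.manual_seed(0)
        B, T, D = 2, 16, 64
        x = torch.randn(B, T, D, device='cuda')
        x_prev = torch.randn(B, D, device='cuda')
        x_k = torch.randn(1, 1, D, device='cuda')
        key_weight = torch.randn(D, 4 * D, device='cuda') / D ** 0.5
        value_weight = torch.randn(4 * D, D, device='cuda') / (4 * D) ** 0.5

        out_triton, _ = channel_mixing_rwkv7(x, x_prev, x_k, key_weight, value_weight)
        out_torch, _ = channel_mixing_rwkv7_torch(x, x_prev, x_k, key_weight, value_weight)
        torch.testing.assert_close(out_triton, out_torch, rtol=1e-4, atol=1e-4)
        print('channel_mixing_rwkv7 matches the PyTorch reference')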
|