# -*- coding: utf-8 -*- import torch from einops import rearrange def delta_rule_recurrence(q, k, v, beta, initial_state=None, output_final_state=True): orig_dtype = q.dtype b, h, l, d_k = q.shape q, k, v, beta = map(lambda x: x.float(), [q, k, v, beta]) d_v = v.shape[-1] o = torch.zeros_like(v) S = torch.zeros(b, h, d_k, d_v).to(v) q = q * (d_k ** -0.5) if beta.ndim < v.ndim: beta = beta[..., None] if initial_state is not None: S += initial_state for i in range(l): _k = k[:, :, i] _q = q[:, :, i] _v = v[:, :, i].clone() beta_i = beta[:, :, i] _v = _v - (S.clone() * _k[..., None]).sum(-2) _v = _v * beta_i S = S.clone() + _k.unsqueeze(-1) * _v.unsqueeze(-2) o[:, :, i] = torch.einsum('bhd,bhdm->bhm', _q, S) S = None if output_final_state is False else S return o.to(orig_dtype), S def delta_rule_chunkwise(q, k, v, beta, chunk_size=32): b, h, l, d_k = q.shape d_v = v.shape[-1] q = q * (d_k ** -0.5) v = v * beta[..., None] k_beta = k * beta[..., None] assert l % chunk_size == 0 # compute (I - tri(diag(beta) KK^T))^{-1} mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device), diagonal=0) q, k, v, k_beta = map(lambda x: rearrange(x, 'b h (n c) d -> b h n c d', c=chunk_size), [q, k, v, k_beta]) attn = -(k_beta @ k.transpose(-1, -2)).masked_fill(mask, 0) for i in range(1, chunk_size): attn[..., i, :i] = attn[..., i, :i] + (attn[..., i, :, None].clone() * attn[..., :, :i].clone()).sum(-2) attn = attn + torch.eye(chunk_size, dtype=torch.float, device=q.device) u = attn @ v w = attn @ k_beta S = k.new_zeros(b, h, d_k, d_v) o = torch.zeros_like(v) mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device), diagonal=1) for i in range(0, l // chunk_size): q_i, k_i = q[:, :, i], k[:, :, i] attn = (q_i @ k_i.transpose(-1, -2)).masked_fill_(mask, 0) u_i = u[:, :, i] - w[:, :, i] @ S o_inter = q_i @ S o[:, :, i] = o_inter + attn @ u_i S = S + k_i.transpose(-1, -2) @ u_i return rearrange(o, 'b h n c d -> b h (n c) d'), S def delta_rule_parallel(q, k, v, beta, BM=128, BN=32): b, h, l, d_k = q.shape # d_v = v.shape[-1] q = q * (d_k ** -0.5) v = v * beta[..., None] k_beta = k * beta[..., None] # compute (I - tri(diag(beta) KK^T))^{-1} q, k, v, k_beta = map(lambda x: rearrange(x, 'b h (n c) d -> b h n c d', c=BN), [q, k, v, k_beta]) mask = torch.triu(torch.ones(BN, BN, dtype=torch.bool, device=q.device), diagonal=0) T = -(k_beta @ k.transpose(-1, -2)).masked_fill(mask, 0) for i in range(1, BN): T[..., i, :i] = T[..., i, :i].clone() + (T[..., i, :, None].clone() * T[..., :, :i].clone()).sum(-2) T = T + torch.eye(BN, dtype=torch.float, device=q.device) mask2 = torch.triu(torch.ones(BN, BN, dtype=torch.bool, device=q.device), diagonal=1) A_local = (q @ k.transpose(-1, -2)).masked_fill(mask2, 0) @ T o_intra = A_local @ v # apply cumprod transition matrices on k to the last position within the chunk k = k - ((k @ k.transpose(-1, -2)).masked_fill(mask, 0) @ T).transpose(-1, -2) @ k_beta # apply cumprod transition matrices on q to the first position within the chunk q = q - A_local @ k_beta o_intra = A_local @ v A = torch.zeros(b, h, l, l, device=q.device) q, k, v, k_beta, o_intra = map(lambda x: rearrange(x, 'b h n c d -> b h (n c) d'), [q, k, v, k_beta, o_intra]) o = torch.empty_like(v) for i in range(0, l, BM): q_i = q[:, :, i:i+BM] o_i = o_intra[:, :, i:i+BM] # intra block for j in range(i + BM - 2 * BN, i-BN, -BN): k_j = k[:, :, j:j+BN] A_ij = q_i @ k_j.transpose(-1, -2) mask = torch.arange(i, i+BM) >= (j + BN) A_ij = A_ij.masked_fill_(~mask[:, None].to(A_ij.device), 0) A[:, :, i:i+BM, j:j+BN] = A_ij q_i = q_i - A_ij @ k_beta[:, :, j:j+BN] o_i += A_ij @ v[:, :, j:j+BN] # inter block for j in range(i - BN, -BN, -BN): k_j = k[:, :, j:j+BN] A_ij = q_i @ k_j.transpose(-1, -2) A[:, :, i:i+BM, j:j+BN] = A_ij q_i = q_i - A_ij @ k_beta[:, :, j:j+BN] o_i += A_ij @ v[:, :, j:j+BN] o[:, :, i:i+BM] = o_i for i in range(0, l//BN): A[:, :, i*BN:i*BN+BN, i*BN:i*BN+BN] = A_local[:, :, i] return o, A