|
|
|
|
|
|
|
import torch |
|
import torch.nn.functional as F |
|
|
|
|
|
def ttt_linear( |
|
q: torch.Tensor, |
|
k: torch.Tensor, |
|
v: torch.Tensor, |
|
w: torch.Tensor, |
|
b: torch.Tensor, |
|
eta: torch.Tensor, |
|
scale: float, |
|
eps: float, |
|
mini_batch_size: int, |
|
initial_state: torch.Tensor, |
|
initial_state_bias: torch.Tensor, |
|
output_final_state: bool |
|
): |
|
B, H, T, D = q.shape |
|
BT = mini_batch_size |
|
NT = T // BT |
|
|
|
_q = q.reshape(B, H, NT, BT, D).permute(2, 0, 1, 3, 4) |
|
_k = k.reshape(B, H, NT, BT, D).permute(2, 0, 1, 3, 4) |
|
_v = v.reshape(B, H, NT, BT, D).permute(2, 0, 1, 3, 4) |
|
|
|
_eta = eta.reshape(B, H, NT, BT, 1).permute(2, 0, 1, 3, 4) |
|
|
|
w = w.reshape(H, 1, D).to(torch.float32) |
|
b = b.reshape(H, 1, D).to(torch.float32) |
|
|
|
h = torch.zeros((B, H, D, D), device=v.device, dtype=torch.float32) if initial_state is None else initial_state |
|
hb = torch.zeros((B, H, 1, D), device=v.device, dtype=torch.float32) if initial_state_bias is None else initial_state_bias |
|
q *= scale |
|
|
|
o = torch.empty_like(_v) |
|
|
|
for i in range(NT): |
|
q_i, k_i, v_i, eta_i = [x[i] for x in [_q, _k, _v, _eta]] |
|
kh = k_i @ h + hb |
|
reconstruction_target = v_i - k_i |
|
|
|
mean = kh.mean(-1, True) |
|
var = kh.var(-1, unbiased=False, keepdim=True).to(torch.float32) |
|
rstd = torch.sqrt(var + eps).to(torch.float32) |
|
kh_hat = (kh - mean) / rstd |
|
|
|
g = w * kh_hat + b - reconstruction_target |
|
g *= w |
|
v_new = (D * g - g.sum(-1, True) - kh_hat * (g * kh_hat).sum(-1, True)) / (rstd * D) |
|
|
|
Attn = torch.tril(q_i @ k_i.transpose(-2, -1)) |
|
o_i = q_i @ h - (eta_i * Attn) @ v_new + hb - torch.tril(eta_i.expand_as(Attn)) @ v_new |
|
h = h - (eta_i[:, :, -1, :, None] * k_i).transpose(-1, -2) @ v_new |
|
hb = hb - torch.sum(eta_i[:, :, -1, :, None] * v_new, dim=-2, keepdim=True) |
|
|
|
|
|
mean = o_i.mean(dim=-1, keepdim=True) |
|
var = o_i.var(dim=-1, unbiased=False, keepdim=True).to(torch.float32) |
|
rstd = torch.sqrt(var + eps).to(torch.float32) |
|
o[i] = o_i + (o_i - mean) / rstd * w + b |
|
|
|
|
|
o = o.permute(1, 2, 0, 3, 4).reshape(B, H, T, D) |
|
h = h if output_final_state else None |
|
hb = hb if output_final_state else None |
|
return o, h, hb |
|
|
|
|
|
def chunk_ttt_linear_ref( |
|
q: torch.Tensor, |
|
k: torch.Tensor, |
|
v: torch.Tensor, |
|
w: torch.Tensor, |
|
b: torch.Tensor, |
|
eta: torch.Tensor, |
|
scale: float = None, |
|
eps: float = 1e-6, |
|
mini_batch_size: int = 16, |
|
initial_state: torch.Tensor = None, |
|
initial_state_bias: torch.Tensor = None, |
|
output_final_state: bool = False, |
|
head_first: bool = True, |
|
): |
|
assert q.dtype == k.dtype == v.dtype |
|
assert k.shape[-1] == v.shape[-1], "The key and value dimension must be the same." |
|
if isinstance(eta, float): |
|
eta = torch.full_like(q[:, :, :, :1], eta) |
|
if scale is None: |
|
scale = k.shape[-1] ** -0.5 |
|
if not head_first: |
|
q = q.transpose(1, 2) |
|
k = k.transpose(1, 2) |
|
v = v.transpose(1, 2) |
|
eta = eta.transpose(1, 2) |
|
T = q.shape[-2] |
|
padded = (mini_batch_size - (T % mini_batch_size)) % mini_batch_size |
|
if padded > 0: |
|
q = F.pad(q, (0, 0, 0, padded)) |
|
k = F.pad(k, (0, 0, 0, padded)) |
|
v = F.pad(v, (0, 0, 0, padded)) |
|
eta = F.pad(eta, (0, 0, 0, padded)) |
|
eta[:, :, -1, :] = eta[:, :, -(padded+1), :] |
|
assert q.shape[-2] % mini_batch_size == 0, "Sequence length should be a multiple of mini_batch_size." |
|
q, k, v, eta, w, b = map(lambda x: x.to(torch.float32), [q, k, v, eta, w, b]) |
|
o, final_state, final_state_bias = ttt_linear( |
|
q, |
|
k, |
|
v, |
|
w, |
|
b, |
|
eta, |
|
scale, |
|
eps, |
|
mini_batch_size, |
|
initial_state, |
|
initial_state_bias, |
|
output_final_state, |
|
) |
|
o = o[:, :, :T, :].contiguous() |
|
if not head_first: |
|
o = o.transpose(1, 2) |
|
return o, final_state, final_state_bias |
|
|