# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang, Yuqi Pan

import torch
import torch.nn.functional as F


def ttt_linear(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    w: torch.Tensor,
    b: torch.Tensor,
    eta: torch.Tensor,
    scale: float,
    eps: float,
    mini_batch_size: int,
    initial_state: torch.Tensor,
    initial_state_bias: torch.Tensor,
    output_final_state: bool
):
    B, H, T, D = q.shape
    BT = mini_batch_size
    NT = T // BT
    # scale queries before chunking, so the loop reads the scaled values without
    # relying on in-place updates propagating through reshaped views
    q = q * scale
    # [NT, B, H, mini_batch_size, D]
    _q = q.reshape(B, H, NT, BT, D).permute(2, 0, 1, 3, 4)
    _k = k.reshape(B, H, NT, BT, D).permute(2, 0, 1, 3, 4)
    _v = v.reshape(B, H, NT, BT, D).permute(2, 0, 1, 3, 4)
    # [NT, B, H, BT, 1]
    _eta = eta.reshape(B, H, NT, BT, 1).permute(2, 0, 1, 3, 4)
    # [H, 1, D]
    w = w.reshape(H, 1, D).to(torch.float32)
    b = b.reshape(H, 1, D).to(torch.float32)

    # fast-weight state [B, H, D, D] and its bias [B, H, 1, D]
    h = torch.zeros((B, H, D, D), device=v.device, dtype=torch.float32) if initial_state is None else initial_state
    hb = torch.zeros((B, H, 1, D), device=v.device, dtype=torch.float32) if initial_state_bias is None else initial_state_bias

    # [NT, B, H, BT, D]
    o = torch.empty_like(_v)
    for i in range(NT):
        q_i, k_i, v_i, eta_i = [x[i] for x in [_q, _k, _v, _eta]]
        # forward pass of the inner model on the current mini-batch
        kh = k_i @ h + hb
        reconstruction_target = v_i - k_i
        # layer norm of the inner-model prediction
        mean = kh.mean(-1, True)
        var = kh.var(-1, unbiased=False, keepdim=True).to(torch.float32)
        rstd = torch.sqrt(var + eps).to(torch.float32)
        kh_hat = (kh - mean) / rstd
        # error of the normalized prediction against the reconstruction target
        g = w * kh_hat + b - reconstruction_target
        # back-propagate through the layer-norm affine scale ...
        g *= w
        # ... and through the normalization itself: gradient w.r.t. kh
        v_new = (D * g - g.sum(-1, True) - kh_hat * (g * kh_hat).sum(-1, True)) / (rstd * D)
        # intra-mini-batch output with the causal (lower-triangular) correction
        Attn = torch.tril(q_i @ k_i.transpose(-2, -1))
        o_i = q_i @ h - (eta_i * Attn) @ v_new + hb - torch.tril(eta_i.expand_as(Attn)) @ v_new
        # gradient-descent step on the fast weights, using the last step's learning rate
        h = h - (eta_i[:, :, -1, :, None] * k_i).transpose(-1, -2) @ v_new
        hb = hb - torch.sum(eta_i[:, :, -1, :, None] * v_new, dim=-2, keepdim=True)
        # layer norm with residual connection
        mean = o_i.mean(dim=-1, keepdim=True)
        var = o_i.var(dim=-1, unbiased=False, keepdim=True).to(torch.float32)
        rstd = torch.sqrt(var + eps).to(torch.float32)
        o[i] = o_i + (o_i - mean) / rstd * w + b
    # [B, H, T, D]
    o = o.permute(1, 2, 0, 3, 4).reshape(B, H, T, D)
    h = h if output_final_state else None
    hb = hb if output_final_state else None
    return o, h, hb


def chunk_ttt_linear_ref(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    w: torch.Tensor,
    b: torch.Tensor,
    eta: torch.Tensor,
    scale: float = None,
    eps: float = 1e-6,
    mini_batch_size: int = 16,
    initial_state: torch.Tensor = None,
    initial_state_bias: torch.Tensor = None,
    output_final_state: bool = False,
    head_first: bool = True,
):
    assert q.dtype == k.dtype == v.dtype
    assert k.shape[-1] == v.shape[-1], "The key and value dimension must be the same."
    if isinstance(eta, float):
        eta = torch.full_like(q[:, :, :, :1], eta)
    if scale is None:
        scale = k.shape[-1] ** -0.5
    if not head_first:
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        eta = eta.transpose(1, 2)
    T = q.shape[-2]
    # pad the sequence length up to a multiple of mini_batch_size
    padded = (mini_batch_size - (T % mini_batch_size)) % mini_batch_size
    if padded > 0:
        q = F.pad(q, (0, 0, 0, padded))
        k = F.pad(k, (0, 0, 0, padded))
        v = F.pad(v, (0, 0, 0, padded))
        eta = F.pad(eta, (0, 0, 0, padded))
        # reuse the last real learning rate for the final (padded) step
        eta[:, :, -1, :] = eta[:, :, -(padded+1), :]
    assert q.shape[-2] % mini_batch_size == 0, "Sequence length should be a multiple of mini_batch_size."
    # compute in float32 for numerical stability
    q, k, v, eta, w, b = map(lambda x: x.to(torch.float32), [q, k, v, eta, w, b])
    o, final_state, final_state_bias = ttt_linear(
        q, k, v, w, b, eta,
        scale, eps, mini_batch_size,
        initial_state, initial_state_bias,
        output_final_state,
    )
    # drop padded positions and restore the input layout
    o = o[:, :, :T, :].contiguous()
    if not head_first:
        o = o.transpose(1, 2)
    return o, final_state, final_state_bias
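

if __name__ == "__main__":
    # Minimal usage sketch, illustrative only: the shapes and values below are
    # assumptions, not part of the reference implementation above. It shows the
    # head-first [B, H, T, D] layout, (H, D) layer-norm parameters, and a scalar
    # eta that the wrapper broadcasts to [B, H, T, 1].
    B, H, T, D = 2, 4, 37, 64  # T deliberately not a multiple of mini_batch_size to exercise padding
    q = torch.randn(B, H, T, D)
    k = torch.randn(B, H, T, D)
    v = torch.randn(B, H, T, D)
    w = torch.ones(H, D)       # per-head layer-norm weight (assumed (H, D) layout)
    b = torch.zeros(H, D)      # per-head layer-norm bias (assumed (H, D) layout)
    eta = 0.01                 # scalar inner-loop learning rate
    o, h, hb = chunk_ttt_linear_ref(
        q, k, v, w, b, eta,
        mini_batch_size=16,
        output_final_state=True,
        head_first=True,
    )
    print(o.shape, h.shape, hb.shape)  # (B, H, T, D), (B, H, D, D), (B, H, 1, D)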