from typing import Optional

import torch
import torch.nn.functional as F

from fla.ops.titans.log_impl import combine_params_log
|
|
def cal_n(theta, eta, seq_len):
    """n[..., j, t] = theta_j * prod_{i=j+1..t} eta_i for j <= t; zero below the diagonal."""
    n = torch.zeros(*theta.shape, seq_len, dtype=theta.dtype, device=theta.device)

    # Diagonal: n[..., t, t] = theta_t.
    indices = torch.arange(seq_len, device=theta.device)
    n[..., indices, indices] = theta[..., indices]

    # Strictly upper triangle: cumulative products of eta along the last dimension,
    # with entries at or below the diagonal replaced by 1 so they do not contribute.
    mask = torch.triu(torch.ones(seq_len, seq_len, device=theta.device), diagonal=1).bool()
    eta_expanded = eta.unsqueeze(-2).expand(*theta.shape[:-1], seq_len, seq_len)
    cumulative = torch.where(mask, eta_expanded, torch.ones_like(eta_expanded))
    cumulative_prod = torch.cumprod(cumulative, dim=-1)

    theta_expanded = theta.unsqueeze(-1).expand(*theta.shape[:-1], seq_len, seq_len)
    upper_triangular = torch.triu(torch.ones_like(n), diagonal=1).bool()
    n = torch.where(upper_triangular, theta_expanded * cumulative_prod, n)
    return n
|
|
def cal_f(beta, seq_len, m):
    """f[..., t] = beta_t * sum_{j<=t} m_j / beta_j, computed in float32 for stability."""
    a = torch.tril(beta.to(torch.float32).unsqueeze(-1).expand(*beta.shape, seq_len), 0)
    ratio = (m.to(torch.float32) / beta.to(torch.float32)).unsqueeze(-1)
    f = torch.matmul(a, ratio).squeeze(-1)
    return f.to(beta.dtype)
|
|
def cal_G(beta, n, seq_len):
    """G[..., i, j] = sum over j <= k <= i of (beta_i / beta_k) * n[..., j, k]."""
    i_indices = torch.arange(seq_len, device=beta.device)
    j_indices = torch.arange(seq_len, device=beta.device)
    k_indices = torch.arange(seq_len, device=beta.device)
    beta_ratio = beta[..., :, None] / beta[..., None, :]

    # Only summation indices k with j <= k <= i contribute to G[..., i, j].
    k_mask = (k_indices[None, None, :] >= j_indices[None, :, None]) & (
        k_indices[None, None, :] <= i_indices[:, None, None]
    )

    masked_beta_ratio = beta_ratio[..., :, None, :] * k_mask
    masked_n = n[..., None, :, :] * k_mask

    G = torch.sum(masked_beta_ratio * masked_n, dim=-1)
    return G
|
|
def combine_params(theta, alpha, eta, seq_len):
    """Naive counterpart of `combine_params_log`: per-token gates -> chunk-level coefficients."""
    theta = theta.squeeze(-1)
    eta = eta.squeeze(-1)
    alpha = alpha.squeeze(-1)

    # beta_t = prod_{i<=t} (1 - alpha_i): cumulative memory retention.
    beta = torch.cumprod(1 - alpha, dim=-1)
    beta_T = beta[..., -1]

    # m_t = prod_{i<=t} eta_i: cumulative momentum decay.
    m = torch.cumprod(eta, dim=-1)
    m_T = m[..., -1]

    n = cal_n(theta, eta, seq_len)
    n_T = n[..., -1]

    f = cal_f(beta, seq_len, m)
    f_T = f[..., -1]

    G = cal_G(beta, n, seq_len)
    g = G[:, :, -1, :]

    return beta, beta_T, f, f_T, g, G, m_T, n_T
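# A sketch of what the coefficients above encode, as read from the per-token recurrences in
# `titans_linear` below. Within a chunk with incoming states S_0, M_0, and writing v_j for the
# corrected (gradient) value of token j, unrolling the recurrences gives
#     S_t = m_t * S_0 - 2 * sum_{j<=t} n[j, t] * k_j^T v_j,    n[j, t] = theta_j * m_t / m_j
#     M_t = beta_t * M_0 + f_t * S_0 - 2 * sum_{j<=t} G[t, j] * k_j^T v_j
# with f_t = beta_t * sum_{j<=t} m_j / beta_j and G[t, j] = sum_{j<=k<=t} (beta_t / beta_k) * n[j, k].
# The *_T values and g are these quantities at the last position of a chunk; `chunk_titans_linear`
# uses them to carry M and S across chunks and evaluates the sums with matrix products instead of
# a token-by-token loop.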
|
|
def titans_linear(
    q, k, v, w, b, theta, alpha, eta, eps, chunk_size, initial_state, output_final_state
):
    """
    Sequential (token-by-token) reference implementation of the Titans linear update rules:

        M_t = (1 - alpha_t) * M_{t-1} + S_t
        S_t = eta_t * S_{t-1} - theta_t * nabla_l(M_{t-1}; x_t)

    Args:
        q: Query tensor
        k: Key tensor
        v: Value tensor
        w: Layer-norm weight (scale) tensor
        b: Layer-norm bias (shift) tensor
        theta: Learning-rate tensor
        alpha: Memory (forget-gate) decay tensor
        eta: Momentum decay tensor
        eps: Epsilon for numerical stability
        chunk_size: Interval at which the memory used inside the gradient is refreshed
        initial_state: Initial state M_0
        output_final_state: Whether to output the final state

    Returns:
        Tuple of (output tensor, final state)
    """
|
    B, H, T, D = q.shape
    device = q.device
    w = w.reshape(H, 1, D).to(torch.float32)
    b = b.reshape(H, 1, D).to(torch.float32)

    if initial_state is None:
        M_prev = torch.zeros(B, H, D, D, device=device)
    else:
        M_prev = initial_state
    M_prev_nabla = M_prev.clone()
    S_prev = torch.zeros_like(M_prev)
    outputs = []

    for t in range(T):
        q_t = q[:, :, t: t + 1, :]
        k_t = k[:, :, t: t + 1, :]
        v_t = v[:, :, t: t + 1, :]
        theta_t = theta[:, :, t: t + 1, :]
        alpha_t = alpha[:, :, t: t + 1, :]
        eta_t = eta[:, :, t: t + 1, :]

        # Normalized memory read and reconstruction-loss gradient.
        km = k_t @ M_prev_nabla
        reconstruction_target = v_t - k_t
        mean = km.mean(-1, keepdim=True)
        var = km.var(-1, unbiased=False, keepdim=True).to(torch.float32)
        rstd = torch.sqrt(var + eps).to(torch.float32)
        km_hat = (km - mean) / rstd

        grad = w * km_hat + b - reconstruction_target
        grad = grad * w

        # Correct the gradient for the normalization.
        v_new = D * grad - grad.sum(-1, keepdim=True) / (rstd * D)
        proj_term = km_hat * (grad * km_hat).sum(-1, keepdim=True) / (rstd * D)
        v_new = v_new - proj_term

        # Momentum (surprise) update.
        S_t = eta_t * S_prev - 2 * theta_t * k_t.transpose(-2, -1) @ v_new

        # Memory update with forget gate (1 - alpha_t).
        M_t = (1 - alpha_t) * M_prev + S_t

        # Query the updated memory and add a normalized residual using the same (w, b).
        output_t = q_t @ M_t
        mean = output_t.mean(dim=-1, keepdim=True)
        var = output_t.var(dim=-1, unbiased=False, keepdim=True).to(torch.float32)
        rstd = torch.sqrt(var + eps).to(torch.float32)
        output_t = output_t + (output_t - mean) / rstd * w + b
        outputs.append(output_t)

        # Refresh the gradient memory only at chunk boundaries.
        if (t + 1) % chunk_size == 0:
            M_prev_nabla = M_t.clone()
        M_prev = M_t
        S_prev = S_t

    output = torch.stack(outputs, dim=-2).squeeze(-3)

    if output_final_state:
        return output, M_prev
    return output, None
|
|
def chunk_titans_linear(
    q, k, v, w, b, theta, alpha, eta, eps, chunk_size, initial_state, output_final_state
):
    B, H, T, D = q.shape
    num_batch = T // chunk_size

    # Split the sequence into chunks of size `chunk_size` and move the chunk index to the front.
    _q = q.reshape(B, H, num_batch, chunk_size, D).permute(2, 0, 1, 3, 4)
    _k = k.reshape(B, H, num_batch, chunk_size, D).permute(2, 0, 1, 3, 4)
    _v = v.reshape(B, H, num_batch, chunk_size, D).permute(2, 0, 1, 3, 4)

    _eta = eta.reshape(B, H, num_batch, chunk_size, 1).permute(2, 0, 1, 3, 4)
    _theta = theta.reshape(B, H, num_batch, chunk_size, 1).permute(2, 0, 1, 3, 4)
    _alpha = alpha.reshape(B, H, num_batch, chunk_size, 1).permute(2, 0, 1, 3, 4)

    w = w.reshape(H, 1, D).to(torch.float32)
    b = b.reshape(H, 1, D).to(torch.float32)

    if initial_state is None:
        M_prev = torch.zeros((B, H, D, D), device=v.device, dtype=torch.float32)
    else:
        M_prev = initial_state

    S_prev = torch.zeros_like(M_prev)

    o = torch.empty_like(_v)

    for i in range(num_batch):
        q_i, k_i, v_i, eta_i, theta_i, alpha_i = [
            x[i] for x in [_q, _k, _v, _eta, _theta, _alpha]
        ]

        # Closed-form within-chunk coefficients (computed in log space for stability).
        beta, beta_T, f, f_T, g, G, m_T, n = combine_params_log(
            theta_i, alpha_i, eta_i, chunk_size
        )

        m_T = m_T.unsqueeze(-1).unsqueeze(-1)
        beta_T = beta_T.unsqueeze(-1).unsqueeze(-1)
        f_T = f_T.unsqueeze(-1).unsqueeze(-1)
        g_diag = torch.diag_embed(g).to(q_i.dtype)
        n = torch.diag_embed(n).to(q_i.dtype)
        beta = torch.diag_embed(beta).to(q_i.dtype)
        f = torch.diag_embed(f).to(q_i.dtype)

        # Gradient of the reconstruction loss against the memory entering the chunk.
        km = k_i @ M_prev
        reconstruction_target = v_i - k_i
        mean = km.mean(-1, keepdim=True)
        var = km.var(-1, unbiased=False, keepdim=True).to(torch.float32)
        rstd = torch.sqrt(var + eps).to(torch.float32)
        km_hat = (km - mean) / rstd

        grad = w * km_hat + b - reconstruction_target
        grad *= w
        v_new = D * grad - grad.sum(-1, keepdim=True) / (rstd * D)
        proj_term = km_hat * (grad * km_hat).sum(-1, keepdim=True) / (rstd * D)
        v_new = v_new - proj_term

        # Intra-chunk term: causal attention between queries and keys, weighted elementwise by G.
        Attn = torch.tril(q_i @ k_i.transpose(-2, -1)) * G

        # Output: contributions from the incoming states plus the intra-chunk attention term.
        output_t = beta @ q_i @ M_prev + f @ q_i @ S_prev - 2 * Attn @ v_new

        # Chunk-level state updates (closed form of the per-token recurrences).
        M_t = (
            beta_T * M_prev
            + f_T * S_prev
            - 2 * (g_diag @ k_i).transpose(-1, -2) @ v_new
        )
        S_t = m_T * S_prev - 2 * (n @ k_i).transpose(-1, -2) @ v_new

        # Same normalized residual read as in the sequential reference.
        mean = output_t.mean(dim=-1, keepdim=True)
        var = output_t.var(dim=-1, unbiased=False, keepdim=True).to(torch.float32)
        rstd = torch.sqrt(var + eps).to(torch.float32)
        output_t = output_t + (output_t - mean) / rstd * w + b
        o[i] = output_t
        S_prev = S_t
        M_prev = M_t

    o = o.permute(1, 2, 0, 3, 4).reshape(B, H, T, D)
    M_prev = M_prev if output_final_state else None
    return o, M_prev
|
|
def chunk_titans_linear_ref(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    w: torch.Tensor,
    b: torch.Tensor,
    theta: torch.Tensor,
    alpha: torch.Tensor,
    eta: torch.Tensor,
    eps: float = 1e-6,
    chunk_size: int = 16,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = False,
    head_first: bool = True,
    use_chunk: bool = True,
):
    """Pad and permute the inputs, then dispatch to the chunked or sequential reference."""
    assert q.dtype == k.dtype == v.dtype
    assert k.shape[-1] == v.shape[-1], "DK must equal DV."
    if not head_first:
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        eta = eta.transpose(1, 2)
        alpha = alpha.transpose(1, 2)
        theta = theta.transpose(1, 2)
    seq_len = q.shape[-2]

    # Pad the sequence to a multiple of chunk_size (F.pad fills with zeros); the gates at the
    # final padded position are copied from the last real token. The padding is sliced off
    # again before returning.
    pad_len = (chunk_size - (seq_len % chunk_size)) % chunk_size
    if pad_len > 0:
        q = F.pad(q, (0, 0, 0, pad_len))
        k = F.pad(k, (0, 0, 0, pad_len))
        v = F.pad(v, (0, 0, 0, pad_len))
        theta = F.pad(theta, (0, 0, 0, pad_len))
        alpha = F.pad(alpha, (0, 0, 0, pad_len))
        eta = F.pad(eta, (0, 0, 0, pad_len))
        theta[:, :, -1, :] = theta[:, :, -(pad_len + 1), :]
        alpha[:, :, -1, :] = alpha[:, :, -(pad_len + 1), :]
        eta[:, :, -1, :] = eta[:, :, -(pad_len + 1), :]
    assert q.shape[-2] % chunk_size == 0, "Sequence length should be a multiple of chunk_size."

    q, k, v, w, b = map(lambda x: x.to(torch.float32), [q, k, v, w, b])
    if use_chunk:
        o, final_state = chunk_titans_linear(
            q, k, v, w, b, theta, alpha, eta,
            eps, chunk_size, initial_state, output_final_state,
        )
    else:
        o, final_state = titans_linear(
            q, k, v, w, b, theta, alpha, eta,
            eps, chunk_size, initial_state, output_final_state,
        )
    o = o[:, :, :seq_len, :]
    if not head_first:
        o = o.transpose(1, 2)
    return o, final_state
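

# A minimal smoke test / usage sketch, not part of the library API: the shapes, the random
# inputs and the sigmoid squashing of theta/alpha/eta below are illustrative assumptions only.
# The sequential path (use_chunk=False) needs just this file and PyTorch; the chunked path also
# needs `combine_params_log` from the `fla` package imported above.
if __name__ == "__main__":
    torch.manual_seed(0)
    B, H, T, D, BT = 1, 2, 32, 16, 16

    q, k, v = (torch.randn(B, H, T, D) for _ in range(3))
    w, b = torch.randn(H, D), torch.randn(H, D)
    # Keep the gates in (0, 1) so the cumulative products stay well behaved.
    theta, alpha, eta = (torch.sigmoid(torch.randn(B, H, T, 1)) for _ in range(3))

    # Sequential (token-by-token) reference.
    o_seq, _ = chunk_titans_linear_ref(
        q, k, v, w, b, theta, alpha, eta, chunk_size=BT, use_chunk=False
    )
    print("sequential output:", tuple(o_seq.shape))

    # Chunked closed-form path, compared against the sequential reference.
    o_chunk, _ = chunk_titans_linear_ref(
        q, k, v, w, b, theta, alpha, eta, chunk_size=BT, use_chunk=True
    )
    print("max |sequential - chunked|:", (o_seq - o_chunk).abs().max().item())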
|
|