# -*- coding: utf-8 -*-
import torch
import torch.nn.functional as F
from fla.ops.titans.log_impl import combine_params_log
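# cal_n builds the step-size products used by the momentum term:
#   n[..., j, i] = theta_j                               for i == j
#   n[..., j, i] = theta_j * prod_{k=j+1..i} eta_k       for i > j
# and zero below the diagonal (in the (j, i) ordering).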
def cal_n(theta, eta, seq_len):
n = torch.zeros(*theta.shape, seq_len, dtype=theta.dtype).to(
theta.device
) # [batch_size, num_heads, seq_len, seq_len]
# 1. deal with diagonal elements
indices = torch.arange(seq_len, device=theta.device)
n[..., indices, indices] = theta[..., indices]
# 2. Create a cumulative product matrix
# First create a mask to mark the positions where eta needs to be multiplied
mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).to(theta.device)
# Convert mask to boolean type
mask = mask.bool()
# Expand eta to match the target shape
eta_expanded = eta.unsqueeze(-2).expand(*theta.shape[:-1], seq_len, seq_len)
# Create a matrix filled with 1s for cumulative product
cumulative = torch.ones_like(eta_expanded)
cumulative = torch.where(mask, eta_expanded, cumulative)
# Calculate the cumulative product
cumulative_prod = torch.cumprod(cumulative, dim=-1)
# 3. Calculate non-diagonal elements
# Create an expanded version of theta
theta_expanded = theta.unsqueeze(-1).expand(*theta.shape[:-1], seq_len, seq_len)
# Create a mask to keep only the upper triangular part (excluding the diagonal)
upper_triangular = torch.triu(torch.ones_like(n), diagonal=1).bool()
# Combine theta and cumulative product
n = torch.where(upper_triangular, theta_expanded * cumulative_prod, n)
return n
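# cal_f evaluates f_t = sum_{i<=t} (beta_t / beta_i) * m_i for every t with a single
# lower-triangular matmul in float32.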
def cal_f(beta, seq_len, m):
a = torch.tril(beta.to(torch.float32).unsqueeze(-1).expand(*beta.shape, seq_len), 0)
ratio = (m.to(torch.float32) / beta.to(torch.float32)).unsqueeze(-1)
f = torch.matmul(a, ratio).squeeze(-1)
return f.to(beta.dtype)
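# cal_G evaluates G[..., i, j] = sum_{k=j..i} (beta_i / beta_k) * n[..., j, k] for j <= i;
# its last row (i = seq_len - 1) is the g vector used in combine_params.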
def cal_G(beta, n, seq_len):
i_indices = torch.arange(seq_len, device=beta.device)
j_indices = torch.arange(seq_len, device=beta.device)
k_indices = torch.arange(seq_len, device=beta.device)
beta_ratio = beta[..., :, None] / beta[..., None, :] # [..., i, k]
# create mask
k_mask = (k_indices[None, None, :] >= j_indices[None, :, None]) & (
k_indices[None, None, :] <= i_indices[:, None, None]
)
# use mask to filter out invalid values
masked_beta_ratio = beta_ratio[..., :, None, :] * k_mask # [..., i, j, k]
masked_n = n[..., None, :, :] * k_mask # [..., i, j, k]
# calculate G
G = torch.sum(masked_beta_ratio * masked_n, dim=-1) # [..., i, j]
return G
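# combine_params turns the per-step gates theta (learning rate), alpha (forgetting) and
# eta (momentum decay) into the closed-form coefficients (beta, f, g/G, m_T, n_T) that
# let chunk_titans_linear apply a whole chunk of the recurrence with a few matmuls.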
def combine_params(theta, alpha, eta, seq_len):
theta = theta.squeeze(-1)
eta = eta.squeeze(-1)
alpha = alpha.squeeze(-1)
beta = torch.cumprod(1 - alpha, dim=-1) # β_t = ∏(1 - α_t) in titans paper
beta_T = beta[..., -1] # β_T
# Calculate m_i = ∏(k=1 to i) η_k
m = torch.cumprod(eta, dim=-1) # [batch_size, num_heads, seq_len]
m_T = m[..., -1] # m_T
# Calculate n_{i,j}
# We need to calculate ∏(k=j+1 to i) η_k for each i,j pair
# # this may be optimized
# n = torch.zeros(*theta.shape, seq_len, dtype = theta.dtype).to(
# theta.device) # [batch_size, num_heads, seq_len, seq_len]
# for i in range(seq_len):
# for j in range(i + 1):
# if i == j:
# n[..., j, i] = theta[..., j]
# else:
# # Calculate product of eta from j+1 to i
# eta_product = torch.prod(eta[..., j + 1:i + 1], dim = -1)
# n[..., j, i] = theta[..., j] * eta_product
n = cal_n(theta, eta, seq_len)
n_T = n[..., -1] # [batch_size, num_heads, seq_len]
# Calculate f_t = ∑(i=1 to t) (β_t/β_i) m_i
# f = torch.zeros_like(theta)
# for t in range(seq_len):
# for i in range(t + 1):
# f[..., t] += (beta[..., t] / beta[..., i]) * m[..., i]
f = cal_f(beta, seq_len, m)
    f_T = f[..., -1]  # [batch_size, num_heads]
# Calculate g_j = ∑(i=j to t) (β_t/β_i) n_{i,j}
# g = torch.zeros_like(theta) # [batch_size, num_heads, seq_len]
# for j in range(seq_len):
# for i in range(j, seq_len):
# g[..., j] += (beta[..., -1] / beta[..., i]) * n[..., j, i]
# G = torch.zeros(*beta.shape[:-1], seq_len, seq_len, device = beta.device)
# # Fill in the lower triangular part
# for i in range(seq_len): # row
# for j in range(i + 1): # column
# # Sum from k=j to i
# for k in range(j, i + 1):
# G[..., i, j] += (beta[..., i] / beta[..., k]) * n[..., j, k]
G = cal_G(beta, n, seq_len)
g = G[:, :, -1, :] # [batch_size, num_heads, seq_len]
# g2, G2 = compute_g_and_G(beta, n, seq_len)
return beta, beta_T, f, f_T, g, G, m_T, n_T
def titans_linear(
q, k, v, w, b, theta, alpha, eta, eps, chunk_size, initial_state, output_final_state
):
"""
Implementation of Titans Linear function based on the update rules:
M_t = (1 - alpha_t) * M_{t-1} + S_t
S_t = eta_t * S_{t-1} - theta_t * nabla_l(M_{t-1}; x_t)
Args:
q: Query tensor
k: Key tensor
v: Value tensor
w: Weight tensor
b: Bias tensor
        theta: Learning-rate tensor (scales the gradient term)
        alpha: Forgetting-gate tensor; M_{t-1} is decayed by (1 - alpha_t)
        eta: Momentum-decay tensor (scales S_{t-1})
        eps: Epsilon for numerical stability
        chunk_size: Number of steps between refreshes of the gradient reference state M_prev_nabla
        initial_state: Initial state M_0
        output_final_state: Whether to output the final state
Returns:
Tuple of (output tensor, final state)
"""
B, H, T, D = q.shape
device = q.device
w = w.reshape(H, 1, D).to(torch.float32)
b = b.reshape(H, 1, D).to(torch.float32)
# Initialize states
if initial_state is None:
M_prev = torch.zeros(B, H, D, D, device=device)
else:
M_prev = initial_state
M_prev_nabla = M_prev.clone()
S_prev = torch.zeros_like(M_prev)
outputs = []
# Process sequence step by step
for t in range(T):
# Get current step inputs
q_t = q[:, :, t: t + 1, :] # (batch_size, num_heads, 1, dim)
k_t = k[:, :, t: t + 1, :] # (batch_size, num_heads, 1, dim)
v_t = v[:, :, t: t + 1, :] # (batch_size, num_heads, 1, dim)
theta_t = theta[:, :, t: t + 1, :] # (batch_size, num_heads, 1, dim)
alpha_t = alpha[:, :, t: t + 1, :] # (batch_size, num_heads, 1, dim)
eta_t = eta[:, :, t: t + 1, :] # (batch_size, num_heads, 1, dim)
# Compute gradient
km = k_t @ M_prev_nabla # (batch_size, num_heads, 1, dim)
reconstruction_target = v_t - k_t
mean = km.mean(-1, keepdim=True)
var = km.var(-1, unbiased=False, keepdim=True).to(torch.float32)
rstd = torch.sqrt(var + eps).to(torch.float32)
km_hat = (km - mean) / rstd
grad = w * km_hat + b - reconstruction_target
grad = grad * w
# v_new = (D * grad - grad.sum(-1, keepdim = True) - km_hat * (grad * km_hat).sum(-1, keepdim = True)) / (
# rstd * D)
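        # backward pass through the layer norm; note that the 1 / (rstd * D) factor is
        # applied only to the two summed terms here (the commented-out lines above show
        # the fully scaled variant), matching the chunked implementation below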
v_new = D * grad - grad.sum(-1, keepdim=True) / (rstd * D)
proj_term = km_hat * (grad * km_hat).sum(-1, keepdim=True) / (rstd * D)
v_new = v_new - proj_term
# v_new = grad
# Update S_t
S_t = eta_t * S_prev - 2 * theta_t * k_t.transpose(-2, -1) @ v_new
# Update M_t
M_t = (1 - alpha_t) * M_prev + S_t
# Store output
        output_t = q_t @ M_t  # (batch_size, num_heads, 1, dim)
mean = output_t.mean(dim=-1, keepdim=True)
var = output_t.var(dim=-1, unbiased=False, keepdim=True).to(torch.float32)
rstd = torch.sqrt(var + eps).to(torch.float32)
output_t = output_t + (output_t - mean) / rstd * w + b
outputs.append(output_t)
# Update states for next step
if (t + 1) % chunk_size == 0:
M_prev_nabla = M_t.clone()
M_prev = M_t
S_prev = S_t
# Stack outputs along sequence dimension
output = torch.stack(outputs, dim=-2).squeeze(
-3
) # (batch_size, num_heads, seq_len, dim)
if output_final_state:
return output, M_prev
return output, None
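# chunk_titans_linear is the chunk-parallel form of titans_linear: within each chunk the
# gradient is taken w.r.t. the chunk-start state M_prev, and the coefficients from
# combine_params_log roll the per-step M/S recurrence into a few matmuls per chunk.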
def chunk_titans_linear(
q, k, v, w, b, theta, alpha, eta, eps, chunk_size, initial_state, output_final_state
):
B, H, T, D = q.shape
num_batch = T // chunk_size
# [num_batch, B, num_heads, mini_batch_size, head_dim]
_q = q.reshape(B, H, num_batch, chunk_size, D).permute(2, 0, 1, 3, 4)
_k = k.reshape(B, H, num_batch, chunk_size, D).permute(2, 0, 1, 3, 4)
_v = v.reshape(B, H, num_batch, chunk_size, D).permute(2, 0, 1, 3, 4)
# [num_batch, B, num_heads, mini_batch_size, 1]
_eta = eta.reshape(B, H, num_batch, chunk_size, 1).permute(2, 0, 1, 3, 4)
_theta = theta.reshape(B, H, num_batch, chunk_size, 1).permute(2, 0, 1, 3, 4)
_alpha = alpha.reshape(B, H, num_batch, chunk_size, 1).permute(2, 0, 1, 3, 4)
# [H, 1, D]
w = w.reshape(H, 1, D).to(torch.float32)
b = b.reshape(H, 1, D).to(torch.float32)
    # M_prev: [B, num_heads, head_dim, head_dim]
if initial_state is None:
M_prev = torch.zeros((B, H, D, D), device=v.device, dtype=v.dtype).to(
torch.float32
)
else:
M_prev = initial_state
S_prev = torch.zeros_like(M_prev)
# [num_batch, B, num_heads, mini_batch_size, head_dim]
o = torch.empty_like(_v)
for i in range(num_batch):
q_i, k_i, v_i, eta_i, theta_i, alpha_i = [
x[i] for x in [_q, _k, _v, _eta, _theta, _alpha]
]
# beta, beta_T, f, f_T, g, G, m_T, n = combine_params(theta_i, alpha_i, eta_i, chunk_size)
beta, beta_T, f, f_T, g, G, m_T, n = combine_params_log(
theta_i, alpha_i, eta_i, chunk_size
)
m_T = m_T.unsqueeze(-1).unsqueeze(-1)
beta_T = beta_T.unsqueeze(-1).unsqueeze(-1)
f_T = f_T.unsqueeze(-1).unsqueeze(-1)
g_diag = torch.diag_embed(g).to(q_i.dtype)
n = torch.diag_embed(n).to(q_i.dtype)
beta = torch.diag_embed(beta).to(q_i.dtype)
f = torch.diag_embed(f).to(q_i.dtype)
km = k_i @ M_prev
reconstruction_target = v_i - k_i
        mean = km.mean(-1, keepdim=True)
var = km.var(-1, unbiased=False, keepdim=True).to(torch.float32)
rstd = torch.sqrt(var + eps).to(torch.float32)
km_hat = (km - mean) / rstd
grad = w * km_hat + b - reconstruction_target
grad *= w
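        # same layer-norm backward as in titans_linear above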
v_new = D * grad - grad.sum(-1, keepdim=True) / (rstd * D)
proj_term = km_hat * (grad * km_hat).sum(-1, keepdim=True) / (rstd * D)
v_new = v_new - proj_term
# v_new = (D * grad - grad.sum(-1, True))
# print(f"Projection term stats: min={torch.abs(beta_T).min()}")
# v_new = grad
Attn = torch.tril(q_i @ k_i.transpose(-2, -1)) * G
# o_i
output_t = beta @ q_i @ M_prev + f @ q_i @ S_prev - 2 * Attn @ v_new
M_t = (
beta_T * M_prev
+ f_T * S_prev
- 2 * (g_diag @ k_i).transpose(-1, -2) @ v_new
)
# cal S_T from S_0
S_t = m_T * S_prev - 2 * (n @ k_i).transpose(-1, -2) @ v_new
# layer norm with residuals
mean = output_t.mean(dim=-1, keepdim=True)
var = output_t.var(dim=-1, unbiased=False, keepdim=True).to(torch.float32)
rstd = torch.sqrt(var + eps).to(torch.float32)
output_t = output_t + (output_t - mean) / rstd * w + b
o[i] = output_t
S_prev = S_t
M_prev = M_t
    # [num_batch, B, num_heads, chunk_size, head_dim] -> [B, num_heads, T, head_dim]
o = o.permute(1, 2, 0, 3, 4).reshape(B, H, T, D)
M_prev = M_prev if output_final_state else None
return o, M_prev
# most of the code below is adapted from the TTT reference implementation
def chunk_titans_linear_ref(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
w: torch.Tensor,
b: torch.Tensor,
theta: torch.Tensor,
alpha: torch.Tensor,
eta: torch.Tensor,
eps: float = 1e-6,
chunk_size: int = 16, # chunk size
initial_state: torch.Tensor = None,
output_final_state: bool = False,
head_first: bool = True,
use_chunk: bool = True,
):
assert q.dtype == k.dtype == v.dtype
    assert k.shape[-1] == v.shape[-1], "DK must equal DV."
if not head_first:
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
eta = eta.transpose(1, 2)
alpha = alpha.transpose(1, 2)
theta = theta.transpose(1, 2)
seq_len = q.shape[-2]
pad_len = (chunk_size - (seq_len % chunk_size)) % chunk_size
if pad_len > 0:
q = F.pad(q, (0, 0, 0, pad_len))
k = F.pad(k, (0, 0, 0, pad_len))
v = F.pad(v, (0, 0, 0, pad_len))
theta = F.pad(theta, (0, 0, 0, pad_len))
alpha = F.pad(alpha, (0, 0, 0, pad_len))
eta = F.pad(eta, (0, 0, 0, pad_len))
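        # reuse the last real gate values for the final padded position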
theta[:, :, -1, :] = theta[:, :, -(pad_len + 1), :]
alpha[:, :, -1, :] = alpha[:, :, -(pad_len + 1), :]
eta[:, :, -1, :] = eta[:, :, -(pad_len + 1), :]
    assert q.shape[-2] % chunk_size == 0, "padded sequence length must be a multiple of chunk_size"
q, k, v, w, b = map(lambda x: x.to(torch.float32), [q, k, v, w, b])
if use_chunk:
o, final_state = chunk_titans_linear(
q,
k,
v,
w,
b,
theta,
alpha,
eta,
eps,
chunk_size,
initial_state,
output_final_state,
)
else:
o, final_state = titans_linear(
q,
k,
v,
w,
b,
theta,
alpha,
eta,
eps,
chunk_size,
initial_state,
output_final_state,
)
o = o[:, :, :seq_len, :]
if not head_first:
o = o.transpose(1, 2)
return o, final_state
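
# What follows is a minimal usage sketch, not part of the reference implementation.
# The tensor shapes are assumptions inferred from the reshapes above, and running it
# requires fla.ops.titans.log_impl (imported at the top of this file) to be available.
if __name__ == "__main__":
    torch.manual_seed(0)
    B, H, T, D = 2, 2, 64, 32
    # q/k/v: [B, H, T, D]; per-step gates theta/alpha/eta: [B, H, T, 1]; norm params w/b: [H, D]
    q, k, v = (torch.randn(B, H, T, D) for _ in range(3))
    w, b = torch.ones(H, D), torch.zeros(H, D)
    theta = 0.01 + 0.09 * torch.rand(B, H, T, 1)   # per-step learning rates
    alpha = 0.01 * torch.rand(B, H, T, 1)          # forgetting gates (close to 0)
    eta = 0.9 + 0.09 * torch.rand(B, H, T, 1)      # momentum decays (close to 1)
    # The chunked closed-form path and the step-by-step recurrence share the same
    # chunk-wise gradient approximation, so their outputs should agree closely.
    o_chunk, _ = chunk_titans_linear_ref(q, k, v, w, b, theta, alpha, eta, chunk_size=16, use_chunk=True)
    o_step, _ = chunk_titans_linear_ref(q, k, v, w, b, theta, alpha, eta, chunk_size=16, use_chunk=False)
    print("max |chunk - step|:", (o_chunk - o_step).abs().max().item())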