zaydzuhri's picture
Add files using upload-large-folder tool
183cbc0 verified
raw
history blame
4.75 kB
import torch
def cal_n_log(log_theta, log_eta, seq_len):
"""
calculate n_{i,j} in log space
log(n_{i,j}) = log(θ_j) + sum_{k=j+1}^i log(η_k)
"""
# create log(n)
log_n = torch.zeros(*log_theta.shape, seq_len, dtype=log_eta.dtype).to(
log_eta.device
) # [batch_size, num_heads, seq_len, seq_len]
for i in range(seq_len):
for j in range(i + 1):
if i == j:
log_n[..., j, i] = log_theta[..., j]
else:
log_n[..., j, i] = log_theta[..., j] + torch.sum(
log_eta[..., j + 1: i + 1], dim=-1
)
return log_n
def cal_f_log(log_beta, seq_len, log_m):
"""
cal_f_log(log_beta, seq_len, log_m) -> f
log(f_t) = log(sum_{i=1}^t exp(sum_{k=i+1}^t log(1-α_k) + sum_{k=1}^i log(η_k)))
"""
# create f
# f = torch.zeros_like(log_beta)
# for t in range(seq_len):
# for i in range(t + 1):
# f[..., t] += torch.exp(log_beta[..., t] - log_beta[..., i] + log_m[..., i])
log_f = torch.zeros_like(log_beta)
for t in range(seq_len):
a_i = log_beta[..., t: t + 1] - log_beta[..., : t + 1] + log_m[..., : t + 1]
log_f[..., t] = torch.logsumexp(a_i, dim=-1)
f = torch.exp(log_f)
# this version overflow and even slower
# t_indices = torch.arange(seq_len, device=log_beta.device)
# i_indices = torch.arange(seq_len, device=log_beta.device)
#
# mask = i_indices.unsqueeze(0) <= t_indices.unsqueeze(1)
# log_beta_t = log_beta.unsqueeze(-1) # [..., seq_len, 1]
# log_beta_i = log_beta.unsqueeze(-2) # [..., 1, seq_len]
# log_m_i = log_m.unsqueeze(-2)
# a_i = log_beta_t - log_beta_i + log_m_i
# masked_a_i = torch.where(mask, a_i, torch.tensor(-float('inf'), device=a_i.device, dtype=a_i.dtype))
# log_f = torch.logsumexp(masked_a_i, dim=-1) # [..., seq_len]
#
# f = torch.exp(log_f)
return f
def cal_G_log(log_beta, log_n, seq_len):
"""
calculate G_{i,j}
log(G_{i,j}) = log(sum_{k=j}^i exp(log(β_i/β_k) + log(n_{k,j})))
"""
# G = torch.zeros(*log_beta.shape[:-1], seq_len, seq_len, device = log_beta.device)
# # Fill in the lower triangular part
# for i in range(seq_len): # row
# for j in range(i + 1): # column
# # Sum from k=j to i
# for k in range(j, i + 1):
# G[..., i, j] += torch.exp(log_beta[..., i] - log_beta[..., k] + log_n[..., j, k])
log_G = torch.full(
(*log_beta.shape[:-1], seq_len, seq_len), float("-inf"), device=log_beta.device
)
# fill in the lower triangular part
for i in range(seq_len): # row
for j in range(i + 1): # column
terms = (
log_beta[..., i: i + 1]
- log_beta[..., j: i + 1]
+ log_n[..., j: j + 1, j: i + 1].squeeze(-2)
)
# use logsumexp to avoid overflow
log_G[..., i, j] = torch.logsumexp(terms, dim=-1)
G = torch.exp(log_G)
return G
def _combine_params_log(log_theta, log_alpha_complement, log_eta, seq_len):
"""
Update rule for Titans in log space
Parameters:
- log_theta: log(θ)
- log_alpha_complement: log(1-α)
- log_eta: log(η)
- seq_len: sequence length
Returns:
- log_beta, beta_T, log_f, f_T, log_g, log_G, m_T, n_T
"""
# calculate log(β_t) = sum_{k=1}^t log(1-α_k)
log_beta = torch.cumsum(log_alpha_complement, dim=-1)
# get β_T
beta_T = torch.exp(log_beta[..., -1])
# calculate log(m_i) = sum_{k=1}^i log(η_k)
log_m = torch.cumsum(log_eta, dim=-1)
m_T = torch.exp(log_m[..., -1])
# cal log(n_{i,j})
log_n = cal_n_log(log_theta, log_eta, seq_len)
n_T = torch.exp(log_n[..., -1])
# cal log(f_t)
f = cal_f_log(log_beta, seq_len, log_m)
f_T = f[..., -1]
# cal log(G_{i,j})
G = cal_G_log(log_beta, log_n, seq_len)
# get log(g_j) = log(G_{T,j})
g = G[..., -1, :]
return log_beta, beta_T, f, f_T, g, G, m_T, n_T
def combine_params_log(theta, alpha, eta, seq_len):
"""
log space Titians
Parameters:
- theta: θ
- alpha: α
- eta: η
- seq_len: sequence length
Returns:
- beta, beta_T, f, f_T, g, G, m_T, n_T
"""
# convert to log space
log_theta = torch.log(theta.squeeze(-1))
log_alpha_complement = torch.log(1 - alpha.squeeze(-1))
log_eta = torch.log(eta.squeeze(-1))
# combine params in log space
log_beta, beta_T, f, f_T, g, G, m_T, n_T = _combine_params_log(
log_theta, log_alpha_complement, log_eta, seq_len
)
# convert back to normal space
beta = torch.exp(log_beta)
return beta, beta_T, f, f_T, g, G, m_T, n_T