import torch
import torch.nn as nn
import torch.nn.functional as F


class TemporalSelfAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, bias_type="linear", gamma=1.0, causal=False):
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        assert bias_type in ["linear", "gaussian"]
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.bias_type = bias_type
        self.gamma = gamma
        self.causal = causal
        self.qkv = nn.Linear(embed_dim, 3 * embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x, timestamps):
        """
        x: [B, T, D]
        timestamps: [B, T] - real-valued time signals per token
        """
        B, T, D = x.size()

        # Project input to Q, K, V
        qkv = self.qkv(x).reshape(B, T, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.unbind(dim=2)  # each: [B, T, num_heads, head_dim]
        q = q.transpose(1, 2)  # [B, num_heads, T, head_dim]
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        # Scaled dot-product attention
        attn_logits = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)  # [B, H, T, T]

        # Compute temporal bias
        t_i = timestamps.unsqueeze(2)  # [B, T, 1]
        t_j = timestamps.unsqueeze(1)  # [B, 1, T]
        delta_t = t_j - t_i  # [B, T, T]
        if self.bias_type == "linear":
            temporal_bias = -self.gamma * torch.abs(delta_t)  # [B, T, T]
        elif self.bias_type == "gaussian":
            temporal_bias = -self.gamma * (delta_t ** 2)

        # Expand for broadcasting: [B, 1, T, T]
        attn_logits = attn_logits + temporal_bias.unsqueeze(1)

        # Causal masking (prevent attending to future)
        if self.causal:
            causal_mask = torch.tril(torch.ones(T, T, device=x.device)).unsqueeze(0).unsqueeze(0)  # [1, 1, T, T]
            attn_logits = attn_logits.masked_fill(causal_mask == 0, float("-inf"))

        attn_weights = F.softmax(attn_logits, dim=-1)  # [B, H, T, T]
        attn_output = torch.matmul(attn_weights, v)  # [B, H, T, head_dim]

        # Merge heads
        attn_output = attn_output.transpose(1, 2).reshape(B, T, D)
        output = self.out_proj(attn_output)
        return output, attn_weights
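

# --- Usage sketch (not part of the original module) ---
# A minimal example, assuming a batch of 2 sequences of length 16 with 64-dim
# embeddings and irregularly spaced, increasing timestamps. The shapes and the
# (output, attn_weights) return signature follow forward() above; the specific
# sizes, seed, and hyperparameter values here are illustrative assumptions.
if __name__ == "__main__":
    torch.manual_seed(0)
    B, T, D = 2, 16, 64  # illustrative sizes (assumption)
    x = torch.randn(B, T, D)
    # Irregularly spaced, monotonically increasing timestamps per sequence
    timestamps = torch.cumsum(torch.rand(B, T), dim=-1)

    attn = TemporalSelfAttention(
        embed_dim=D, num_heads=4, bias_type="gaussian", gamma=0.5, causal=True
    )
    out, weights = attn(x, timestamps)
    print(out.shape)      # expected: torch.Size([2, 16, 64])
    print(weights.shape)  # expected: torch.Size([2, 4, 16, 16])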