sanskxr02 committed on
Commit 0076e3f · verified · 1 Parent(s): dfdc6d9

Update temporal_attention.py

Files changed (1)
  1. temporal_attention.py +64 -1
temporal_attention.py CHANGED
@@ -1 +1,64 @@
- <PASTE YOUR FINAL TEMPORAL ATTENTION CLASS HERE>
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ class TemporalSelfAttention(nn.Module):
+     def __init__(self, embed_dim, num_heads, bias_type="linear", gamma=1.0, causal=False):
+         super().__init__()
+         assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
+         assert bias_type in ["linear", "gaussian"]
+
+         self.embed_dim = embed_dim
+         self.num_heads = num_heads
+         self.head_dim = embed_dim // num_heads
+         self.bias_type = bias_type
+         self.gamma = gamma
+         self.causal = causal
+
+         self.qkv = nn.Linear(embed_dim, 3 * embed_dim)
+         self.out_proj = nn.Linear(embed_dim, embed_dim)
+
+     def forward(self, x, timestamps):
+         """
+         x: [B, T, D]
+         timestamps: [B, T], real-valued time signals per token
+         """
+         B, T, D = x.size()
+
+         # Project input to Q, K, V
+         qkv = self.qkv(x).reshape(B, T, 3, self.num_heads, self.head_dim)
+         q, k, v = qkv.unbind(dim=2)  # each: [B, T, num_heads, head_dim]
+
+         q = q.transpose(1, 2)  # [B, num_heads, T, head_dim]
+         k = k.transpose(1, 2)
+         v = v.transpose(1, 2)
+
+         # Scaled dot-product attention
+         attn_logits = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)  # [B, H, T, T]
+
+         # Compute temporal bias
+         t_i = timestamps.unsqueeze(2)  # [B, T, 1]
+         t_j = timestamps.unsqueeze(1)  # [B, 1, T]
+         delta_t = t_j - t_i  # [B, T, T]
+
+         if self.bias_type == "linear":
+             temporal_bias = -self.gamma * torch.abs(delta_t)  # [B, T, T]
+         elif self.bias_type == "gaussian":
+             temporal_bias = -self.gamma * (delta_t ** 2)
+
+         # Expand for broadcasting: [B, 1, T, T]
+         attn_logits = attn_logits + temporal_bias.unsqueeze(1)
+
+         # Causal masking (prevent attending to future)
+         if self.causal:
+             causal_mask = torch.tril(torch.ones(T, T, device=x.device)).unsqueeze(0).unsqueeze(0)  # [1, 1, T, T]
+             attn_logits = attn_logits.masked_fill(causal_mask == 0, float("-inf"))
+
+         attn_weights = F.softmax(attn_logits, dim=-1)  # [B, H, T, T]
+         attn_output = torch.matmul(attn_weights, v)  # [B, H, T, head_dim]
+
+         # Merge heads
+         attn_output = attn_output.transpose(1, 2).reshape(B, T, D)
+         output = self.out_proj(attn_output)
+
+         return output, attn_weights
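
A minimal usage sketch for the module added above. The shapes, hyperparameter values, and the sorted random timestamps are illustrative assumptions, not part of this commit; any real-valued, per-token time signal of shape [B, T] would work.

    import torch

    # Hypothetical toy dimensions for a quick shape check
    B, T, D, H = 2, 16, 64, 4

    attn = TemporalSelfAttention(embed_dim=D, num_heads=H, bias_type="gaussian", gamma=0.5, causal=True)

    x = torch.randn(B, T, D)                                   # token embeddings [B, T, D]
    timestamps = torch.sort(torch.rand(B, T), dim=-1).values   # increasing real-valued times [B, T]

    out, weights = attn(x, timestamps)
    print(out.shape)      # torch.Size([2, 16, 64])
    print(weights.shape)  # torch.Size([2, 4, 16, 16])

With bias_type="gaussian", the logit penalty grows quadratically with the time gap, so attention concentrates more sharply on temporally nearby tokens than the linear variant does; gamma controls how fast that decay happens.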