Create core/mamba_block.py
core/mamba_block.py    ADDED    +116 -0
@@ -0,0 +1,116 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
import math


class MambaBlock(nn.Module):
    """
    Production-ready Mamba block for graph processing
    Based on official Mamba implementation with graph optimizations
    """
    def __init__(self, d_model, d_state=16, d_conv=4, expand=2, dt_rank="auto", bias=False):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.expand = expand
        self.d_inner = int(self.expand * self.d_model)

        if dt_rank == "auto":
            self.dt_rank = math.ceil(self.d_model / 16)
        else:
            self.dt_rank = dt_rank

        # Linear projections
        self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias)

        # Convolution for local patterns
        self.conv1d = nn.Conv1d(
            in_channels=self.d_inner,
            out_channels=self.d_inner,
            kernel_size=d_conv,
            groups=self.d_inner,
            padding=d_conv - 1,
            bias=True,
        )

        # SSM parameters
        self.x_proj = nn.Linear(self.d_inner, self.dt_rank + self.d_state * 2, bias=False)
        self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True)

        # Initialize A (state evolution matrix)
        A = repeat(torch.arange(1, self.d_state + 1, dtype=torch.float32), 'n -> d n', d=self.d_inner)
        self.A_log = nn.Parameter(torch.log(A))
        self.D = nn.Parameter(torch.ones(self.d_inner))

        # Output projection
        self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias)

        # Activation
        self.act = nn.SiLU()

    def forward(self, x):
        """
        x: (batch, length, d_model)
        Returns: (batch, length, d_model)
        """
        batch, length, _ = x.shape

        # Input projection and split
        xz = self.in_proj(x)  # (batch, length, 2 * d_inner)
        x, z = xz.chunk(2, dim=-1)  # Each: (batch, length, d_inner)

        # Convolution
        x = rearrange(x, 'b l d -> b d l')
        x = self.conv1d(x)[:, :, :length]
        x = rearrange(x, 'b d l -> b l d')
        x = self.act(x)

        # SSM
        y = self.selective_scan(x)

        # Gating
        y = y * self.act(z)

        # Output projection
        out = self.out_proj(y)

        return out

    def selective_scan(self, u):
        """Selective scan operation - core of Mamba"""
        batch, length, d_inner = u.shape

        # Compute ∆, B, C
        x_dbl = self.x_proj(u)  # (batch, length, dt_rank + 2*d_state)
        delta, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)

        # Softplus ensures delta > 0
        delta = F.softplus(self.dt_proj(delta))  # (batch, length, d_inner)

        return self._selective_scan_pytorch(u, delta, B, C)

    def _selective_scan_pytorch(self, u, delta, B, C):
        """PyTorch implementation of selective scan"""
        batch, length, d_inner = u.shape

        # Discretize
        deltaA = torch.exp(delta.unsqueeze(-1) * (-torch.exp(self.A_log)))  # (batch, length, d_inner, d_state)
        deltaB_u = delta.unsqueeze(-1) * B.unsqueeze(2) * u.unsqueeze(-1)  # (batch, length, d_inner, d_state)

        # Initialize state
        x = torch.zeros((batch, d_inner, self.d_state), device=u.device, dtype=u.dtype)
        ys = []

        for i in range(length):
            x = deltaA[:, i] * x + deltaB_u[:, i]
            y = torch.einsum('bdn,bn->bd', x, C[:, i])
            ys.append(y)

        y = torch.stack(ys, dim=1)  # (batch, length, d_inner)

        # Add skip connection
        y = y + u * self.D

        return y
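
For reference (this restates what the code above computes; it is not part of the commit), the loop in _selective_scan_pytorch evaluates the discretised state-space recurrence one position at a time. Writing h_t for the running state the code stores in x, and using elementwise products over the (d_inner, d_state) axes:

h_t = \exp(\Delta_t A) \odot h_{t-1} + (\Delta_t \odot B_t)\, u_t,
\qquad
y_t = C_t \cdot h_t + D \odot u_t

The state transition uses the exact exponential discretisation of A (deltaA), while the input term uses the simplified \Delta_t B_t u_t form (deltaB_u) rather than the full zero-order-hold expression; as far as I can tell this matches the convention used in the Mamba paper and in minimal PyTorch reimplementations.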
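
A minimal smoke test for the block. The sizes below and the assumption that the repository root is on PYTHONPATH (so core.mamba_block is importable) are illustrative choices of mine, not part of the commit:

import torch
from core.mamba_block import MambaBlock

# Hypothetical sizes chosen for illustration; any (batch, length, d_model) works.
block = MambaBlock(d_model=64, d_state=16, d_conv=4, expand=2)
x = torch.randn(2, 128, 64)       # (batch, length, d_model)
y = block(x)                      # output has the same shape as the input
assert y.shape == (2, 128, 64)

Note that the depthwise conv1d is made causal by padding with d_conv - 1 and truncating back to length, so the block can be used on autoregressive sequences; the pure-PyTorch selective scan then runs in O(length) sequential steps.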