kfoughali committed (verified)
Commit 159f602 · 1 Parent(s): e0c1384

Create core/mamba_block.py

Files changed (1):
  core/mamba_block.py (+116, -0)
core/mamba_block.py ADDED
@@ -0,0 +1,116 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from einops import rearrange, repeat
+ import math
+
+ class MambaBlock(nn.Module):
+     """
+     Production-ready Mamba block for graph processing.
+     Based on the official Mamba implementation, with graph optimizations.
+     """
+     def __init__(self, d_model, d_state=16, d_conv=4, expand=2, dt_rank="auto", bias=False):
+         super().__init__()
+         self.d_model = d_model
+         self.d_state = d_state
+         self.d_conv = d_conv
+         self.expand = expand
+         self.d_inner = int(self.expand * self.d_model)
+
+         if dt_rank == "auto":
+             self.dt_rank = math.ceil(self.d_model / 16)
+         else:
+             self.dt_rank = dt_rank
+
+         # Linear projections
+         self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias)
+
+         # Depthwise convolution for local patterns
+         self.conv1d = nn.Conv1d(
+             in_channels=self.d_inner,
+             out_channels=self.d_inner,
+             kernel_size=d_conv,
+             groups=self.d_inner,
+             padding=d_conv - 1,
+             bias=True,
+         )
+
+         # SSM parameters
+         self.x_proj = nn.Linear(self.d_inner, self.dt_rank + self.d_state * 2, bias=False)
+         self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True)
+
+         # Initialize A (state evolution matrix)
+         A = repeat(torch.arange(1, self.d_state + 1, dtype=torch.float32), 'n -> d n', d=self.d_inner)
+         self.A_log = nn.Parameter(torch.log(A))
+         self.D = nn.Parameter(torch.ones(self.d_inner))
+
+         # Output projection
+         self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias)
+
+         # Activation
+         self.act = nn.SiLU()
+
+     def forward(self, x):
+         """
+         x: (batch, length, d_model)
+         Returns: (batch, length, d_model)
+         """
+         batch, length, _ = x.shape
+
+         # Input projection and split
+         xz = self.in_proj(x)  # (batch, length, 2 * d_inner)
+         x, z = xz.chunk(2, dim=-1)  # Each: (batch, length, d_inner)
+
+         # Causal depthwise convolution (truncate the padding back to the original length)
+         x = rearrange(x, 'b l d -> b d l')
+         x = self.conv1d(x)[:, :, :length]
+         x = rearrange(x, 'b d l -> b l d')
+         x = self.act(x)
+
+         # SSM
+         y = self.selective_scan(x)
+
+         # Gating
+         y = y * self.act(z)
+
+         # Output projection
+         out = self.out_proj(y)
+
+         return out
+
+     def selective_scan(self, u):
+         """Selective scan operation - the core of Mamba"""
+         batch, length, d_inner = u.shape
+
+         # Compute ∆, B, C
+         x_dbl = self.x_proj(u)  # (batch, length, dt_rank + 2*d_state)
+         delta, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
+
+         # Softplus ensures delta > 0
+         delta = F.softplus(self.dt_proj(delta))  # (batch, length, d_inner)
+
+         return self._selective_scan_pytorch(u, delta, B, C)
+
+     def _selective_scan_pytorch(self, u, delta, B, C):
+         """Pure-PyTorch (sequential) implementation of the selective scan"""
+         batch, length, d_inner = u.shape
+
+         # Discretize: A_bar = exp(delta * A), B_bar * u = delta * B * u
+         deltaA = torch.exp(delta.unsqueeze(-1) * (-torch.exp(self.A_log)))  # (batch, length, d_inner, d_state)
+         deltaB_u = delta.unsqueeze(-1) * B.unsqueeze(2) * u.unsqueeze(-1)  # (batch, length, d_inner, d_state)
+
+         # Initialize state
+         x = torch.zeros((batch, d_inner, self.d_state), device=u.device, dtype=u.dtype)
+         ys = []
+
+         for i in range(length):
+             x = deltaA[:, i] * x + deltaB_u[:, i]
+             y = torch.einsum('bdn,bn->bd', x, C[:, i])
+             ys.append(y)
+
+         y = torch.stack(ys, dim=1)  # (batch, length, d_inner)
+
+         # Add skip connection
+         y = y + u * self.D
+
+         return y
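
For context, a minimal usage sketch of the block added above. This is not part of the commit: the batch/length/width values and the pre-norm residual wrapping are illustrative assumptions, and the import path assumes "core" is importable as a package.

import torch
from core.mamba_block import MambaBlock

# Illustrative sizes only.
batch, length, d_model = 2, 128, 64

block = MambaBlock(d_model=d_model, d_state=16, d_conv=4, expand=2)
x = torch.randn(batch, length, d_model)     # e.g. node features ordered into a sequence

y = block(x)
assert y.shape == (batch, length, d_model)  # the block is shape-preserving

# Typical stacking pattern (assumption, not shown in this commit): pre-norm residual.
norm = torch.nn.LayerNorm(d_model)
out = x + block(norm(x))

The sequential loop in _selective_scan_pytorch implements the discretized recurrence h_t = exp(Δ_t A) ⊙ h_{t-1} + (Δ_t B_t) u_t with output y_t = C_t h_t + D ⊙ u_t, applied independently per channel. It runs in O(length) Python steps and serves as a readability-first counterpart to the fused CUDA scan kernel used in the official mamba-ssm package.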