Update core/mamba_block.py
Browse files- core/mamba_block.py +16 -5
core/mamba_block.py
CHANGED
@@ -7,7 +7,7 @@ import math
|
|
7 |
class MambaBlock(nn.Module):
|
8 |
"""
|
9 |
Production-ready Mamba block for graph processing
|
10 |
-
|
11 |
"""
|
12 |
def __init__(self, d_model, d_state=16, d_conv=4, expand=2, dt_rank="auto", bias=False):
|
13 |
super().__init__()
|
@@ -56,6 +56,11 @@ class MambaBlock(nn.Module):
|
|
56 |
Returns: (batch, length, d_model)
|
57 |
"""
|
58 |
batch, length, _ = x.shape
|
|
|
|
|
|
|
|
|
|
|
59 |
|
60 |
# Input projection and split
|
61 |
xz = self.in_proj(x) # (batch, length, 2 * d_inner)
|
@@ -81,6 +86,7 @@ class MambaBlock(nn.Module):
|
|
81 |
def selective_scan(self, u):
|
82 |
"""Selective scan operation - core of Mamba"""
|
83 |
batch, length, d_inner = u.shape
|
|
|
84 |
|
85 |
# Compute ∆, B, C
|
86 |
x_dbl = self.x_proj(u) # (batch, length, dt_rank + 2*d_state)
|
@@ -92,15 +98,20 @@ class MambaBlock(nn.Module):
|
|
92 |
return self._selective_scan_pytorch(u, delta, B, C)
|
93 |
|
94 |
def _selective_scan_pytorch(self, u, delta, B, C):
|
95 |
-
"""PyTorch implementation of selective scan"""
|
96 |
batch, length, d_inner = u.shape
|
|
|
|
|
|
|
|
|
|
|
97 |
|
98 |
# Discretize
|
99 |
-
deltaA = torch.exp(delta.unsqueeze(-1) * (-torch.exp(
|
100 |
deltaB_u = delta.unsqueeze(-1) * B.unsqueeze(2) * u.unsqueeze(-1) # (batch, length, d_inner, d_state)
|
101 |
|
102 |
# Initialize state
|
103 |
-
x = torch.zeros((batch, d_inner, self.d_state), device=
|
104 |
ys = []
|
105 |
|
106 |
for i in range(length):
|
@@ -111,6 +122,6 @@ class MambaBlock(nn.Module):
|
|
111 |
y = torch.stack(ys, dim=1) # (batch, length, d_inner)
|
112 |
|
113 |
# Add skip connection
|
114 |
-
y = y + u *
|
115 |
|
116 |
return y
|
|
|
7 |
class MambaBlock(nn.Module):
|
8 |
"""
|
9 |
Production-ready Mamba block for graph processing
|
10 |
+
Device-safe implementation
|
11 |
"""
|
12 |
def __init__(self, d_model, d_state=16, d_conv=4, expand=2, dt_rank="auto", bias=False):
|
13 |
super().__init__()
|
|
|
56 |
Returns: (batch, length, d_model)
|
57 |
"""
|
58 |
batch, length, _ = x.shape
|
59 |
+
device = x.device
|
60 |
+
|
61 |
+
# Ensure all parameters are on correct device
|
62 |
+
self.A_log = self.A_log.to(device)
|
63 |
+
self.D = self.D.to(device)
|
64 |
|
65 |
# Input projection and split
|
66 |
xz = self.in_proj(x) # (batch, length, 2 * d_inner)
|
|
|
86 |
def selective_scan(self, u):
|
87 |
"""Selective scan operation - core of Mamba"""
|
88 |
batch, length, d_inner = u.shape
|
89 |
+
device = u.device
|
90 |
|
91 |
# Compute ∆, B, C
|
92 |
x_dbl = self.x_proj(u) # (batch, length, dt_rank + 2*d_state)
|
|
|
98 |
return self._selective_scan_pytorch(u, delta, B, C)
|
99 |
|
100 |
def _selective_scan_pytorch(self, u, delta, B, C):
|
101 |
+
"""PyTorch implementation of selective scan - device safe"""
|
102 |
batch, length, d_inner = u.shape
|
103 |
+
device = u.device
|
104 |
+
|
105 |
+
# Ensure A_log and D are on correct device
|
106 |
+
A_log = self.A_log.to(device)
|
107 |
+
D = self.D.to(device)
|
108 |
|
109 |
# Discretize
|
110 |
+
deltaA = torch.exp(delta.unsqueeze(-1) * (-torch.exp(A_log))) # (batch, length, d_inner, d_state)
|
111 |
deltaB_u = delta.unsqueeze(-1) * B.unsqueeze(2) * u.unsqueeze(-1) # (batch, length, d_inner, d_state)
|
112 |
|
113 |
# Initialize state
|
114 |
+
x = torch.zeros((batch, d_inner, self.d_state), device=device, dtype=u.dtype)
|
115 |
ys = []
|
116 |
|
117 |
for i in range(length):
|
|
|
122 |
y = torch.stack(ys, dim=1) # (batch, length, d_inner)
|
123 |
|
124 |
# Add skip connection
|
125 |
+
y = y + u * D
|
126 |
|
127 |
return y
|