Update core/mamba_block.py
core/mamba_block.py  +16 -5
@@ -7,7 +7,7 @@ import math
 class MambaBlock(nn.Module):
     """
     Production-ready Mamba block for graph processing
-
+    Device-safe implementation
     """
     def __init__(self, d_model, d_state=16, d_conv=4, expand=2, dt_rank="auto", bias=False):
         super().__init__()
@@ -56,6 +56,11 @@ class MambaBlock(nn.Module):
         Returns: (batch, length, d_model)
         """
         batch, length, _ = x.shape
+        device = x.device
+
+        # Ensure all parameters are on correct device
+        self.A_log = self.A_log.to(device)
+        self.D = self.D.to(device)
 
         # Input projection and split
         xz = self.in_proj(x)  # (batch, length, 2 * d_inner)
@@ -81,6 +86,7 @@ class MambaBlock(nn.Module):
     def selective_scan(self, u):
         """Selective scan operation - core of Mamba"""
         batch, length, d_inner = u.shape
+        device = u.device
 
         # Compute ∆, B, C
         x_dbl = self.x_proj(u)  # (batch, length, dt_rank + 2*d_state)
@@ -92,15 +98,20 @@ class MambaBlock(nn.Module):
         return self._selective_scan_pytorch(u, delta, B, C)
 
     def _selective_scan_pytorch(self, u, delta, B, C):
-        """PyTorch implementation of selective scan"""
+        """PyTorch implementation of selective scan - device safe"""
         batch, length, d_inner = u.shape
+        device = u.device
+
+        # Ensure A_log and D are on correct device
+        A_log = self.A_log.to(device)
+        D = self.D.to(device)
 
         # Discretize
-        deltaA = torch.exp(delta.unsqueeze(-1) * (-torch.exp(self.A_log)))  # (batch, length, d_inner, d_state)
+        deltaA = torch.exp(delta.unsqueeze(-1) * (-torch.exp(A_log)))  # (batch, length, d_inner, d_state)
         deltaB_u = delta.unsqueeze(-1) * B.unsqueeze(2) * u.unsqueeze(-1)  # (batch, length, d_inner, d_state)
 
         # Initialize state
-        x = torch.zeros((batch, d_inner, self.d_state), device=u.device)
+        x = torch.zeros((batch, d_inner, self.d_state), device=device, dtype=u.dtype)
         ys = []
 
         for i in range(length):
@@ -111,6 +122,6 @@ class MambaBlock(nn.Module):
         y = torch.stack(ys, dim=1)  # (batch, length, d_inner)
 
         # Add skip connection
-        y = y + u * self.D
+        y = y + u * D
 
         return y
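
For context on what these hunks touch: _selective_scan_pytorch runs the zero-order-hold discretization of a selective state-space model, keeping A in log space (A = -exp(A_log)) so that for positive delta the decay factor deltaA = exp(delta * A) stays in (0, 1). The per-step loop body sits outside the hunks above, so the sketch below is the standard Mamba recurrence rather than a quote from this file; reference_scan and its shapes are illustrative assumptions, not the repo's API.

import torch

def reference_scan(u, delta, B, C, A_log, D):
    """Sequential selective scan: x_i = deltaA_i * x_{i-1} + deltaB_u_i, y_i = <C_i, x_i>.

    Assumed shapes: u, delta (batch, length, d_inner); B, C (batch, length, d_state);
    A_log (d_inner, d_state); D (d_inner,).
    """
    batch, length, d_inner = u.shape
    A = -torch.exp(A_log)                                  # negative => stable decay
    deltaA = torch.exp(delta.unsqueeze(-1) * A)            # (batch, length, d_inner, d_state)
    deltaB_u = delta.unsqueeze(-1) * B.unsqueeze(2) * u.unsqueeze(-1)
    x = torch.zeros(batch, d_inner, A_log.shape[-1], device=u.device, dtype=u.dtype)
    ys = []
    for i in range(length):
        x = deltaA[:, i] * x + deltaB_u[:, i]              # elementwise state update
        ys.append((x * C[:, i].unsqueeze(1)).sum(dim=-1))  # contract over d_state
    return torch.stack(ys, dim=1) + u * D                  # skip connection via D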
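
A minimal usage sketch of the device behavior this commit targets, assuming only the constructor signature visible in the diff (the sizes are made up):

import torch
from core.mamba_block import MambaBlock

block = MambaBlock(d_model=64)              # parameters are created on CPU by default
device = "cuda" if torch.cuda.is_available() else "cpu"
block = block.to(device)                    # moving the module up front makes the
                                            # in-forward .to(device) calls no-ops
x = torch.randn(4, 128, 64, device=device)  # (batch, length, d_model)
y = block(x)
assert y.shape == x.shape                   # forward returns (batch, length, d_model)

One caveat: if A_log and D are registered as nn.Parameters (their creation is outside these hunks), Parameter.to(device) returns a plain Tensor whenever it actually moves data, and nn.Module rejects assigning a plain tensor over a registered parameter, so the self.A_log = self.A_log.to(device) line in forward assumes they are buffers or plain tensors. Calling block.to(device) once, as above, sidesteps the question.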