kfoughali committed on
Commit
069fc7a
·
verified ·
1 Parent(s): f97b87b

Update core/mamba_block.py

Browse files
Files changed (1) hide show
  1. core/mamba_block.py +16 -5
core/mamba_block.py CHANGED
@@ -7,7 +7,7 @@ import math
7
  class MambaBlock(nn.Module):
8
  """
9
  Production-ready Mamba block for graph processing
10
- Based on official Mamba implementation with graph optimizations
11
  """
12
  def __init__(self, d_model, d_state=16, d_conv=4, expand=2, dt_rank="auto", bias=False):
13
  super().__init__()
@@ -56,6 +56,11 @@ class MambaBlock(nn.Module):
56
  Returns: (batch, length, d_model)
57
  """
58
  batch, length, _ = x.shape
 
 
 
 
 
59
 
60
  # Input projection and split
61
  xz = self.in_proj(x) # (batch, length, 2 * d_inner)
@@ -81,6 +86,7 @@ class MambaBlock(nn.Module):
81
  def selective_scan(self, u):
82
  """Selective scan operation - core of Mamba"""
83
  batch, length, d_inner = u.shape
 
84
 
85
  # Compute ∆, B, C
86
  x_dbl = self.x_proj(u) # (batch, length, dt_rank + 2*d_state)
@@ -92,15 +98,20 @@ class MambaBlock(nn.Module):
92
  return self._selective_scan_pytorch(u, delta, B, C)
93
 
94
  def _selective_scan_pytorch(self, u, delta, B, C):
95
- """PyTorch implementation of selective scan"""
96
  batch, length, d_inner = u.shape
 
 
 
 
 
97
 
98
  # Discretize
99
- deltaA = torch.exp(delta.unsqueeze(-1) * (-torch.exp(self.A_log))) # (batch, length, d_inner, d_state)
100
  deltaB_u = delta.unsqueeze(-1) * B.unsqueeze(2) * u.unsqueeze(-1) # (batch, length, d_inner, d_state)
101
 
102
  # Initialize state
103
- x = torch.zeros((batch, d_inner, self.d_state), device=u.device, dtype=u.dtype)
104
  ys = []
105
 
106
  for i in range(length):
@@ -111,6 +122,6 @@ class MambaBlock(nn.Module):
111
  y = torch.stack(ys, dim=1) # (batch, length, d_inner)
112
 
113
  # Add skip connection
114
- y = y + u * self.D
115
 
116
  return y
 
7
  class MambaBlock(nn.Module):
8
  """
9
  Production-ready Mamba block for graph processing
10
+ Device-safe implementation
11
  """
12
  def __init__(self, d_model, d_state=16, d_conv=4, expand=2, dt_rank="auto", bias=False):
13
  super().__init__()
 
56
  Returns: (batch, length, d_model)
57
  """
58
  batch, length, _ = x.shape
59
+ device = x.device
60
+
61
+ # Ensure all parameters are on correct device
62
+ self.A_log = self.A_log.to(device)
63
+ self.D = self.D.to(device)
64
 
65
  # Input projection and split
66
  xz = self.in_proj(x) # (batch, length, 2 * d_inner)
 
86
  def selective_scan(self, u):
87
  """Selective scan operation - core of Mamba"""
88
  batch, length, d_inner = u.shape
89
+ device = u.device
90
 
91
  # Compute ∆, B, C
92
  x_dbl = self.x_proj(u) # (batch, length, dt_rank + 2*d_state)
 
98
  return self._selective_scan_pytorch(u, delta, B, C)
99
 
100
  def _selective_scan_pytorch(self, u, delta, B, C):
101
+ """PyTorch implementation of selective scan - device safe"""
102
  batch, length, d_inner = u.shape
103
+ device = u.device
104
+
105
+ # Ensure A_log and D are on correct device
106
+ A_log = self.A_log.to(device)
107
+ D = self.D.to(device)
108
 
109
  # Discretize
110
+ deltaA = torch.exp(delta.unsqueeze(-1) * (-torch.exp(A_log))) # (batch, length, d_inner, d_state)
111
  deltaB_u = delta.unsqueeze(-1) * B.unsqueeze(2) * u.unsqueeze(-1) # (batch, length, d_inner, d_state)
112
 
113
  # Initialize state
114
+ x = torch.zeros((batch, d_inner, self.d_state), device=device, dtype=u.dtype)
115
  ys = []
116
 
117
  for i in range(length):
 
122
  y = torch.stack(ys, dim=1) # (batch, length, d_inner)
123
 
124
  # Add skip connection
125
+ y = y + u * D
126
 
127
  return y