kfoughali commited on
Commit
3c6b427
·
verified ·
1 Parent(s): 3ab8374

Update core/mamba_block.py

Browse files
Files changed (1) hide show
  1. core/mamba_block.py +16 -1
core/mamba_block.py CHANGED
@@ -7,7 +7,7 @@ import math
7
  class MambaBlock(nn.Module):
8
  """
9
  Production-ready Mamba block for graph processing
10
- Device-safe implementation
11
  """
12
  def __init__(self, d_model, d_state=16, d_conv=4, expand=2, dt_rank="auto", bias=False):
13
  super().__init__()
@@ -50,6 +50,21 @@ class MambaBlock(nn.Module):
50
  # Activation
51
  self.act = nn.SiLU()
52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  def forward(self, x):
54
  """
55
  x: (batch, length, d_model)
 
7
  class MambaBlock(nn.Module):
8
  """
9
  Production-ready Mamba block for graph processing
10
+ Device-safe implementation with optimizations
11
  """
12
  def __init__(self, d_model, d_state=16, d_conv=4, expand=2, dt_rank="auto", bias=False):
13
  super().__init__()
 
50
  # Activation
51
  self.act = nn.SiLU()
52
 
53
+ # Initialize parameters
54
+ self._init_parameters()
55
+
56
+ def _init_parameters(self):
57
+ """Initialize parameters with proper scaling"""
58
+ # Initialize dt projection specially
59
+ dt_init_std = self.dt_rank**-0.5 * self.d_state
60
+ with torch.no_grad():
61
+ self.dt_proj.bias.uniform_(-dt_init_std, dt_init_std)
62
+
63
+ # Initialize other projections
64
+ nn.init.xavier_uniform_(self.in_proj.weight)
65
+ nn.init.xavier_uniform_(self.x_proj.weight)
66
+ nn.init.xavier_uniform_(self.out_proj.weight)
67
+
68
  def forward(self, x):
69
  """
70
  x: (batch, length, d_model)