Upload 2 files
- configuration_nemotron_h.py +1 -1
- modeling_nemotron_h.py +45 -48
configuration_nemotron_h.py
CHANGED
@@ -239,4 +239,4 @@ class NemotronHConfig(PretrainedConfig):
        return [
            "mamba" if self.hybrid_override_pattern[i] == "M" else
            "attention" if self.hybrid_override_pattern[i] == "*" else "mlp"
-           for i in range(self.num_hidden_layers)]
+           for i in range(self.num_hidden_layers)]
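In layers_block_type above, each character of hybrid_override_pattern selects a layer type: "M" yields a Mamba mixer, "*" an attention layer, and any other character falls through to "mlp". A minimal standalone sketch of that mapping, using a made-up pattern string rather than a real NemotronHConfig:

    # Sketch of the layers_block_type mapping; the pattern string below is illustrative only.
    hybrid_override_pattern = "M-M-M*-M"   # hypothetical 8-layer pattern
    num_hidden_layers = len(hybrid_override_pattern)

    layers_block_type = [
        "mamba" if hybrid_override_pattern[i] == "M" else
        "attention" if hybrid_override_pattern[i] == "*" else "mlp"
        for i in range(num_hidden_layers)]

    print(layers_block_type)
    # ['mamba', 'mlp', 'mamba', 'mlp', 'mamba', 'attention', 'mlp', 'mamba']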
modeling_nemotron_h.py
CHANGED
@@ -469,14 +469,12 @@ class NemotronHMamba2Mixer(nn.Module):
                self.conv1d(hidden_states_B_C.transpose(1, 2))[..., :seq_len].transpose(1, 2)
            )
        else:
-
-
-
-
-
-
-               activation=self.activation,
-           ).transpose(1, 2)
+           hidden_states_B_C = causal_conv1d_fn(
+               x=hidden_states_B_C.transpose(1, 2),
+               weight=self.conv1d.weight.squeeze(1),
+               bias=self.conv1d.bias,
+               activation=self.activation,
+           ).transpose(1, 2)
            hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask)
            hidden_states, B, C = torch.split(
                hidden_states_B_C,
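In the hunk above, the else branch calls the fused kernel from the external causal-conv1d package: causal_conv1d_fn expects channels-first input of shape (batch, dim, seq_len) and a depthwise weight of shape (dim, kernel_width), which is why the activations are transposed and the Conv1d weight is squeezed. A hedged sketch of the unfused path that the preceding if branch represents, with made-up sizes, a plain depthwise nn.Conv1d in place of the module's own layer, and silu standing in for self.act:

    # Sketch only: the causal depthwise convolution that causal_conv1d_fn fuses.
    # Sizes and the conv layer are illustrative, not taken from the NemotronH module.
    import torch
    import torch.nn as nn

    batch, seq_len, conv_dim, kernel = 2, 16, 64, 4
    x = torch.randn(batch, seq_len, conv_dim)                # (B, L, D), like hidden_states_B_C

    conv1d = nn.Conv1d(conv_dim, conv_dim, kernel_size=kernel,
                       groups=conv_dim, padding=kernel - 1)  # depthwise, left-padded for causality

    # Convolve over the time axis, trim back to seq_len so no future positions leak in,
    # then return to (B, L, D) and apply the activation.
    y = torch.nn.functional.silu(
        conv1d(x.transpose(1, 2))[..., :seq_len].transpose(1, 2)
    )
    print(y.shape)  # torch.Size([2, 16, 64])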
@@ -485,23 +483,21 @@ class NemotronHMamba2Mixer(nn.Module):
            )

            # 3. SSM transformation
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-               **dt_limit_kwargs,
-           )
+           scan_output, ssm_state = mamba_chunk_scan_combined(
+               hidden_states.view(batch_size, seq_len, -1, self.head_dim),
+               dt,
+               A,
+               B.view(batch_size, seq_len, self.n_groups, -1),
+               C.view(batch_size, seq_len, self.n_groups, -1),
+               chunk_size=self.chunk_size,
+               D=self.D,
+               z=None,
+               seq_idx=None,
+               return_final_states=True,
+               dt_bias=self.dt_bias,
+               dt_softplus=True,
+               **dt_limit_kwargs,
+           )

            # Init cache
            if ssm_state is not None and cache_params is not None:
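The "# 3. SSM transformation" step calls mamba_chunk_scan_combined from the external mamba_ssm package (Triton SSD kernels), so it only runs with that package installed on a CUDA device. A hedged shape sketch of the call as used above, with made-up dimensions standing in for the module's head_dim, n_groups, state size, and chunk_size:

    # Shape sketch only; dimensions are illustrative, not NemotronH's real config values.
    # Requires the mamba_ssm package and a CUDA device.
    import torch
    from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined

    batch_size, seq_len = 2, 256
    num_heads, head_dim = 8, 64      # hidden_states is viewed as (B, L, num_heads, head_dim)
    n_groups, state_size = 1, 128    # B and C are viewed as (B, L, n_groups, state_size)
    chunk_size = 128

    device, dtype = "cuda", torch.bfloat16
    x = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype)
    dt = torch.rand(batch_size, seq_len, num_heads, device=device, dtype=dtype)
    A = -torch.rand(num_heads, device=device, dtype=torch.float32)   # one decay rate per head
    B = torch.randn(batch_size, seq_len, n_groups, state_size, device=device, dtype=dtype)
    C = torch.randn(batch_size, seq_len, n_groups, state_size, device=device, dtype=dtype)
    D = torch.ones(num_heads, device=device)
    dt_bias = torch.zeros(num_heads, device=device)

    # return_final_states=True also returns the last SSM state, which the surrounding
    # code stores into the cache as ssm_state.
    scan_output, ssm_state = mamba_chunk_scan_combined(
        x, dt, A, B, C,
        chunk_size=chunk_size,
        D=D, z=None, seq_idx=None,
        return_final_states=True,
        dt_bias=dt_bias, dt_softplus=True,
    )
    print(scan_output.shape)  # (batch_size, seq_len, num_heads, head_dim)
    print(ssm_state.shape)    # (batch_size, num_heads, head_dim, state_size)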
@@ -768,30 +764,31 @@ class NemotronHBlock(nn.Module):
        cache_position: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ):
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-               hidden_states
-
-
-
+       with torch.cuda.stream(torch.cuda.default_stream(hidden_states.device)):
+           # * Use torch.cuda.stream() to avoid NaN issues when using multiple GPUs
+           residual = hidden_states
+           hidden_states = self.norm(hidden_states.to(dtype=self.norm.weight.dtype))
+           if self.residual_in_fp32:
+               residual = residual.to(torch.float32)
+
+           if self.block_type == "mamba":
+               hidden_states = self.mixer(
+                   hidden_states, cache_params=cache_params, cache_position=cache_position
+               )
+           elif self.block_type == "attention":
+               hidden_states = self.mixer(
+                   hidden_states, cache_position=cache_position
+               )
+               hidden_states = hidden_states[0]
+           elif self.block_type == "mlp":
+               hidden_states = self.mixer(
+                   hidden_states
+               )
+           else:
+               raise ValueError(f"Invalid block_type: {self.block_type}")

-
-
+           hidden_states = residual + hidden_states
+           return hidden_states


# Copied from transformers.models.nemotron.modeling_nemotron Nemotron->NemotronH
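The headline change in this last hunk is the new torch.cuda.stream(torch.cuda.default_stream(...)) wrapper around the block body, which pins the enclosed kernels to the device's default stream; per the in-line comment, this avoids NaN issues seen when running on multiple GPUs. A minimal sketch of that pattern in isolation, with a toy tensor op standing in for the real block body:

    # Minimal sketch of pinning work to the default CUDA stream; the arithmetic below is a
    # stand-in for the real block body, not NemotronH code.
    import torch

    def forward_on_default_stream(x: torch.Tensor) -> torch.Tensor:
        # torch.cuda.stream() is a context manager; passing the default stream of the
        # tensor's device forces the ops below onto that stream even if the caller was
        # launched on a side stream.
        with torch.cuda.stream(torch.cuda.default_stream(x.device)):
            return x * 2 + 1

    if torch.cuda.is_available():
        out = forward_on_default_stream(torch.randn(4, 8, device="cuda"))
        print(out.shape)   # torch.Size([4, 8])
    else:
        print("CUDA not available; torch.cuda.stream() is a CUDA-only context manager.")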