nvidia/nemotron-h

suhara committed (verified) · Commit 6d8fd0b · Parent(s): 83908dc

Upload 2 files
configuration_nemotron_h.py CHANGED
@@ -239,4 +239,4 @@ class NemotronHConfig(PretrainedConfig):
         return [
             "mamba" if self.hybrid_override_pattern[i] == "M" else
             "attention" if self.hybrid_override_pattern[i] == "*" else "mlp"
-            for i in range(self.num_hidden_layers)]
+            for i in range(self.num_hidden_layers)]
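For reference, the hunk above is the tail of the config property that maps each character of hybrid_override_pattern to a layer type. A standalone sketch of that mapping (an illustrative re-implementation for clarity, not the commit's code; the helper name block_types is hypothetical):

# "M" -> Mamba mixer, "*" -> attention, anything else -> MLP,
# mirroring the list comprehension shown in the diff above.
def block_types(hybrid_override_pattern: str) -> list:
    return [
        "mamba" if c == "M" else "attention" if c == "*" else "mlp"
        for c in hybrid_override_pattern
    ]

print(block_types("M-M*-"))  # ['mamba', 'mlp', 'mamba', 'attention', 'mlp']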
modeling_nemotron_h.py CHANGED
@@ -469,14 +469,12 @@ class NemotronHMamba2Mixer(nn.Module):
                     self.conv1d(hidden_states_B_C.transpose(1, 2))[..., :seq_len].transpose(1, 2)
                 )
             else:
-                # * Use torch.cuda.stream() to avoid NaN issues when using multiple GPUs
-                with torch.cuda.stream(torch.cuda.default_stream(hidden_states_B_C.device)):
-                    hidden_states_B_C = causal_conv1d_fn(
-                        x=hidden_states_B_C.transpose(1, 2),
-                        weight=self.conv1d.weight.squeeze(1),
-                        bias=self.conv1d.bias,
-                        activation=self.activation,
-                    ).transpose(1, 2)
+                hidden_states_B_C = causal_conv1d_fn(
+                    x=hidden_states_B_C.transpose(1, 2),
+                    weight=self.conv1d.weight.squeeze(1),
+                    bias=self.conv1d.bias,
+                    activation=self.activation,
+                ).transpose(1, 2)
             hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask)
             hidden_states, B, C = torch.split(
                 hidden_states_B_C,
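In the hunk above, only the stream guard around the fused causal_conv1d_fn call is removed; the call itself is unchanged, and the context line shows the eager fallback it replaces. A small shape sketch of that fallback path (illustrative, with made-up sizes, not the model's code): a depthwise, padded Conv1d truncated back to seq_len, i.e. a causal convolution over the sequence dimension.

import torch
import torch.nn as nn

batch, seq_len, channels, kernel = 2, 8, 16, 4
hidden_states_B_C = torch.randn(batch, seq_len, channels)

# Depthwise conv with symmetric padding; truncating to seq_len keeps the causal part.
conv1d = nn.Conv1d(channels, channels, kernel_size=kernel,
                   groups=channels, padding=kernel - 1)
act = nn.SiLU()

out = act(conv1d(hidden_states_B_C.transpose(1, 2))[..., :seq_len].transpose(1, 2))
print(out.shape)  # torch.Size([2, 8, 16])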
@@ -485,23 +483,21 @@ class NemotronHMamba2Mixer(nn.Module):
             )
 
             # 3. SSM transformation
-            # * Use torch.cuda.stream() to avoid NaN issues when using multiple GPUs
-            with torch.cuda.stream(torch.cuda.default_stream(hidden_states_B_C.device)):
-                scan_output, ssm_state = mamba_chunk_scan_combined(
-                    hidden_states.view(batch_size, seq_len, -1, self.head_dim),
-                    dt,
-                    A,
-                    B.view(batch_size, seq_len, self.n_groups, -1),
-                    C.view(batch_size, seq_len, self.n_groups, -1),
-                    chunk_size=self.chunk_size,
-                    D=self.D,
-                    z=None,
-                    seq_idx=None,
-                    return_final_states=True,
-                    dt_bias=self.dt_bias,
-                    dt_softplus=True,
-                    **dt_limit_kwargs,
-                )
+            scan_output, ssm_state = mamba_chunk_scan_combined(
+                hidden_states.view(batch_size, seq_len, -1, self.head_dim),
+                dt,
+                A,
+                B.view(batch_size, seq_len, self.n_groups, -1),
+                C.view(batch_size, seq_len, self.n_groups, -1),
+                chunk_size=self.chunk_size,
+                D=self.D,
+                z=None,
+                seq_idx=None,
+                return_final_states=True,
+                dt_bias=self.dt_bias,
+                dt_softplus=True,
+                **dt_limit_kwargs,
+            )
 
             # Init cache
             if ssm_state is not None and cache_params is not None:
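Aside from the removed stream guard, the mamba_chunk_scan_combined call (from the mamba-ssm package) is unchanged; it evaluates the SSD recurrence in chunks. Below is a naive, unchunked reference of that recurrence for orientation only, under simplifying assumptions: a single B/C group shared by all heads, dt already bias-added and softplus-ed, no dt limits, and no initial state. It is a sketch, not the kernel's implementation.

import torch

def ssd_reference(x, dt, A, B, C, D):
    # x: (batch, seq, heads, head_dim)   dt: (batch, seq, heads)
    # A: (heads,)   B, C: (batch, seq, state)   D: (heads,)
    batch, seq, heads, p = x.shape
    n = B.shape[-1]
    state = x.new_zeros(batch, heads, p, n)
    ys = []
    for t in range(seq):
        decay = torch.exp(dt[:, t] * A)                                    # (batch, heads)
        update = torch.einsum("bh,bhp,bn->bhpn", dt[:, t], x[:, t], B[:, t])
        state = decay[..., None, None] * state + update
        y = torch.einsum("bhpn,bn->bhp", state, C[:, t]) + D[None, :, None] * x[:, t]
        ys.append(y)
    return torch.stack(ys, dim=1)                                          # (batch, seq, heads, head_dim)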
@@ -768,30 +764,31 @@ class NemotronHBlock(nn.Module):
         cache_position: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
     ):
-        residual = hidden_states
-        hidden_states = self.norm(hidden_states.to(dtype=self.norm.weight.dtype))
-        if self.residual_in_fp32:
-            residual = residual.to(torch.float32)
-
-        if self.block_type == "mamba":
-            hidden_states = self.mixer(
-                hidden_states, cache_params=cache_params, cache_position=cache_position #, attention_mask=attention_mask
-            )
-        elif self.block_type == "attention":
-            hidden_states = self.mixer(
-                hidden_states, cache_position=cache_position #, attention_mask=attention_mask
-            )
-            # hidden_states = (attn_output, attn_weights, past_key_value)
-            hidden_states = hidden_states[0]
-        elif self.block_type == "mlp":
-            hidden_states = self.mixer(
-                hidden_states
-            )
-        else:
-            raise ValueError(f"Invalid block_type: {self.block_type}")
+        with torch.cuda.stream(torch.cuda.default_stream(hidden_states.device)):
+            # * Use torch.cuda.stream() to avoid NaN issues when using multiple GPUs
+            residual = hidden_states
+            hidden_states = self.norm(hidden_states.to(dtype=self.norm.weight.dtype))
+            if self.residual_in_fp32:
+                residual = residual.to(torch.float32)
+
+            if self.block_type == "mamba":
+                hidden_states = self.mixer(
+                    hidden_states, cache_params=cache_params, cache_position=cache_position
+                )
+            elif self.block_type == "attention":
+                hidden_states = self.mixer(
+                    hidden_states, cache_position=cache_position
+                )
+                hidden_states = hidden_states[0]
+            elif self.block_type == "mlp":
+                hidden_states = self.mixer(
+                    hidden_states
+                )
+            else:
+                raise ValueError(f"Invalid block_type: {self.block_type}")
 
-        hidden_states = residual + hidden_states
-        return hidden_states
+            hidden_states = residual + hidden_states
+            return hidden_states
 
 
     # Copied from transformers.models.nemotron.modeling_nemotron Nemotron->NemotronH
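Net effect of the modeling change: the torch.cuda.stream() guard that previously wrapped only the causal_conv1d_fn and mamba_chunk_scan_combined calls inside NemotronHMamba2Mixer now wraps the entire NemotronHBlock.forward body, and the commented-out attention_mask pass-through is dropped. A minimal sketch of the hoisted pattern (an illustrative helper, not part of the commit): pin a module's forward to its device's default CUDA stream, which the in-code comment describes as a workaround for NaN issues when the model is split across multiple GPUs.

import torch
import torch.nn as nn

def forward_on_default_stream(module: nn.Module, hidden_states: torch.Tensor, **kwargs):
    # Make the device's default stream current for the whole forward call,
    # mirroring the guard this commit hoists up to NemotronHBlock.forward.
    if hidden_states.is_cuda:
        with torch.cuda.stream(torch.cuda.default_stream(hidden_states.device)):
            return module(hidden_states, **kwargs)
    return module(hidden_states, **kwargs)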