nvidia/nemotron-h · Text Generation · Transformers · Safetensors · PyTorch
Commit 7c37779 (verified) · committed by FremyCompany · Parent(s): b9aaa05

Enable `cache_params` to work with `generate()` from `GenerationMixin`

IMPORTANT: The cache does not seem to work very well in my tests, and given that it was disabled and still contained a `breakpoint()` call, I assume it simply was not ready yet. Even so, enabling it matters for understanding how fast the model can run when the cache is used.

In my tests, the model in Q4 BF16 goes from generating about 3.3 tok/s to about 19.3 tok/s on the same RTX 5090. Ideally this model would be run in native FP4 on the 5090, but that is not yet supported in PyTorch, so I assume it is only possible in the NeMo engine for now.
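
For anyone who wants to reproduce the comparison, here is a minimal timing sketch (not part of the commit): the checkpoint id is a placeholder, loading in BF16 with `trust_remote_code=True` is an assumption about how the remote code is consumed, and `use_cache` is used to toggle between the cached and uncached decoding paths.

```python
# Rough throughput sketch: compares generate() speed with and without the cache.
# The model id and dtype are placeholders; adjust to the actual checkpoint/hardware.
import time

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "nvidia/nemotron-h"  # placeholder: substitute the exact checkpoint name
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, trust_remote_code=True
).to("cuda")

inputs = tokenizer("The quick brown fox", return_tensors="pt").to("cuda")

def tokens_per_second(use_cache: bool, max_new_tokens: int = 128) -> float:
    # Greedy decoding so both runs produce the same tokens; only the cache path differs.
    torch.cuda.synchronize()
    start = time.perf_counter()
    out = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=False,
        use_cache=use_cache,
    )
    torch.cuda.synchronize()
    elapsed = time.perf_counter() - start
    return (out.shape[-1] - inputs["input_ids"].shape[-1]) / elapsed

print("no cache  :", tokens_per_second(False))
print("with cache:", tokens_per_second(True))
```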

Files changed (1): modeling_nemotron_h.py (+20, -15)
modeling_nemotron_h.py CHANGED
```diff
@@ -31,6 +31,9 @@ from transformers.modeling_attn_mask_utils import (
     AttentionMaskConverter,
 )
 from transformers.modeling_utils import PreTrainedModel
+from transformers.modeling_outputs import (
+    MoeCausalLMOutputWithPast,
+)
 from transformers.utils import (
     ModelOutput,
     add_code_sample_docstrings,
@@ -168,12 +171,14 @@ class HybridMambaAttentionDynamicCache(DynamicCache):
 
     def __init__(self, config, batch_size, dtype=torch.float16, device=None):
         super().__init__()
+        self.device=device
         self.dtype = dtype
         self.hybrid_override_pattern = config.hybrid_override_pattern
         self.has_previous_state = False  # only used by mamba
-        intermediate_size = config.expand * config.hidden_size
-        ssm_state_size = config.ssm_state_size
-        conv_kernel_size = config.conv_kernel
+        self.intermediate_size = config.expand * config.hidden_size
+        self.ssm_state_size = config.ssm_state_size
+        self.conv_kernel_size = config.conv_kernel
+        self.conv_dim = self.intermediate_size + 2 * config.n_groups * config.ssm_state_size
         self.conv_states = []
         self.ssm_states = []
         self.transformer_layers = []
@@ -181,10 +186,10 @@ class HybridMambaAttentionDynamicCache(DynamicCache):
             if self.hybrid_override_pattern[i] == "M":
                 # Mamba layer
                 self.conv_states += [
-                    torch.zeros(batch_size, intermediate_size, conv_kernel_size, device=device, dtype=dtype)
+                    torch.zeros(batch_size, self.conv_dim, self.conv_kernel_size, device=device, dtype=dtype)
                 ]
                 self.ssm_states += [
-                    torch.zeros(batch_size, intermediate_size, ssm_state_size, device=device, dtype=dtype)
+                    torch.zeros(batch_size, self.intermediate_size, self.ssm_state_size, device=device, dtype=dtype)
                 ]
             else:
                 # Attention or MLP layer
@@ -245,14 +250,14 @@ class HybridMambaAttentionDynamicCache(DynamicCache):
         self, layer_idx: int, new_conv_state: torch.Tensor, cache_init: bool = False
     ) -> torch.Tensor:
         if cache_init:
-            self.conv_states[layer_idx] = new_conv_state.to(self.conv_states.device)
+            self.conv_states[layer_idx] = new_conv_state.to(self.device)
         else:
             self.conv_states[layer_idx] = self.conv_states[layer_idx].roll(shifts=-1, dims=-1)
-            self.conv_states[layer_idx][:, :, -1] = new_conv_state[:, 0, :].to(self.conv_states.device)
+            self.conv_states[layer_idx][:, :, -1] = new_conv_state[:, 0, :].to(self.conv_states[layer_idx].device)
         return self.conv_states[layer_idx]
 
     def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor):
-        self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states.device)
+        self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states[layer_idx].device)
         return self.ssm_states[layer_idx]
 
     def reset(self):
@@ -413,7 +418,7 @@ class NemotronHMamba2Mixer(nn.Module):
                 dt_softplus=True,
             )
             hidden_states = hidden_states.view(batch_size, self.num_heads * self.head_dim)
-            breakpoint()
+            # TODO: why was there a breakpoint() call here?
            hidden_states = self.norm(hidden_states, gate)
 
             # 4. Final linear projection
@@ -560,7 +565,7 @@ class NemotronHMamba2Mixer(nn.Module):
         A = -torch.exp(self.A_log.float())  # [num_heads]
         if cache_params is not None and cache_position is not None and cache_position[0] > 0:
             # We need to guarantee that anything regarding the cache is on the same device
-            cache_device = cache_params.ssm_states.device
+            cache_device = cache_params.ssm_states[0].device if len(cache_params.ssm_states) > 0 else cache_params.device
 
             # Note: there is no need to pad parameter matrices here, as there is just one new token
             # for batched generation
@@ -1185,7 +1190,7 @@ class NemotronHOutput(ModelOutput):
 
 @dataclass
 # Copied from transformers.models.mamba2.modeling_mamba2.MambaCausalLMOutput with Mamba2->NemotronH
-class NemotronHCausalLMOutput(ModelOutput):
+class NemotronHCausalLMOutput(MoeCausalLMOutputWithPast):
     """
     Base class for causal language model (or autoregressive) outputs.
 
@@ -1208,7 +1213,7 @@ class NemotronHCausalLMOutput(ModelOutput):
 
     loss: Optional[torch.FloatTensor] = None
     logits: Optional[torch.FloatTensor] = None
-    cache_params: Optional[HybridMambaAttentionDynamicCache] = None
+    past_key_values: Optional[HybridMambaAttentionDynamicCache] = None
     hidden_states: Optional[Tuple[torch.FloatTensor]] = None
     attentions: Optional[Tuple[torch.FloatTensor]] = None
 
@@ -1568,7 +1573,7 @@ class NemotronHForCausalLM(NemotronHPreTrainedModel, GenerationMixin):
         input_ids: Optional[torch.LongTensor] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
-        cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
+        past_key_values: Optional[HybridMambaAttentionDynamicCache] = None,
         labels: Optional[torch.LongTensor] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
@@ -1593,7 +1598,7 @@ class NemotronHForCausalLM(NemotronHPreTrainedModel, GenerationMixin):
 
         nemotron_h_outputs = self.backbone(
             input_ids,
-            cache_params=cache_params,
+            cache_params=past_key_values,
             inputs_embeds=inputs_embeds,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
@@ -1626,7 +1631,7 @@ class NemotronHForCausalLM(NemotronHPreTrainedModel, GenerationMixin):
         return NemotronHCausalLMOutput(
             loss=loss,
             logits=logits,
-            cache_params=nemotron_h_outputs.cache_params,
+            past_key_values=nemotron_h_outputs.cache_params,
             hidden_states=nemotron_h_outputs.hidden_states,
             attentions=nemotron_h_outputs.attentions,
         )
```
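
Side note on the conv-state shape fix in `__init__` above: in Mamba-2-style mixers the causal conv1d is applied to the concatenated hidden-state, B, and C projections, so the per-layer conv cache needs `intermediate_size + 2 * n_groups * ssm_state_size` channels rather than `intermediate_size` alone, which is what the new `conv_dim` captures. The sketch below only illustrates the arithmetic; the numbers are hypothetical placeholders, not the real Nemotron-H hyperparameters, and only the formula mirrors the diff.

```python
# Illustration of the conv-state sizing used in the diff above.
# All values are hypothetical, not Nemotron-H's real config.
expand, hidden_size = 2, 4096
n_groups, ssm_state_size = 8, 128
conv_kernel = 4

intermediate_size = expand * hidden_size                      # 8192
conv_dim = intermediate_size + 2 * n_groups * ssm_state_size  # 10240

# Per Mamba layer, the cache now allocates:
#   conv_states[i]: (batch_size, conv_dim, conv_kernel)             # previously intermediate_size channels
#   ssm_states[i]:  (batch_size, intermediate_size, ssm_state_size)
```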