nvidia / nemotron-h
Text Generation · Transformers · Safetensors · PyTorch

Enable `cache_params` to work with `generate()` from `GenerationMixin`

#3
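With this change, the model's `forward()` accepts the hybrid Mamba/attention cache under the `past_key_values` name that `GenerationMixin` expects, so the cache is carried across decoding steps instead of being dropped. A minimal usage sketch (not part of the diff; the checkpoint id is taken from this repo, and the dtype/device settings are illustrative):

```python
# Minimal sketch only: exercising generate() against the patched modeling code.
# dtype/device choices are illustrative, not prescribed by this PR.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "nvidia/nemotron-h"  # repo this PR targets
tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    repo_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)

inputs = tokenizer("The capital of France is", return_tensors="pt").to(model.device)
# generate() now threads the HybridMambaAttentionDynamicCache through forward()
# as `past_key_values` instead of the previously unrecognized `cache_params`.
output_ids = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```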
Files changed (1)
  1. modeling_nemotron_h.py +20 -15
modeling_nemotron_h.py CHANGED
@@ -31,6 +31,9 @@ from transformers.modeling_attn_mask_utils import (
     AttentionMaskConverter,
 )
 from transformers.modeling_utils import PreTrainedModel
+from transformers.modeling_outputs import (
+    MoeCausalLMOutputWithPast,
+)
 from transformers.utils import (
     ModelOutput,
     add_code_sample_docstrings,
@@ -168,12 +171,14 @@ class HybridMambaAttentionDynamicCache(DynamicCache):
 
     def __init__(self, config, batch_size, dtype=torch.float16, device=None):
         super().__init__()
+        self.device = device
         self.dtype = dtype
         self.hybrid_override_pattern = config.hybrid_override_pattern
         self.has_previous_state = False  # only used by mamba
-        intermediate_size = config.expand * config.hidden_size
-        ssm_state_size = config.ssm_state_size
-        conv_kernel_size = config.conv_kernel
+        self.intermediate_size = config.expand * config.hidden_size
+        self.ssm_state_size = config.ssm_state_size
+        self.conv_kernel_size = config.conv_kernel
+        self.conv_dim = self.intermediate_size + 2 * config.n_groups * config.ssm_state_size
         self.conv_states = []
         self.ssm_states = []
         self.transformer_layers = []
@@ -181,10 +186,10 @@ class HybridMambaAttentionDynamicCache(DynamicCache):
             if self.hybrid_override_pattern[i] == "M":
                 # Mamba layer
                 self.conv_states += [
-                    torch.zeros(batch_size, intermediate_size, conv_kernel_size, device=device, dtype=dtype)
+                    torch.zeros(batch_size, self.conv_dim, self.conv_kernel_size, device=device, dtype=dtype)
                 ]
                 self.ssm_states += [
-                    torch.zeros(batch_size, intermediate_size, ssm_state_size, device=device, dtype=dtype)
+                    torch.zeros(batch_size, self.intermediate_size, self.ssm_state_size, device=device, dtype=dtype)
                 ]
             else:
                 # Attention or MLP layer
@@ -245,14 +250,14 @@ class HybridMambaAttentionDynamicCache(DynamicCache):
         self, layer_idx: int, new_conv_state: torch.Tensor, cache_init: bool = False
     ) -> torch.Tensor:
         if cache_init:
-            self.conv_states[layer_idx] = new_conv_state.to(self.conv_states.device)
+            self.conv_states[layer_idx] = new_conv_state.to(self.device)
         else:
             self.conv_states[layer_idx] = self.conv_states[layer_idx].roll(shifts=-1, dims=-1)
-            self.conv_states[layer_idx][:, :, -1] = new_conv_state[:, 0, :].to(self.conv_states.device)
+            self.conv_states[layer_idx][:, :, -1] = new_conv_state[:, 0, :].to(self.conv_states[layer_idx].device)
         return self.conv_states[layer_idx]
 
     def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor):
-        self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states.device)
+        self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states[layer_idx].device)
         return self.ssm_states[layer_idx]
 
     def reset(self):
@@ -413,7 +418,7 @@ class NemotronHMamba2Mixer(nn.Module):
                 dt_softplus=True,
             )
             hidden_states = hidden_states.view(batch_size, self.num_heads * self.head_dim)
-            breakpoint()
+            # TODO: why was there a breakpoint() call here?
             hidden_states = self.norm(hidden_states, gate)
 
             # 4. Final linear projection
@@ -560,7 +565,7 @@ class NemotronHMamba2Mixer(nn.Module):
         A = -torch.exp(self.A_log.float())  # [num_heads]
         if cache_params is not None and cache_position is not None and cache_position[0] > 0:
             # We need to guarantee that anything regarding the cache is on the same device
-            cache_device = cache_params.ssm_states.device
+            cache_device = cache_params.ssm_states[0].device if len(cache_params.ssm_states) > 0 else cache_params.device
 
             # Note: there is no need to pad parameter matrices here, as there is just one new token
             # for batched generation
@@ -1185,7 +1190,7 @@ class NemotronHOutput(ModelOutput):
 
 @dataclass
 # Copied from transformers.models.mamba2.modeling_mamba2.MambaCausalLMOutput with Mamba2->NemotronH
-class NemotronHCausalLMOutput(ModelOutput):
+class NemotronHCausalLMOutput(MoeCausalLMOutputWithPast):
     """
     Base class for causal language model (or autoregressive) outputs.
 
@@ -1208,7 +1213,7 @@ class NemotronHCausalLMOutput(ModelOutput):
 
     loss: Optional[torch.FloatTensor] = None
     logits: Optional[torch.FloatTensor] = None
-    cache_params: Optional[HybridMambaAttentionDynamicCache] = None
+    past_key_values: Optional[HybridMambaAttentionDynamicCache] = None
     hidden_states: Optional[Tuple[torch.FloatTensor]] = None
     attentions: Optional[Tuple[torch.FloatTensor]] = None
 
@@ -1568,7 +1573,7 @@ class NemotronHForCausalLM(NemotronHPreTrainedModel, GenerationMixin):
         input_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
-        cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
+        past_key_values: Optional[HybridMambaAttentionDynamicCache] = None,
         labels: Optional[torch.LongTensor] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
@@ -1593,7 +1598,7 @@ class NemotronHForCausalLM(NemotronHPreTrainedModel, GenerationMixin):
 
         nemotron_h_outputs = self.backbone(
             input_ids,
-            cache_params=cache_params,
+            cache_params=past_key_values,
             inputs_embeds=inputs_embeds,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
@@ -1626,7 +1631,7 @@ class NemotronHForCausalLM(NemotronHPreTrainedModel, GenerationMixin):
         return NemotronHCausalLMOutput(
             loss=loss,
             logits=logits,
-            cache_params=nemotron_h_outputs.cache_params,
+            past_key_values=nemotron_h_outputs.cache_params,
             hidden_states=nemotron_h_outputs.hidden_states,
             attentions=nemotron_h_outputs.attentions,
         )
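As a sanity check on the resized convolution cache allocated in `__init__`, here is a small shape sketch. The config numbers below are placeholders, not the values shipped with this checkpoint; the key point is that, as in the Mamba2-style mixer, the conv cache spans the concatenated [x, B, C] channels fed to the causal conv1d.

```python
# Shape sketch only; hidden_size, expand, n_groups, ssm_state_size and conv_kernel
# are hypothetical placeholders, not read from the real nemotron-h config.
import torch

hidden_size, expand = 4096, 2
n_groups, ssm_state_size, conv_kernel = 8, 128, 4
batch_size = 1

intermediate_size = expand * hidden_size                        # 8192
# Extra 2 * n_groups * ssm_state_size columns cover the B and C projections,
# which is why this PR widens the conv cache from intermediate_size to conv_dim.
conv_dim = intermediate_size + 2 * n_groups * ssm_state_size    # 8192 + 2048 = 10240

conv_state = torch.zeros(batch_size, conv_dim, conv_kernel)             # per Mamba layer
ssm_state = torch.zeros(batch_size, intermediate_size, ssm_state_size)  # per Mamba layer
print(conv_state.shape, ssm_state.shape)  # torch.Size([1, 10240, 4]) torch.Size([1, 8192, 128])
```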