Enable `cache_params` to work with `generate()` from `GenerationMixin`
IMPORTANT: The cache doesn't seem to work very well in my tests, and given that it was disabled and still contained a `breakpoint()` call, I assume it simply wasn't ready yet. Even so, enabling it matters for understanding how fast the model can run when the cache is used.
In my tests, the model in Q4 BF16 goes from generating about 3.3 tok/s to about 19.3 tok/s on the same RTX 5090. I understand that ideally this model should be run in native FP4 on the 5090, but that is not yet supported in PyTorch, so I guess it is only possible with the NeMo engine for now.
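For context, this is roughly how I measured throughput. Treat it as a sketch rather than the exact script: the checkpoint id, prompt, and token count below are placeholders, not the values from my runs.

```python
import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder checkpoint id; substitute the Nemotron-H checkpoint being tested.
model_id = "nvidia/Nemotron-H-8B-Base-8K"
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map="cuda", trust_remote_code=True
)

inputs = tokenizer("Hello, my name is", return_tensors="pt").to(model.device)

for use_cache in (False, True):
    torch.cuda.synchronize()
    start = time.time()
    out = model.generate(**inputs, max_new_tokens=128, use_cache=use_cache)
    torch.cuda.synchronize()
    new_tokens = out.shape[1] - inputs["input_ids"].shape[1]
    print(f"use_cache={use_cache}: {new_tokens / (time.time() - start):.1f} tok/s")
```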
- modeling_nemotron_h.py +20 -15
```diff
--- a/modeling_nemotron_h.py
+++ b/modeling_nemotron_h.py
@@ -31,6 +31,9 @@ from transformers.modeling_attn_mask_utils import (
     AttentionMaskConverter,
 )
 from transformers.modeling_utils import PreTrainedModel
+from transformers.modeling_outputs import (
+    MoeCausalLMOutputWithPast,
+)
 from transformers.utils import (
     ModelOutput,
     add_code_sample_docstrings,
@@ -168,12 +171,14 @@ class HybridMambaAttentionDynamicCache(DynamicCache):
 
     def __init__(self, config, batch_size, dtype=torch.float16, device=None):
         super().__init__()
+        self.device = device
         self.dtype = dtype
         self.hybrid_override_pattern = config.hybrid_override_pattern
         self.has_previous_state = False  # only used by mamba
-        intermediate_size = config.expand * config.hidden_size
-        ssm_state_size = config.ssm_state_size
-        conv_kernel_size = config.conv_kernel
+        self.intermediate_size = config.expand * config.hidden_size
+        self.ssm_state_size = config.ssm_state_size
+        self.conv_kernel_size = config.conv_kernel
+        self.conv_dim = self.intermediate_size + 2 * config.n_groups * config.ssm_state_size
         self.conv_states = []
         self.ssm_states = []
         self.transformer_layers = []
@@ -181,10 +186,10 @@ class HybridMambaAttentionDynamicCache(DynamicCache):
             if self.hybrid_override_pattern[i] == "M":
                 # Mamba layer
                 self.conv_states += [
-                    torch.zeros(batch_size,
+                    torch.zeros(batch_size, self.conv_dim, self.conv_kernel_size, device=device, dtype=dtype)
                 ]
                 self.ssm_states += [
-                    torch.zeros(batch_size, intermediate_size, ssm_state_size, device=device, dtype=dtype)
+                    torch.zeros(batch_size, self.intermediate_size, self.ssm_state_size, device=device, dtype=dtype)
                 ]
             else:
                 # Attention or MLP layer
@@ -245,14 +250,14 @@ class HybridMambaAttentionDynamicCache(DynamicCache):
         self, layer_idx: int, new_conv_state: torch.Tensor, cache_init: bool = False
     ) -> torch.Tensor:
         if cache_init:
-            self.conv_states[layer_idx] = new_conv_state.to(self.
+            self.conv_states[layer_idx] = new_conv_state.to(self.device)
         else:
             self.conv_states[layer_idx] = self.conv_states[layer_idx].roll(shifts=-1, dims=-1)
-            self.conv_states[layer_idx][:, :, -1] = new_conv_state[:, 0, :].to(self.conv_states.device)
+            self.conv_states[layer_idx][:, :, -1] = new_conv_state[:, 0, :].to(self.conv_states[layer_idx].device)
         return self.conv_states[layer_idx]
 
     def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor):
-        self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states.device)
+        self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states[layer_idx].device)
         return self.ssm_states[layer_idx]
 
     def reset(self):
@@ -413,7 +418,7 @@ class NemotronHMamba2Mixer(nn.Module):
                 dt_softplus=True,
             )
             hidden_states = hidden_states.view(batch_size, self.num_heads * self.head_dim)
-            breakpoint()
+            # TODO: why was there a breakpoint() call here?
             hidden_states = self.norm(hidden_states, gate)
 
             # 4. Final linear projection
@@ -560,7 +565,7 @@ class NemotronHMamba2Mixer(nn.Module):
         A = -torch.exp(self.A_log.float())  # [num_heads]
         if cache_params is not None and cache_position is not None and cache_position[0] > 0:
             # We need to guarantee that anything regarding the cache is on the same device
-            cache_device = cache_params.ssm_states.device
+            cache_device = cache_params.ssm_states[0].device if len(cache_params.ssm_states) > 0 else cache_params.device
 
             # Note: there is no need to pad parameter matrices here, as there is just one new token
             # for batched generation
@@ -1185,7 +1190,7 @@ class NemotronHOutput(ModelOutput):
 
 @dataclass
 # Copied from transformers.models.mamba2.modeling_mamba2.MambaCausalLMOutput with Mamba2->NemotronH
-class NemotronHCausalLMOutput(ModelOutput):
+class NemotronHCausalLMOutput(MoeCausalLMOutputWithPast):
     """
     Base class for causal language model (or autoregressive) outputs.
 
@@ -1208,7 +1213,7 @@ class NemotronHCausalLMOutput(ModelOutput):
 
     loss: Optional[torch.FloatTensor] = None
     logits: Optional[torch.FloatTensor] = None
-    cache_params: Optional[HybridMambaAttentionDynamicCache] = None
+    past_key_values: Optional[HybridMambaAttentionDynamicCache] = None
     hidden_states: Optional[Tuple[torch.FloatTensor]] = None
     attentions: Optional[Tuple[torch.FloatTensor]] = None
 
@@ -1568,7 +1573,7 @@ class NemotronHForCausalLM(NemotronHPreTrainedModel, GenerationMixin):
         input_ids: Optional[torch.LongTensor] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
-        cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
+        past_key_values: Optional[HybridMambaAttentionDynamicCache] = None,
         labels: Optional[torch.LongTensor] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
@@ -1593,7 +1598,7 @@ class NemotronHForCausalLM(NemotronHPreTrainedModel, GenerationMixin):
 
         nemotron_h_outputs = self.backbone(
             input_ids,
-            cache_params=cache_params,
+            cache_params=past_key_values,
             inputs_embeds=inputs_embeds,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
@@ -1626,7 +1631,7 @@ class NemotronHForCausalLM(NemotronHPreTrainedModel, GenerationMixin):
         return NemotronHCausalLMOutput(
             loss=loss,
             logits=logits,
-            cache_params=nemotron_h_outputs.cache_params,
+            past_key_values=nemotron_h_outputs.cache_params,
             hidden_states=nemotron_h_outputs.hidden_states,
             attentions=nemotron_h_outputs.attentions,
         )
```
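The rename from `cache_params` to `past_key_values` in the forward signature and in `NemotronHCausalLMOutput` is what lets `GenerationMixin.generate()` pick the cache up from one step and feed it back into the next without custom plumbing. One other detail worth calling out: `self.conv_states` and `self.ssm_states` are plain Python lists, so `.device` only exists on the per-layer tensors, not on the list itself, hence the `[layer_idx]` / `[0]` indexing in the updated code. A minimal, self-contained sketch of that pattern (the tensor shapes here are made up for illustration):

```python
import torch

# Per-layer conv states, mimicking HybridMambaAttentionDynamicCache.conv_states:
# one (batch, conv_dim, conv_kernel_size) tensor per Mamba layer.
conv_states = [torch.zeros(1, 8, 4) for _ in range(2)]

# conv_states.device would raise AttributeError: a list has no .device,
# which is why the old cache code failed once it was actually exercised.

layer_idx = 0
new_conv_state = torch.randn(1, 1, 8)  # one decode step, shape (batch, 1, conv_dim)

# Shift the rolling window left by one and write the new column,
# staying on the device of that layer's tensor (as in the fix).
conv_states[layer_idx] = conv_states[layer_idx].roll(shifts=-1, dims=-1)
conv_states[layer_idx][:, :, -1] = new_conv_state[:, 0, :].to(conv_states[layer_idx].device)
print(conv_states[layer_idx].shape)  # torch.Size([1, 8, 4])
```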