Enable `cache_params` to work with `generate()` from `GenerationMixin` (#3)
Opened by FremyCompany
Files changed: modeling_nemotron_h.py (+20 -15)
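The intent of the change, sketched below as a minimal usage example (not part of the diff): once the causal-LM forward accepts and returns the hybrid cache under `past_key_values`, `GenerationMixin.generate()` can thread the `HybridMambaAttentionDynamicCache` through decoding on its own. The checkpoint id is a placeholder, and loading assumes `trust_remote_code=True` so the repository's custom modeling file is used.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "nvidia/Nemotron-H-8B-Base-8K"  # placeholder NemotronH checkpoint id
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, trust_remote_code=True
)

inputs = tokenizer("The capital of France is", return_tensors="pt").to(model.device)
# With this patch, generate() builds the hybrid cache once and feeds it back each
# decoding step through `past_key_values` instead of the old `cache_params` name.
output_ids = model.generate(**inputs, max_new_tokens=32, use_cache=True)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))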
modeling_nemotron_h.py
CHANGED
@@ -31,6 +31,9 @@ from transformers.modeling_attn_mask_utils import (
     AttentionMaskConverter,
 )
 from transformers.modeling_utils import PreTrainedModel
+from transformers.modeling_outputs import (
+    MoeCausalLMOutputWithPast,
+)
 from transformers.utils import (
     ModelOutput,
     add_code_sample_docstrings,
@@ -168,12 +171,14 @@ class HybridMambaAttentionDynamicCache(DynamicCache):

     def __init__(self, config, batch_size, dtype=torch.float16, device=None):
         super().__init__()
+        self.device=device
         self.dtype = dtype
         self.hybrid_override_pattern = config.hybrid_override_pattern
         self.has_previous_state = False  # only used by mamba
-        intermediate_size = config.expand * config.hidden_size
-        ssm_state_size = config.ssm_state_size
-        conv_kernel_size = config.conv_kernel
+        self.intermediate_size = config.expand * config.hidden_size
+        self.ssm_state_size = config.ssm_state_size
+        self.conv_kernel_size = config.conv_kernel
+        self.conv_dim = self.intermediate_size + 2 * config.n_groups * config.ssm_state_size
         self.conv_states = []
         self.ssm_states = []
         self.transformer_layers = []
@@ -181,10 +186,10 @@ class HybridMambaAttentionDynamicCache(DynamicCache):
             if self.hybrid_override_pattern[i] == "M":
                 # Mamba layer
                 self.conv_states += [
-                    torch.zeros(batch_size,
+                    torch.zeros(batch_size, self.conv_dim, self.conv_kernel_size, device=device, dtype=dtype)
                 ]
                 self.ssm_states += [
-                    torch.zeros(batch_size, intermediate_size, ssm_state_size, device=device, dtype=dtype)
+                    torch.zeros(batch_size, self.intermediate_size, self.ssm_state_size, device=device, dtype=dtype)
                 ]
             else:
                 # Attention or MLP layer
@@ -245,14 +250,14 @@ class HybridMambaAttentionDynamicCache(DynamicCache):
         self, layer_idx: int, new_conv_state: torch.Tensor, cache_init: bool = False
     ) -> torch.Tensor:
         if cache_init:
-            self.conv_states[layer_idx] = new_conv_state.to(self.conv_states.device)
+            self.conv_states[layer_idx] = new_conv_state.to(self.device)
         else:
             self.conv_states[layer_idx] = self.conv_states[layer_idx].roll(shifts=-1, dims=-1)
-            self.conv_states[layer_idx][:, :, -1] = new_conv_state[:, 0, :].to(self.conv_states.device)
+            self.conv_states[layer_idx][:, :, -1] = new_conv_state[:, 0, :].to(self.conv_states[layer_idx].device)
         return self.conv_states[layer_idx]

     def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor):
-        self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states.device)
+        self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states[layer_idx].device)
         return self.ssm_states[layer_idx]

     def reset(self):
@@ -413,7 +418,7 @@ class NemotronHMamba2Mixer(nn.Module):
                 dt_softplus=True,
             )
             hidden_states = hidden_states.view(batch_size, self.num_heads * self.head_dim)
-            breakpoint()
+            # TODO: why was there a breakpoint() call here?
            hidden_states = self.norm(hidden_states, gate)

             # 4. Final linear projection
@@ -560,7 +565,7 @@ class NemotronHMamba2Mixer(nn.Module):
         A = -torch.exp(self.A_log.float())  # [num_heads]
         if cache_params is not None and cache_position is not None and cache_position[0] > 0:
             # We need to guarantee that anything regarding the cache is on the same device
-            cache_device = cache_params.ssm_states.device
+            cache_device = cache_params.ssm_states[0].device if len(cache_params.ssm_states) > 0 else cache_params.device

             # Note: there is no need to pad parameter matrices here, as there is just one new token
             # for batched generation
@@ -1185,7 +1190,7 @@ class NemotronHOutput(ModelOutput):

 @dataclass
 # Copied from transformers.models.mamba2.modeling_mamba2.MambaCausalLMOutput with Mamba2->NemotronH
-class NemotronHCausalLMOutput(ModelOutput):
+class NemotronHCausalLMOutput(MoeCausalLMOutputWithPast):
     """
     Base class for causal language model (or autoregressive) outputs.

@@ -1208,7 +1213,7 @@ class NemotronHCausalLMOutput(ModelOutput):

     loss: Optional[torch.FloatTensor] = None
     logits: Optional[torch.FloatTensor] = None
-    cache_params: Optional[HybridMambaAttentionDynamicCache] = None
+    past_key_values: Optional[HybridMambaAttentionDynamicCache] = None
     hidden_states: Optional[Tuple[torch.FloatTensor]] = None
     attentions: Optional[Tuple[torch.FloatTensor]] = None

@@ -1568,7 +1573,7 @@ class NemotronHForCausalLM(NemotronHPreTrainedModel, GenerationMixin):
         input_ids: Optional[torch.LongTensor] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
-        cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
+        past_key_values: Optional[HybridMambaAttentionDynamicCache] = None,
         labels: Optional[torch.LongTensor] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
@@ -1593,7 +1598,7 @@ class NemotronHForCausalLM(NemotronHPreTrainedModel, GenerationMixin):

         nemotron_h_outputs = self.backbone(
             input_ids,
-            cache_params=cache_params,
+            cache_params=past_key_values,
             inputs_embeds=inputs_embeds,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
@@ -1626,7 +1631,7 @@ class NemotronHForCausalLM(NemotronHPreTrainedModel, GenerationMixin):
         return NemotronHCausalLMOutput(
             loss=loss,
             logits=logits,
-            cache_params=nemotron_h_outputs.cache_params,
+            past_key_values=nemotron_h_outputs.cache_params,
             hidden_states=nemotron_h_outputs.hidden_states,
             attentions=nemotron_h_outputs.attentions,
         )
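For reviewers, a standalone sketch of the roll-and-write pattern that the `update_conv_state` hunk touches, on plain tensors; shapes are illustrative only and do not come from a real NemotronH config.

import torch

# Illustrative shapes; in the model, conv_dim and conv_kernel_size come from the config
# (expand * hidden_size + 2 * n_groups * ssm_state_size, and conv_kernel respectively).
batch_size, conv_dim, conv_kernel_size = 2, 8, 4
conv_state = torch.zeros(batch_size, conv_dim, conv_kernel_size)

def roll_and_write(conv_state: torch.Tensor, new_conv_state: torch.Tensor) -> torch.Tensor:
    # Shift the sliding window left by one slot along the kernel axis...
    conv_state = conv_state.roll(shifts=-1, dims=-1)
    # ...then write the newest column into the freed last slot, on the state's own device.
    conv_state[:, :, -1] = new_conv_state[:, 0, :].to(conv_state.device)
    return conv_state

# One decoding step contributes one new column of shape (batch, 1, conv_dim).
new_column = torch.randn(batch_size, 1, conv_dim)
conv_state = roll_and_write(conv_state, new_column)
print(conv_state.shape)  # torch.Size([2, 8, 4])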