Upload DogeForCausalLM
- config.json +44 -43
- configuration_doge.py +8 -0
- generation_config.json +7 -7
- model.safetensors +2 -2
- modeling_doge.py +125 -29
config.json
CHANGED
@@ -1,43 +1,44 @@
-{
-  "_name_or_path": "./results/Doge-60M",
-  "architectures": [
-    "DogeForCausalLM"
-  ],
-  "attention_dropout": 0.0,
-  "auto_map": {
-    "AutoConfig": "configuration_doge.DogeConfig",
-    "AutoModelForCausalLM": "modeling_doge.DogeForCausalLM"
-  },
-  "bos_token_id": 0,
-  … (removed lines 12-43 are truncated in the source)
+{
+  "_name_or_path": "./results/Doge-60M",
+  "architectures": [
+    "DogeForCausalLM"
+  ],
+  "attention_dropout": 0.0,
+  "auto_map": {
+    "AutoConfig": "configuration_doge.DogeConfig",
+    "AutoModelForCausalLM": "modeling_doge.DogeForCausalLM"
+  },
+  "bos_token_id": 0,
+  "dynamic_mask_ratio": 0.0,
+  "eos_token_id": 1,
+  "expert_retrieval_size": 256,
+  "hidden_act": "silu",
+  "hidden_bias": false,
+  "hidden_dropout": 0.0,
+  "hidden_size": 512,
+  "initializer_range": 0.02,
+  "intermediate_size": 1024,
+  "is_moe": false,
+  "max_position_embeddings": 2048,
+  "model_type": "doge",
+  "num_attention_heads": 4,
+  "num_cdmmoe_experts": 2048,
+  "num_cdmmoe_experts_per_head": 8,
+  "num_cdmmoe_heads": 4,
+  "num_channels": 3,
+  "num_hidden_layers": 16,
+  "num_key_value_heads": 2,
+  "pad_token_id": 2,
+  "patch_size": 16,
+  "rms_norm_eps": 1e-06,
+  "rope_scaling": {
+    "factor": 4.0,
+    "original_max_position_embeddings": 2048,
+    "rope_type": "dynamic"
+  },
+  "rope_theta": 10000.0,
+  "torch_dtype": "float32",
+  "transformers_version": "4.49.0.dev0",
+  "use_cache": true,
+  "vocab_size": 32768
+}
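Because auto_map points AutoConfig and AutoModelForCausalLM at the bundled configuration_doge.py and modeling_doge.py, the checkpoint must be loaded with trust_remote_code=True. A minimal loading sketch; the local path "./Doge-60M" is a placeholder, not a repository id taken from this commit:

# Loading sketch; adjust the path or repo id to wherever this checkpoint lives.
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("./Doge-60M", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("./Doge-60M", trust_remote_code=True)

# Fields touched by this commit are exposed directly on the config object.
print(config.dynamic_mask_ratio)   # 0.0
print(config.num_key_value_heads)  # 2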
configuration_doge.py
CHANGED
@@ -111,6 +111,8 @@ class DogeConfig(PretrainedConfig):
             If it is not specified, will default to `num_attention_heads`.
         attention_dropout (`float`, *optional*, defaults to 0.0):
             The dropout ratio for the attention probabilities.
+        dynamic_mask_ratio (`float`, *optional*, defaults to 0.0, range [0, 1]):
+            The ratio to control the proportion of the dynamic mask filled with the minimum value.
         is_moe (`bool`, *optional*, defaults to `False`):
             Whether to use the Cross Domain Mixture of Experts, if `True`, the MoE will inherit the MLP to initialize
         num_cdmmoe_experts (`int`, *optional*, defaults to 2048):
@@ -154,6 +156,7 @@
         num_attention_heads=8,
         num_key_value_heads=None,
         attention_dropout=0.0,
+        dynamic_mask_ratio=0.0,
         is_moe=False,
         num_cdmmoe_experts=2048,
         num_cdmmoe_heads=4,
@@ -183,6 +186,7 @@
         self.num_attention_heads = num_attention_heads
         self.num_key_value_heads = num_key_value_heads
         self.attention_dropout = attention_dropout
+        self.dynamic_mask_ratio = dynamic_mask_ratio
         self.is_moe = is_moe
         self.num_cdmmoe_experts = num_cdmmoe_experts
         self.num_cdmmoe_heads = num_cdmmoe_heads
@@ -195,6 +199,10 @@
             self.rope_scaling["rope_type"] = self.rope_scaling["type"]
         rope_config_validation(self)

+        # for backward compatibility
+        if num_key_value_heads is None:
+            self.num_key_value_heads = num_attention_heads
+
         super().__init__(
             bos_token_id=bos_token_id,
             eos_token_id=eos_token_id,
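The two additions in this file interact with the existing defaults; a small sketch of how they behave when the config is constructed directly (argument values here are illustrative, not taken from the checkpoint):

# Illustrative only: exercises the new dynamic_mask_ratio field and the
# num_key_value_heads backward-compatibility fallback added above.
from configuration_doge import DogeConfig

cfg = DogeConfig(num_attention_heads=8, num_key_value_heads=None, dynamic_mask_ratio=0.1)
assert cfg.num_key_value_heads == 8   # None falls back to num_attention_heads
assert cfg.dynamic_mask_ratio == 0.1  # new knob, defaults to 0.0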
generation_config.json
CHANGED
@@ -1,7 +1,7 @@
-{
-  "_from_model_config": true,
-  "bos_token_id": 0,
-  "eos_token_id": 1,
-  "pad_token_id": 2,
-  "transformers_version": "4.…
-}
+{
+  "_from_model_config": true,
+  "bos_token_id": 0,
+  "eos_token_id": 1,
+  "pad_token_id": 2,
+  "transformers_version": "4.49.0.dev0"
+}
model.safetensors
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:…
-size …
+oid sha256:550dbbf30bc9f8b88c7ac4136a1412414be8db29a5146b9f0bab2e795ab991e5
+size 218325576
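The LFS pointer only records the new blob's hash and byte size; inspecting the weights themselves requires opening the file. A quick sketch, assuming the safetensors package and a local copy of the file:

# Inspection sketch: list tensor names and count parameters in the checkpoint.
from safetensors import safe_open

with safe_open("model.safetensors", framework="pt", device="cpu") as f:
    names = list(f.keys())
    total = sum(f.get_tensor(name).numel() for name in names)

print(len(names), "tensors,", total, "parameters")  # serialized size per the pointer: 218,325,576 bytes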
modeling_doge.py
CHANGED
@@ -22,6 +22,7 @@ import math
 from typing import List, Optional, Tuple, Union

 import torch
+from torch.nn.attention.flex_attention import flex_attention
 import torch.nn.functional as F
 import torch.utils.checkpoint
 from torch import nn
@@ -216,14 +217,15 @@ class DogeDynamicMaskAttention(nn.Module):
         self.num_key_value_heads = config.num_key_value_heads
         self.num_key_value_groups = self.num_heads // self.num_key_value_heads
         self.attention_dropout = config.attention_dropout
+        self.dynamic_mask_ratio = config.dynamic_mask_ratio

         # Q K V O projections
         self.q_proj = nn.Linear(self.hidden_dim, self.num_heads * self.head_dim, bias=config.hidden_bias)
         self.k_proj = nn.Linear(self.hidden_dim, self.num_key_value_heads * self.head_dim, bias=config.hidden_bias)
+        self.v_proj = nn.Linear(self.hidden_dim, self.num_key_value_heads * self.head_dim, bias=config.hidden_bias)
         # dynamic mask for the QK^T attention score matrix
         self.A = nn.Parameter(torch.ones(self.num_heads))
-        self.dt_proj = nn.Linear(self.…
-        self.v_proj = nn.Linear(self.hidden_dim, self.num_key_value_heads * self.head_dim, bias=config.hidden_bias)
+        self.dt_proj = nn.Linear(self.num_key_value_heads * self.head_dim, self.num_heads, bias=config.hidden_bias)
         self.o_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=config.hidden_bias)

     def forward(
@@ -254,6 +256,10 @@ class DogeDynamicMaskAttention(nn.Module):
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

+        # calculate dynamic mask from value_states
+        dt_states = self.dt_proj(value_states.transpose(1, 2).reshape(bsz, value_states.shape[-2], -1))
+        dynamic_mask = torch.exp(self.A * F.softplus(dt_states)).transpose(-1, -2)
+
         # repeat key and value states
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
@@ -262,12 +268,13 @@ class DogeDynamicMaskAttention(nn.Module):
         attn_weights = torch.matmul(query_states, key_states.transpose(-1, -2)) / math.sqrt(self.head_dim)

         # add mask to attention scores
-        … (removed lines 265-270, applying the old dynamic_mask inline, are truncated in the source)
+        attn_mask = self.prepare_dynamic_mask(
+            hidden_states=hidden_states,
+            dynamic_mask=dynamic_mask,
+            dynamic_mask_ratio=self.dynamic_mask_ratio,
+            attention_mask=attention_mask,
+        )
+        attn_weights = attn_weights + attn_mask

         # upcast attention scores to fp32
         attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
@@ -282,8 +289,37 @@ class DogeDynamicMaskAttention(nn.Module):

         return attn_output, past_key_value

+    def prepare_dynamic_mask(
+        self,
+        hidden_states: torch.Tensor,
+        dynamic_mask: torch.Tensor,
+        dynamic_mask_ratio: float = 0.0,
+        attention_mask: Optional[torch.Tensor] = None,
+    ):
+        """
+        Combine `dynamic_mask` with `attention_mask` to generate the final `attn_mask`.
+
+        Args:
+            hidden_states (`torch.Tensor`): The input hidden_states, used to determine the minimum value of the current input precision.
+            dynamic_mask (`torch.Tensor`): dynamic mask of shape `(batch_size, num_heads, key_sequence_length)`.
+            dynamic_mask_ratio (`float`, *optional*): Ratio from 0.0 to 1.0 used to control the proportion of the dynamic mask filled with the minimum value.
+            attention_mask (`torch.Tensor`, *optional*): attention mask of shape `(batch_size, 1, query_sequence_length, key_sequence_length)`.
+        """
+        min_type = torch.finfo(hidden_states.dtype).min
+        attn_mask = dynamic_mask[:, :, None, :]
+        if 0.0 < dynamic_mask_ratio < 1.0:
+            rate_value = torch.kthvalue(
+                attn_mask,
+                int(attn_mask.shape[-1] * dynamic_mask_ratio),
+                dim=-1, keepdim=True,
+            ).values
+            attn_mask = attn_mask.masked_fill(attn_mask < rate_value, min_type)
+        if attention_mask is not None:
+            attn_mask = attn_mask.masked_fill(attention_mask[:, :, :, : hidden_states.shape[-2]] == min_type, min_type)
+        return attn_mask
+
+
-class DogeSdpaDynamicMaskAttn(DogeDynamicMaskAttention):
+class DogeSdpaDynamicMaskAttention(DogeDynamicMaskAttention):

     def forward(
         self,
@@ -312,34 +348,31 @@ class DogeSdpaDynamicMaskAttn(DogeDynamicMaskAttention):
             # sin and cos are specific to RoPE models; cache_position needed for the static cache
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+        # calculate dynamic mask from value_states
+        dt_states = self.dt_proj(value_states.transpose(1, 2).reshape(bsz, value_states.shape[-2], -1))
+        dynamic_mask = torch.exp(self.A * F.softplus(dt_states)).transpose(-1, -2)

-        … (removed lines 316-321 are truncated in the source)
-        dt_states = self.dt_proj(value_states.transpose(1, 2).reshape(bsz, value_states.shape[-2], -1))
-        dynamic_mask = torch.exp(self.A * F.softplus(dt_states)).transpose(-1, -2)
-        dynamic_mask = dynamic_mask < 1.0
-        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]].masked_fill(dynamic_mask[:, :, None, :], torch.finfo(hidden_states.dtype).min)
+        attn_mask = self.prepare_dynamic_mask(
+            hidden_states=hidden_states,
+            dynamic_mask=dynamic_mask,
+            dynamic_mask_ratio=self.dynamic_mask_ratio,
+            attention_mask=attention_mask,
+        )

         query_states = query_states.contiguous()
         key_states = key_states.contiguous()
         value_states = value_states.contiguous()

-        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
-        is_causal = True if causal_mask is None and q_len > 1 else False
-
         # NOTE: As of pytorch 2.5.1, cuDNN's SDPA backward pass is still incorrect, so we disable cuDNN SDPA (see https://github.com/pytorch/pytorch/issues/138581)
         torch.backends.cuda.enable_cudnn_sdp(False)
         attn_output = F.scaled_dot_product_attention(
             query_states,
             key_states,
             value_states,
-            attn_mask=…
+            attn_mask=attn_mask,
             dropout_p=self.attention_dropout if self.training else 0.0,
-            …
+            enable_gqa=True,
         )

         attn_output = attn_output.transpose(1, 2).contiguous()
@@ -349,9 +382,70 @@ class DogeSdpaDynamicMaskAttn(DogeDynamicMaskAttention):
         return attn_output, past_key_value


+class DogeFlexDynamicMaskAttention(DogeDynamicMaskAttention):
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Cache] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Optional[Cache]]:
+        bsz, q_len, _ = hidden_states.shape
+
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+
+        query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+
+        cos, sin = position_embeddings
+        query_states, key_states = apply_QK_rotary_pos_emb(query_states, key_states, cos, sin)
+
+        if past_key_value is not None:
+            # sin and cos are specific to RoPE models; cache_position needed for the static cache
+            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+        dt_states = self.dt_proj(value_states.transpose(1, 2).reshape(bsz, value_states.shape[-2], -1))
+        dynamic_mask = torch.exp(self.A * F.softplus(dt_states)).transpose(-1, -2)
+
+        attn_mask = self.prepare_dynamic_mask(
+            hidden_states=hidden_states,
+            dynamic_mask=dynamic_mask,
+            dynamic_mask_ratio=self.dynamic_mask_ratio,
+            attention_mask=attention_mask,
+        )
+        # TODO: flex_attention: Captured buffers that require grad are not yet supported.
+        # NOTE: So we only use flex_attention in inference mode.
+        def dynamic_mask_mod(score, batch, head, q_idx, kv_idx):
+            score = score + attn_mask[batch][head][q_idx][kv_idx]
+            return score
+
+        attn_output = flex_attention(
+            query_states,
+            key_states,
+            value_states,
+            score_mod=dynamic_mask_mod,
+            enable_gqa=True,
+        )
+
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.view(bsz, q_len, -1)
+        attn_output = self.o_proj(attn_output)
+
+        return attn_output, past_key_value
+
+
 DOGE_ATTENTION_CLASSES = {
+    "flex_attention": DogeFlexDynamicMaskAttention,
     "eager": DogeDynamicMaskAttention,
-    "sdpa": DogeSdpaDynamicMaskAttn,
+    "sdpa": DogeSdpaDynamicMaskAttention,
 }
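The attention hunks above are the core of this commit: the dynamic mask is now computed once from value_states (dt_proj plus the per-head A parameter) and turned into a single additive attn_mask by prepare_dynamic_mask, which the eager, SDPA, and flex_attention paths all consume. A standalone shape-level sketch of that path, using the dimensions from the shipped config (hidden_size 512, 4 attention heads, 2 KV heads):

# Shape-level sketch of the dynamic-mask path; random weights, config-sized tensors.
import torch
import torch.nn.functional as F

bsz, num_heads, num_kv_heads, head_dim, kv_len = 1, 4, 2, 128, 16
A = torch.ones(num_heads)                                      # per-head scale, like self.A
dt_proj = torch.nn.Linear(num_kv_heads * head_dim, num_heads)  # like self.dt_proj

value_states = torch.randn(bsz, num_kv_heads, kv_len, head_dim)
dt_states = dt_proj(value_states.transpose(1, 2).reshape(bsz, kv_len, -1))
dynamic_mask = torch.exp(A * F.softplus(dt_states)).transpose(-1, -2)  # (bsz, num_heads, kv_len)

# prepare_dynamic_mask broadcasts over the query axis and, when dynamic_mask_ratio > 0,
# pushes the smallest entries down to the dtype minimum so softmax ignores them.
attn_mask = dynamic_mask[:, :, None, :]                        # (bsz, num_heads, 1, kv_len)
ratio = 0.5
rate_value = torch.kthvalue(attn_mask, int(kv_len * ratio), dim=-1, keepdim=True).values
attn_mask = attn_mask.masked_fill(attn_mask < rate_value, torch.finfo(attn_mask.dtype).min)
print(attn_mask.shape)  # torch.Size([1, 4, 1, 16])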
@@ -519,6 +613,7 @@ class DogePreTrainedModel(PreTrainedModel):
     supports_gradient_checkpointing = True
     _no_split_modules = ["DogeDecoderLayer"]
     _skip_keys_device_placement = ["past_key_values"]
+    _supports_flex_attn = True
     _supports_sdpa = True
     _supports_cache_class = True
     _supports_quantized_cache = True
@@ -693,7 +788,7 @@ class DogeModel(DogePreTrainedModel):
         all_self_attns = () if output_attentions else None
         next_decoder_cache = None

-        for decoder_layer in self.layers:
+        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)

@@ -877,7 +972,7 @@ class DogeForCausalLM(DogePreTrainedModel, GenerationMixin):
         input_ids: torch.LongTensor = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[torch.…
+        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
@@ -886,7 +981,7 @@ class DogeForCausalLM(DogePreTrainedModel, GenerationMixin):
         return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
         num_logits_to_keep: int = 0,
-        **…
+        **kwargs,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
         r"""
         Args:
@@ -920,6 +1015,7 @@ class DogeForCausalLM(DogePreTrainedModel, GenerationMixin):
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
             cache_position=cache_position,
+            **kwargs,
         )

         hidden_states = outputs[0]
@@ -929,7 +1025,7 @@ class DogeForCausalLM(DogePreTrainedModel, GenerationMixin):

         loss = None
         if labels is not None:
-            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.vocab_size, **…
+            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.vocab_size, **kwargs)

         if not return_dict:
             output = (logits,) + outputs[1:]
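With DOGE_ATTENTION_CLASSES now exposing eager, sdpa, and flex_attention backends and _supports_flex_attn set on the pretrained base class, the implementation can be selected at load time through the usual attn_implementation argument. A usage sketch (placeholder path again; per the TODO note in the diff, flex_attention is intended for inference only):

# Usage sketch: choose an attention backend and run a short generation.
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained(
    "./Doge-60M",                   # placeholder path
    trust_remote_code=True,
    attn_implementation="sdpa",     # or "eager" / "flex_attention"
)
tokenizer = AutoTokenizer.from_pretrained("./Doge-60M")

inputs = tokenizer("Hello", return_tensors="pt")
out = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(out[0], skip_special_tokens=True))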