Text Generation · Transformers · Safetensors · English · doge · conversational · custom_code
JingzeShi committed · commit 44e481b (verified) · 1 parent: 24676a2

Upload DogeForCausalLM

config.json CHANGED
@@ -1,43 +1,44 @@
 {
   "_name_or_path": "./results/Doge-60M",
   "architectures": [
     "DogeForCausalLM"
   ],
   "attention_dropout": 0.0,
   "auto_map": {
     "AutoConfig": "configuration_doge.DogeConfig",
     "AutoModelForCausalLM": "modeling_doge.DogeForCausalLM"
   },
   "bos_token_id": 0,
+  "dynamic_mask_ratio": 0.0,
   "eos_token_id": 1,
   "expert_retrieval_size": 256,
   "hidden_act": "silu",
   "hidden_bias": false,
   "hidden_dropout": 0.0,
   "hidden_size": 512,
   "initializer_range": 0.02,
   "intermediate_size": 1024,
   "is_moe": false,
   "max_position_embeddings": 2048,
   "model_type": "doge",
   "num_attention_heads": 4,
   "num_cdmmoe_experts": 2048,
   "num_cdmmoe_experts_per_head": 8,
   "num_cdmmoe_heads": 4,
   "num_channels": 3,
   "num_hidden_layers": 16,
   "num_key_value_heads": 2,
   "pad_token_id": 2,
   "patch_size": 16,
   "rms_norm_eps": 1e-06,
   "rope_scaling": {
     "factor": 4.0,
     "original_max_position_embeddings": 2048,
     "rope_type": "dynamic"
   },
   "rope_theta": 10000.0,
   "torch_dtype": "float32",
-  "transformers_version": "4.46.1",
+  "transformers_version": "4.49.0.dev0",
   "use_cache": true,
   "vocab_size": 32768
 }
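Because `auto_map` in `config.json` points to the repository's custom `configuration_doge.py` and `modeling_doge.py`, the checkpoint has to be loaded with `trust_remote_code=True`. A minimal loading sketch, assuming the tokenizer files are also present in the repository and using a placeholder repo id:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "<namespace>/Doge-60M"  # placeholder: substitute the actual Hub path of this checkpoint

tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    repo_id,
    trust_remote_code=True,  # lets auto_map resolve configuration_doge.DogeConfig / modeling_doge.DogeForCausalLM
)

inputs = tokenizer("Hello, Doge!", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```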
configuration_doge.py CHANGED
@@ -111,6 +111,8 @@ class DogeConfig(PretrainedConfig):
             If it is not specified, will default to `num_attention_heads`.
         attention_dropout (`float`, *optional*, defaults to 0.0):
             The dropout ratio for the attention probabilities.
+        dynamic_mask_ratio (`float`, *optional*, defaults to 0.0, range [0, 1]):
+            The ratio to control the proportion of the dynamic mask filled with the minimum value.
         is_moe (`bool`, *optional*, defaults to `False`):
             Whether to use the Cross Domain Mixture of Experts, if `True`, the MoE will inherit the MLP to initialize
         num_cdmmoe_experts (`int`, *optional*, defaults to 2048):
@@ -154,6 +156,7 @@
         num_attention_heads=8,
         num_key_value_heads=None,
         attention_dropout=0.0,
+        dynamic_mask_ratio=0.0,
         is_moe=False,
         num_cdmmoe_experts=2048,
         num_cdmmoe_heads=4,
@@ -183,6 +186,7 @@
         self.num_attention_heads = num_attention_heads
         self.num_key_value_heads = num_key_value_heads
         self.attention_dropout = attention_dropout
+        self.dynamic_mask_ratio = dynamic_mask_ratio
         self.is_moe = is_moe
         self.num_cdmmoe_experts = num_cdmmoe_experts
         self.num_cdmmoe_heads = num_cdmmoe_heads
@@ -195,6 +199,10 @@
             self.rope_scaling["rope_type"] = self.rope_scaling["type"]
         rope_config_validation(self)
 
+        # for backward compatibility
+        if num_key_value_heads is None:
+            self.num_key_value_heads = num_attention_heads
+
         super().__init__(
             bos_token_id=bos_token_id,
             eos_token_id=eos_token_id,
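A short sketch of how the new argument surfaces through `DogeConfig`: `dynamic_mask_ratio` is stored on the config, and the backward-compatibility branch added above lets `num_key_value_heads` fall back to `num_attention_heads` when it is left unset. Illustrative only, run from a checkout of this repository so `configuration_doge` is importable:

```python
from configuration_doge import DogeConfig

# dynamic_mask_ratio defaults to 0.0 (no extra masking); values in (0, 1)
# fill that fraction of the dynamic mask with the dtype minimum.
config = DogeConfig(num_attention_heads=4, num_key_value_heads=2, dynamic_mask_ratio=0.25)
print(config.dynamic_mask_ratio)   # 0.25

# leaving num_key_value_heads unset now falls back to num_attention_heads
config = DogeConfig(num_attention_heads=8)
print(config.num_key_value_heads)  # 8
```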
generation_config.json CHANGED
@@ -1,7 +1,7 @@
 {
   "_from_model_config": true,
   "bos_token_id": 0,
   "eos_token_id": 1,
   "pad_token_id": 2,
-  "transformers_version": "4.46.1"
+  "transformers_version": "4.49.0.dev0"
 }
model.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:f6ff7db0f6721882934053a9c20eec73c33b55fc47ef428e20a0e91391738985
-size 218391112
+oid sha256:550dbbf30bc9f8b88c7ac4136a1412414be8db29a5146b9f0bab2e795ab991e5
+size 218325576
modeling_doge.py CHANGED
@@ -22,6 +22,7 @@ import math
 from typing import List, Optional, Tuple, Union
 
 import torch
+from torch.nn.attention.flex_attention import flex_attention
 import torch.nn.functional as F
 import torch.utils.checkpoint
 from torch import nn
@@ -216,14 +217,15 @@ class DogeDynamicMaskAttention(nn.Module):
         self.num_key_value_heads = config.num_key_value_heads
         self.num_key_value_groups = self.num_heads // self.num_key_value_heads
         self.attention_dropout = config.attention_dropout
+        self.dynamic_mask_ratio = config.dynamic_mask_ratio
 
         # Q K V O projections
         self.q_proj = nn.Linear(self.hidden_dim, self.num_heads * self.head_dim, bias=config.hidden_bias)
         self.k_proj = nn.Linear(self.hidden_dim, self.num_key_value_heads * self.head_dim, bias=config.hidden_bias)
+        self.v_proj = nn.Linear(self.hidden_dim, self.num_key_value_heads * self.head_dim, bias=config.hidden_bias)
         # dynamic mask for the QK^T attention score matrix
         self.A = nn.Parameter(torch.ones(self.num_heads))
-        self.dt_proj = nn.Linear(self.hidden_dim, self.num_heads, bias=config.hidden_bias)
-        self.v_proj = nn.Linear(self.hidden_dim, self.num_key_value_heads * self.head_dim, bias=config.hidden_bias)
+        self.dt_proj = nn.Linear(self.num_key_value_heads * self.head_dim, self.num_heads, bias=config.hidden_bias)
         self.o_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=config.hidden_bias)
 
     def forward(
@@ -254,6 +256,10 @@
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
+        # calculate dynamic mask from value_states
+        dt_states = self.dt_proj(value_states.transpose(1, 2).reshape(bsz, value_states.shape[-2], -1))
+        dynamic_mask = torch.exp(self.A * F.softplus(dt_states)).transpose(-1, -2)
+
         # repeat key and value states
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
@@ -262,12 +268,13 @@
         attn_weights = torch.matmul(query_states, key_states.transpose(-1, -2)) / math.sqrt(self.head_dim)
 
         # add mask to attention scores
-        if attention_mask is not None:
-            dt_states = self.dt_proj(value_states.transpose(1, 2).reshape(bsz, value_states.shape[-2], -1))
-            dynamic_mask = torch.exp(self.A * F.softplus(dt_states)).transpose(-1, -2)
-            dynamic_mask = dynamic_mask < 1.0
-            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]].masked_fill(dynamic_mask[:, :, None, :], torch.finfo(hidden_states.dtype).min)
-            attn_weights = attn_weights + causal_mask
+        attn_mask = self.prepare_dynamic_mask(
+            hidden_states=hidden_states,
+            dynamic_mask=dynamic_mask,
+            dynamic_mask_ratio=self.dynamic_mask_ratio,
+            attention_mask=attention_mask,
+        )
+        attn_weights = attn_weights + attn_mask
 
         # upcast attention scores to fp32
         attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
@@ -282,8 +289,37 @@
 
         return attn_output, past_key_value
 
+    def prepare_dynamic_mask(
+        self,
+        hidden_states: torch.Tensor,
+        dynamic_mask: torch.Tensor,
+        dynamic_mask_ratio: float = 0.0,
+        attention_mask: Optional[torch.Tensor] = None,
+    ):
+        """
+        Combine `dynamic_mask` with `attention_mask` to generate the final `attn_mask`.
+
+        Args:
+            hidden_states (`torch.Tensor`): The input hidden_states, used to determine the minimum value of the current input precision.
+            dynamic_mask (`torch.Tensor`): dynamic mask of shape `(batch_size, num_heads, key_sequence_length)`.
+            dynamic_mask_ratio (`float`, *optional*): Ratio from 0.0 to 1.0 used to control the proportion of the dynamic mask filled with the minimum value.
+            attention_mask (`torch.Tensor`, *optional*): attention mask of shape `(batch_size, 1, query_sequence_length, key_sequence_length)`.
+        """
+        min_type = torch.finfo(hidden_states.dtype).min
+        attn_mask = dynamic_mask[:, :, None, :]
+        if 0.0 < dynamic_mask_ratio < 1.0:
+            rate_value = torch.kthvalue(
+                attn_mask,
+                int(attn_mask.shape[-1] * dynamic_mask_ratio),
+                dim=-1, keepdim=True,
+            ).values
+            attn_mask = attn_mask.masked_fill(attn_mask < rate_value, min_type)
+        if attention_mask is not None:
+            attn_mask = attn_mask.masked_fill(attention_mask[:, :, :, : hidden_states.shape[-2]] == min_type, min_type)
+        return attn_mask
+
 
-class DogeSdpaDynamicMaskAttn(DogeDynamicMaskAttention):
+class DogeSdpaDynamicMaskAttention(DogeDynamicMaskAttention):
 
     def forward(
         self,
@@ -312,34 +348,31 @@ class DogeSdpaDynamicMaskAttn(DogeDynamicMaskAttention):
             # sin and cos are specific to RoPE models; cache_position needed for the static cache
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+        # calculate dynamic mask from value_states
+        dt_states = self.dt_proj(value_states.transpose(1, 2).reshape(bsz, value_states.shape[-2], -1))
+        dynamic_mask = torch.exp(self.A * F.softplus(dt_states)).transpose(-1, -2)
 
-        # repeat key and value states
-        key_states = repeat_kv(key_states, self.num_key_value_groups)
-        value_states = repeat_kv(value_states, self.num_key_value_groups)
-
-        causal_mask = attention_mask
-        if attention_mask is not None:
-            dt_states = self.dt_proj(value_states.transpose(1, 2).reshape(bsz, value_states.shape[-2], -1))
-            dynamic_mask = torch.exp(self.A * F.softplus(dt_states)).transpose(-1, -2)
-            dynamic_mask = dynamic_mask < 1.0
-            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]].masked_fill(dynamic_mask[:, :, None, :], torch.finfo(hidden_states.dtype).min)
+        attn_mask = self.prepare_dynamic_mask(
+            hidden_states=hidden_states,
+            dynamic_mask=dynamic_mask,
+            dynamic_mask_ratio=self.dynamic_mask_ratio,
+            attention_mask=attention_mask,
+        )
 
         query_states = query_states.contiguous()
         key_states = key_states.contiguous()
         value_states = value_states.contiguous()
 
-        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
-        is_causal = True if causal_mask is None and q_len > 1 else False
-
         # NOTE: As of pytorch 2.5.1, cuDNN's SDPA backward pass is still incorrect, so we disable cuDNN SDPA (see https://github.com/pytorch/pytorch/issues/138581)
         torch.backends.cuda.enable_cudnn_sdp(False)
         attn_output = F.scaled_dot_product_attention(
             query_states,
             key_states,
             value_states,
-            attn_mask=causal_mask,
+            attn_mask=attn_mask,
             dropout_p=self.attention_dropout if self.training else 0.0,
-            is_causal=is_causal,
+            enable_gqa=True,
         )
 
         attn_output = attn_output.transpose(1, 2).contiguous()
@@ -349,9 +382,70 @@
         return attn_output, past_key_value
 
 
+class DogeFlexDynamicMaskAttention(DogeDynamicMaskAttention):
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Cache] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Optional[Cache]]:
+        bsz, q_len, _ = hidden_states.shape
+
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+
+        query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+
+        cos, sin = position_embeddings
+        query_states, key_states = apply_QK_rotary_pos_emb(query_states, key_states, cos, sin)
+
+        if past_key_value is not None:
+            # sin and cos are specific to RoPE models; cache_position needed for the static cache
+            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+        dt_states = self.dt_proj(value_states.transpose(1, 2).reshape(bsz, value_states.shape[-2], -1))
+        dynamic_mask = torch.exp(self.A * F.softplus(dt_states)).transpose(-1, -2)
+
+        attn_mask = self.prepare_dynamic_mask(
+            hidden_states=hidden_states,
+            dynamic_mask=dynamic_mask,
+            dynamic_mask_ratio=self.dynamic_mask_ratio,
+            attention_mask=attention_mask,
+        )
+        # TODO: flex_attention: Captured buffers that require grad are not yet supported.
+        # NOTE: So we only use flex_attention in inference mode.
+        def dynamic_mask_mod(score, batch, head, q_idx, kv_idx):
+            score = score + attn_mask[batch][head][q_idx][kv_idx]
+            return score
+
+        attn_output = flex_attention(
+            query_states,
+            key_states,
+            value_states,
+            score_mod=dynamic_mask_mod,
+            enable_gqa=True,
+        )
+
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.view(bsz, q_len, -1)
+        attn_output = self.o_proj(attn_output)
+
+        return attn_output, past_key_value
+
+
 DOGE_ATTENTION_CLASSES = {
+    "flex_attention": DogeFlexDynamicMaskAttention,
     "eager": DogeDynamicMaskAttention,
-    "sdpa": DogeSdpaDynamicMaskAttn,
+    "sdpa": DogeSdpaDynamicMaskAttention,
 }
 
 
@@ -519,6 +613,7 @@ class DogePreTrainedModel(PreTrainedModel):
     supports_gradient_checkpointing = True
     _no_split_modules = ["DogeDecoderLayer"]
     _skip_keys_device_placement = ["past_key_values"]
+    _supports_flex_attn = True
     _supports_sdpa = True
     _supports_cache_class = True
     _supports_quantized_cache = True
@@ -693,7 +788,7 @@ class DogeModel(DogePreTrainedModel):
         all_self_attns = () if output_attentions else None
         next_decoder_cache = None
 
-        for decoder_layer in self.layers:
+        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
 
@@ -877,7 +972,7 @@ class DogeForCausalLM(DogePreTrainedModel, GenerationMixin):
         input_ids: torch.LongTensor = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[torch.Tensor] = None,
+        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
@@ -886,7 +981,7 @@ class DogeForCausalLM(DogePreTrainedModel, GenerationMixin):
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         num_logits_to_keep: int = 0,
-        **loss_kwargs,
+        **kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
         r"""
         Args:
@@ -920,6 +1015,7 @@ class DogeForCausalLM(DogePreTrainedModel, GenerationMixin):
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
             cache_position=cache_position,
+            **kwargs,
         )
 
         hidden_states = outputs[0]
@@ -929,7 +1025,7 @@ class DogeForCausalLM(DogePreTrainedModel, GenerationMixin):
 
         loss = None
         if labels is not None:
-            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.vocab_size, **loss_kwargs)
+            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.vocab_size, **kwargs)
 
         if not return_dict:
             output = (logits,) + outputs[1:]
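With the updated `DOGE_ATTENTION_CLASSES`, the eager, SDPA, and flex-attention paths are selected through the usual `attn_implementation` switch when loading the model. A hedged sketch under those assumptions (placeholder repo id; per the TODO/NOTE in the flex path above, `flex_attention` is intended for inference only and needs a recent PyTorch):

```python
import torch
from transformers import AutoModelForCausalLM

repo_id = "<namespace>/Doge-60M"  # placeholder

# "sdpa" maps to DogeSdpaDynamicMaskAttention; "eager" and "flex_attention"
# map to the other classes registered in DOGE_ATTENTION_CLASSES.
model = AutoModelForCausalLM.from_pretrained(
    repo_id,
    trust_remote_code=True,
    attn_implementation="sdpa",
    torch_dtype=torch.float32,
)
model.eval()  # inference mode; required for the flex_attention path
```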