huihui-ai committed
Commit 5643ecf · verified · 1 Parent(s): 183e79b

Upload 4 files

Files changed (4)
  1. __init__.py +27 -0
  2. configuration_dots1.py +221 -0
  3. modeling_dots1.py +699 -0
  4. modular_dots1.py +94 -0
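
Taken together, these files add standalone `dots.llm1` modeling code to the repository. As a hedged sketch of how such custom code is typically consumed — this assumes the repo's `config.json` carries an `auto_map` entry pointing at these files (not part of this commit), and the repo id below is only a placeholder:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder repo id; trust_remote_code=True lets transformers import the
# configuration_dots1.py / modeling_dots1.py shipped with the repository.
repo_id = "huihui-ai/<this-repo>"
tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True, torch_dtype="auto")
```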
__init__.py ADDED
@@ -0,0 +1,27 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import TYPE_CHECKING
15
+
16
+ from ...utils import _LazyModule
17
+ from ...utils.import_utils import define_import_structure
18
+
19
+
20
+ if TYPE_CHECKING:
21
+ from .configuration_dots1 import *
22
+ from .modeling_dots1 import *
23
+ else:
24
+ import sys
25
+
26
+ _file = globals()["__file__"]
27
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
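
The `__init__.py` above defers the real imports until an attribute is first accessed. As a rough mental model only (not the actual `_LazyModule` implementation), the same behaviour can be sketched in a package `__init__.py` with PEP 562's module-level `__getattr__`:

```python
# Illustrative sketch of lazy attribute resolution for a package __init__.py.
import importlib

_LAZY_ATTRS = {
    "Dots1Config": ".configuration_dots1",
    "Dots1Model": ".modeling_dots1",
    "Dots1ForCausalLM": ".modeling_dots1",
}

def __getattr__(name):
    # Import the submodule the first time one of its symbols is requested.
    if name in _LAZY_ATTRS:
        module = importlib.import_module(_LAZY_ATTRS[name], __package__)
        return getattr(module, name)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```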
configuration_dots1.py ADDED
@@ -0,0 +1,221 @@
1
+ # coding=utf-8
2
+ # Copyright 2025 The rednote-hilab team and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ #from ...configuration_utils import PretrainedConfig, layer_type_validation
16
+ #from ...utils import logging
17
+
18
+ from transformers.configuration_utils import PretrainedConfig, layer_type_validation
19
+ from transformers.utils import logging
20
+
21
+ logger = logging.get_logger(__name__)
22
+
23
+
24
+ class Dots1Config(PretrainedConfig):
25
+ r"""
26
+ This is the configuration class to store the configuration of a [`Dots1Model`]. It is used to instantiate a
27
+ `dots.llm1` model according to the specified arguments, defining the model architecture. Instantiating a
28
+ configuration with the defaults will yield a similar configuration to that of
29
+ [rednote-hilab/dots.llm1.base](https://huggingface.co/rednote-hilab/dots.llm1.base).
30
+
31
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
32
+ documentation from [`PretrainedConfig`] for more information.
33
+
34
+ Args:
35
+ vocab_size (`int`, *optional*, defaults to 152064):
36
+ Vocabulary size of the model. Defines the number of different tokens that can be represented by the
37
+ `input_ids` passed when calling [`Dots1Model`].
38
+ hidden_size (`int`, *optional*, defaults to 4608):
39
+ Dimension of the hidden representations.
40
+ intermediate_size (`int`, *optional*, defaults to 10944):
41
+ Dimension of the MLP representations.
42
+ moe_intermediate_size (`int`, *optional*, defaults to 1408):
43
+ Dimension of the MoE representations.
44
+ num_hidden_layers (`int`, *optional*, defaults to 62):
45
+ Number of hidden layers in the Transformer decoder.
46
+ num_attention_heads (`int`, *optional*, defaults to 32):
47
+ Number of attention heads for each attention layer in the Transformer decoder.
48
+ num_key_value_heads (`int`, *optional*, defaults to 32):
49
+ Number of key/value heads for Grouped Query Attention. If `num_key_value_heads=num_attention_heads`, Multi
50
+ Head Attention (MHA) is used. If `num_key_value_heads=1`, Multi Query Attention (MQA) is used. Otherwise,
51
+ Grouped Query Attention (GQA) is used. If not specified, defaults to `num_attention_heads`.
52
+ n_shared_experts (`int`, *optional*, defaults to `None`):
+ Number of shared experts. `None` means a dense model.
+ n_routed_experts (`int`, *optional*, defaults to `None`):
+ Number of routed experts. `None` means a dense model.
+ n_group (`int`, *optional*, defaults to 1):
+ Number of groups for routed experts.
+ topk_group (`int`, *optional*, defaults to 1):
+ Number of selected groups for each token (selected experts are chosen only within the `topk_group` selected groups).
+ num_experts_per_tok (`int`, *optional*, defaults to `None`):
+ Number of selected experts per token. `None` means a dense model.
62
+ first_k_dense_replace (`int`, *optional*, defaults to 0):
63
+ Number of dense layers at the beginning of the model before the first MoE layer.
64
+ norm_topk_prob (`bool`, *optional*, defaults to `False`):
65
+ Whether to normalize the weights of the routed experts.
66
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
67
+ The non-linear activation function (function or string).
68
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
69
+ Maximum sequence length the model might ever be used with.
70
+ initializer_range (`float`, *optional*, defaults to 0.02):
71
+ Standard deviation of the truncated_normal_initializer for initializing all weight matrices.
72
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06):
73
+ Epsilon used by the RMS normalization layers.
74
+ use_cache (`bool`, *optional*, defaults to `True`):
75
+ Whether or not the model should return the last key/values attentions. Only relevant if `config.is_decoder=True`.
76
+ pretraining_tp (`int`, *optional*, defaults to 1):
77
+ Experimental: tensor parallelism rank used during pretraining. This is necessary for exact reproducibility
78
+ of pretraining results.
79
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
80
+ Whether to tie the input and output word embeddings.
81
+ rope_theta (`float`, *optional*, defaults to 10000.0):
82
+ The base period of the RoPE embeddings.
83
+ rope_scaling (`dict`, *optional*):
84
+ Dictionary for scaling RoPE embeddings. Supports `{"type": strategy name, "factor": scaling factor}`.
85
+ attention_bias (`bool`, *optional*, defaults to `False`):
86
+ Whether to use a bias in the self-attention projections.
87
+ attention_dropout (`float`, *optional*, defaults to 0.0):
88
+ Dropout ratio for the attention probabilities.
89
+ routed_scaling_factor (`float`, *optional*, defaults to 1.0):
90
+ Scaling factor for routed experts.
91
+ use_sliding_window (`bool`, *optional*, defaults to `False`):
92
+ Whether to use sliding window attention.
93
+ sliding_window (`int`, *optional*, defaults to 4096):
94
+ Size of the sliding window for attention. If not specified, defaults to `4096`.
95
+ max_window_layers (`int`, *optional*, defaults to 62):
96
+ The number of leading layers that use full attention; layers at index `max_window_layers` and above use SWA (Sliding Window Attention).
97
+ layer_types (`list`, *optional*):
98
+ Attention pattern for each layer.
99
+
100
+ Examples:
101
+ ```python
102
+ >>> from transformers import Dots1Model, Dots1Config
103
+
104
+ >>> # Initializing a Dots1 style configuration
105
+ >>> configuration = Dots1Config()
106
+
+ >>> # Initializing a model from the dots.llm1.base style configuration
+ >>> model = Dots1Model(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
109
+ ```
110
+ """
111
+
112
+ model_type = "dots1"
113
+ keys_to_ignore_at_inference = ["past_key_values"]
114
+
115
+ base_model_tp_plan = { # TODO: only replicate attention layers when > first_k_dense_replace
116
+ "layers.*.self_attn.q_proj": "colwise",
117
+ "layers.*.self_attn.k_proj": "colwise",
118
+ "layers.*.self_attn.v_proj": "colwise",
119
+ "layers.*.self_attn.o_proj": "rowwise",
120
+ "layers.*.mlp.experts.*.gate_proj": "local_colwise",
121
+ "layers.*.mlp.experts.*.up_proj": "local_colwise",
122
+ "layers.*.mlp.experts.*.down_proj": "local_rowwise",
123
+ "layers.*.mlp.experts.*": "local", # each expert is wrapped in a module list
124
+ "layers.*.mlp.shared_experts.gate_proj": "local_colwise",
125
+ "layers.*.mlp.shared_experts.up_proj": "local_colwise",
126
+ "layers.*.mlp.shared_experts.down_proj": "local_rowwise",
127
+ "layers.*.mlp.shared_experts": "local",
128
+ "layers.*.mlp.gate_proj": "local_colwise",
129
+ "layers.*.mlp.up_proj": "local_colwise",
130
+ "layers.*.mlp.down_proj": "local_rowwise",
131
+ "layers.*.mlp": "gather", # This is the only moment where results are gathered
132
+ }
133
+
134
+ base_model_pp_plan = {
135
+ "embed_tokens": (["input_ids"], ["inputs_embeds"]),
136
+ "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
137
+ "norm": (["hidden_states"], ["hidden_states"]),
138
+ }
139
+
140
+ def __init__(
141
+ self,
142
+ vocab_size=152064,
143
+ hidden_size=4608,
144
+ intermediate_size=10944,
145
+ moe_intermediate_size=1408,
146
+ num_hidden_layers=62,
147
+ num_attention_heads=32,
148
+ num_key_value_heads=32,
149
+ n_shared_experts=None,
150
+ n_routed_experts=None,
151
+ n_group=1,
152
+ topk_group=1,
153
+ num_experts_per_tok=None,
154
+ first_k_dense_replace=0,
155
+ norm_topk_prob=False,
156
+ hidden_act="silu",
157
+ max_position_embeddings=2048,
158
+ initializer_range=0.02,
159
+ rms_norm_eps=1e-6,
160
+ use_cache=True,
161
+ pretraining_tp=1,
162
+ tie_word_embeddings=False,
163
+ rope_theta=10000.0,
164
+ rope_scaling=None,
165
+ attention_bias=False,
166
+ attention_dropout=0.0,
167
+ routed_scaling_factor=1.0,
168
+ use_sliding_window=False,
169
+ sliding_window=4096,
170
+ max_window_layers=62,
171
+ layer_types=None,
172
+ **kwargs,
173
+ ):
174
+ self.vocab_size = vocab_size
175
+ self.max_position_embeddings = max_position_embeddings
176
+ self.hidden_size = hidden_size
177
+ self.intermediate_size = intermediate_size
178
+ self.moe_intermediate_size = moe_intermediate_size
179
+ self.num_hidden_layers = num_hidden_layers
180
+ self.num_attention_heads = num_attention_heads
181
+ self.n_shared_experts = n_shared_experts
182
+ self.n_routed_experts = n_routed_experts
183
+ self.num_experts_per_tok = num_experts_per_tok
184
+ self.first_k_dense_replace = first_k_dense_replace
185
+ self.norm_topk_prob = norm_topk_prob
186
+ if num_key_value_heads is None:
187
+ num_key_value_heads = num_attention_heads
188
+ self.n_group = n_group
189
+ self.topk_group = topk_group
190
+ self.num_key_value_heads = num_key_value_heads
191
+ self.hidden_act = hidden_act
192
+ self.initializer_range = initializer_range
193
+ self.rms_norm_eps = rms_norm_eps
194
+ self.pretraining_tp = pretraining_tp
195
+ self.use_cache = use_cache
196
+ self.rope_theta = rope_theta
197
+ self.rope_scaling = rope_scaling
198
+ self.attention_bias = attention_bias
199
+ self.attention_dropout = attention_dropout
200
+ self.routed_scaling_factor = routed_scaling_factor
201
+ self.use_sliding_window = use_sliding_window
202
+ self.sliding_window = sliding_window if self.use_sliding_window else None
203
+ self.max_window_layers = max_window_layers
204
+
205
+ self.layer_types = layer_types
206
+ if self.layer_types is None:
207
+ self.layer_types = [
208
+ "sliding_attention"
209
+ if self.sliding_window is not None and i >= self.max_window_layers
210
+ else "full_attention"
211
+ for i in range(self.num_hidden_layers)
212
+ ]
213
+ layer_type_validation(self.layer_types)
214
+
215
+ super().__init__(
216
+ tie_word_embeddings=tie_word_embeddings,
217
+ **kwargs,
218
+ )
219
+
220
+
221
+ __all__ = ["Dots1Config"]
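
To illustrate how a few of these arguments interact: `first_k_dense_replace` controls how many leading layers stay dense, and `layer_types` is derived from `use_sliding_window`/`max_window_layers` when not given explicitly. A small sketch, with made-up sizes (not the model's real configuration) and assuming `configuration_dots1.py` is importable from the working directory:

```python
from configuration_dots1 import Dots1Config  # same import style modeling_dots1.py uses

config = Dots1Config(
    num_hidden_layers=8,
    first_k_dense_replace=2,   # layers 0-1 use the dense Dots1MLP, layers 2-7 use Dots1MoE
    n_routed_experts=16,
    n_shared_experts=1,
    num_experts_per_tok=4,
    use_sliding_window=True,
    max_window_layers=6,       # layers with index >= 6 get "sliding_attention"
)
print(config.layer_types)
# ['full_attention', 'full_attention', 'full_attention', 'full_attention',
#  'full_attention', 'full_attention', 'sliding_attention', 'sliding_attention']
```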
modeling_dots1.py ADDED
@@ -0,0 +1,699 @@
1
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
2
+ # This file was automatically generated from src/transformers/models/dots1/modular_dots1.py.
3
+ # Do NOT edit this file manually as any edits will be overwritten by the generation of
4
+ # the file from the modular. If any change should be done, please apply the change to the
5
+ # modular_dots1.py file directly. One of our CI enforces this.
6
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
7
+ from typing import Callable, Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from torch import nn
12
+
13
+ #from ...activations import ACT2FN
14
+ #from ...cache_utils import Cache, DynamicCache
15
+ #from ...generation import GenerationMixin
16
+ #from ...integrations import use_kernel_forward_from_hub
17
+ #from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
18
+ #from ...modeling_flash_attention_utils import FlashAttentionKwargs
19
+ #from ...modeling_layers import GradientCheckpointingLayer
20
+ #from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
21
+ #from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
22
+ #from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
23
+ #from ...processing_utils import Unpack
24
+ #from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging
25
+
26
+ from transformers.activations import ACT2FN
27
+ from transformers.cache_utils import Cache, DynamicCache
28
+ from transformers.generation import GenerationMixin
29
+ from transformers.integrations import use_kernel_forward_from_hub
30
+ from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask
31
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
32
+ from transformers.modeling_layers import GradientCheckpointingLayer
33
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
34
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
35
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
36
+ from transformers.processing_utils import Unpack
37
+ from transformers.utils import LossKwargs, auto_docstring, can_return_tuple, logging
38
+
39
+ from configuration_dots1 import Dots1Config
40
+
41
+
42
+ logger = logging.get_logger(__name__)
43
+
44
+
45
+ @use_kernel_forward_from_hub("RMSNorm")
46
+ class Dots1RMSNorm(nn.Module):
47
+ def __init__(self, hidden_size, eps=1e-6):
48
+ """
49
+ Dots1RMSNorm is equivalent to T5LayerNorm
50
+ """
51
+ super().__init__()
52
+ self.weight = nn.Parameter(torch.ones(hidden_size))
53
+ self.variance_epsilon = eps
54
+
55
+ def forward(self, hidden_states):
56
+ input_dtype = hidden_states.dtype
57
+ hidden_states = hidden_states.to(torch.float32)
58
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
59
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
60
+ return self.weight * hidden_states.to(input_dtype)
61
+
62
+ def extra_repr(self):
63
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
64
+
65
+
66
+ class Dots1RotaryEmbedding(nn.Module):
67
+ def __init__(self, config: Dots1Config, device=None):
68
+ super().__init__()
69
+ # BC: "rope_type" was originally "type"
70
+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
71
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
72
+ else:
73
+ self.rope_type = "default"
74
+ self.max_seq_len_cached = config.max_position_embeddings
75
+ self.original_max_seq_len = config.max_position_embeddings
76
+
77
+ self.config = config
78
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
79
+
80
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
81
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
82
+ self.original_inv_freq = self.inv_freq
83
+
84
+ @torch.no_grad()
85
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
86
+ def forward(self, x, position_ids):
87
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
88
+ position_ids_expanded = position_ids[:, None, :].float()
89
+
90
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
91
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
92
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
93
+ emb = torch.cat((freqs, freqs), dim=-1)
94
+ cos = emb.cos() * self.attention_scaling
95
+ sin = emb.sin() * self.attention_scaling
96
+
97
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
98
+
99
+
100
+ def rotate_half(x):
101
+ """Rotates half the hidden dims of the input."""
102
+ x1 = x[..., : x.shape[-1] // 2]
103
+ x2 = x[..., x.shape[-1] // 2 :]
104
+ return torch.cat((-x2, x1), dim=-1)
105
+
106
+
107
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
108
+ """Applies Rotary Position Embedding to the query and key tensors.
109
+
110
+ Args:
111
+ q (`torch.Tensor`): The query tensor.
112
+ k (`torch.Tensor`): The key tensor.
113
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
114
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
115
+ position_ids (`torch.Tensor`, *optional*):
116
+ Deprecated and unused.
117
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
118
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
119
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
120
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
121
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
122
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
123
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
124
+ Returns:
125
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
126
+ """
127
+ cos = cos.unsqueeze(unsqueeze_dim)
128
+ sin = sin.unsqueeze(unsqueeze_dim)
129
+ q_embed = (q * cos) + (rotate_half(q) * sin)
130
+ k_embed = (k * cos) + (rotate_half(k) * sin)
131
+ return q_embed, k_embed
132
+
133
+
134
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
135
+ """
136
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
137
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
138
+ """
139
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
140
+ if n_rep == 1:
141
+ return hidden_states
142
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
143
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
144
+
145
+
146
+ def eager_attention_forward(
147
+ module: nn.Module,
148
+ query: torch.Tensor,
149
+ key: torch.Tensor,
150
+ value: torch.Tensor,
151
+ attention_mask: Optional[torch.Tensor],
152
+ scaling: float,
153
+ dropout: float = 0.0,
154
+ **kwargs,
155
+ ):
156
+ key_states = repeat_kv(key, module.num_key_value_groups)
157
+ value_states = repeat_kv(value, module.num_key_value_groups)
158
+
159
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
160
+ if attention_mask is not None:
161
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
162
+ attn_weights = attn_weights + causal_mask
163
+
164
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
165
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
166
+ attn_output = torch.matmul(attn_weights, value_states)
167
+ attn_output = attn_output.transpose(1, 2).contiguous()
168
+
169
+ return attn_output, attn_weights
170
+
171
+
172
+ class Dots1Attention(nn.Module):
173
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
174
+
175
+ def __init__(self, config: Dots1Config, layer_idx: int):
176
+ super().__init__()
177
+ self.config = config
178
+ self.layer_idx = layer_idx
179
+ self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
180
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
181
+ self.scaling = self.head_dim**-0.5
182
+ self.attention_dropout = config.attention_dropout
183
+ self.is_causal = True
184
+
185
+ self.q_proj = nn.Linear(
186
+ config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
187
+ )
188
+ self.k_proj = nn.Linear(
189
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
190
+ )
191
+ self.v_proj = nn.Linear(
192
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
193
+ )
194
+ self.o_proj = nn.Linear(
195
+ config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
196
+ )
197
+ self.q_norm = Dots1RMSNorm(self.head_dim, eps=config.rms_norm_eps) # unlike olmo, only on the head dim!
198
+ self.k_norm = Dots1RMSNorm(self.head_dim, eps=config.rms_norm_eps) # thus post q_norm does not need reshape
199
+ self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None
200
+
201
+ def forward(
202
+ self,
203
+ hidden_states: torch.Tensor,
204
+ position_embeddings: Tuple[torch.Tensor, torch.Tensor],
205
+ attention_mask: Optional[torch.Tensor],
206
+ past_key_value: Optional[Cache] = None,
207
+ cache_position: Optional[torch.LongTensor] = None,
208
+ **kwargs: Unpack[FlashAttentionKwargs],
209
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
210
+ input_shape = hidden_states.shape[:-1]
211
+ hidden_shape = (*input_shape, -1, self.head_dim)
212
+
213
+ query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
214
+ key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
215
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
216
+
217
+ cos, sin = position_embeddings
218
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
219
+
220
+ if past_key_value is not None:
221
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
222
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
223
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
224
+
225
+ attention_interface: Callable = eager_attention_forward
226
+ if self.config._attn_implementation != "eager":
227
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
228
+
229
+ attn_output, attn_weights = attention_interface(
230
+ self,
231
+ query_states,
232
+ key_states,
233
+ value_states,
234
+ attention_mask,
235
+ dropout=0.0 if not self.training else self.attention_dropout,
236
+ scaling=self.scaling,
237
+ sliding_window=self.sliding_window, # diff with Llama
238
+ **kwargs,
239
+ )
240
+
241
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
242
+ attn_output = self.o_proj(attn_output)
243
+ return attn_output, attn_weights
244
+
245
+
246
+ class Dots1MLP(nn.Module):
247
+ def __init__(self, config, hidden_size=None, intermediate_size=None):
248
+ super().__init__()
249
+ self.config = config
250
+ self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
251
+ self.intermediate_size = config.intermediate_size if intermediate_size is None else intermediate_size
252
+
253
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
254
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
255
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
256
+ self.act_fn = ACT2FN[config.hidden_act]
257
+
258
+ def forward(self, x):
259
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
260
+ return down_proj
261
+
262
+
263
+ class Dots1MoE(nn.Module):
264
+ """
265
+ A mixed expert module containing shared experts.
266
+ """
267
+
268
+ def __init__(self, config):
269
+ super().__init__()
270
+ self.config = config
271
+ self.experts = nn.ModuleList(
272
+ [Dots1MLP(config, intermediate_size=config.moe_intermediate_size) for _ in range(config.n_routed_experts)]
273
+ )
274
+ self.gate = Dots1TopkRouter(config)
275
+ self.shared_experts = Dots1MLP(
276
+ config=config, intermediate_size=config.moe_intermediate_size * config.n_shared_experts
277
+ )
278
+
279
+ def moe(self, hidden_states: torch.Tensor, topk_indices: torch.Tensor, topk_weights: torch.Tensor):
280
+ r"""
281
+ CALL FOR CONTRIBUTION! Expert weights should be fused so that the per-expert loop below
+ can be avoided (DeepSeek has 256 experts).
283
+ """
284
+ final_hidden_states = torch.zeros_like(hidden_states, dtype=topk_weights.dtype)
285
+ expert_mask = torch.nn.functional.one_hot(topk_indices, num_classes=len(self.experts))
286
+ expert_mask = expert_mask.permute(2, 0, 1)
287
+
288
+ for expert_idx in range(len(self.experts)):
289
+ expert = self.experts[expert_idx]
290
+ mask = expert_mask[expert_idx]
291
+ token_indices, weight_indices = torch.where(mask)
292
+
293
+ if token_indices.numel() > 0:
294
+ expert_weights = topk_weights[token_indices, weight_indices]
295
+ expert_input = hidden_states[token_indices]
296
+ expert_output = expert(expert_input)
297
+ weighted_output = expert_output * expert_weights.unsqueeze(-1)
298
+ final_hidden_states.index_add_(0, token_indices, weighted_output)
299
+
300
+ # in the original deepseek, the outputs of the experts are gathered once we leave this module
+ # thus the moe module is itself an IsolatedParallel module
+ # and all experts are "local", meaning we shard but we don't gather
303
+ return final_hidden_states.type(hidden_states.dtype)
304
+
305
+ def forward(self, hidden_states):
306
+ residuals = hidden_states
307
+ orig_shape = hidden_states.shape
308
+ topk_indices, topk_weights = self.gate(hidden_states)
309
+ hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
310
+ hidden_states = self.moe(hidden_states, topk_indices, topk_weights).view(*orig_shape)
311
+ hidden_states = hidden_states + self.shared_experts(residuals)
312
+ return hidden_states
313
+
314
+
315
+ class Dots1TopkRouter(nn.Module):
316
+ def __init__(self, config):
317
+ super().__init__()
318
+ self.config = config
319
+ self.top_k = config.num_experts_per_tok
320
+ self.n_routed_experts = config.n_routed_experts
321
+ self.routed_scaling_factor = config.routed_scaling_factor
322
+ self.n_group = config.n_group
323
+ self.topk_group = config.topk_group
324
+ self.norm_topk_prob = config.norm_topk_prob
325
+
326
+ self.weight = nn.Parameter(torch.empty((self.n_routed_experts, config.hidden_size)))
327
+ self.register_buffer("e_score_correction_bias", torch.zeros((self.n_routed_experts)))
328
+
329
+ @torch.no_grad()
330
+ def get_topk_indices(self, scores):
331
+ scores_for_choice = scores.view(-1, self.n_routed_experts) + self.e_score_correction_bias.unsqueeze(0)
332
+ group_scores = (
333
+ scores_for_choice.view(-1, self.n_group, self.n_routed_experts // self.n_group)
334
+ .topk(2, dim=-1)[0]
335
+ .sum(dim=-1)
336
+ )
337
+ group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1]
338
+ group_mask = torch.zeros_like(group_scores)
339
+ group_mask.scatter_(1, group_idx, 1)
340
+ score_mask = (
341
+ group_mask.unsqueeze(-1)
342
+ .expand(-1, self.n_group, self.n_routed_experts // self.n_group)
343
+ .reshape(-1, self.n_routed_experts)
344
+ )
345
+ scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)
346
+ topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1]
347
+ return topk_indices
348
+
349
+ def forward(self, hidden_states):
350
+ hidden_states = hidden_states.view(-1, self.config.hidden_size)
351
+ router_logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32))
352
+ scores = router_logits.sigmoid()
353
+ topk_indices = self.get_topk_indices(scores)
354
+ topk_weights = scores.gather(1, topk_indices)
355
+ if self.norm_topk_prob:
356
+ denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20
357
+ topk_weights /= denominator
358
+ topk_weights = topk_weights * self.routed_scaling_factor
359
+ return topk_indices, topk_weights
360
+
361
+
362
+ class Dots1DecoderLayer(GradientCheckpointingLayer):
363
+ def __init__(self, config: Dots1Config, layer_idx: int):
364
+ super().__init__()
365
+ self.hidden_size = config.hidden_size
366
+
367
+ self.self_attn = Dots1Attention(config=config, layer_idx=layer_idx)
368
+
369
+ if layer_idx >= config.first_k_dense_replace:
370
+ self.mlp = Dots1MoE(config)
371
+ else:
372
+ self.mlp = Dots1MLP(config)
373
+
374
+ self.input_layernorm = Dots1RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
375
+ self.post_attention_layernorm = Dots1RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
376
+ self.attention_type = config.layer_types[layer_idx]
377
+
378
+ def forward(
379
+ self,
380
+ hidden_states: torch.Tensor,
381
+ attention_mask: Optional[torch.Tensor] = None,
382
+ position_ids: Optional[torch.LongTensor] = None,
383
+ past_key_value: Optional[Cache] = None,
384
+ output_attentions: Optional[bool] = False,
385
+ use_cache: Optional[bool] = False,
386
+ cache_position: Optional[torch.LongTensor] = None,
387
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
388
+ **kwargs: Unpack[FlashAttentionKwargs],
389
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
390
+ residual = hidden_states
391
+ hidden_states = self.input_layernorm(hidden_states)
392
+
393
+ # Self Attention
394
+ hidden_states, self_attn_weights = self.self_attn(
395
+ hidden_states=hidden_states,
396
+ attention_mask=attention_mask,
397
+ position_ids=position_ids,
398
+ past_key_value=past_key_value,
399
+ output_attentions=output_attentions,
400
+ use_cache=use_cache,
401
+ cache_position=cache_position,
402
+ position_embeddings=position_embeddings,
403
+ **kwargs,
404
+ )
405
+ hidden_states = residual + hidden_states
406
+
407
+ # Fully Connected
408
+ residual = hidden_states
409
+ hidden_states = self.post_attention_layernorm(hidden_states)
410
+ hidden_states = self.mlp(hidden_states)
411
+ hidden_states = residual + hidden_states
412
+
413
+ outputs = (hidden_states,)
414
+ if output_attentions:
415
+ outputs += (self_attn_weights,)
416
+
417
+ return outputs
418
+
419
+
420
+ @auto_docstring
421
+ class Dots1PreTrainedModel(PreTrainedModel):
422
+ config_class = Dots1Config
423
+ base_model_prefix = "model"
424
+ supports_gradient_checkpointing = True
425
+ _no_split_modules = ["Dots1DecoderLayer"]
426
+ _skip_keys_device_placement = ["past_key_values"]
427
+ _supports_flash_attn_2 = True
428
+ _supports_sdpa = True
429
+ _supports_flex_attn = True
430
+ _supports_cache_class = True
431
+ _supports_quantized_cache = True
432
+ _supports_static_cache = True
433
+ _supports_attention_backend = True
434
+
435
+ def _init_weights(self, module):
436
+ std = self.config.initializer_range
437
+ if isinstance(module, nn.Linear):
438
+ module.weight.data.normal_(mean=0.0, std=std)
439
+ if module.bias is not None:
440
+ module.bias.data.zero_()
441
+ elif isinstance(module, nn.Embedding):
442
+ module.weight.data.normal_(mean=0.0, std=std)
443
+ if module.padding_idx is not None:
444
+ module.weight.data[module.padding_idx].zero_()
445
+ elif isinstance(module, Dots1RMSNorm):
446
+ module.weight.data.fill_(1.0)
447
+ elif isinstance(module, Dots1TopkRouter):
448
+ module.weight.data.normal_(mean=0.0, std=std)
449
+
450
+
451
+ @auto_docstring
452
+ class Dots1Model(Dots1PreTrainedModel):
453
+ def __init__(self, config: Dots1Config):
454
+ super().__init__(config)
455
+ self.padding_idx = config.pad_token_id
456
+ self.vocab_size = config.vocab_size
457
+
458
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
459
+ self.layers = nn.ModuleList(
460
+ [Dots1DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
461
+ )
462
+ self.norm = Dots1RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
463
+ self.rotary_emb = Dots1RotaryEmbedding(config=config)
464
+ self.gradient_checkpointing = False
465
+ self.has_sliding_layers = "sliding_attention" in self.config.layer_types
466
+
467
+ # Initialize weights and apply final processing
468
+ self.post_init()
469
+
470
+ def get_input_embeddings(self):
471
+ return self.embed_tokens
472
+
473
+ def set_input_embeddings(self, value):
474
+ self.embed_tokens = value
475
+
476
+ @can_return_tuple
477
+ @auto_docstring
478
+ def forward(
479
+ self,
480
+ input_ids: Optional[torch.LongTensor] = None,
481
+ attention_mask: Optional[torch.Tensor] = None,
482
+ position_ids: Optional[torch.LongTensor] = None,
483
+ past_key_values: Optional[Cache] = None,
484
+ inputs_embeds: Optional[torch.FloatTensor] = None,
485
+ use_cache: Optional[bool] = None,
486
+ output_attentions: Optional[bool] = None,
487
+ output_hidden_states: Optional[bool] = None,
488
+ cache_position: Optional[torch.LongTensor] = None,
489
+ **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
490
+ ) -> BaseModelOutputWithPast:
491
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
492
+ output_hidden_states = (
493
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
494
+ )
495
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
496
+
497
+ if (input_ids is None) ^ (inputs_embeds is not None):
498
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
499
+
500
+ if self.gradient_checkpointing and self.training and use_cache:
501
+ logger.warning_once(
502
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
503
+ )
504
+ use_cache = False
505
+
506
+ # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache
507
+ if not isinstance(past_key_values, (type(None), Cache)):
508
+ raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.")
509
+
510
+ if inputs_embeds is None:
511
+ inputs_embeds = self.embed_tokens(input_ids)
512
+
513
+ if use_cache and past_key_values is None:
514
+ past_key_values = DynamicCache()
515
+
516
+ if cache_position is None:
517
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
518
+ cache_position = torch.arange(
519
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
520
+ )
521
+
522
+ if position_ids is None:
523
+ position_ids = cache_position.unsqueeze(0)
524
+
525
+ # It may already have been prepared by e.g. `generate`
526
+ if not isinstance(causal_mask_mapping := attention_mask, dict):
527
+ # Prepare mask arguments
528
+ mask_kwargs = {
529
+ "config": self.config,
530
+ "input_embeds": inputs_embeds,
531
+ "attention_mask": attention_mask,
532
+ "cache_position": cache_position,
533
+ "past_key_values": past_key_values,
534
+ }
535
+ # Create the masks
536
+ causal_mask_mapping = {
537
+ "full_attention": create_causal_mask(**mask_kwargs),
538
+ }
539
+ # The sliding window alternating layers are not always activated depending on the config
540
+ if self.has_sliding_layers:
541
+ causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask(**mask_kwargs)
542
+
543
+ hidden_states = inputs_embeds
544
+
545
+ # create position embeddings to be shared across the decoder layers
546
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
547
+
548
+ # decoder layers
549
+ all_hidden_states = () if output_hidden_states else None
550
+ all_self_attns = () if output_attentions else None
551
+
552
+ for decoder_layer in self.layers[: self.config.num_hidden_layers]:
553
+ if output_hidden_states:
554
+ all_hidden_states += (hidden_states,)
555
+
556
+ layer_outputs = decoder_layer(
557
+ hidden_states,
558
+ attention_mask=causal_mask_mapping[decoder_layer.attention_type],
559
+ position_ids=position_ids,
560
+ past_key_value=past_key_values,
561
+ output_attentions=output_attentions,
562
+ use_cache=use_cache,
563
+ cache_position=cache_position,
564
+ position_embeddings=position_embeddings,
565
+ **flash_attn_kwargs,
566
+ )
567
+
568
+ hidden_states = layer_outputs[0]
569
+
570
+ if output_attentions:
571
+ all_self_attns += (layer_outputs[1],)
572
+
573
+ hidden_states = self.norm(hidden_states)
574
+
575
+ # add hidden states from the last decoder layer
576
+ if output_hidden_states:
577
+ all_hidden_states += (hidden_states,)
578
+
579
+ return BaseModelOutputWithPast(
580
+ last_hidden_state=hidden_states,
581
+ past_key_values=past_key_values if use_cache else None,
582
+ hidden_states=all_hidden_states,
583
+ attentions=all_self_attns,
584
+ )
585
+
586
+
587
+ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
588
+
589
+
590
+ @auto_docstring
591
+ class Dots1ForCausalLM(Dots1PreTrainedModel, GenerationMixin):
592
+ _tied_weights_keys = ["lm_head.weight"]
593
+ _tp_plan = {"lm_head": "colwise_rep"}
594
+ _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
595
+
596
+ def __init__(self, config):
597
+ super().__init__(config)
598
+ self.model = Dots1Model(config)
599
+ self.vocab_size = config.vocab_size
600
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
601
+
602
+ # Initialize weights and apply final processing
603
+ self.post_init()
604
+
605
+ def get_input_embeddings(self):
606
+ return self.model.embed_tokens
607
+
608
+ def set_input_embeddings(self, value):
609
+ self.model.embed_tokens = value
610
+
611
+ def get_output_embeddings(self):
612
+ return self.lm_head
613
+
614
+ def set_output_embeddings(self, new_embeddings):
615
+ self.lm_head = new_embeddings
616
+
617
+ def set_decoder(self, decoder):
618
+ self.model = decoder
619
+
620
+ def get_decoder(self):
621
+ return self.model
622
+
623
+ @can_return_tuple
624
+ @auto_docstring
625
+ def forward(
626
+ self,
627
+ input_ids: Optional[torch.LongTensor] = None,
628
+ attention_mask: Optional[torch.Tensor] = None,
629
+ position_ids: Optional[torch.LongTensor] = None,
630
+ past_key_values: Optional[Cache] = None,
631
+ inputs_embeds: Optional[torch.FloatTensor] = None,
632
+ labels: Optional[torch.LongTensor] = None,
633
+ use_cache: Optional[bool] = None,
634
+ output_attentions: Optional[bool] = None,
635
+ output_hidden_states: Optional[bool] = None,
636
+ cache_position: Optional[torch.LongTensor] = None,
637
+ logits_to_keep: Union[int, torch.Tensor] = 0,
638
+ **kwargs: Unpack[KwargsForCausalLM],
639
+ ) -> CausalLMOutputWithPast:
640
+ r"""
641
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
642
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
643
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
644
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
645
+
646
+ Example:
647
+
648
+ ```python
649
+ >>> from transformers import AutoTokenizer, Dots1ForCausalLM
650
+
651
+ >>> model = Dots1ForCausalLM.from_pretrained("rednote-hilab/dots.llm1.inst")
+ >>> tokenizer = AutoTokenizer.from_pretrained("rednote-hilab/dots.llm1.inst")
653
+
654
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
655
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
656
+
657
+ >>> # Generate
658
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
659
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
660
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
661
+ ```"""
662
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
663
+ output_hidden_states = (
664
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
665
+ )
666
+
667
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
668
+ outputs: BaseModelOutputWithPast = self.model(
669
+ input_ids=input_ids,
670
+ attention_mask=attention_mask,
671
+ position_ids=position_ids,
672
+ past_key_values=past_key_values,
673
+ inputs_embeds=inputs_embeds,
674
+ use_cache=use_cache,
675
+ output_attentions=output_attentions,
676
+ output_hidden_states=output_hidden_states,
677
+ cache_position=cache_position,
678
+ **kwargs,
679
+ )
680
+
681
+ hidden_states = outputs.last_hidden_state
682
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
683
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
684
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
685
+
686
+ loss = None
687
+ if labels is not None:
688
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
689
+
690
+ return CausalLMOutputWithPast(
691
+ loss=loss,
692
+ logits=logits,
693
+ past_key_values=outputs.past_key_values,
694
+ hidden_states=outputs.hidden_states,
695
+ attentions=outputs.attentions,
696
+ )
697
+
698
+
699
+ __all__ = ["Dots1PreTrainedModel", "Dots1Model", "Dots1ForCausalLM"]
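
The group-limited top-k selection in `Dots1TopkRouter.get_topk_indices` above can be hard to follow from the tensor reshapes alone. Below is a standalone re-implementation on toy shapes (made-up sizes, and the `e_score_correction_bias` term is omitted) that mirrors the same three steps:

```python
import torch

n_experts, n_group, topk_group, top_k = 16, 4, 2, 4
scores = torch.rand(5, n_experts)  # (tokens, experts), e.g. sigmoid of router logits

# 1) score each group of experts by the sum of its two best members
group_scores = scores.view(-1, n_group, n_experts // n_group).topk(2, dim=-1)[0].sum(dim=-1)
# 2) keep only the `topk_group` best groups per token
group_idx = group_scores.topk(topk_group, dim=-1, sorted=False)[1]
group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1)
score_mask = (
    group_mask.unsqueeze(-1)
    .expand(-1, n_group, n_experts // n_group)
    .reshape(-1, n_experts)
)
# 3) pick the overall top_k experts, restricted to the surviving groups
topk_indices = scores.masked_fill(~score_mask.bool(), 0.0).topk(top_k, dim=-1, sorted=False)[1]
print(topk_indices.shape)  # torch.Size([5, 4])
```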
modular_dots1.py ADDED
@@ -0,0 +1,94 @@
1
+ from ...modeling_outputs import CausalLMOutputWithPast
2
+ from ...processing_utils import Unpack
3
+ from ...utils import logging
4
+ from ..deepseek_v3.modeling_deepseek_v3 import (
5
+ DeepseekV3DecoderLayer,
6
+ DeepseekV3MLP,
7
+ DeepseekV3MoE,
8
+ DeepseekV3PreTrainedModel,
9
+ DeepseekV3TopkRouter,
10
+ )
11
+ from ..llama.modeling_llama import (
12
+ KwargsForCausalLM,
13
+ LlamaRMSNorm,
14
+ )
15
+ from ..qwen3.modeling_qwen3 import Qwen3Attention, Qwen3ForCausalLM, Qwen3Model, Qwen3RotaryEmbedding
16
+ from .configuration_dots1 import Dots1Config
17
+
18
+
19
+ logger = logging.get_logger(__name__)
20
+
21
+
22
+ class Dots1RMSNorm(LlamaRMSNorm):
23
+ pass
24
+
25
+
26
+ class Dots1RotaryEmbedding(Qwen3RotaryEmbedding):
27
+ pass
28
+
29
+
30
+ class Dots1Attention(Qwen3Attention):
31
+ pass
32
+
33
+
34
+ class Dots1MLP(DeepseekV3MLP):
35
+ pass
36
+
37
+
38
+ class Dots1MoE(DeepseekV3MoE):
39
+ pass
40
+
41
+
42
+ class Dots1TopkRouter(DeepseekV3TopkRouter):
43
+ pass
44
+
45
+
46
+ class Dots1DecoderLayer(DeepseekV3DecoderLayer):
47
+ def __init__(self, config: Dots1Config, layer_idx: int):
48
+ super().__init__(config, layer_idx)
49
+ self.attention_type = config.layer_types[layer_idx]
50
+
51
+
52
+ class Dots1PreTrainedModel(DeepseekV3PreTrainedModel):
53
+ pass
54
+
55
+
56
+ class Dots1Model(Qwen3Model):
57
+ pass
58
+
59
+
60
+ class Dots1ForCausalLM(Qwen3ForCausalLM):
61
+ def forward(
62
+ self,
63
+ **super_kwargs: Unpack[KwargsForCausalLM],
64
+ ) -> CausalLMOutputWithPast:
65
+ r"""
66
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
67
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
68
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
69
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
70
+
71
+ Example:
72
+
73
+ ```python
74
+ >>> from transformers import AutoTokenizer, Dots1ForCausalLM
75
+
76
+ >>> model = Dots1ForCausalLM.from_pretrained("rednote-hilab/dots.llm1.inst")
+ >>> tokenizer = AutoTokenizer.from_pretrained("rednote-hilab/dots.llm1.inst")
78
+
79
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
80
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
81
+
82
+ >>> # Generate
83
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
84
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
85
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
86
+ ```"""
87
+ return super().forward(**super_kwargs)
88
+
89
+
90
+ __all__ = [
91
+ "Dots1PreTrainedModel",
92
+ "Dots1Model",
93
+ "Dots1ForCausalLM",
94
+ ]
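
Finally, a minimal smoke test of the uploaded code: a sketch that assumes the Python files from this commit are on the import path and that a recent `transformers` release provides the utilities they import (`masking_utils`, `modeling_layers`, etc.). All sizes are tiny, made-up values so the randomly initialized model runs on CPU:

```python
import torch

from configuration_dots1 import Dots1Config
from modeling_dots1 import Dots1ForCausalLM

config = Dots1Config(
    vocab_size=1000,
    hidden_size=64,
    intermediate_size=128,
    moe_intermediate_size=32,
    num_hidden_layers=4,
    num_attention_heads=4,
    num_key_value_heads=2,
    first_k_dense_replace=1,   # layer 0 dense, layers 1-3 MoE
    n_routed_experts=8,
    n_shared_experts=1,
    num_experts_per_tok=2,
)
model = Dots1ForCausalLM(config).eval()

input_ids = torch.randint(0, config.vocab_size, (1, 8))
with torch.no_grad():
    out = model(input_ids=input_ids)
print(out.logits.shape)  # torch.Size([1, 8, 1000])
```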