Qubitium committed
Commit eb2ed2a
1 Parent(s): 3057ec6

Create modeling_dbrx.py

Files changed (1)
  1. modeling_dbrx.py +1448 -0
modeling_dbrx.py ADDED
@@ -0,0 +1,1448 @@
1
+ # code adapted from https://huggingface.co/fahadh4ilyas
2
+ """PyTorch Dbrx model."""
3
+
4
+ import math
5
+ import warnings
6
+ from copy import deepcopy
7
+ from functools import partial
8
+ from typing import Any, Callable, Dict, Optional, Tuple, Union
9
+
10
+ import torch
11
+ import torch.nn.functional as F
12
+ import torch.utils.checkpoint
13
+ from torch import nn
14
+ from transformers.cache_utils import Cache, DynamicCache, StaticCache
15
+ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
16
+ from transformers.modeling_outputs import (MoeCausalLMOutputWithPast,
17
+ MoeModelOutputWithPast)
18
+ from transformers.modeling_utils import PreTrainedModel
19
+ from transformers.utils import is_flash_attn_2_available, logging
20
+
21
+ from .configuration_dbrx import DbrxAttentionConfig, DbrxConfig, DbrxFFNConfig
22
+
23
+ if is_flash_attn_2_available():
24
+ try:
25
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
26
+ from flash_attn.bert_padding import pad_input # noqa
27
+ from flash_attn.bert_padding import index_first_axis, unpad_input
28
+ except ImportError:
29
+ pass
30
+
31
+ logger = logging.get_logger(__name__)
32
+
33
+ _CONFIG_FOR_DOC = 'DbrxConfig'
34
+
35
+ #############################################################################
36
+ # Copied from LLaMaRotaryEmbedding
37
+ #############################################################################
38
+
39
+
40
+ class DbrxRotaryEmbedding(nn.Module):
41
+
42
+ def __init__(self,
43
+ dim: int,
44
+ max_position_embeddings: int = 2048,
45
+ base: float = 10000.0,
46
+ scaling_factor: float = 1.0):
47
+ super().__init__()
48
+ self.scaling_factor = scaling_factor
49
+ self.dim = dim
50
+ self.max_position_embeddings = max_position_embeddings
51
+ self.base = base
52
+ inv_freq = 1.0 / (self.base**(
53
+ torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
54
+ self.register_buffer('inv_freq', inv_freq, persistent=False)
55
+ # For BC we register cos and sin cached
56
+ self.max_seq_len_cached = max_position_embeddings
57
+
58
+ @torch.no_grad()
59
+ def forward(
60
+ self, x: torch.Tensor, position_ids: torch.LongTensor
61
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
62
+ # x: [bs, num_attention_heads, seq_len, head_size]
63
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(
64
+ position_ids.shape[0], -1, 1)
65
+ position_ids_expanded = position_ids[:, None, :].float()
66
+ # Force float32 since bfloat16 loses precision on long contexts
67
+ # See https://github.com/huggingface/transformers/pull/29285
68
+ device_type = x.device.type
69
+ device_type = device_type if isinstance(
70
+ device_type, str) and device_type != 'mps' else 'cpu'
71
+ with torch.autocast(device_type=device_type, enabled=False):
72
+ freqs = (inv_freq_expanded.float()
73
+ @ position_ids_expanded.float()).transpose(1, 2)
74
+ emb = torch.cat((freqs, freqs), dim=-1)
75
+ cos = emb.cos()
76
+ sin = emb.sin()
77
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
78
+
79
+
80
+ def rotate_half(x: torch.Tensor) -> torch.Tensor:
81
+ """Rotates half the hidden dims of the input."""
82
+ x1 = x[..., :x.shape[-1] // 2]
83
+ x2 = x[..., x.shape[-1] // 2:]
84
+ return torch.cat((-x2, x1), dim=-1)
85
+
86
+
87
+ def apply_rotary_pos_emb(
88
+ q: torch.Tensor,
89
+ k: torch.Tensor,
90
+ cos: torch.Tensor,
91
+ sin: torch.Tensor,
92
+ unsqueeze_dim: int = 1) -> Tuple[torch.Tensor, torch.Tensor]:
93
+ """Applies Rotary Position Embedding to the query and key tensors.
94
+
95
+ Args:
96
+ q (`torch.Tensor`): The query tensor.
97
+ k (`torch.Tensor`): The key tensor.
98
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
99
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
100
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
101
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos and
102
+ sin so that they can be properly broadcasted to the dimensions of q and k. For example, note
103
+ that cos and sin have the shape [batch_size, seq_len, head_dim]. Then, if q and
104
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
105
+ cos and sin broadcastable to the shapes of q and k. Similarly, if q and k have
106
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
107
+
108
+ Returns:
109
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
110
+ """
111
+ cos = cos.unsqueeze(unsqueeze_dim)
112
+ sin = sin.unsqueeze(unsqueeze_dim)
113
+ q_embed = (q * cos) + (rotate_half(q) * sin)
114
+ k_embed = (k * cos) + (rotate_half(k) * sin)
115
+ return q_embed, k_embed
116
+
117
+
118
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
119
+ """Equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep).
120
+
121
+ The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to
122
+ (batch, num_attention_heads, seqlen, head_dim)
123
+ """
124
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
125
+ if n_rep == 1:
126
+ return hidden_states
127
+ hidden_states = hidden_states[:, :,
128
+ None, :, :].expand(batch, num_key_value_heads,
129
+ n_rep, slen, head_dim)
130
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen,
131
+ head_dim)
132
+
133
+
134
+ #############################################################################
135
+
136
+ #############################################################################
137
+ # Modified from modeling_mixtral
138
+ #############################################################################
139
+
140
+
141
+ def load_balancing_loss_func(
142
+ gate_logits: torch.Tensor,
143
+ num_experts: int,
144
+ top_k: int,
145
+ attention_mask: Optional[torch.Tensor],
146
+ ) -> torch.Tensor:
147
+ r"""Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
148
+
149
+ See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss
150
+ function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
151
+ experts is too unbalanced.
152
+
153
+ Args:
154
+ gate_logits (Union[`torch.Tensor`, Tuple[torch.Tensor]):
155
+ Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
156
+ shape [batch_size X sequence_length, num_experts].
157
+ num_experts (`int`):
158
+ Number of experts.
159
+ top_k (`int`):
160
+ The number of experts each token is routed to.
161
+ attention_mask (`torch.Tensor`, None):
162
+ The attention_mask used in forward function
163
+ shape [batch_size X sequence_length] if not None.
164
+
165
+ Returns:
166
+ The auxiliary loss.
167
+ """
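+ # In Switch Transformer notation (eqs. 4-6), the quantity below is essentially
+ # num_experts * sum_e f_e * P_e, where f_e is the fraction of (token, top-k slot)
+ # assignments routed to expert e and P_e is the mean router probability for expert e.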
168
+ if gate_logits is None or not isinstance(gate_logits, tuple):
169
+ return torch.tensor(0.0)
170
+
171
+ if isinstance(gate_logits, tuple):
172
+ compute_device = gate_logits[0].device
173
+ concatenated_gate_logits = torch.cat(
174
+ [layer_gate.to(compute_device) for layer_gate in gate_logits],
175
+ dim=0)
176
+
177
+ routing_weights = torch.nn.functional.softmax(concatenated_gate_logits,
178
+ dim=-1)
179
+
180
+ _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
181
+
182
+ expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
183
+
184
+ if attention_mask is None:
185
+ # Compute the percentage of tokens routed to each expert
186
+ tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
187
+
188
+ # Compute the average probability of routing to these experts
189
+ router_prob_per_expert = torch.mean(routing_weights, dim=0)
190
+ else:
191
+ batch_size, sequence_length = attention_mask.shape
192
+ num_hidden_layers = concatenated_gate_logits.shape[0] // (
193
+ batch_size * sequence_length)
194
+
195
+ # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
196
+ expert_attention_mask = (attention_mask[None, :, :, None, None].expand(
197
+ (num_hidden_layers, batch_size, sequence_length, top_k,
198
+ num_experts)).reshape(-1, top_k, num_experts).to(compute_device))
199
+
200
+ # Compute the percentage of tokens routed to each expert
201
+ tokens_per_expert = torch.sum(
202
+ expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
203
+ expert_attention_mask, dim=0)
204
+
205
+ # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
206
+ router_per_expert_attention_mask = (
207
+ attention_mask[None, :, :, None].expand(
208
+ (num_hidden_layers, batch_size, sequence_length,
209
+ num_experts)).reshape(-1, num_experts).to(compute_device))
210
+
211
+ # Compute the average probability of routing to these experts
212
+ router_prob_per_expert = torch.sum(
213
+ routing_weights * router_per_expert_attention_mask,
214
+ dim=0) / torch.sum(router_per_expert_attention_mask, dim=0)
215
+
216
+ overall_loss = torch.sum(tokens_per_expert *
217
+ router_prob_per_expert.unsqueeze(0))
218
+ return overall_loss * num_experts
219
+
220
+
221
+ #############################################################################
222
+
223
+
224
+ def resolve_ffn_act_fn(
225
+ ffn_act_fn: dict) -> Callable[[torch.Tensor], torch.Tensor]:
226
+ """Resolve the activation function for the feed-forward network.
227
+
228
+ Args:
229
+ ffn_act_fn (dict): The configuration dictionary for the activation function.
230
+ The dict config must specify the 'name' of a torch.nn.functional activation
231
+ function. All of other key values pairs are bound to the function as a partial.
232
+
233
+ Returns:
234
+ Callable[[torch.Tensor], torch.Tensor]: The activation function.
235
+ """
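+ # Illustrative configs (assumed examples, not taken from any checkpoint):
+ #   {'name': 'silu'}                        -> F.silu
+ #   {'name': 'gelu', 'approximate': 'tanh'} -> partial(F.gelu, approximate='tanh')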
236
+ config = deepcopy(ffn_act_fn)
237
+ name = config.pop('name')
238
+ if not hasattr(nn.functional, name):
239
+ raise ValueError(f'Unrecognised activation function name ({name}).')
240
+ act = getattr(nn.functional, name)
241
+ return partial(act, **config)
242
+
243
+
244
+ #############################################################################
245
+ # Copied from LLaMaAttention
246
+ #############################################################################
247
+
248
+
249
+ def _get_unpad_data(attention_mask: torch.Tensor):
250
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
251
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
252
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
253
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32),
254
+ (1, 0))
255
+ return (
256
+ indices,
257
+ cu_seqlens,
258
+ max_seqlen_in_batch,
259
+ )
260
+
261
+
262
+ class DbrxAttention(nn.Module):
263
+ """Multi-head self attention."""
264
+
265
+ def __init__(self,
266
+ hidden_size: int,
267
+ num_heads: int,
268
+ max_position_embeddings: int,
269
+ attn_config: DbrxAttentionConfig,
270
+ block_idx: Optional[int] = None):
271
+ super().__init__()
272
+ self.hidden_size = hidden_size
273
+ self.num_heads = num_heads
274
+ self.head_dim = self.hidden_size // self.num_heads
275
+ self.max_position_embeddings = max_position_embeddings
276
+ self.block_idx = block_idx
277
+ self.config = attn_config
278
+ if block_idx is None:
279
+ logger.warning_once(
280
+ f'Instantiating {self.__class__.__name__} without passing a `block_idx` is not recommended and will '
281
+ +
282
+ 'lead to errors during the forward call if caching is used. Please make sure to provide a `block_idx` '
283
+ + 'when creating this class.')
284
+
285
+ self.attn_pdrop = attn_config.attn_pdrop
286
+ self.clip_qkv = attn_config.clip_qkv
287
+ self.num_key_value_heads = attn_config.kv_n_heads
288
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
289
+ self.rope_theta = attn_config.rope_theta
290
+
291
+ self.q_proj = nn.Linear(self.hidden_size,
292
+ self.hidden_size,
293
+ bias=False)
294
+ self.k_proj = nn.Linear(self.hidden_size,
295
+ self.num_key_value_heads * self.head_dim,
296
+ bias=False)
297
+ self.v_proj = nn.Linear(self.hidden_size,
298
+ self.num_key_value_heads * self.head_dim,
299
+ bias=False)
300
+ self.out_proj = nn.Linear(self.hidden_size,
301
+ self.hidden_size,
302
+ bias=False)
303
+ self.rotary_emb = DbrxRotaryEmbedding(
304
+ self.head_dim,
305
+ max_position_embeddings=self.max_position_embeddings,
306
+ base=self.rope_theta,
307
+ )
308
+
309
+ def forward(
310
+ self,
311
+ hidden_states: torch.Tensor,
312
+ position_ids: torch.LongTensor,
313
+ attention_mask: Optional[torch.Tensor] = None,
314
+ past_key_value: Optional[Cache] = None,
315
+ output_attentions: bool = False,
316
+ use_cache: bool = False,
317
+ cache_position: Optional[torch.LongTensor] = None,
318
+ **kwargs: Any,
319
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
320
+ bsz, q_len, _ = hidden_states.size()
321
+
322
+ query_states = self.q_proj(hidden_states)
323
+ key_states = self.k_proj(hidden_states)
324
+ value_states = self.v_proj(hidden_states)
325
+ if self.clip_qkv is not None:
326
+ query_states = query_states.clamp(min=-self.clip_qkv, max=self.clip_qkv)
327
+ key_states = key_states.clamp(min=-self.clip_qkv, max=self.clip_qkv)
328
+ value_states = value_states.clamp(min=-self.clip_qkv, max=self.clip_qkv)
329
+
330
+ query_states = query_states.view(bsz, q_len, self.num_heads,
331
+ self.head_dim).transpose(1, 2)
332
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads,
333
+ self.head_dim).transpose(1, 2)
334
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads,
335
+ self.head_dim).transpose(1, 2)
336
+
337
+ past_key_value = getattr(self, 'past_key_value', past_key_value)
338
+ cos, sin = self.rotary_emb(value_states, position_ids)
339
+ query_states, key_states = apply_rotary_pos_emb(query_states,
340
+ key_states, cos, sin)
341
+
342
+ if past_key_value is not None:
343
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
344
+ cache_kwargs = {
345
+ 'sin': sin,
346
+ 'cos': cos,
347
+ 'cache_position': cache_position
348
+ }
349
+ key_states, value_states = past_key_value.update(
350
+ key_states, value_states, self.block_idx, cache_kwargs)
351
+
352
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
353
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
354
+
355
+ attn_weights = torch.matmul(query_states, key_states.transpose(
356
+ 2, 3)) / math.sqrt(self.head_dim)
357
+
358
+ if attention_mask is not None: # no matter the length, we just slice it
359
+ causal_mask = attention_mask[:, :, :, :key_states.shape[-2]]
360
+ attn_weights = attn_weights + causal_mask
361
+
362
+ # upcast attention to fp32
363
+ attn_weights = nn.functional.softmax(attn_weights,
364
+ dim=-1,
365
+ dtype=torch.float32).to(
366
+ query_states.dtype)
367
+ attn_weights = nn.functional.dropout(attn_weights,
368
+ p=self.attn_pdrop,
369
+ training=self.training)
370
+ attn_output = torch.matmul(attn_weights, value_states)
371
+
372
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
373
+ raise ValueError(
374
+ f'`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is'
375
+ + f' {attn_output.size()}')
376
+
377
+ attn_output = attn_output.transpose(1, 2).contiguous()
378
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
379
+ attn_output = self.out_proj(attn_output)
380
+
381
+ if not output_attentions:
382
+ attn_weights = None
383
+
384
+ return attn_output, attn_weights, past_key_value
385
+
386
+
387
+ class DbrxFlashAttention2(DbrxAttention):
388
+ """Dbrx flash attention module.
389
+
390
+ This module inherits from `DbrxAttention` as the weights of the module stay
391
+ untouched. The only required change would be on the forward pass where it
392
+ calls the public API of flash attention.
393
+ """
394
+
395
+ def __init__(self, *args: Any, **kwargs: Any):
396
+ if not is_flash_attn_2_available():
397
+ raise ImportError(
398
+ 'Flash Attention 2 is not available. Please install it with `pip install flash-attn`.'
399
+ )
400
+
401
+ super().__init__(*args, **kwargs)
402
+
403
+ def forward(
404
+ self,
405
+ hidden_states: torch.Tensor,
406
+ attention_mask: Optional[torch.LongTensor] = None,
407
+ position_ids: Optional[torch.LongTensor] = None,
408
+ past_key_value: Optional[Cache] = None,
409
+ output_attentions: bool = False,
410
+ use_cache: bool = False,
411
+ cache_position: Optional[torch.LongTensor] = None,
412
+ **kwargs: Any,
413
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
414
+ Optional[Tuple[torch.Tensor]]]:
415
+ logger.debug(
416
+ 'Implicitly setting `output_attentions` to False as it is not supported in Flash Attention.'
417
+ )
418
+ output_attentions = False
419
+
420
+ bsz, q_len, _ = hidden_states.size()
421
+
422
+ query_states = self.q_proj(hidden_states)
423
+ key_states = self.k_proj(hidden_states)
424
+ value_states = self.v_proj(hidden_states)
425
+ if self.clip_qkv is not None:
426
+ query_states = query_states.clamp(min=-self.clip_qkv, max=self.clip_qkv)
427
+ key_states = key_states.clamp(min=-self.clip_qkv, max=self.clip_qkv)
428
+ value_states = value_states.clamp(min=-self.clip_qkv, max=self.clip_qkv)
429
+
430
+ # Flash attention requires the input to have the shape
431
+ # batch_size x seq_length x num_heads x head_dim
432
+ # therefore we just need to keep the original shape
433
+ query_states = query_states.view(bsz, q_len, self.num_heads,
434
+ self.head_dim).transpose(1, 2)
435
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads,
436
+ self.head_dim).transpose(1, 2)
437
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads,
438
+ self.head_dim).transpose(1, 2)
439
+
440
+ cos, sin = self.rotary_emb(value_states, position_ids)
441
+ query_states, key_states = apply_rotary_pos_emb(query_states,
442
+ key_states, cos, sin)
443
+
444
+ past_key_value = getattr(self, 'past_key_value', past_key_value)
445
+
446
+ if past_key_value is not None:
447
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
448
+ cache_kwargs = {
449
+ 'sin': sin,
450
+ 'cos': cos,
451
+ 'cache_position': cache_position
452
+ }
453
+ key_states, value_states = past_key_value.update(
454
+ key_states, value_states, self.block_idx, cache_kwargs)
455
+
456
+ # TODO: These transposes are quite inefficient, but Flash Attention requires the layout
457
+ # [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
458
+ # to be able to avoid many of these transpose/reshape/view.
459
+ query_states = query_states.transpose(1, 2)
460
+ key_states = key_states.transpose(1, 2)
461
+ value_states = value_states.transpose(1, 2)
462
+
463
+ dropout_rate = self.attn_pdrop if self.training else 0.0
464
+
465
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
466
+ # therefore the input hidden states get silently cast to float32. Hence, we need
467
+ # to cast them back to the correct dtype just to be sure everything works as expected.
468
+ # This might slow down training & inference, so it is recommended not to cast the LayerNorms
469
+ # in fp32. (LlamaRMSNorm handles it correctly)
470
+ input_dtype = query_states.dtype
471
+ if input_dtype == torch.float32:
472
+ if torch.is_autocast_enabled():
473
+ target_dtype = torch.get_autocast_gpu_dtype()
474
+ # Handle the case where the model is quantized
475
+ elif hasattr(self.config, '_pre_quantization_dtype'):
476
+ target_dtype = self.config._pre_quantization_dtype
477
+ else:
478
+ target_dtype = query_states.dtype
479
+
480
+ logger.warning_once(
481
+ f'The input hidden states seem to have been silently cast to float32; this might be '
482
+ +
483
+ f'related to the fact you have upcasted embedding or layer norm layers in '
484
+ + f'float32. We will cast the input back to {target_dtype}.')
485
+
486
+ query_states = query_states.to(target_dtype)
487
+ key_states = key_states.to(target_dtype)
488
+ value_states = value_states.to(target_dtype)
489
+
490
+ attn_output = self._flash_attention_forward(
491
+ query_states,
492
+ key_states,
493
+ value_states,
494
+ attention_mask,
495
+ q_len,
496
+ dropout=dropout_rate,
497
+ )
498
+
499
+ attn_output = attn_output.reshape(bsz, q_len,
500
+ self.hidden_size).contiguous()
501
+ attn_output = self.out_proj(attn_output)
502
+
503
+ if not output_attentions:
504
+ attn_weights = None
505
+
506
+ return attn_output, attn_weights, past_key_value # type: ignore
507
+
508
+ def _flash_attention_forward(
509
+ self,
510
+ query_states: torch.Tensor,
511
+ key_states: torch.Tensor,
512
+ value_states: torch.Tensor,
513
+ attention_mask: Union[torch.LongTensor, None],
514
+ query_length: int,
515
+ dropout: float = 0.0,
516
+ softmax_scale: Optional[float] = None,
517
+ ):
518
+ """Use FlashAttention, stripping padding tokens if necessary.
519
+
520
+ Args:
521
+ query_states (torch.Tensor): Input query states to be passed to Flash Attention API
522
+ key_states (torch.Tensor): Input key states to be passed to Flash Attention API
523
+ value_states (torch.Tensor): Input value states to be passed to Flash Attention API
524
+ attention_mask (torch.LongTensor | None): The padding mask - corresponds to a tensor of size
525
+ (batch_size, seq_len) where 0 stands for the position of padding tokens and 1
526
+ for the position of non-padding tokens.
527
+ query_length (int): The length of the query sequence
528
+ dropout (float): Attention dropout
529
+ softmax_scale (float, optional): The scaling of QK^T before applying softmax.
530
+ Defaults to 1 / sqrt(head_dim)
531
+ """
532
+ causal = True
533
+ # Contains at least one padding token in the sequence
534
+ if attention_mask is not None:
535
+ batch_size = query_states.shape[0]
536
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
537
+ query_states, key_states, value_states, attention_mask,
538
+ query_length)
539
+
540
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
541
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
542
+
543
+ attn_output_unpad = flash_attn_varlen_func(
544
+ query_states,
545
+ key_states,
546
+ value_states,
547
+ cu_seqlens_q=cu_seqlens_q,
548
+ cu_seqlens_k=cu_seqlens_k,
549
+ max_seqlen_q=max_seqlen_in_batch_q,
550
+ max_seqlen_k=max_seqlen_in_batch_k,
551
+ dropout_p=dropout,
552
+ softmax_scale=softmax_scale,
553
+ causal=causal,
554
+ )
555
+
556
+ attn_output = pad_input(
557
+ attn_output_unpad,
558
+ indices_q,
559
+ batch_size,
560
+ query_length,
561
+ )
562
+ else:
563
+ attn_output = flash_attn_func(
564
+ query_states,
565
+ key_states,
566
+ value_states,
567
+ dropout,
568
+ softmax_scale=softmax_scale,
569
+ causal=causal,
570
+ )
571
+
572
+ return attn_output
573
+
574
+ def _upad_input(self, query_layer: torch.Tensor, key_layer: torch.Tensor,
575
+ value_layer: torch.Tensor, attention_mask: torch.Tensor,
576
+ query_length: int):
577
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(
578
+ attention_mask)
579
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
580
+
581
+ key_layer = index_first_axis(
582
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads,
583
+ head_dim), indices_k)
584
+ value_layer = index_first_axis(
585
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads,
586
+ head_dim), indices_k)
587
+ if query_length == kv_seq_len:
588
+ query_layer = index_first_axis(
589
+ query_layer.reshape(batch_size * kv_seq_len, self.num_heads,
590
+ head_dim), indices_k)
591
+ cu_seqlens_q = cu_seqlens_k
592
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
593
+ indices_q = indices_k
594
+ elif query_length == 1:
595
+ max_seqlen_in_batch_q = 1
596
+ cu_seqlens_q = torch.arange(
597
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
598
+ ) # There is a memcpy here, that is very bad.
599
+ indices_q = cu_seqlens_q[:-1]
600
+ query_layer = query_layer.squeeze(1)
601
+ else:
602
+ # The -q_len: slice assumes left padding.
603
+ attention_mask = attention_mask[:, -query_length:]
604
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(
605
+ query_layer, attention_mask)
606
+
607
+ return (
608
+ query_layer,
609
+ key_layer,
610
+ value_layer,
611
+ indices_q,
612
+ (cu_seqlens_q, cu_seqlens_k),
613
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
614
+ )
615
+
616
+
617
+ DBRX_ATTENTION_CLASSES = {
618
+ 'eager': DbrxAttention,
619
+ 'flash_attention_2': DbrxFlashAttention2,
620
+ }
621
+
622
+
623
+ class DbrxNormAttentionNorm(nn.Module):
624
+
625
+ def __init__(
626
+ self,
627
+ hidden_size: int,
628
+ num_heads: int,
629
+ max_position_embeddings: int,
630
+ resid_pdrop: float,
631
+ attn_implementation: str,
632
+ attn_config: DbrxAttentionConfig,
633
+ block_idx: Optional[int] = None,
634
+ ):
635
+ super().__init__()
636
+ self.block_idx = block_idx
637
+ self.resid_pdrop = resid_pdrop
638
+ self.norm_1 = nn.LayerNorm(hidden_size, bias=False)
639
+ self.attn = DBRX_ATTENTION_CLASSES[attn_implementation](
640
+ hidden_size=hidden_size,
641
+ num_heads=num_heads,
642
+ max_position_embeddings=max_position_embeddings,
643
+ attn_config=attn_config,
644
+ block_idx=block_idx,
645
+ )
646
+ self.norm_2 = nn.LayerNorm(hidden_size, bias=False)
647
+
648
+ def forward(
649
+ self,
650
+ hidden_states: torch.Tensor,
651
+ position_ids: torch.LongTensor,
652
+ attention_mask: Optional[torch.Tensor] = None,
653
+ past_key_value: Optional[Cache] = None,
654
+ output_attentions: bool = False,
655
+ use_cache: bool = False,
656
+ cache_position: Optional[torch.LongTensor] = None,
657
+ **kwargs: Any,
658
+ ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
659
+ Optional[Cache]]:
660
+
661
+ residual_states = hidden_states
662
+ hidden_states = self.norm_1(hidden_states).to(hidden_states.dtype)
663
+
664
+ hidden_states, attn_weights, past_key_value = self.attn(
665
+ hidden_states=hidden_states,
666
+ attention_mask=attention_mask,
667
+ position_ids=position_ids,
668
+ past_key_value=past_key_value,
669
+ output_attentions=output_attentions,
670
+ use_cache=use_cache,
671
+ cache_position=cache_position,
672
+ **kwargs,
673
+ )
674
+
675
+ hidden_states = nn.functional.dropout(hidden_states,
676
+ p=self.resid_pdrop,
677
+ training=self.training)
678
+ hidden_states = hidden_states + residual_states
679
+
680
+ residual_states = hidden_states
681
+ hidden_states = self.norm_2(hidden_states).to(hidden_states.dtype)
682
+
683
+ return residual_states, hidden_states, attn_weights, past_key_value
684
+
685
+
686
+ class DbrxRouter(nn.Module):
687
+
688
+ def __init__(self, hidden_size: int, moe_num_experts: int, moe_top_k: int,
689
+ moe_jitter_eps: Optional[float],
690
+ moe_normalize_expert_weights: Optional[float],
691
+ uniform_expert_assignment: bool):
692
+ super().__init__()
693
+ self.hidden_size = hidden_size
694
+ self.moe_num_experts = moe_num_experts
695
+ self.moe_top_k = moe_top_k
696
+ self.moe_jitter_eps = moe_jitter_eps
697
+ self.moe_normalize_expert_weights = moe_normalize_expert_weights
698
+ self.uniform_expert_assignment = uniform_expert_assignment
699
+
700
+ self.layer = nn.Linear(self.hidden_size,
701
+ self.moe_num_experts,
702
+ bias=False)
703
+
704
+ def jitter(self, x: torch.Tensor) -> torch.Tensor:
705
+ if self.moe_jitter_eps is None:
706
+ raise RuntimeError('The router does not have moe_jitter_eps set.')
707
+ low = 1.0 - self.moe_jitter_eps
708
+ high = 1.0 + self.moe_jitter_eps
709
+ noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
710
+ return low + noise * (high - low)
711
+
712
+ def forward(
713
+ self, x: torch.Tensor
714
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.LongTensor]:
715
+ if self.training and self.moe_jitter_eps is not None:
716
+ x = x * self.jitter(x)
717
+
718
+ weights = self.layer(x.view(-1,
719
+ x.shape[-1])).softmax(dim=-1,
720
+ dtype=torch.float32)
721
+ top_weights, top_experts = torch.topk(weights, self.moe_top_k, dim=-1)
722
+
723
+ if self.moe_normalize_expert_weights:
724
+ top_weights = top_weights / torch.norm(
725
+ top_weights,
726
+ p=self.moe_normalize_expert_weights,
727
+ dim=-1,
728
+ keepdim=True)
729
+
730
+ if self.uniform_expert_assignment:
731
+ with torch.no_grad():
732
+ uniform_tensor = torch.arange(
733
+ 0,
734
+ top_experts.numel(),
735
+ device=top_experts.device,
736
+ dtype=top_experts.dtype) % self.moe_num_experts
737
+ top_experts = uniform_tensor.reshape(top_experts.shape)
738
+ # Note, weights and top_weights are not changed
739
+
740
+ weights = weights.to(x.dtype)
741
+ top_weights = top_weights.to(x.dtype)
742
+ return weights, top_weights, top_experts # type: ignore
743
+
744
+
745
+ class DbrxMLP(nn.Module):
746
+
747
+ def __init__(self, hidden_size: int, ffn_hidden_size: int, ffn_act_fn: dict):
748
+ super().__init__()
749
+
750
+ self.w1 = nn.Linear(hidden_size, ffn_hidden_size, bias=False)
751
+ self.v1 = nn.Linear(hidden_size, ffn_hidden_size, bias=False)
752
+ self.w2 = nn.Linear(ffn_hidden_size, hidden_size, bias=False)
753
+ self.activation_fn = resolve_ffn_act_fn(ffn_act_fn)
754
+
755
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
756
+
757
+ return self.w2(self.activation_fn(self.w1(x)) * self.v1(x))
758
+
759
+
760
+ class DbrxExperts(nn.Module):
761
+
762
+ def __init__(self, hidden_size: int, ffn_hidden_size: int,
763
+ moe_num_experts: int, ffn_act_fn: dict):
764
+ super().__init__()
765
+ self.moe_num_experts = moe_num_experts
766
+ self.mlp = nn.ModuleList([DbrxMLP(hidden_size, ffn_hidden_size, ffn_act_fn) for _ in range(moe_num_experts)])
767
+
768
+ def forward(self, x: torch.Tensor, weights: torch.Tensor,
769
+ top_weights: torch.Tensor,
770
+ top_experts: torch.LongTensor) -> torch.Tensor:
771
+ bsz, q_len, hidden_size = x.shape
772
+ x = x.view(-1, hidden_size)
773
+ out = torch.zeros_like(x)
774
+
775
+ expert_mask = nn.functional.one_hot(
776
+ top_experts, num_classes=self.moe_num_experts).permute(2, 1, 0)
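+ # expert_mask now has shape (moe_num_experts, moe_top_k, bsz * q_len), so each
+ # row passed to torch.where below yields the (top-k slot, token) pairs assigned to that expert.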
777
+ for expert_idx in range(0, self.moe_num_experts):
778
+ topk_idx, token_idx = torch.where(expert_mask[expert_idx])
779
+ if token_idx.shape[0] == 0:
780
+ continue
781
+
782
+ expert_tokens = x[None, token_idx].reshape(-1, hidden_size)
783
+ expert_out = self.mlp[expert_idx](expert_tokens) * top_weights[token_idx, topk_idx, None]
784
+
785
+ out.index_add_(0, token_idx, expert_out)
786
+
787
+ out = out.reshape(bsz, q_len, hidden_size)
788
+ return out
789
+
790
+
791
+ class DbrxFFN(nn.Module):
792
+
793
+ def __init__(self, hidden_size: int, ffn_config: DbrxFFNConfig):
794
+ super().__init__()
795
+
796
+ self.router = DbrxRouter(
797
+ hidden_size,
798
+ moe_num_experts=ffn_config.moe_num_experts,
799
+ moe_top_k=ffn_config.moe_top_k,
800
+ moe_jitter_eps=ffn_config.moe_jitter_eps,
801
+ moe_normalize_expert_weights=ffn_config.
802
+ moe_normalize_expert_weights,
803
+ uniform_expert_assignment=ffn_config.uniform_expert_assignment,
804
+ )
805
+
806
+ self.experts = DbrxExperts(
807
+ hidden_size=hidden_size,
808
+ ffn_hidden_size=ffn_config.ffn_hidden_size,
809
+ moe_num_experts=ffn_config.moe_num_experts,
810
+ ffn_act_fn=ffn_config.ffn_act_fn,
811
+ )
812
+
813
+ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
814
+ weights, top_weights, top_experts = self.router(x)
815
+ out = self.experts(x, weights, top_weights, top_experts)
816
+ return out, weights
817
+
818
+
819
+ class DbrxBlock(nn.Module):
820
+
821
+ def __init__(self, config: DbrxConfig, block_idx: int):
822
+ super().__init__()
823
+ self.hidden_size = config.d_model
824
+ self.resid_pdrop = config.resid_pdrop
825
+ self.block_idx = block_idx
826
+ self.norm_attn_norm = DbrxNormAttentionNorm(
827
+ hidden_size=config.d_model,
828
+ num_heads=config.n_heads,
829
+ max_position_embeddings=config.max_seq_len,
830
+ resid_pdrop=config.resid_pdrop,
831
+ attn_implementation=config._attn_implementation,
832
+ attn_config=config.attn_config,
833
+ block_idx=block_idx,
834
+ )
835
+ self.ffn = DbrxFFN(hidden_size=config.d_model,
836
+ ffn_config=config.ffn_config)
837
+
838
+ def forward(
839
+ self,
840
+ hidden_states: torch.Tensor,
841
+ position_ids: torch.LongTensor,
842
+ attention_mask: Optional[torch.Tensor] = None,
843
+ past_key_value: Optional[Cache] = None,
844
+ output_attentions: Optional[bool] = False,
845
+ output_router_logits: Optional[bool] = False,
846
+ use_cache: Optional[bool] = False,
847
+ cache_position: Optional[torch.LongTensor] = None,
848
+ **kwargs: Any,
849
+ ) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, Optional[torch.Tensor]],
850
+ Tuple[torch.Tensor, Optional[Cache]], Tuple[
851
+ torch.Tensor, Optional[torch.Tensor], Optional[Cache]],
852
+ Tuple[torch.Tensor, Optional[torch.Tensor],
853
+ Optional[torch.Tensor]], Tuple[
854
+ torch.Tensor, Optional[Cache], Optional[torch.Tensor]],
855
+ Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache],
856
+ Optional[torch.Tensor]],]:
857
+ """Forward function for DbrxBlock.
858
+
859
+ Args:
860
+ hidden_states (`torch.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
861
+ position_ids (`torch.LongTensor`): position ids of shape `(batch, seq_len)`
862
+ attention_mask (`torch.Tensor`, optional): attention mask of size (batch_size, sequence_length)
863
+ if flash attention is used or (batch_size, 1, query_sequence_length, key_sequence_length)
864
+ if default attention is used.
865
+ past_key_value (`Tuple(torch.Tensor)`, optional): cached past key and value projection states
866
+ output_attentions (`bool`, optional): Whether or not to return the attentions tensors of all
867
+ attention layers. See `attentions` under returned tensors for more detail.
868
+ output_router_logits (`bool`, optional): Whether or not to return the router logits.
869
+ use_cache (`bool`, optional): If set to `True`, `past_key_values` key value states are
870
+ returned and can be used to speed up decoding (see `past_key_values`).
871
+ cache_position (`torch.LongTensor`, optional): position ids of the cache
872
+ """
873
+ if 'padding_mask' in kwargs:
874
+ warnings.warn(
875
+ 'Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure to use `attention_mask` instead.'
876
+ )
877
+
878
+ # Norm + Attention + Norm
879
+ resid_states, hidden_states, self_attn_weights, present_key_value = self.norm_attn_norm(
880
+ hidden_states=hidden_states,
881
+ attention_mask=attention_mask,
882
+ position_ids=position_ids,
883
+ past_key_value=past_key_value,
884
+ output_attentions=output_attentions,
885
+ use_cache=use_cache,
886
+ cache_position=cache_position,
887
+ **kwargs,
888
+ )
889
+
890
+ # Fully Connected
891
+ hidden_states, router_logits = self.ffn(hidden_states)
892
+ hidden_states = nn.functional.dropout(hidden_states,
893
+ p=self.resid_pdrop,
894
+ training=self.training)
895
+ hidden_states = resid_states + hidden_states
896
+
897
+ outputs = (hidden_states,)
898
+
899
+ if output_attentions:
900
+ outputs += (self_attn_weights,)
901
+
902
+ if use_cache:
903
+ outputs += (present_key_value,)
904
+
905
+ if output_router_logits:
906
+ outputs += (router_logits,)
907
+
908
+ return outputs
909
+
910
+
911
+ class DbrxPreTrainedModel(PreTrainedModel):
912
+ config_class = DbrxConfig
913
+ base_model_prefix = 'transformer'
914
+ supports_gradient_checkpointing = True
915
+ _no_split_modules = ['DbrxBlock']
916
+ _skip_keys_device_placement = ['past_key_values']
917
+ _supports_flash_attn_2 = True
918
+ _supports_sdpa = False
919
+ _supports_cache_class = True
920
+
921
+ def _init_weights(self, module: nn.Module):
922
+ std = self.config.initializer_range
923
+ if isinstance(module, nn.Linear):
924
+ module.weight.data.normal_(mean=0.0, std=std)
925
+ if module.bias is not None:
926
+ module.bias.data.zero_()
927
+ elif isinstance(module, nn.Embedding):
928
+ module.weight.data.normal_(mean=0.0, std=std)
929
+ if module.padding_idx is not None:
930
+ module.weight.data[module.padding_idx].zero_()
931
+ elif isinstance(module, nn.LayerNorm):
932
+ module.weight.data.normal_(mean=0.0, std=std)
933
+ if module.bias is not None:
934
+ module.bias.data.zero_()
935
+
936
+ def _setup_cache(self, cache_cls: Any, max_batch_size: int,
937
+ max_cache_len: int): # TODO: how to set var type of class?
938
+ if self.config._attn_implementation == 'flash_attention_2' and cache_cls == StaticCache:
939
+ raise ValueError(
940
+ '`static` cache implementation is not compatible with ' +
941
+ '`attn_implementation==flash_attention_2`. Make sure to use ' +
942
+ '`sdpa` in the meantime and open an issue at https://github.com/huggingface/transformers.'
943
+ )
944
+
945
+ for block in self.transformer.blocks:
946
+ device = block.norm_attn_norm.norm_1.weight.device
947
+ if hasattr(self.config, '_pre_quantization_dtype'):
948
+ dtype = self.config._pre_quantization_dtype
949
+ else:
950
+ dtype = block.norm_attn_norm.attn.out_proj.weight.dtype
951
+ block.norm_attn_norm.attn.past_key_value = cache_cls(self.config,
952
+ max_batch_size,
953
+ max_cache_len,
954
+ device=device,
955
+ dtype=dtype)
956
+
957
+ def _reset_cache(self):
958
+ for block in self.transformer.blocks:
959
+ block.norm_attn_norm.attn.past_key_value = None
960
+
961
+
962
+ class DbrxModel(DbrxPreTrainedModel):
963
+ """Transformer decoder consisting of *config.num_hidden_layers*
964
+
965
+ [`DbrxBlock`] layers.
966
+
967
+ Args:
968
+ config: DbrxConfig
969
+ """
970
+
971
+ def __init__(self, config: DbrxConfig):
972
+ super().__init__(config)
973
+ self.padding_idx = config.pad_token_id
974
+ self.vocab_size = config.vocab_size
975
+ self.emb_pdrop = config.emb_pdrop
976
+
977
+ self.wte = nn.Embedding(config.vocab_size, config.d_model,
978
+ self.padding_idx)
979
+ self.blocks = nn.ModuleList([
980
+ DbrxBlock(config, block_idx) for block_idx in range(config.n_layers)
981
+ ])
982
+ self.norm_f = nn.LayerNorm(config.d_model, bias=False)
983
+ self.gradient_checkpointing = False
984
+
985
+ # Initialize weights and apply final processing
986
+ self.post_init()
987
+
988
+ def get_input_embeddings(self) -> nn.Embedding:
989
+ return self.wte
990
+
991
+ def set_input_embeddings(self, value: nn.Embedding):
992
+ self.wte = value
993
+
994
+ def _autocast_input_embeddings(self,
995
+ inputs_embeds: torch.Tensor) -> torch.Tensor:
996
+ if inputs_embeds.device.type == 'cuda' and torch.is_autocast_enabled():
997
+ return inputs_embeds.to(dtype=torch.get_autocast_gpu_dtype())
998
+ elif inputs_embeds.device.type == 'cpu' and torch.is_autocast_cpu_enabled(
999
+ ):
1000
+ return inputs_embeds.to(dtype=torch.get_autocast_cpu_dtype())
1001
+ else:
1002
+ return inputs_embeds
1003
+
1004
+ def forward(
1005
+ self,
1006
+ input_ids: Optional[torch.LongTensor] = None,
1007
+ attention_mask: Optional[torch.Tensor] = None,
1008
+ position_ids: Optional[torch.LongTensor] = None,
1009
+ past_key_values: Optional[Cache] = None,
1010
+ inputs_embeds: Optional[torch.Tensor] = None,
1011
+ use_cache: Optional[bool] = None,
1012
+ output_attentions: Optional[bool] = None,
1013
+ output_hidden_states: Optional[bool] = None,
1014
+ output_router_logits: Optional[bool] = None,
1015
+ return_dict: Optional[bool] = None,
1016
+ cache_position: Optional[torch.LongTensor] = None,
1017
+ ) -> Union[Tuple, MoeModelOutputWithPast]:
1018
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1019
+ output_hidden_states = (output_hidden_states
1020
+ if output_hidden_states is not None else
1021
+ self.config.output_hidden_states)
1022
+ output_router_logits = (output_router_logits
1023
+ if output_router_logits is not None else
1024
+ self.config.output_router_logits)
1025
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1026
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1027
+
1028
+ if (input_ids is None) ^ (inputs_embeds is not None):
1029
+ raise ValueError(
1030
+ 'You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one'
1031
+ )
1032
+
1033
+ if self.gradient_checkpointing and self.training and use_cache:
1034
+ logger.warning_once(
1035
+ '`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.'
1036
+ )
1037
+ use_cache = False
1038
+
1039
+ if inputs_embeds is None:
1040
+ inputs_embeds = self.wte(input_ids)
1041
+
1042
+ inputs_embeds = self._autocast_input_embeddings(
1043
+ inputs_embeds) # type: ignore
1044
+ inputs_embeds = nn.functional.dropout(inputs_embeds,
1045
+ p=self.emb_pdrop,
1046
+ training=self.training)
1047
+
1048
+ past_seen_tokens = 0
1049
+ if use_cache: # kept for BC (cache positions)
1050
+ if not isinstance(past_key_values, StaticCache):
1051
+ past_key_values = DynamicCache.from_legacy_cache(
1052
+ past_key_values)
1053
+ past_seen_tokens = past_key_values.get_seq_length( # type: ignore
1054
+ )
1055
+
1056
+ if cache_position is None:
1057
+ if isinstance(past_key_values, StaticCache):
1058
+ raise ValueError(
1059
+ 'cache_position is a required argument when using StaticCache.'
1060
+ )
1061
+ cache_position = torch.arange( # type: ignore
1062
+ past_seen_tokens,
1063
+ past_seen_tokens + inputs_embeds.shape[1],
1064
+ device=inputs_embeds.device)
1065
+
1066
+ if position_ids is None:
1067
+ position_ids = cache_position.unsqueeze(0) # type: ignore
1068
+
1069
+ causal_mask = self._update_causal_mask(attention_mask, inputs_embeds,
1070
+ cache_position) # type: ignore
1071
+
1072
+ # embed positions
1073
+ hidden_states = inputs_embeds
1074
+
1075
+ # decoder layers
1076
+ all_hidden_states = () if output_hidden_states else None
1077
+ all_self_attns = () if output_attentions else None
1078
+ all_router_logits = () if output_router_logits else None
1079
+ next_decoder_cache = None
1080
+
1081
+ for block in self.blocks:
1082
+ if output_hidden_states:
1083
+ all_hidden_states += (hidden_states,) # type: ignore
1084
+
1085
+ if self.gradient_checkpointing and self.training:
1086
+ block_outputs = self._gradient_checkpointing_func(
1087
+ block.__call__,
1088
+ hidden_states,
1089
+ attention_mask=causal_mask,
1090
+ position_ids=position_ids,
1091
+ past_key_value=past_key_values,
1092
+ output_attentions=output_attentions,
1093
+ output_router_logits=output_router_logits,
1094
+ use_cache=use_cache,
1095
+ cache_position=cache_position,
1096
+ )
1097
+ else:
1098
+ block_outputs = block(
1099
+ hidden_states,
1100
+ attention_mask=causal_mask,
1101
+ position_ids=position_ids,
1102
+ past_key_value=past_key_values,
1103
+ output_attentions=output_attentions,
1104
+ output_router_logits=output_router_logits,
1105
+ use_cache=use_cache,
1106
+ cache_position=cache_position,
1107
+ )
1108
+
1109
+ hidden_states = block_outputs[0]
1110
+
1111
+ if use_cache:
1112
+ next_decoder_cache = block_outputs[
1113
+ 2 if output_attentions else 1]
1114
+
1115
+ if output_attentions:
1116
+ all_self_attns += (block_outputs[1],) # type: ignore
1117
+
1118
+ if output_router_logits:
1119
+ all_router_logits += (block_outputs[-1],) # type: ignore
1120
+
1121
+ hidden_states = self.norm_f(hidden_states)
1122
+
1123
+ # add hidden states from the last decoder layer
1124
+ if output_hidden_states:
1125
+ all_hidden_states += (hidden_states,) # type: ignore
1126
+
1127
+ next_cache = None
1128
+ if use_cache:
1129
+ next_cache = (
1130
+ next_decoder_cache.to_legacy_cache() # type: ignore
1131
+ if isinstance(next_decoder_cache, Cache) else
1132
+ next_decoder_cache)
1133
+ if not return_dict:
1134
+ return tuple(v for v in [
1135
+ hidden_states, next_cache, all_hidden_states, all_self_attns,
1136
+ all_router_logits
1137
+ ] if v is not None)
1138
+ return MoeModelOutputWithPast(
1139
+ last_hidden_state=hidden_states,
1140
+ past_key_values=next_cache,
1141
+ hidden_states=all_hidden_states,
1142
+ attentions=all_self_attns,
1143
+ router_logits=all_router_logits,
1144
+ )
1145
+
1146
+ # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
1147
+ # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
1148
+ # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
1149
+ # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
1150
+ def _update_causal_mask(
1151
+ self, attention_mask: Optional[torch.Tensor],
1152
+ input_tensor: torch.Tensor,
1153
+ cache_position: torch.Tensor) -> Optional[torch.Tensor]:
1154
+ if self.config._attn_implementation == 'flash_attention_2':
1155
+ if attention_mask is not None and 0.0 in attention_mask:
1156
+ return attention_mask
1157
+ return None
1158
+
1159
+ dtype, device = input_tensor.dtype, input_tensor.device
1160
+ min_dtype = torch.finfo(dtype).min
1161
+ sequence_length = input_tensor.shape[1]
1162
+ if hasattr(self.blocks[0].norm_attn_norm.attn,
1163
+ 'past_key_value'): # static cache
1164
+ target_length = self.config.max_position_embeddings
1165
+ else: # dynamic cache
1166
+ target_length = (attention_mask.shape[-1] if isinstance(
1167
+ attention_mask, torch.Tensor) else cache_position[-1] + 1)
1168
+ target_length = int(target_length)
1169
+
1170
+ causal_mask = torch.full((sequence_length, target_length),
1171
+ fill_value=min_dtype,
1172
+ dtype=dtype,
1173
+ device=device)
1174
+ if sequence_length != 1:
1175
+ causal_mask = torch.triu(causal_mask, diagonal=1)
1176
+ causal_mask *= torch.arange(
1177
+ target_length, device=device) > cache_position.reshape(-1, 1)
1178
+ causal_mask = causal_mask[None,
1179
+ None, :, :].expand(input_tensor.shape[0], 1,
1180
+ -1, -1)
1181
+ if attention_mask is not None:
1182
+ causal_mask = causal_mask.clone(
1183
+ ) # copy to contiguous memory for in-place edit
1184
+ if attention_mask.dim() == 2:
1185
+ mask_length = attention_mask.shape[-1]
1186
+ padding_mask = causal_mask[..., :mask_length].eq(
1187
+ 0.0) * attention_mask[:, None, None, :].eq(0.0)
1188
+ causal_mask[..., :mask_length] = causal_mask[
1189
+ ..., :mask_length].masked_fill(padding_mask, min_dtype)
1190
+ elif attention_mask.dim() == 4:
1191
+ # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
1192
+ # cache. In that case, the 4D attention mask attends to the newest tokens only.
1193
+ if attention_mask.shape[
1194
+ -2] < cache_position[0] + sequence_length:
1195
+ offset = cache_position[0]
1196
+ else:
1197
+ offset = 0
1198
+ mask_shape = attention_mask.shape
1199
+ mask_slice = (attention_mask.eq(0.0)).to(
1200
+ dtype=dtype) * min_dtype
1201
+ causal_mask[:mask_shape[0], :mask_shape[1],
1202
+ offset:mask_shape[2] +
1203
+ offset, :mask_shape[3]] = mask_slice
1204
+
1205
+ if (self.config._attn_implementation == 'sdpa' and
1206
+ attention_mask is not None and
1207
+ attention_mask.device.type == 'cuda'):
1208
+ # TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
1209
+ is_tracing = (
1210
+ torch.jit.is_tracing() or
1211
+ isinstance(input_tensor, torch.fx.Proxy) or # type: ignore
1212
+ (hasattr(torch, '_dynamo') and torch._dynamo.is_compiling()))
1213
+ if not is_tracing and torch.any(attention_mask != 1):
1214
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
1215
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
1216
+ # Details: https://github.com/pytorch/pytorch/issues/110213
1217
+ causal_mask = AttentionMaskConverter._unmask_unattended(
1218
+ causal_mask, min_dtype)
1219
+
1220
+ return causal_mask
1221
+
1222
+
1223
+ class DbrxForCausalLM(DbrxPreTrainedModel):
1224
+
1225
+ def __init__(self, config: DbrxConfig):
1226
+ super().__init__(config)
1227
+ self.transformer = DbrxModel(config)
1228
+ self.vocab_size = config.vocab_size
1229
+ self.lm_head = nn.Linear(config.hidden_size,
1230
+ config.vocab_size,
1231
+ bias=False)
1232
+ self.router_aux_loss_coef = config.router_aux_loss_coef
1233
+ self.num_experts = config.ffn_config.moe_num_experts
1234
+ self.num_experts_per_tok = config.ffn_config.moe_top_k
1235
+
1236
+ # Initialize weights and apply final processing
1237
+ self.post_init()
1238
+
1239
+ def get_input_embeddings(self) -> nn.Embedding:
1240
+ return self.transformer.get_input_embeddings()
1241
+
1242
+ def set_input_embeddings(self, value: nn.Embedding):
1243
+ self.transformer.set_input_embeddings(value)
1244
+
1245
+ def get_output_embeddings(self) -> nn.Linear:
1246
+ return self.lm_head
1247
+
1248
+ def set_output_embeddings(self, new_embeddings: nn.Linear):
1249
+ self.lm_head = new_embeddings
1250
+
1251
+ def set_decoder(self, decoder: DbrxModel):
1252
+ self.transformer = decoder
1253
+
1254
+ def get_decoder(self) -> DbrxModel:
1255
+ return self.transformer
1256
+
1257
+ def forward(
1258
+ self,
1259
+ input_ids: Optional[torch.LongTensor] = None,
1260
+ attention_mask: Optional[torch.Tensor] = None,
1261
+ position_ids: Optional[torch.LongTensor] = None,
1262
+ past_key_values: Optional[Cache] = None,
1263
+ inputs_embeds: Optional[torch.Tensor] = None,
1264
+ labels: Optional[torch.LongTensor] = None,
1265
+ use_cache: Optional[bool] = None,
1266
+ output_attentions: Optional[bool] = None,
1267
+ output_hidden_states: Optional[bool] = None,
1268
+ output_router_logits: Optional[bool] = None,
1269
+ return_dict: Optional[bool] = None,
1270
+ cache_position: Optional[torch.LongTensor] = None,
1271
+ ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
1272
+ r"""Forward function for causal language modeling.
1273
+
1274
+ Example:
1275
+ ```python
1276
+ >>> from transformers import AutoTokenizer, DbrxForCausalLM
1277
+
1278
+ >>> model = DbrxForCausalLM.from_pretrained("databricks/dbrx")
1279
+ >>> tokenizer = AutoTokenizer.from_pretrained("databricks/dbrx")
1280
+
1281
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
1282
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
1283
+
1284
+ >>> # Generate
1285
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1286
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1287
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1288
+ ```
1289
+ """
1290
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1291
+ output_hidden_states = (output_hidden_states
1292
+ if output_hidden_states is not None else
1293
+ self.config.output_hidden_states)
1294
+ output_router_logits = (output_router_logits
1295
+ if output_router_logits is not None else
1296
+ self.config.output_router_logits)
1297
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1298
+
1299
+ # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
1300
+ outputs = self.transformer(
1301
+ input_ids=input_ids,
1302
+ attention_mask=attention_mask,
1303
+ position_ids=position_ids,
1304
+ past_key_values=past_key_values,
1305
+ inputs_embeds=inputs_embeds,
1306
+ use_cache=use_cache,
1307
+ output_attentions=output_attentions,
1308
+ output_hidden_states=output_hidden_states,
1309
+ output_router_logits=output_router_logits,
1310
+ return_dict=return_dict,
1311
+ cache_position=cache_position,
1312
+ )
1313
+
1314
+ hidden_states = outputs[0]
1315
+ logits = self.lm_head(hidden_states)
1316
+
1317
+ loss = None
1318
+ if labels is not None:
1319
+ # Shift so that tokens < n predict n
1320
+ shift_logits = logits[..., :-1, :].contiguous()
1321
+ shift_labels = labels[..., 1:].contiguous()
1322
+ # Flatten the tokens
1323
+ loss_fct = nn.CrossEntropyLoss()
1324
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
1325
+ shift_labels = shift_labels.view(-1)
1326
+ # Enable model parallelism
1327
+ shift_labels = shift_labels.to(shift_logits.device)
1328
+ loss = loss_fct(shift_logits, shift_labels)
1329
+
1330
+ aux_loss = None
1331
+ if output_router_logits:
1332
+ aux_loss = load_balancing_loss_func(
1333
+ outputs.router_logits if return_dict else outputs[-1],
1334
+ self.num_experts,
1335
+ self.num_experts_per_tok,
1336
+ attention_mask,
1337
+ )
1338
+ if labels is not None and loss is not None:
1339
+ loss += self.router_aux_loss_coef * aux_loss.to(
1340
+ loss.device) # make sure to reside in the same device
1341
+
1342
+ if not return_dict:
1343
+ output = (logits,) + outputs[1:]
1344
+ return (loss,) + output if loss is not None else output
1345
+
1346
+ return MoeCausalLMOutputWithPast(
1347
+ loss=loss,
1348
+ aux_loss=aux_loss,
1349
+ logits=logits,
1350
+ past_key_values=outputs.past_key_values,
1351
+ hidden_states=outputs.hidden_states,
1352
+ attentions=outputs.attentions,
1353
+ router_logits=outputs.router_logits,
1354
+ )
1355
+
1356
+ def prepare_inputs_for_generation(
1357
+ self,
1358
+ input_ids: torch.Tensor,
1359
+ past_key_values: Optional[Cache] = None,
1360
+ attention_mask: Optional[torch.Tensor] = None,
1361
+ inputs_embeds: Optional[torch.Tensor] = None,
1362
+ **kwargs: Any) -> Dict[str, Any]:
1363
+ past_length = 0
1364
+ if past_key_values is not None:
1365
+ if isinstance(past_key_values, Cache):
1366
+ cache_length = past_key_values.get_seq_length()
1367
+ past_length = past_key_values.seen_tokens
1368
+ max_cache_length = past_key_values.get_max_length()
1369
+ else:
1370
+ cache_length = past_length = past_key_values[0][0].shape[2]
1371
+ max_cache_length = None
1372
+
1373
+ # Keep only the unprocessed tokens:
1374
+ # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
1375
+ # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
1376
+ # input)
1377
+ if attention_mask is not None and attention_mask.shape[
1378
+ 1] > input_ids.shape[1]:
1379
+ input_ids = input_ids[:,
1380
+ -(attention_mask.shape[1] - past_length):]
1381
+ # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
1382
+ # input_ids based on the past_length.
1383
+ elif past_length < input_ids.shape[1]:
1384
+ input_ids = input_ids[:, past_length:]
1385
+ # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
1386
+
1387
+ # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
1388
+ if (max_cache_length is not None and attention_mask is not None and
1389
+ cache_length + input_ids.shape[1] > max_cache_length):
1390
+ attention_mask = attention_mask[:, -max_cache_length:]
1391
+
1392
+ position_ids = kwargs.get('position_ids', None)
1393
+ if attention_mask is not None and position_ids is None:
1394
+ # create position_ids on the fly for batch generation
1395
+ position_ids = attention_mask.long().cumsum(-1) - 1
1396
+ position_ids.masked_fill_(attention_mask == 0, 1)
1397
+ if past_key_values:
1398
+ position_ids = position_ids[:, -input_ids.shape[1]:]
1399
+
1400
+ if self.generation_config.cache_implementation == 'static':
1401
+ # generation with static cache
1402
+ cache_position = kwargs.get('cache_position', None)
1403
+ if cache_position is None:
1404
+ past_length = 0
1405
+ else:
1406
+ past_length = cache_position[-1] + 1
1407
+ input_ids = input_ids[:, past_length:]
1408
+ position_ids = position_ids[:,
1409
+ past_length:] if position_ids is not None else None
1410
+
1411
+ # TODO @gante we should only keep a `cache_position` in generate, and do +=1.
1412
+ # same goes for position ids. Could also help with continued generation.
1413
+ input_length = position_ids.shape[
1414
+ -1] if position_ids is not None else input_ids.shape[-1]
1415
+ cache_position = torch.arange(past_length,
1416
+ past_length + input_length,
1417
+ device=input_ids.device)
1418
+ position_ids = position_ids.contiguous(
1419
+ ) if position_ids is not None else None
1420
+
1421
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1422
+ if inputs_embeds is not None and past_key_values is None:
1423
+ model_inputs = {'inputs_embeds': inputs_embeds}
1424
+ else:
1425
+ # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
1426
+ # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114
1427
+ # TODO: use `next_tokens` directly instead.
1428
+ model_inputs = {'input_ids': input_ids.contiguous()}
1429
+
1430
+ model_inputs.update(
1431
+ { # type: ignore
1432
+ 'position_ids': position_ids,
1433
+ 'cache_position': cache_position,
1434
+ 'past_key_values': past_key_values,
1435
+ 'use_cache': kwargs.get('use_cache'),
1436
+ 'attention_mask': attention_mask,
1437
+ }
1438
+ )
1439
+ return model_inputs
1440
+
1441
+ @staticmethod
1442
+ def _reorder_cache(past_key_values: Cache, beam_idx: torch.LongTensor):
1443
+ reordered_past = ()
1444
+ for layer_past in past_key_values:
1445
+ reordered_past += (tuple(
1446
+ past_state.index_select(0, beam_idx.to(past_state.device))
1447
+ for past_state in layer_past),)
1448
+ return reordered_past