lexsming committed on
Commit 7c4d8d7 · verified · 1 Parent(s): 9d92469

Update modeling_bailing_moe.py

Files changed (1): modeling_bailing_moe.py (+25 −371)
modeling_bailing_moe.py CHANGED
@@ -20,17 +20,14 @@
 """ PyTorch BailingMoE model."""
 import math
 import warnings
-from dataclasses import dataclass
 from typing import List, Optional, Tuple, Union
 
 import torch
-import torch.distributed as dist
 import torch.nn.functional as F
 import torch.utils.checkpoint
-import transformers
-from packaging import version
 from torch import nn
 from torch.nn import CrossEntropyLoss
+
 from transformers.activations import ACT2FN
 from transformers.cache_utils import Cache, DynamicCache
 from transformers.modeling_attn_mask_utils import (
@@ -40,10 +37,8 @@ from transformers.modeling_attn_mask_utils import (
     _prepare_4d_causal_attention_mask_for_sdpa,
 )
 from transformers.modeling_outputs import (
-    ModelOutput,
-    MoeCausalLMOutputWithPast,
     MoeModelOutputWithPast,
-    SequenceClassifierOutputWithPast,
+    MoeCausalLMOutputWithPast,
 )
 from transformers.modeling_utils import PreTrainedModel
 from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_13
@@ -56,9 +51,9 @@ from transformers.utils import (
     replace_return_docstrings,
 )
 from transformers.utils.import_utils import is_torch_fx_available
-
 from .configuration_bailing_moe import BailingMoeConfig
 
+
 if is_flash_attn_2_available():
     from flash_attn import flash_attn_func, flash_attn_varlen_func
     from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
@@ -108,220 +103,6 @@ def _make_causal_mask(
     )
 
 
-def _unpack_router_logits(router_outputs):
-    """
-    Unpack the router tuple for balance loss calculation.
-    """
-    total_router_logits = []
-    total_expert_indexes = []
-    for router_output in router_outputs:
-        if router_output[0] is not None:
-            router_logits, expert_indexes = router_output
-            total_router_logits.append(router_logits.unsqueeze(0))
-            total_expert_indexes.append(expert_indexes.unsqueeze(0))
-    return torch.cat(total_router_logits, dim=0), total_expert_indexes
-
-
-def load_balancing_loss_func(router_probs: torch.Tensor, expert_indices: torch.Tensor, labels: torch.Tensor) -> float:
-    num_layers, _, seq_len, num_experts = router_probs.shape
-    num_experts = router_probs.shape[-1]
-    new_labels = labels.clone().detach()
-    for batch_tensor in new_labels:
-        neg_mask = batch_tensor == -100
-        diff_neg_ones = torch.diff(neg_mask.float())
-        start_pos = torch.where(diff_neg_ones == 1.0)[0]  # find where the last run of -100 labels starts
-        if start_pos.nelement() == 0:  # no start position found; may need adjusting for the actual data
-            pass
-        else:
-            last_start = start_pos[-1]  # start of the last run of -100s to modify
-            batch_tensor[:last_start] = 0  # set everything before it to 0
-    new_labels = new_labels.to(torch.int64)
-
-    # cast the expert indices to int64, otherwise one-hot encoding will fail
-
-    if expert_indices.dtype != torch.int64:
-        expert_indices = expert_indices.to(torch.int64)
-
-    if len(expert_indices.shape) == 3:
-        expert_indices = expert_indices.unsqueeze(3)
-
-    expert_mask = torch.nn.functional.one_hot(expert_indices, num_experts)
-
-    # For a given token, determine if it was routed to a given expert.
-    expert_mask = torch.max(expert_mask, axis=-2).values
-
-    # cast to float32 otherwise mean will fail
-    expert_mask = expert_mask.to(torch.float32)
-    labels_mask = (new_labels[None, ..., None].expand_as(expert_mask) != -100).long()
-
-    # sample-level balance loss
-    tokens_per_group_and_expert = torch.sum(expert_mask * labels_mask, dim=-2) / torch.sum(labels_mask, dim=-2)
-    router_prob_per_group_and_expert = torch.sum(router_probs * labels_mask, dim=-2) / torch.sum(labels_mask, dim=-2)
-    return torch.mean(tokens_per_group_and_expert * router_prob_per_group_and_expert) * (num_experts**2)
-
-
-def router_z_loss_func(router_logits: torch.Tensor, labels: torch.Tensor) -> float:
-    r"""
-    Compute the router z-loss implemented in PyTorch.
-
-    The router z-loss was introduced in [Designing Effective Sparse Expert Models](https://arxiv.org/abs/2202.08906).
-    It encourages router logits to remain small in an effort to improve stability.
-
-    Args:
-        router_logits (`float`):
-            Input logits of shape [num_layers, batch_size, sequence_length, num_experts]
-
-    Returns:
-        Scalar router z-loss.
-    """
-    num_layers, num_groups, tokens_per_group, _ = router_logits.shape
-    labels_mask = (labels[None, ..., None].expand_as(router_logits) != -100).long()
-
-    ori_dtype = router_logits.dtype
-    if ori_dtype == torch.bfloat16:
-        loss_func_inputs = (router_logits * labels_mask).to(torch.float32)
-    else:
-        loss_func_inputs = router_logits * labels_mask
-    log_z = torch.logsumexp(loss_func_inputs, dim=-1).to(ori_dtype)
-    z_loss = log_z**2
-
-    return torch.sum(z_loss) / (num_layers * num_groups * tokens_per_group)
-
-
-def auxiliary_loss(router_tuple, lm_logits, labels, config: BailingMoeConfig):
-    balance_loss, z_loss, last_logits_l2_loss = 0.0, 0.0, 0.0
-
-    loss = 0
-    if router_tuple is not None:
-        router_logits, layer_router_index = _unpack_router_logits(router_tuple)
-        top1_expert_index = torch.cat(layer_router_index, dim=0)
-        z_loss = router_z_loss_func(router_logits, labels)
-        router_probs = torch.nn.Softmax(dim=-1)(router_logits)
-        balance_loss = load_balancing_loss_func(router_probs, top1_expert_index, labels)
-
-        num_layers = router_probs.shape[0]
-        num_experts = router_probs.shape[-1]
-        router_probs_log = router_probs.detach().view(num_layers, -1, num_experts)
-        router_probs_mean = router_probs_log.mean(1)
-        router_probs_sort_mean = router_probs_log.sort(-1, descending=True)[0].mean(1)
-        router_probs_log = torch.stack([router_probs_mean, router_probs_sort_mean], dim=1)
-        dist.all_reduce(router_probs_log, dist.ReduceOp.SUM)
-        router_probs_log = router_probs_log / torch.distributed.get_world_size()
-        if dist.get_rank() == 0:
-            router_probs_log = router_probs_log.float()
-            router_probs_log /= router_probs_log.sum(-1, keepdim=True)
-
-        loss = float(config.router_z_loss_alpha) * z_loss + float(config.router_balance_loss_alpha) * balance_loss
-
-    last_logits_l2_loss = 0.0
-    if float(config.last_logits_l2_alpha) >= 0:
-        shift_logits = lm_logits[..., :-1, :].contiguous()
-        shift_labels = labels[..., 1:].contiguous()
-
-        shift_logits = lm_logits.view(-1, lm_logits.size(-1))
-        labels_mask = (shift_labels.view(-1) != -100).long()
-
-        last_logits_l2_loss = torch.sum(torch.linalg.norm(shift_logits.float(), 2.0, dim=-1) * labels_mask) / torch.sum(
-            labels_mask
-        )
-        loss += float(config.last_logits_l2_alpha) * last_logits_l2_loss
-        last_logits_l2_loss = last_logits_l2_loss.item()
-
-    return loss, balance_loss, z_loss, last_logits_l2_loss
-
-
-def local_token_level_cross_entropy(logits, labels, **kwargs):
-    # Token-level average within each batch, then average across batches.
-    if isinstance(logits, ModelOutput):
-        logits = logits.logits
-    elif isinstance(logits, Tuple):
-        logits = logits[0]
-
-    logits = logits.float()
-    shift_logits = logits[..., :-1, :].contiguous()
-    shift_labels = labels[..., 1:].contiguous()
-    loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100)
-    loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
-    return loss
-
-
-def sample_level_cross_entropy(logits, labels, **kwargs):
-    # Token-level average within each sample first, then average over all samples.
-    if isinstance(logits, ModelOutput):
-        logits = logits.logits
-    elif isinstance(logits, Tuple):
-        logits = logits[0]
-
-    logits = logits.float()
-    shift_logits = logits[..., :-1, :].contiguous()
-    shift_labels = labels[..., 1:].contiguous()
-    loss_fct = CrossEntropyLoss(ignore_index=-100, reduction='none')
-    loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)).reshape(
-        shift_labels.shape[0], -1
-    )
-    loss = loss.sum(-1) / (shift_labels != -100).sum(-1)
-    loss = loss.mean()
-    return loss
-
-
-def global_token_level_cross_entropy(logits, labels, **kwargs):
-    # Token-level average over all samples together.
-    if isinstance(logits, ModelOutput):
-        logits = logits.logits
-    elif isinstance(logits, Tuple):
-        logits = logits[0]
-
-    logits = logits.float()
-    shift_logits = logits[..., :-1, :].contiguous()
-    shift_labels = labels[..., 1:].contiguous()
-    loss_fct = CrossEntropyLoss(ignore_index=-100, reduction='none')
-    loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)).reshape(
-        shift_labels.shape[0], -1
-    )
-    num_tokens = (shift_labels != -100).sum()
-    loss = loss.sum()
-
-    num_tokens_tensor = torch.zeros([1], device=loss.device, dtype=loss.dtype)
-    num_tokens_tensor[0] = num_tokens.item()
-
-    torch.distributed.all_reduce(num_tokens_tensor)
-
-    global_num_tokens = num_tokens_tensor.sum()
-
-    torch.distributed.barrier()
-    # global_num_tokens is the global token count; gradients are automatically averaged
-    # across all ranks at update time, so we multiply by world_size here.
-    loss = loss.sum() / global_num_tokens * torch.distributed.get_world_size()
-
-    return loss
-
-
-BAILING_LOSS_MAPPING = {
-    'local_token_level_cross_entropy': local_token_level_cross_entropy,
-    'sample_level_cross_entropy': sample_level_cross_entropy,
-    'global_token_level_cross_entropy': global_token_level_cross_entropy,
-}
-
-
-@dataclass
-class CustomMoeOutput(ModelOutput):
-    """Fully custom output class containing all required fields."""
-
-    loss: Optional[torch.FloatTensor] = None
-    aux_loss: Optional[torch.FloatTensor] = None
-    logits: torch.FloatTensor = None
-    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
-    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
-    attentions: Optional[Tuple[torch.FloatTensor]] = None
-    router_logits: Optional[Tuple[torch.FloatTensor]] = None
-    # extra loss components
-    lm_loss: Optional[torch.FloatTensor] = None
-    balance_loss: Optional[torch.FloatTensor] = None
-    z_loss: Optional[torch.FloatTensor] = None
-    last_logits_l2_loss: Optional[torch.FloatTensor] = None
-
-
 class BailingMoeRMSNorm(nn.Module):
     def __init__(self, hidden_size, eps=1e-6):
         """
@@ -696,7 +477,6 @@ class BailingMoeAttention(nn.Module):
         value_states = value_states.transpose(1, 2)
 
         kv_seq_len = key_states.shape[-2]
-
         if past_key_value is not None:
             if self.layer_idx is None:
                 raise ValueError(
@@ -705,7 +485,6 @@
                     "with a layer index."
                 )
             kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
-
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
@@ -1564,67 +1343,36 @@ class BailingMoeForCausalLM(BailingMoePreTrainedModel):
 
         logits = logits.float()
 
-        lm_loss = None
+        loss = None
         aux_loss = None
 
         if labels is not None:
-            built_in_loss_mapping = {}
-            if version.parse(transformers.__version__) >= version.parse("4.46.0"):
-                from transformers.loss.loss_utils import LOSS_MAPPING
-
-                built_in_loss_mapping = dict(LOSS_MAPPING)
-            built_in_loss_mapping.update(BAILING_LOSS_MAPPING)
-
-            loss_type = getattr(self.config, "loss_type", None)
-            if loss_type is None or loss_type not in built_in_loss_mapping:
-                logger.warning_once(
-                    f"`loss_type={loss_type}` was set in the config but it is unrecognised. "
-                    f"Using the default loss: `global_token_level_cross_entropy`."
-                )
-                loss_type = "global_token_level_cross_entropy"
-
-            loss_fct = built_in_loss_mapping[loss_type]
-            lm_loss = loss_fct(logits, labels)
-
-            loss = lm_loss
-            if output_router_logits and labels is not None:
-                aux_loss, balance_loss, z_loss, last_logits_l2_loss = auxiliary_loss(
-                    outputs.router_logits, logits, labels, self.config
-                )
-                loss = lm_loss + self.config.router_aux_loss_coef * aux_loss
+            # Shift so that tokens < n predict n
+            shift_logits = logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            # Flatten the tokens
+            loss_fct = CrossEntropyLoss()
+            shift_logits = shift_logits.view(-1, self.config.vocab_size)
+            shift_labels = shift_labels.view(-1)
+            # Enable model parallelism
+            shift_labels = shift_labels.to(shift_logits.device)
+            loss = loss_fct(shift_logits, shift_labels)
 
         if not return_dict:
             output = (logits,) + outputs[1:]
-            if output_router_logits and labels is not None:
-                output = (aux_loss, balance_loss, z_loss, last_logits_l2_loss) + output
+            if output_router_logits:
+                output = (aux_loss,) + output
             return (loss,) + output if loss is not None else output
 
-        if output_router_logits and labels is not None:
-            moe_output = CustomMoeOutput(
-                loss=loss,
-                aux_loss=aux_loss,
-                logits=logits,
-                past_key_values=outputs.past_key_values,
-                hidden_states=outputs.hidden_states,
-                attentions=outputs.attentions,
-                router_logits=outputs.router_logits,
-                lm_loss=lm_loss,
-                balance_loss=balance_loss,
-                z_loss=z_loss,
-                last_logits_l2_loss=last_logits_l2_loss,
-            )
-
-            return moe_output
-        else:
-            return MoeCausalLMOutputWithPast(
-                loss=loss,
-                aux_loss=aux_loss,
-                logits=logits,
-                past_key_values=outputs.past_key_values,
-                hidden_states=outputs.hidden_states,
-                attentions=outputs.attentions,
-                router_logits=outputs.router_logits,
-            )
+        return MoeCausalLMOutputWithPast(
+            loss=loss,
+            aux_loss=aux_loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+            router_logits=outputs.router_logits,
+        )
 
     def prepare_inputs_for_generation(
         self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, token_type_ids=None, **kwargs
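
A note on the simplified loss path above: CrossEntropyLoss() is constructed with no arguments, but its default ignore_index is -100, so label positions masked with -100 are still excluded from the loss, matching the masking convention used by the removed custom losses. A minimal standalone sketch (hypothetical sizes, not part of the diff):

    import torch
    from torch.nn import CrossEntropyLoss

    batch, seq_len, vocab = 2, 6, 32                    # hypothetical sizes
    logits = torch.randn(batch, seq_len, vocab)
    labels = torch.randint(0, vocab, (batch, seq_len))
    labels[:, :3] = -100                                # e.g. prompt tokens masked out

    # Shift so that tokens < n predict n, exactly as in the new forward()
    shift_logits = logits[..., :-1, :].contiguous().view(-1, vocab)
    shift_labels = labels[..., 1:].contiguous().view(-1)
    loss = CrossEntropyLoss()(shift_logits, shift_labels)  # -100 targets ignored by default
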
@@ -1693,97 +1441,3 @@ class BailingMoeForCausalLM(BailingMoePreTrainedModel):
                 tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
             )
         return reordered_past
-
-
-class BailingMoeForRewardModel(BailingMoePreTrainedModel):
-    def __init__(self, config: BailingMoeConfig, model: BailingMoeModel = None):
-        super().__init__(config)
-        self.num_labels = 1  # config.num_labels
-        if model:
-            self.model = model
-        else:
-            self.model = BailingMoeModel(config)
-        self.value_head = nn.Sequential(
-            nn.Linear(config.hidden_size, config.hidden_size), nn.ReLU(), nn.Linear(config.hidden_size, self.num_labels)
-        )
-
-        # Initialize weights and apply final processing
-        self.post_init()
-
-    def get_input_embeddings(self):
-        return self.model.word_embeddings
-
-    def set_input_embeddings(self, value):
-        self.model.word_embeddings = value
-
-    @add_start_docstrings_to_model_forward(BAILINGMOE_INPUTS_DOCSTRING)
-    def forward(
-        self,
-        input_ids: torch.LongTensor = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[List[torch.FloatTensor]] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        labels: Optional[torch.LongTensor] = None,
-        use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
-        r"""
-        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
-            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
-            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
-            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
-        """
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-        transformer_outputs = self.model(
-            input_ids,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            past_key_values=past_key_values,
-            inputs_embeds=inputs_embeds,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-        )
-
-        if return_dict:
-            last_hidden_state = transformer_outputs.last_hidden_state
-        else:
-            last_hidden_state = transformer_outputs[0]
-
-        logits = self.value_head(last_hidden_state)
-
-        if input_ids is not None:
-            batch_size = input_ids.shape[0]
-        else:
-            batch_size = inputs_embeds.shape[0]
-
-        if self.config.pad_token_id is None and batch_size != 1:
-            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
-        if self.config.pad_token_id is None:
-            sequence_lengths = -1
-        else:
-            if input_ids is not None:
-                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
-                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
-                sequence_lengths = sequence_lengths % input_ids.shape[-1]
-                sequence_lengths = sequence_lengths.to(logits.device)
-            else:
-                sequence_lengths = -1
-
-        if isinstance(sequence_lengths, int) and sequence_lengths == -1:
-            sequence_lengths = (attention_mask.sum(dim=-1, keepdim=True) - 1).squeeze()
-
-        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]  # logits of last token
-        pooled_logits = pooled_logits.squeeze()
-
-        return SequenceClassifierOutputWithPast(
-            logits=pooled_logits,
-            past_key_values=transformer_outputs.past_key_values,
-            hidden_states=transformer_outputs.hidden_states,
-            attentions=transformer_outputs.hidden_states,
-        )
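
With the custom loss machinery and the reward-model head removed, BailingMoeForCausalLM now always returns a MoeCausalLMOutputWithPast. A usage sketch under stated assumptions: the repo id below is a placeholder, and a checkpoint shipping this remote code must be loaded with trust_remote_code=True.

    from transformers import AutoModelForCausalLM, AutoTokenizer

    path = "your-org/bailing-moe-checkpoint"  # placeholder repo id
    tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(path, trust_remote_code=True)

    inputs = tokenizer("Hello, world", return_tensors="pt")
    out = model(**inputs, labels=inputs["input_ids"])
    print(out.loss)          # plain shifted cross-entropy after this commit
    print(out.logits.shape)  # (batch, seq_len, vocab_size)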