Update modeling_bailing_moe.py
modeling_bailing_moe.py (+25 -371)
@@ -20,17 +20,14 @@
 """ PyTorch BailingMoE model."""
 import math
 import warnings
-from dataclasses import dataclass
 from typing import List, Optional, Tuple, Union

 import torch
-import torch.distributed as dist
import torch.nn.functional as F
 import torch.utils.checkpoint
-import transformers
-from packaging import version
 from torch import nn
 from torch.nn import CrossEntropyLoss
+
 from transformers.activations import ACT2FN
 from transformers.cache_utils import Cache, DynamicCache
 from transformers.modeling_attn_mask_utils import (
@@ -40,10 +37,8 @@ from transformers.modeling_attn_mask_utils import (
     _prepare_4d_causal_attention_mask_for_sdpa,
 )
 from transformers.modeling_outputs import (
-    ModelOutput,
-    MoeCausalLMOutputWithPast,
     MoeModelOutputWithPast,
-
+    MoeCausalLMOutputWithPast,
 )
 from transformers.modeling_utils import PreTrainedModel
 from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_13
@@ -56,9 +51,9 @@ from transformers.utils import (
     replace_return_docstrings,
 )
 from transformers.utils.import_utils import is_torch_fx_available
-
 from .configuration_bailing_moe import BailingMoeConfig

+
 if is_flash_attn_2_available():
     from flash_attn import flash_attn_func, flash_attn_varlen_func
     from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
@@ -108,220 +103,6 @@ def _make_causal_mask(
     )


-def _unpack_router_logits(router_outputs):
-    """
-    Unpack the router tuple for balance loss calculation.
-    """
-    total_router_logits = []
-    total_expert_indexes = []
-    for router_output in router_outputs:
-        if router_output[0] is not None:
-            router_logits, expert_indexes = router_output
-            total_router_logits.append(router_logits.unsqueeze(0))
-            total_expert_indexes.append(expert_indexes.unsqueeze(0))
-    return torch.cat(total_router_logits, dim=0), total_expert_indexes
-
-
-def load_balancing_loss_func(router_probs: torch.Tensor, expert_indices: torch.Tensor, labels: torch.Tensor) -> float:
-    num_layers, _, seq_len, num_experts = router_probs.shape
-    num_experts = router_probs.shape[-1]
-    new_labels = labels.clone().detach()
-    ##
-    for batch_tensor in new_labels:
-        neg_mask = batch_tensor == -100
-        diff_neg_ones = torch.diff(neg_mask.float())
-        start_pos = torch.where(diff_neg_ones == 1.0)[0]  # find where the runs of -100 begin
-        if start_pos.nelement() == 0:  # no start position found; may need adjusting for the actual data
-            pass
-        else:
-            last_start = start_pos[-1]  # start of the last run of -100 that needs modifying
-            batch_tensor[:last_start] = 0  # set all of these -100 values to 0
-    new_labels = new_labels.to(torch.int64)
-
-    # cast the expert indices to int64, otherwise one-hot encoding will fail
-
-    if expert_indices.dtype != torch.int64:
-        expert_indices = expert_indices.to(torch.int64)
-
-    if len(expert_indices.shape) == 3:
-        expert_indices = expert_indices.unsqueeze(3)
-
-    expert_mask = torch.nn.functional.one_hot(expert_indices, num_experts)
-
-    # For a given token, determine if it was routed to a given expert.
-    expert_mask = torch.max(expert_mask, axis=-2).values
-
-    # cast to float32 otherwise mean will fail
-    expert_mask = expert_mask.to(torch.float32)
-    labels_mask = (new_labels[None, ..., None].expand_as(expert_mask) != -100).long()
-
-    # sample level balance loss
-    tokens_per_group_and_expert = torch.sum(expert_mask * labels_mask, dim=-2) / torch.sum(labels_mask, dim=-2)
-    router_prob_per_group_and_expert = torch.sum(router_probs * labels_mask, dim=-2) / torch.sum(labels_mask, dim=-2)
-    return torch.mean(tokens_per_group_and_expert * router_prob_per_group_and_expert) * (num_experts**2)
-
-
-def router_z_loss_func(router_logits: torch.Tensor, labels: torch.Tensor) -> float:
-    r"""
-    Compute the router z-loss implemented in PyTorch.
-
-    The router z-loss was introduced in [Designing Effective Sparse Expert Models](https://arxiv.org/abs/2202.08906).
-    It encourages router logits to remain small in an effort to improve stability.
-
-    Args:
-        router_logits (`float`):
-            Input logits of shape [num_layers, batch_size, sequence_length, num_experts]
-
-    Returns:
-        Scalar router z-loss.
-    """
-    num_layers, num_groups, tokens_per_group, _ = router_logits.shape
-    labels_mask = (labels[None, ..., None].expand_as(router_logits) != -100).long()
-
-    ori_dtype = router_logits.dtype
-    if ori_dtype == torch.bfloat16:
-        loss_func_inputs = (router_logits * labels_mask).to(torch.float32)
-    else:
-        loss_func_inputs = router_logits * labels_mask
-    log_z = torch.logsumexp(loss_func_inputs, dim=-1).to(ori_dtype)
-    z_loss = log_z**2
-
-    return torch.sum(z_loss) / (num_layers * num_groups * tokens_per_group)
-
-
-def auxiliary_loss(router_tuple, lm_logits, labels, config: BailingMoeConfig):
-    balance_loss, z_loss, last_logits_l2_loss = 0.0, 0.0, 0.0
-
-    loss = 0
-    if router_tuple is not None:
-        router_logits, layer_router_index = _unpack_router_logits(router_tuple)
-        top1_expert_index = torch.cat(layer_router_index, dim=0)
-        z_loss = router_z_loss_func(router_logits, labels)
-        router_probs = torch.nn.Softmax(dim=-1)(router_logits)
-        balance_loss = load_balancing_loss_func(router_probs, top1_expert_index, labels)
-
-        num_layers = router_probs.shape[0]
-        num_experts = router_probs.shape[-1]
-        router_probs_log = router_probs.detach().view(num_layers, -1, num_experts)
-        router_probs_mean = router_probs_log.mean(1)
-        router_probs_sort_mean = router_probs_log.sort(-1, descending=True)[0].mean(1)
-        router_probs_log = torch.stack([router_probs_mean, router_probs_sort_mean], dim=1)
-        dist.all_reduce(router_probs_log, dist.ReduceOp.SUM)
-        router_probs_log = router_probs_log / torch.distributed.get_world_size()
-        if dist.get_rank() == 0:
-            router_probs_log = router_probs_log.float()
-            router_probs_log /= router_probs_log.sum(-1, keepdim=True)
-
-        loss = float(config.router_z_loss_alpha) * z_loss + float(config.router_balance_loss_alpha) * balance_loss
-
-    last_logits_l2_loss = 0.0
-    if float(config.last_logits_l2_alpha) >= 0:
-        shift_logits = lm_logits[..., :-1, :].contiguous()
-        shift_labels = labels[..., 1:].contiguous()
-
-        shift_logits = lm_logits.view(-1, lm_logits.size(-1))
-        labels_mask = (shift_labels.view(-1) != -100).long()
-
-        last_logits_l2_loss = torch.sum(torch.linalg.norm(shift_logits.float(), 2.0, dim=-1) * labels_mask) / torch.sum(
-            labels_mask
-        )
-        loss += float(config.last_logits_l2_alpha) * last_logits_l2_loss
-        last_logits_l2_loss = last_logits_l2_loss.item()
-
-    return loss, balance_loss, z_loss, last_logits_l2_loss
-
-
-def local_token_level_cross_entropy(logits, labels, **kwargs):
-    # Average at the token level within each batch, then average across batches
-    if isinstance(logits, ModelOutput):
-        logits = logits.logits
-    elif isinstance(logits, Tuple):
-        logits = logits[0]
-
-    logits = logits.float()
-    shift_logits = logits[..., :-1, :].contiguous()
-    shift_labels = labels[..., 1:].contiguous()
-    loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100)
-    loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
-    return loss
-
-
-def sample_level_cross_entropy(logits, labels, **kwargs):
-    # First take the token-level average within each sample, then average over all samples
-    if isinstance(logits, ModelOutput):
-        logits = logits.logits
-    elif isinstance(logits, Tuple):
-        logits = logits[0]
-
-    logits = logits.float()
-    shift_logits = logits[..., :-1, :].contiguous()
-    shift_labels = labels[..., 1:].contiguous()
-    loss_fct = CrossEntropyLoss(ignore_index=-100, reduction='none')
-    loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)).reshape(
-        shift_labels.shape[0], -1
-    )
-    loss = loss.sum(-1) / (shift_labels != -100).sum(-1)
-    loss = loss.mean()
-    return loss
-
-
-def global_token_level_cross_entropy(logits, labels, **kwargs):
-    # Token-level average over all samples together
-    if isinstance(logits, ModelOutput):
-        logits = logits.logits
-    elif isinstance(logits, Tuple):
-        logits = logits[0]
-
-    logits = logits.float()
-    shift_logits = logits[..., :-1, :].contiguous()
-    shift_labels = labels[..., 1:].contiguous()
-    loss_fct = CrossEntropyLoss(ignore_index=-100, reduction='none')
-    loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)).reshape(
-        shift_labels.shape[0], -1
-    )
-    num_tokens = (shift_labels != -100).sum()
-    loss = loss.sum()
-
-    num_tokens_tensor = torch.zeros([1], device=loss.device, dtype=loss.dtype)
-    num_tokens_tensor[0] = num_tokens.item()
-
-    torch.distributed.all_reduce(num_tokens_tensor)
-
-    global_num_tokens = num_tokens_tensor.sum()
-
-    torch.distributed.barrier()
-    # global_num_tokens is the global token count; gradients are averaged over all
-    # ranks automatically at the update step, so multiply by world_size here
-    loss = loss.sum() / global_num_tokens * torch.distributed.get_world_size()
-
-    return loss
-
-
-BAILING_LOSS_MAPPING = {
-    'local_token_level_cross_entropy': local_token_level_cross_entropy,
-    'sample_level_cross_entropy': sample_level_cross_entropy,
-    'global_token_level_cross_entropy': global_token_level_cross_entropy,
-}
-
-
-@dataclass
-class CustomMoeOutput(ModelOutput):
-    """Fully custom output class containing every field we need"""
-
-    loss: Optional[torch.FloatTensor] = None
-    aux_loss: Optional[torch.FloatTensor] = None
-    logits: torch.FloatTensor = None
-    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
-    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
-    attentions: Optional[Tuple[torch.FloatTensor]] = None
-    router_logits: Optional[Tuple[torch.FloatTensor]] = None
-    # extra loss components
-    lm_loss: Optional[torch.FloatTensor] = None
-    balance_loss: Optional[torch.FloatTensor] = None
-    z_loss: Optional[torch.FloatTensor] = None
-    last_logits_l2_loss: Optional[torch.FloatTensor] = None
-
-
 class BailingMoeRMSNorm(nn.Module):
     def __init__(self, hidden_size, eps=1e-6):
         """
@@ -696,7 +477,6 @@ class BailingMoeAttention(nn.Module):
         value_states = value_states.transpose(1, 2)

         kv_seq_len = key_states.shape[-2]
-
         if past_key_value is not None:
             if self.layer_idx is None:
                 raise ValueError(
@@ -705,7 +485,6 @@ class BailingMoeAttention(nn.Module):
                     "with a layer index."
                 )
             kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
-
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

@@ -1564,67 +1343,36 @@ class BailingMoeForCausalLM(BailingMoePreTrainedModel):

         logits = logits.float()

-
+        loss = None
         aux_loss = None

         if labels is not None:
-                f"`loss_type={loss_type}` was set in the config but it is unrecognised. "
-                f"Using the default loss: `global_token_level_cross_entropy`."
-            )
-            loss_type = "global_token_level_cross_entropy"
-
-            loss_fct = built_in_loss_mapping[loss_type]
-            lm_loss = loss_fct(logits, labels)
-
-            loss = lm_loss
-        if output_router_logits and labels is not None:
-            aux_loss, balance_loss, z_loss, last_logits_l2_loss = auxiliary_loss(
-                outputs.router_logits, logits, labels, self.config
-            )
-            loss = lm_loss + self.config.router_aux_loss_coef * aux_loss
+            # Shift so that tokens < n predict n
+            shift_logits = logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            # Flatten the tokens
+            loss_fct = CrossEntropyLoss()
+            shift_logits = shift_logits.view(-1, self.config.vocab_size)
+            shift_labels = shift_labels.view(-1)
+            # Enable model parallelism
+            shift_labels = shift_labels.to(shift_logits.device)
+            loss = loss_fct(shift_logits, shift_labels)

         if not return_dict:
             output = (logits,) + outputs[1:]
-            if output_router_logits:
-                output = (aux_loss,
+            if output_router_logits:
+                output = (aux_loss,) + output
             return (loss,) + output if loss is not None else output

-            moe_output = CustomMoeOutput(
-                loss=loss,
-                aux_loss=aux_loss,
-                logits=logits,
-                past_key_values=outputs.past_key_values,
-                hidden_states=outputs.hidden_states,
-                attentions=outputs.attentions,
-                router_logits=outputs.router_logits,
-                lm_loss=lm_loss,
-                balance_loss=balance_loss,
-                z_loss=z_loss,
-                last_logits_l2_loss=last_logits_l2_loss,
-            )
-
-            return moe_output
-        else:
-            return MoeCausalLMOutputWithPast(
-                loss=loss,
-                aux_loss=aux_loss,
-                logits=logits,
-                past_key_values=outputs.past_key_values,
-                hidden_states=outputs.hidden_states,
-                attentions=outputs.attentions,
-                router_logits=outputs.router_logits,
-            )
+        return MoeCausalLMOutputWithPast(
+            loss=loss,
+            aux_loss=aux_loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+            router_logits=outputs.router_logits,
+        )

     def prepare_inputs_for_generation(
         self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, token_type_ids=None, **kwargs
@@ -1693,97 +1441,3 @@ class BailingMoeForCausalLM(BailingMoePreTrainedModel):
                 tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
             )
         return reordered_past
-
-
-class BailingMoeForRewardModel(BailingMoePreTrainedModel):
-    def __init__(self, config: BailingMoeConfig, model: BailingMoeModel = None):
-        super().__init__(config)
-        self.num_labels = 1  # config.num_labels
-        if model:
-            self.model = model
-        else:
-            self.model = BailingMoeModel(config)
-        self.value_head = nn.Sequential(
-            nn.Linear(config.hidden_size, config.hidden_size), nn.ReLU(), nn.Linear(config.hidden_size, self.num_labels)
-        )
-
-        # Initialize weights and apply final processing
-        self.post_init()
-
-    def get_input_embeddings(self):
-        return self.model.word_embeddings
-
-    def set_input_embeddings(self, value):
-        self.model.word_embeddings = value
-
-    @add_start_docstrings_to_model_forward(BAILINGMOE_INPUTS_DOCSTRING)
-    def forward(
-        self,
-        input_ids: torch.LongTensor = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[List[torch.FloatTensor]] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        labels: Optional[torch.LongTensor] = None,
-        use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
-        r"""
-        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
-            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
-            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
-            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
-        """
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-        transformer_outputs = self.model(
-            input_ids,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            past_key_values=past_key_values,
-            inputs_embeds=inputs_embeds,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-        )
-
-        if return_dict:
-            last_hidden_state = transformer_outputs.last_hidden_state
-        else:
-            last_hidden_state = transformer_outputs[0]
-
-        logits = self.value_head(last_hidden_state)
-
-        if input_ids is not None:
-            batch_size = input_ids.shape[0]
-        else:
-            batch_size = inputs_embeds.shape[0]
-
-        if self.config.pad_token_id is None and batch_size != 1:
-            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
-        if self.config.pad_token_id is None:
-            sequence_lengths = -1
-        else:
-            if input_ids is not None:
-                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
-                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
-                sequence_lengths = sequence_lengths % input_ids.shape[-1]
-                sequence_lengths = sequence_lengths.to(logits.device)
-            else:
-                sequence_lengths = -1
-
-        if isinstance(sequence_lengths, int) and sequence_lengths == -1:
-            sequence_lengths = (attention_mask.sum(dim=-1, keepdim=True) - 1).squeeze()
-
-        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]  # logits of last token
-        pooled_logits = pooled_logits.squeeze()
-
-        return SequenceClassifierOutputWithPast(
-            logits=pooled_logits,
-            past_key_values=transformer_outputs.past_key_values,
-            hidden_states=transformer_outputs.hidden_states,
-            attentions=transformer_outputs.hidden_states,
-        )
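
Note on the replacement loss: the bespoke loss machinery deleted above gives way to the stock causal-LM objective used across transformers models. Below is a minimal, self-contained sketch of what the new forward() computes; the shapes and tensors are dummy values invented here for illustration, and only the shift-by-one pattern and the -100 ignore index come from the diff itself.

    import torch
    from torch.nn import CrossEntropyLoss

    batch_size, seq_len, vocab_size = 2, 8, 32  # dummy sizes
    logits = torch.randn(batch_size, seq_len, vocab_size)
    labels = torch.randint(0, vocab_size, (batch_size, seq_len))
    labels[:, :3] = -100  # e.g. prompt tokens excluded from the loss

    # Shift so that tokens < n predict n, exactly as in the new code
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()

    # CrossEntropyLoss ignores -100 targets by default, so masked positions
    # drop out and the mean runs over unmasked tokens only
    loss_fct = CrossEntropyLoss()
    loss = loss_fct(shift_logits.view(-1, vocab_size), shift_labels.view(-1))

On a single process this matches the token-level average that the old default global_token_level_cross_entropy produced; what disappears is the cross-rank token-count normalization and the auxiliary router terms.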
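For reference, the two router regularizers removed in this commit are standard MoE auxiliary objectives. With L layers, G groups, T tokens per group, and E experts, router logits x and probabilities p, the deleted functions compute, up to the -100 label masking:

    % Router z-loss (Zoph et al., 2022, arXiv:2202.08906):
    % penalizes large router logits to improve training stability.
    \mathcal{L}_z = \frac{1}{L G T} \sum_{l,g,t} \Big( \log \sum_{e=1}^{E} \exp x_{l,g,t,e} \Big)^{2}

    % Load-balancing loss: f is the dispatch frequency (fraction of tokens
    % routed to each expert), P the mean router probability per expert;
    % the product is smallest when both are uniform across experts.
    \mathcal{L}_{\mathrm{balance}} = E^{2} \cdot \operatorname{mean}_{l,g,e}\big( f_{l,g,e} \, P_{l,g,e} \big)

The old forward() combined these as router_z_loss_alpha * L_z + router_balance_loss_alpha * L_balance, plus an optional L2 penalty on the final logits.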
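One subtlety worth preserving from the removed global_token_level_cross_entropy: it divides by the global token count and then multiplies by world_size, because data-parallel training averages gradients across ranks (dividing by world_size again). A small sketch of that arithmetic with made-up numbers, no process group needed:

    # Two ranks with different numbers of unmasked tokens.
    local_loss_sums = [12.0, 4.0]  # per-rank sums of token CE losses
    local_token_counts = [6, 2]

    world_size = len(local_loss_sums)
    global_tokens = sum(local_token_counts)  # what all_reduce gathered

    # What each rank returned in the removed function:
    per_rank = [s / global_tokens * world_size for s in local_loss_sums]

    # DDP-style averaging divides the summed gradients by world_size,
    # so the effective objective is the mean over ranks...
    effective = sum(per_rank) / world_size

    # ...which is exactly the true global token-level average:
    assert abs(effective - sum(local_loss_sums) / global_tokens) < 1e-9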