rbao2018 committed
Commit 8ba2030 · Parent: 55fbe95

fix four files

config.json CHANGED
@@ -6,7 +6,8 @@
   "auto_map": {
     "AutoConfig": "configuration_bailing_moe.BailingMoeConfig",
     "AutoModel": "modeling_bailing_moe.BailingMoeModel",
-    "AutoModelForCausalLM": "modeling_bailing_moe.BailingMoeForCausalLM"
+    "AutoModelForCausalLM": "modeling_bailing_moe.BailingMoeForCausalLM",
+    "AutoModelForTokenClassification": "modeling_bailing_moe.BailingMoeForTokenClassification"
   },
   "eos_token_id": 126081,
   "pad_token_id": 126081,
@@ -40,6 +41,5 @@
   "embedding_dropout": 0.1,
   "norm_head": true,
   "norm_softmax": false,
-  "output_dropout": 0.1,
-  "head_dim": 0
+  "output_dropout": 0.1
 }
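With the extra `auto_map` entry, the new token-classification head becomes loadable through the Auto API when `trust_remote_code=True`. A minimal sketch; the repository id and label count below are placeholders, not part of this commit:

```python
from transformers import AutoConfig, AutoModelForTokenClassification

repo_id = "your-org/bailing-moe"  # hypothetical repo id for illustration

config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
config.num_labels = 3  # e.g. a 3-way tagging task

# auto_map resolves this to modeling_bailing_moe.BailingMoeForTokenClassification
model = AutoModelForTokenClassification.from_pretrained(
    repo_id, config=config, trust_remote_code=True
)
```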
modeling_bailing_moe.py CHANGED
@@ -24,7 +24,6 @@ from typing import List, Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F
-import torch.utils.checkpoint
 from torch import nn
 from torch.nn import CrossEntropyLoss
 
@@ -72,6 +71,81 @@ logger = logging.get_logger(__name__)
 
 _CONFIG_FOR_DOC = "BailingMoeConfig"
 
+# Copied from transformers.models.mixtral.modeling_mixtral.load_balancing_loss_func
+def load_balancing_loss_func(
+    gate_logits_and_topk: Union[torch.Tensor, Tuple[torch.Tensor], None],
+    num_experts: Optional[int] = None,
+    top_k=2,
+    attention_mask: Optional[torch.Tensor] = None,
+) -> Union[torch.Tensor, int]:
+    r"""
+    Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
+
+    See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss
+    function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
+    experts is too unbalanced.
+
+    Args:
+        gate_logits:
+            Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
+            shape [batch_size X sequence_length, num_experts].
+        num_experts:
+            Number of experts
+        top_k:
+            The number of experts to route per-token, can be also interpreted as the `top-k` routing
+            parameter.
+        attention_mask (`torch.Tensor`, *optional*):
+            The attention_mask used in forward function
+            shape [batch_size X sequence_length] if not None.
+
+    Returns:
+        The auxiliary loss.
+    """
+    if gate_logits_and_topk is None or not isinstance(gate_logits_and_topk, tuple):
+        return 0
+
+    if isinstance(gate_logits_and_topk, tuple):
+        # concatenated_gate_logits.shape = [batch_size * num_layers * seq_len, num_experts]
+        concatenated_gate_logits = torch.cat([layer_gate[0] for layer_gate in gate_logits_and_topk], dim=0)
+        # selected_experts.shape = [batch_size * num_layers * seq_len, top_k_experts]
+        selected_experts = torch.cat([layer_gate[1] for layer_gate in gate_logits_and_topk], dim=0)
+        selected_experts.to(concatenated_gate_logits.device)
+
+    routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
+
+    expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
+
+    if attention_mask is None:
+        tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
+        # Compute the average probability of routing to these experts
+        router_prob_per_expert = torch.mean(routing_weights, dim=0)
+    else:
+        batch_size, sequence_length = attention_mask.shape
+        num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)
+        # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
+        expert_attention_mask = (
+            attention_mask[None, :, :, None, None]
+            .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
+            .reshape(-1, top_k, num_experts)
+        )
+
+        # Compute the percentage of tokens routed to each experts
+        tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
+            expert_attention_mask, dim=0
+        )
+
+        # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
+        router_per_expert_attention_mask = (
+            attention_mask[None, :, :, None]
+            .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
+            .reshape(-1, num_experts)
+        )
+
+        # Compute the average probability of routing to these experts
+        router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(router_per_expert_attention_mask, dim=0)
+
+    overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
+    return overall_loss
 
 def _get_unpad_data(attention_mask):
     seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
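The helper accepts one `(router_logits, topk_idx)` pair per layer, with shapes `[batch_size * seq_len, num_experts]` and `[batch_size * seq_len, top_k]`. A toy call, assuming the function is importable from the local `modeling_bailing_moe` module of this repository:

```python
import torch

from modeling_bailing_moe import load_balancing_loss_func  # local module from this repo

batch_size, seq_len, num_experts, top_k, num_layers = 2, 4, 8, 2, 3
tokens = batch_size * seq_len

def fake_layer_router():
    logits = torch.randn(tokens, num_experts)              # per-token router logits
    topk_idx = torch.topk(logits, top_k, dim=-1).indices   # per-token selected experts
    return logits, topk_idx

gate_logits_and_topk = tuple(fake_layer_router() for _ in range(num_layers))
aux_loss = load_balancing_loss_func(
    gate_logits_and_topk, num_experts=num_experts, top_k=top_k, attention_mask=None
)
print(aux_loss)  # scalar tensor; grows as routing becomes more imbalanced
```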
@@ -421,7 +495,7 @@ class BailingMoeSparseMoeBlock(nn.Module):
         y = self.moe_infer(hidden_states, topk_idx, topk_weight).view(bsz, seq_len, h)
         if self.config.num_shared_experts is not None:
             y = y + self.shared_experts(identity)
-        return y, (router_logits.view(bsz, seq_len, -1), topk_idx.view(bsz, seq_len, -1))
+        return y, (router_logits, topk_idx)
 
     @torch.no_grad()
     def moe_infer(self, x, topk_ids, topk_weight):
@@ -1363,21 +1437,12 @@ class BailingMoeForCausalLM(BailingMoePreTrainedModel):
 
     def compute_logit(self, hidden_states):
         if self.norm_head:
-            if self.training:
-                norm_weight = (
-                    self.lm_head.weight / (torch.norm(self.lm_head.weight, p=2, dim=0, keepdim=True) + 1e-7).detach()
-                )
-                logits = F.linear(hidden_states, norm_weight, None)
-            else:
-                self.lm_head.weight.data = (
-                    self.lm_head.weight.data.float()
-                    / (torch.norm(self.lm_head.weight.data.float(), p=2, dim=0, keepdim=True) + 1e-7)
-                ).to(hidden_states.dtype)
-                logits = F.linear(hidden_states, self.lm_head.weight.data, None)
-                self.norm_head = False
+            weight_float = self.lm_head.weight.float()
+            norm = torch.norm(weight_float, p=2, dim=0, keepdim=True).clamp(min=1e-7)
+            norm_weight = (weight_float / norm).to(hidden_states.dtype)
+            logits = F.linear(hidden_states, norm_weight, None)
         else:
             logits = self.lm_head(hidden_states)
-        return logits
 
     @add_start_docstrings_to_model_forward(BAILINGMOE_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
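The rewritten `compute_logit` applies the same column-wise L2 normalization of the `lm_head` weight in training and inference, instead of branching on `self.training` and mutating the weight in place. A rough standalone sketch of the math, with made-up shapes:

```python
import torch
import torch.nn.functional as F

hidden_states = torch.randn(2, 5, 16)   # [batch, seq, hidden], illustrative sizes
weight = torch.randn(32, 16)            # lm_head weight: [vocab, hidden]

weight_float = weight.float()
norm = torch.norm(weight_float, p=2, dim=0, keepdim=True).clamp(min=1e-7)  # [1, hidden]
norm_weight = (weight_float / norm).to(hidden_states.dtype)
logits = F.linear(hidden_states, norm_weight)  # [batch, seq, vocab]
```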
@@ -1452,6 +1517,14 @@ class BailingMoeForCausalLM(BailingMoePreTrainedModel):
         loss = None
         aux_loss = None
 
+        if output_router_logits:
+            aux_loss = load_balancing_loss_func(
+                outputs.router_logits if return_dict else outputs[-1],
+                self.num_experts,
+                self.num_experts_per_tok,
+                attention_mask,
+            )
+
         if labels is not None:
             # Shift so that tokens < n predict n
             shift_logits = logits[..., :-1, :].contiguous()
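With `output_router_logits=True` the causal-LM forward now fills `aux_loss` from the per-layer router outputs. A hedged sketch of using it during fine-tuning, assuming `model`, `input_ids`, `attention_mask`, and `labels` are already prepared; the weighting coefficient is a placeholder that this commit does not define, so check whether the model already folds the term into `loss` before adding it yourself:

```python
outputs = model(
    input_ids,
    attention_mask=attention_mask,
    labels=labels,
    output_router_logits=True,
    return_dict=True,
)

aux_coef = 0.01  # hypothetical weighting for the balancing term
total_loss = outputs.loss + aux_coef * outputs.aux_loss
total_loss.backward()
```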
@@ -1547,3 +1620,107 @@ class BailingMoeForCausalLM(BailingMoePreTrainedModel):
                 tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
             )
         return reordered_past
+
+
+# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Qwen2Moe, LLAMA->QWEN2MOE
+class BailingMoeForTokenClassification(BailingMoePreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+
+        self.num_experts = config.num_experts
+        self.num_experts_per_tok = config.num_experts_per_tok
+
+        self.model = BailingMoeModel(config)
+        if getattr(config, "classifier_dropout", None) is not None:
+            classifier_dropout = config.classifier_dropout
+        elif getattr(config, "hidden_dropout", None) is not None:
+            classifier_dropout = config.hidden_dropout
+        else:
+            classifier_dropout = 0.1
+        self.dropout = nn.Dropout(classifier_dropout)
+        self.score = nn.Linear(config.hidden_size, config.num_labels)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.model.embed_tokens
+
+    def set_input_embeddings(self, value):
+        self.model.embed_tokens = value
+
+    @add_start_docstrings_to_model_forward(BAILINGMOE_INPUTS_DOCSTRING)
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        output_router_logits: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        output_router_logits = (
+            output_router_logits if output_router_logits is not None else self.config.output_router_logits
+        )
+
+        outputs = self.model(
+            input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            output_router_logits=output_router_logits,
+            return_dict=return_dict,
+        )
+        sequence_output = outputs[0]
+        sequence_output = self.dropout(sequence_output)
+        logits = self.score(sequence_output)
+
+        loss = None
+        aux_loss = None
+        if labels is not None:
+            loss = self.loss_function(logits, labels, self.config)
+
+        if output_router_logits:
+            aux_loss = load_balancing_loss_func(
+                outputs.router_logits if return_dict else outputs[-1],
+                self.num_experts,
+                self.num_experts_per_tok,
+                attention_mask,
+            )
+
+        if not return_dict:
+            output = (logits,) + outputs[1:]
+            if output_router_logits:
+                output = (aux_loss,) + output
+            return (loss,) + output if loss is not None else output
+
+        if not return_dict:
+            output = (logits,) + outputs[2:]
+            return ((loss,) + output) if loss is not None else output
+
+        return MoeCausalLMOutputWithPast(
+            loss=loss,
+            aux_loss=aux_loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+            router_logits=outputs.router_logits,
+        )
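End to end, the new head can be exercised through the Auto classes; this relies on the auto_map entry added in config.json above, and the repo id and label count are placeholders:

```python
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification

repo_id = "your-org/bailing-moe"  # hypothetical repo id
tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForTokenClassification.from_pretrained(
    repo_id, num_labels=5, trust_remote_code=True
)

enc = tokenizer("Bailing MoE token tagging", return_tensors="pt")
enc.pop("token_type_ids", None)  # drop keys the forward signature does not accept, if present
labels = torch.zeros_like(enc["input_ids"])  # dummy per-token labels

out = model(**enc, labels=labels, output_router_logits=True)
print(out.logits.shape)          # [batch, seq_len, num_labels]
print(out.loss, out.aux_loss)
```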
special_tokens_map.json CHANGED
@@ -1,10 +1,15 @@
 {
+  "additional_special_tokens": [
+    "<|arithmetic_start|>",
+    "<|arithmetic_end|>",
+    "<role>",
+    "</role>",
+    "<|number_end|>",
+    "<|number_start|>"
+  ],
   "bos_token": "<|startoftext|>",
   "cls_token": "[CLS]",
   "eos_token": "<|endoftext|>",
   "gmask_token": "[gMASK]",
-  "additional_special_tokens": [
-    "<role>",
-    "</role>"
-  ]
+  "pad_token": "<|endoftext|>"
 }
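Because these strings are registered as additional special tokens, the tokenizer should keep each of them as a single, atomic token. A quick check; the ids are repository-specific, so none are shown here:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("your-org/bailing-moe", trust_remote_code=True)  # placeholder repo id

for tok in ["<role>", "</role>", "<|arithmetic_start|>", "<|number_end|>"]:
    ids = tokenizer(tok, add_special_tokens=False)["input_ids"]
    print(tok, ids)  # each entry is expected to be a single id
```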
tokenizer_config.json CHANGED
@@ -1,15 +1,25 @@
 {
   "add_bos_token": false,
   "add_eos_token": false,
+  "additional_special_tokens": [
+    "<role>",
+    "</role>",
+    "<|arithmetic_start|>",
+    "<|arithmetic_end|>",
+    "<|number_start|>",
+    "<|number_end|>"
+  ],
   "bos_token": "<|startoftext|>",
+  "chat_template": "{% for message in messages %}{% set role = message['role'] | lower %}{% if role == 'user' %}{% set role = 'HUMAN' %}{% endif %}{% set role = role | upper %}{{ '<role>' + role + '</role>' + message['content'] }}{% if role == 'ASSISTANT' %}{{ eos_token }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<role>ASSISTANT</role>' }}{% endif %}",
   "clean_up_tokenization_spaces": false,
   "cls_token": "[CLS]",
   "eos_token": "<|endoftext|>",
   "gmask_token": "[gMASK]",
   "merges_file": null,
   "model_max_length": 1000000000000000019884624838656,
+  "pad_token": "<|endoftext|>",
   "tokenizer_class": "PreTrainedTokenizerFast",
+  "trust_remote_code": true,
   "vocab_file": null,
-  "pad_token": "<|endoftext|>",
   "fast_tokenizer": true
 }
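The new chat_template wraps every turn in `<role>...</role>` tags, maps the user role to HUMAN, appends the EOS token after assistant turns, and with `add_generation_prompt=True` leaves an open assistant tag at the end. A sketch of the expected rendering, using a placeholder repo id:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("your-org/bailing-moe", trust_remote_code=True)  # placeholder repo id

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is 2 + 2?"},
]
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
print(prompt)
# Expected, per the template above:
# <role>SYSTEM</role>You are a helpful assistant.<role>HUMAN</role>What is 2 + 2?<role>ASSISTANT</role>
```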