zhujiangang committed (verified)
Commit ad87bcb · 1 Parent(s): c9c10d7

Update modeling_bailing_moe.py

Files changed (1):
  1. modeling_bailing_moe.py +127 -22
modeling_bailing_moe.py CHANGED
@@ -207,6 +207,90 @@ class BailingMoeDynamicNTKScalingRotaryEmbedding(BailingMoeRotaryEmbedding):
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)


+# Inverse dim formula to find dim based on number of rotations
+def yarn_find_correction_dim(num_rotations, dim, base=10000, max_position_embeddings=2048):
+    return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))
+
+
+# Find dim range bounds based on rotations
+def yarn_find_correction_range(low_rot, high_rot, dim, base=10000, max_position_embeddings=2048):
+    low = math.floor(yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings))
+    high = math.ceil(yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings))
+    return max(low, 0), min(high, dim - 1)  # Clamp values just in case
+
+
+def yarn_get_mscale(scale=1, mscale=1):
+    if scale <= 1:
+        return 1.0
+    return 0.1 * mscale * math.log(scale) + 1.0
+
+
+def yarn_linear_ramp_mask(min, max, dim):
+    if min == max:
+        max += 0.001  # Prevent singularity
+
+    linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
+    ramp_func = torch.clamp(linear_func, 0, 1)
+    return ramp_func
+
+
+class BailingMoeYarnRotaryEmbedding(BailingMoeRotaryEmbedding):
+
+    def __init__(
+        self,
+        dim,
+        max_position_embeddings=2048,
+        base=10000,
+        device=None,
+        scaling_factor=1.0,
+        original_max_position_embeddings=4096,
+        beta_fast=32,
+        beta_slow=1,
+        mscale=1,
+        mscale_all_dim=0,
+    ):
+        self.scaling_factor = scaling_factor
+        self.original_max_position_embeddings = original_max_position_embeddings
+        self.beta_fast = beta_fast
+        self.beta_slow = beta_slow
+        self.mscale = mscale
+        self.mscale_all_dim = mscale_all_dim
+        super().__init__(dim, max_position_embeddings, base, device)
+
+    def _set_cos_sin_cache(self, seq_len, device, dtype):
+        self.max_seq_len_cached = seq_len
+        dim = self.dim
+
+        freq_extra = 1.0 / (self.base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))
+        freq_inter = 1.0 / (
+            self.scaling_factor * self.base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)
+        )
+
+        low, high = yarn_find_correction_range(
+            self.beta_fast,
+            self.beta_slow,
+            dim,
+            self.base,
+            self.original_max_position_embeddings,
+        )
+        inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2).to(device=device, dtype=torch.float32)
+        inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+
+        t = torch.arange(seq_len, device=device, dtype=torch.float32)
+
+        freqs = torch.outer(t, inv_freq)
+
+        _mscale = float(
+            yarn_get_mscale(self.scaling_factor, self.mscale)
+            / yarn_get_mscale(self.scaling_factor, self.mscale_all_dim)
+        )
+
+        emb = torch.cat((freqs, freqs), dim=-1)
+        self.register_buffer("cos_cached", (emb.cos() * _mscale).to(dtype), persistent=False)
+        self.register_buffer("sin_cached", (emb.sin() * _mscale).to(dtype), persistent=False)
+
+
# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
@@ -278,7 +362,7 @@ class BailingMoeGate(nn.Module):

        init.kaiming_uniform_(self.weight, a=math.sqrt(5))

-    def forward(self, hidden_states):
+    def forward(self, hidden_states, sort=False):
        bsz, seq_len, h = hidden_states.shape
        # compute gating score
        hidden_states = hidden_states.view(-1, h)
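
Stepping back to the first hunk: the yarn_* helpers and BailingMoeYarnRotaryEmbedding added there implement YaRN-style RoPE scaling. yarn_find_correction_range picks the band of rotary dimensions that gets blended between the original (extrapolated) and position-interpolated frequencies, yarn_linear_ramp_mask builds the blending ramp over the dim // 2 frequencies, and the ratio of two yarn_get_mscale calls rescales the cached cos/sin tables. The class is wired up further down in the elif scaling_type == "yarn" branch of BailingMoeAttention; presumably (an assumption based on the usual transformers rope_scaling convention, not shown in this diff) a config entry along the lines of {"type": "yarn", "factor": 4.0, "original_max_position_embeddings": 4096} routes into it. A minimal sketch of how the helpers combine, with purely illustrative numbers:

# Sketch only: assumes this file is importable as modeling_bailing_moe; the
# values below are hypothetical and not taken from any released config.
from modeling_bailing_moe import yarn_find_correction_range, yarn_get_mscale, yarn_linear_ramp_mask

head_dim, scaling_factor, orig_ctx = 128, 4.0, 4096

# Band of rotary dimensions blended between interpolation and extrapolation
# (beta_fast=32 and beta_slow=1 are the defaults of the new class).
low, high = yarn_find_correction_range(32, 1, head_dim, 10000, orig_ctx)

# Ramp over the head_dim // 2 rotary frequencies: 0 below low, 1 above high.
ramp = yarn_linear_ramp_mask(low, high, head_dim // 2)

# Magnitude correction for the cos/sin caches; with mscale=1, mscale_all_dim=0
# and scaling_factor=4 this is 0.1 * ln(4) + 1, roughly 1.139.
attn_scale = yarn_get_mscale(scaling_factor, 1) / yarn_get_mscale(scaling_factor, 0)

print(low, high, ramp.shape, attn_scale)
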
@@ -286,7 +370,7 @@ class BailingMoeGate(nn.Module):
        scores = logits.softmax(dim=-1, dtype=torch.float32)

        # select top-k experts
-        topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)
+        topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=sort)

        # norm gate to sum 1
        if self.top_k > 1 and self.norm_topk_prob:
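
The only behavioural change in these two gate hunks is the new sort flag, which is forwarded to torch.topk's sorted argument; the default sort=False matches the old call, while sort=True returns the selected experts ordered by gating score. A toy illustration of the difference (standalone tensors, not the model's gate):

import torch

scores = torch.tensor([[0.10, 0.40, 0.20, 0.30]])

# sorted=True guarantees descending order: weights (0.40, 0.30), indices (1, 3).
w_sorted, idx_sorted = torch.topk(scores, k=2, dim=-1, sorted=True)

# sorted=False returns the same top-2 set, but the ordering is not guaranteed.
w_any, idx_any = torch.topk(scores, k=2, dim=-1, sorted=False)

print(w_sorted, idx_sorted)
print(w_any, idx_any)
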
@@ -305,7 +389,7 @@ class BailingMoeSparseMoeBlock(nn.Module):
        super().__init__()
        self.config = config
        self.num_experts_per_tok = config.num_experts_per_tok
-        self.experts = self._setup_experts()
+        self._setup_experts()
        self.gate = BailingMoeGate(config)
        if config.num_shared_experts is not None:
            self.shared_experts = BailingMoeMLP(
@@ -313,7 +397,7 @@ class BailingMoeSparseMoeBlock(nn.Module):
            )

    def _setup_experts(self):
-        return nn.ModuleList(
+        self.experts = nn.ModuleList(
            [
                BailingMoeMLP(config=self.config, intermediate_size=self.config.moe_intermediate_size)
                for _ in range(self.config.num_experts)
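
Together, this hunk and the one above make _setup_experts assign self.experts in place instead of returning the ModuleList to __init__; construction behaviour is unchanged. One plausible side benefit (an assumption, not something stated in the commit) is that a subclass can now swap the expert implementation by overriding the hook alone, roughly like this hypothetical sketch:

import torch.nn as nn

class MySparseMoeBlock(BailingMoeSparseMoeBlock):  # assumes the classes above are in scope
    def _setup_experts(self):
        # MyCustomExpertMLP is purely illustrative; it is not part of this commit.
        self.experts = nn.ModuleList(
            [MyCustomExpertMLP(config=self.config, intermediate_size=self.config.moe_intermediate_size)
             for _ in range(self.config.num_experts)]
        )
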
@@ -443,6 +527,25 @@ class BailingMoeAttention(nn.Module):
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                )
+            elif scaling_type == "yarn":
+                kwargs = {
+                    key: self.config.rope_scaling[key]
+                    for key in [
+                        "original_max_position_embeddings",
+                        "beta_fast",
+                        "beta_slow",
+                        "mscale",
+                        "mscale_all_dim",
+                    ]
+                    if key in self.config.rope_scaling
+                }
+                self.rotary_emb = BailingMoeYarnRotaryEmbedding(
+                    self.head_dim,
+                    max_position_embeddings=self.max_position_embeddings,
+                    scaling_factor=scaling_factor,
+                    base=self.rope_theta,
+                    **kwargs,
+                )
            else:
                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

@@ -1258,6 +1361,24 @@ class BailingMoeForCausalLM(BailingMoePreTrainedModel):
    def get_decoder(self):
        return self.model

+    def compute_logit(self, hidden_states):
+        if self.norm_head:
+            if self.training:
+                norm_weight = (
+                    self.lm_head.weight / (torch.norm(self.lm_head.weight, p=2, dim=0, keepdim=True) + 1e-7).detach()
+                )
+                logits = F.linear(hidden_states, norm_weight, None)
+            else:
+                self.lm_head.weight.data = (
+                    self.lm_head.weight.data.float()
+                    / (torch.norm(self.lm_head.weight.data.float(), p=2, dim=0, keepdim=True) + 1e-7)
+                ).to(hidden_states.dtype)
+                logits = F.linear(hidden_states, self.lm_head.weight.data, None)
+                self.norm_head = False
+        else:
+            logits = self.lm_head(hidden_states)
+        return logits
+
    @add_start_docstrings_to_model_forward(BAILINGMOE_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
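
The norm_head logits path, previously inlined in forward (see the next hunk), now lives in a dedicated compute_logit helper. When norm_head is set it L2-normalizes each column of lm_head.weight (with a 1e-7 epsilon): during training the norm is recomputed every step with the denominator detached from the gradient, while at inference the weight is normalized once in place and norm_head is flipped to False so later calls take the plain lm_head path. A rough standalone sketch of the normalization itself, using a toy head rather than the model's:

import torch
import torch.nn.functional as F

lm_head = torch.nn.Linear(8, 16, bias=False)   # toy (hidden=8, vocab=16) head
hidden_states = torch.randn(2, 8)

# Column-wise (dim=0) L2 normalization, mirroring the expression in compute_logit.
norm_weight = lm_head.weight / (torch.norm(lm_head.weight, p=2, dim=0, keepdim=True) + 1e-7)
logits = F.linear(hidden_states, norm_weight, None)

# Every column of the normalized weight now has (approximately) unit norm.
print(torch.norm(norm_weight, p=2, dim=0))
print(logits.shape)  # torch.Size([2, 16])
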
@@ -1325,22 +1446,7 @@ class BailingMoeForCausalLM(BailingMoePreTrainedModel):

        hidden_states = outputs[0]

-        if self.norm_head:
-            if self.training:
-                norm_weight = (
-                    self.lm_head.weight / (torch.norm(self.lm_head.weight, p=2, dim=0, keepdim=True) + 1e-7).detach()
-                )
-                logits = F.linear(hidden_states, norm_weight, None)
-            else:
-                self.lm_head.weight.data = (
-                    self.lm_head.weight.data.float()
-                    / (torch.norm(self.lm_head.weight.data.float(), p=2, dim=0, keepdim=True) + 1e-7)
-                ).to(hidden_states.dtype)
-                logits = F.linear(hidden_states, self.lm_head.weight.data, None)
-                self.norm_head = False
-        else:
-            logits = self.lm_head(hidden_states)
-
+        logits = self.compute_logit(hidden_states=hidden_states)
        logits = logits.float()

        loss = None
@@ -1392,8 +1498,7 @@ class BailingMoeForCausalLM(BailingMoePreTrainedModel):

            # Keep only the unprocessed tokens:
            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
-            # some of the inputs are exclusivelly passed as part of the cache (e.g. when passing input_embeds as
-            # input)
+            # some of the inputs are exclusivelly passed as part of the cache (e.g. when passing input_embeds as input)
            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
 