Update modeling_bailing_moe.py
modeling_bailing_moe.py CHANGED (+127 -22)
@@ -207,6 +207,90 @@ class BailingMoeDynamicNTKScalingRotaryEmbedding(BailingMoeRotaryEmbedding):
         self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
 
 
+# Inverse dim formula to find dim based on number of rotations
+def yarn_find_correction_dim(num_rotations, dim, base=10000, max_position_embeddings=2048):
+    return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))
+
+
+# Find dim range bounds based on rotations
+def yarn_find_correction_range(low_rot, high_rot, dim, base=10000, max_position_embeddings=2048):
+    low = math.floor(yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings))
+    high = math.ceil(yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings))
+    return max(low, 0), min(high, dim - 1)  # Clamp values just in case
+
+
+def yarn_get_mscale(scale=1, mscale=1):
+    if scale <= 1:
+        return 1.0
+    return 0.1 * mscale * math.log(scale) + 1.0
+
+
+def yarn_linear_ramp_mask(min, max, dim):
+    if min == max:
+        max += 0.001  # Prevent singularity
+
+    linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
+    ramp_func = torch.clamp(linear_func, 0, 1)
+    return ramp_func
+
+
+class BailingMoeYarnRotaryEmbedding(BailingMoeRotaryEmbedding):
+
+    def __init__(
+        self,
+        dim,
+        max_position_embeddings=2048,
+        base=10000,
+        device=None,
+        scaling_factor=1.0,
+        original_max_position_embeddings=4096,
+        beta_fast=32,
+        beta_slow=1,
+        mscale=1,
+        mscale_all_dim=0,
+    ):
+        self.scaling_factor = scaling_factor
+        self.original_max_position_embeddings = original_max_position_embeddings
+        self.beta_fast = beta_fast
+        self.beta_slow = beta_slow
+        self.mscale = mscale
+        self.mscale_all_dim = mscale_all_dim
+        super().__init__(dim, max_position_embeddings, base, device)
+
+    def _set_cos_sin_cache(self, seq_len, device, dtype):
+        self.max_seq_len_cached = seq_len
+        dim = self.dim
+
+        freq_extra = 1.0 / (self.base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))
+        freq_inter = 1.0 / (
+            self.scaling_factor * self.base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)
+        )
+
+        low, high = yarn_find_correction_range(
+            self.beta_fast,
+            self.beta_slow,
+            dim,
+            self.base,
+            self.original_max_position_embeddings,
+        )
+        inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2).to(device=device, dtype=torch.float32)
+        inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+
+        t = torch.arange(seq_len, device=device, dtype=torch.float32)
+
+        freqs = torch.outer(t, inv_freq)
+
+        _mscale = float(
+            yarn_get_mscale(self.scaling_factor, self.mscale)
+            / yarn_get_mscale(self.scaling_factor, self.mscale_all_dim)
+        )
+
+        emb = torch.cat((freqs, freqs), dim=-1)
+        self.register_buffer("cos_cached", (emb.cos() * _mscale).to(dtype), persistent=False)
+        self.register_buffer("sin_cached", (emb.sin() * _mscale).to(dtype), persistent=False)
+
+
 # Copied from transformers.models.llama.modeling_llama.rotate_half
 def rotate_half(x):
     """Rotates half the hidden dims of the input."""
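Reviewer note: the new `BailingMoeYarnRotaryEmbedding` blends unscaled ("extrapolated") and position-interpolated inverse frequencies per rotary dimension. The sketch below reproduces that blend standalone so the ramp is easy to inspect; the dimension, base, and scaling factor are illustrative assumptions, not values read from the Bailing config.

```python
import math

import torch

# Illustrative values (assumed, not taken from the model config)
dim, base = 128, 10000
scaling_factor = 4.0                       # rope_scaling "factor"
original_max_position_embeddings = 4096
beta_fast, beta_slow = 32, 1

def find_correction_dim(num_rotations, dim, base, max_pos):
    # Same inverse-dim formula as yarn_find_correction_dim in the patch
    return (dim * math.log(max_pos / (num_rotations * 2 * math.pi))) / (2 * math.log(base))

low = max(math.floor(find_correction_dim(beta_fast, dim, base, original_max_position_embeddings)), 0)
high = min(math.ceil(find_correction_dim(beta_slow, dim, base, original_max_position_embeddings)), dim - 1)

freq_extra = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))  # unscaled RoPE
freq_inter = freq_extra / scaling_factor                                           # position interpolation
ramp = torch.clamp((torch.arange(dim // 2, dtype=torch.float32) - low) / (high - low), 0, 1)

# Fast-rotating dims (index < low) keep the original frequencies; slow-rotating dims
# (index > high) are fully interpolated; dims in between are linearly blended.
inv_freq = freq_inter * ramp + freq_extra * (1 - ramp)
print(low, high, inv_freq.shape)  # correction range and the (dim // 2,) frequency vector
```

The cos/sin caches are then scaled by the `yarn_get_mscale` ratio, exactly as in `_set_cos_sin_cache` above.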
@@ -278,7 +362,7 @@ class BailingMoeGate(nn.Module):
 
         init.kaiming_uniform_(self.weight, a=math.sqrt(5))
 
-    def forward(self, hidden_states):
+    def forward(self, hidden_states, sort=False):
         bsz, seq_len, h = hidden_states.shape
         # compute gating score
         hidden_states = hidden_states.view(-1, h)
@@ -286,7 +370,7 @@ class BailingMoeGate(nn.Module):
         scores = logits.softmax(dim=-1, dtype=torch.float32)
 
         # select top-k experts
-        topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)
+        topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=sort)
 
         # norm gate to sum 1
         if self.top_k > 1 and self.norm_topk_prob:
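Reviewer note: the only behavioural knob added to the gate is the `sort` flag, which is passed straight through to the `sorted` argument of `torch.topk` and defaults to `False`. A quick illustration with made-up scores:

```python
import torch

scores = torch.tensor([[0.10, 0.40, 0.20, 0.30]])

# sorted=True returns the top-k in descending order of score
w_sorted, idx_sorted = torch.topk(scores, k=2, dim=-1, sorted=True)      # values [[0.40, 0.30]]

# sorted=False returns the same top-k set, but the order is not guaranteed
w_unsorted, idx_unsorted = torch.topk(scores, k=2, dim=-1, sorted=False)
```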
@@ -305,7 +389,7 @@ class BailingMoeSparseMoeBlock(nn.Module):
         super().__init__()
         self.config = config
         self.num_experts_per_tok = config.num_experts_per_tok
-        self.
+        self._setup_experts()
         self.gate = BailingMoeGate(config)
         if config.num_shared_experts is not None:
             self.shared_experts = BailingMoeMLP(
@@ -313,7 +397,7 @@ class BailingMoeSparseMoeBlock(nn.Module):
             )
 
     def _setup_experts(self):
-
+        self.experts = nn.ModuleList(
             [
                 BailingMoeMLP(config=self.config, intermediate_size=self.config.moe_intermediate_size)
                 for _ in range(self.config.num_experts)
@@ -443,6 +527,25 @@ class BailingMoeAttention(nn.Module):
                     scaling_factor=scaling_factor,
                     base=self.rope_theta,
                 )
+            elif scaling_type == "yarn":
+                kwargs = {
+                    key: self.config.rope_scaling[key]
+                    for key in [
+                        "original_max_position_embeddings",
+                        "beta_fast",
+                        "beta_slow",
+                        "mscale",
+                        "mscale_all_dim",
+                    ]
+                    if key in self.config.rope_scaling
+                }
+                self.rotary_emb = BailingMoeYarnRotaryEmbedding(
+                    self.head_dim,
+                    max_position_embeddings=self.max_position_embeddings,
+                    scaling_factor=scaling_factor,
+                    base=self.rope_theta,
+                    **kwargs,
+                )
             else:
                 raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
 
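Reviewer note: a model config would select this new branch through its `rope_scaling` dict. A hypothetical example is sketched below; only the five optional keys listed in the comprehension come from the patch, while `"type"` and `"factor"` are assumed to follow the same convention as the other scaling branches.

```python
# Hypothetical rope_scaling entry (key names "type" and "factor" are assumptions based on
# the surrounding branches; the remaining keys are the optional ones the patch forwards
# to BailingMoeYarnRotaryEmbedding as **kwargs).
rope_scaling = {
    "type": "yarn",
    "factor": 4.0,
    "original_max_position_embeddings": 4096,
    "beta_fast": 32,
    "beta_slow": 1,
    "mscale": 1.0,
    "mscale_all_dim": 0.0,
}
```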
@@ -1258,6 +1361,24 @@ class BailingMoeForCausalLM(BailingMoePreTrainedModel):
     def get_decoder(self):
         return self.model
 
+    def compute_logit(self, hidden_states):
+        if self.norm_head:
+            if self.training:
+                norm_weight = (
+                    self.lm_head.weight / (torch.norm(self.lm_head.weight, p=2, dim=0, keepdim=True) + 1e-7).detach()
+                )
+                logits = F.linear(hidden_states, norm_weight, None)
+            else:
+                self.lm_head.weight.data = (
+                    self.lm_head.weight.data.float()
+                    / (torch.norm(self.lm_head.weight.data.float(), p=2, dim=0, keepdim=True) + 1e-7)
+                ).to(hidden_states.dtype)
+                logits = F.linear(hidden_states, self.lm_head.weight.data, None)
+                self.norm_head = False
+        else:
+            logits = self.lm_head(hidden_states)
+        return logits
+
     @add_start_docstrings_to_model_forward(BAILINGMOE_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
     def forward(
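Reviewer note: `compute_logit` factors the existing `norm_head` logic out of `forward` (see the next hunk); the normalisation itself is unchanged. For reference, a minimal standalone sketch of the inference-time path with toy shapes:

```python
import torch
import torch.nn.functional as F

# Toy shapes for illustration only: vocab_size=8, hidden_size=4, batch=2, seq_len=3
lm_head_weight = torch.randn(8, 4)
hidden_states = torch.randn(2, 3, 4)

# Each column of lm_head.weight (one per hidden unit) is rescaled to unit L2 norm; in the
# patch this is done in place once and norm_head is then set to False, so later calls fall
# through to a plain self.lm_head(hidden_states).
norm_weight = lm_head_weight / (torch.norm(lm_head_weight, p=2, dim=0, keepdim=True) + 1e-7)
logits = F.linear(hidden_states, norm_weight)  # shape (2, 3, 8)
```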
@@ -1325,22 +1446,7 @@ class BailingMoeForCausalLM(BailingMoePreTrainedModel):
 
         hidden_states = outputs[0]
 
-        if self.norm_head:
-            if self.training:
-                norm_weight = (
-                    self.lm_head.weight / (torch.norm(self.lm_head.weight, p=2, dim=0, keepdim=True) + 1e-7).detach()
-                )
-                logits = F.linear(hidden_states, norm_weight, None)
-            else:
-                self.lm_head.weight.data = (
-                    self.lm_head.weight.data.float()
-                    / (torch.norm(self.lm_head.weight.data.float(), p=2, dim=0, keepdim=True) + 1e-7)
-                ).to(hidden_states.dtype)
-                logits = F.linear(hidden_states, self.lm_head.weight.data, None)
-                self.norm_head = False
-        else:
-            logits = self.lm_head(hidden_states)
-
+        logits = self.compute_logit(hidden_states=hidden_states)
         logits = logits.float()
 
         loss = None
@@ -1392,8 +1498,7 @@ class BailingMoeForCausalLM(BailingMoePreTrainedModel):
 
             # Keep only the unprocessed tokens:
             # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
-            # some of the inputs are exclusivelly passed as part of the cache (e.g. when passing input_embeds as
-            # input)
+            # some of the inputs are exclusivelly passed as part of the cache (e.g. when passing input_embeds as input)
            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
                 input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
             # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard