Update modeling_ernie_45t_vl.py
Browse files- modeling_ernie_45t_vl.py +5 -1
modeling_ernie_45t_vl.py
CHANGED
@@ -1635,7 +1635,11 @@ class MOELayer(nn.Module):
|
|
1635 |
S, H = x.shape
|
1636 |
E = gate_logits.shape[1]
|
1637 |
device = x.device
|
1638 |
-
|
|
|
|
|
|
|
|
|
1639 |
combine_weights = topk_prob # [S, k]
|
1640 |
expert_id = topk_idx # [S, k]
|
1641 |
y = x.new_zeros((E, capacity, H)) # [E, C, H]
|
|
|
1635 |
S, H = x.shape
|
1636 |
E = gate_logits.shape[1]
|
1637 |
device = x.device
|
1638 |
+
if self.use_correction_bias:
|
1639 |
+
_, topk_idx = torch.topk(gate_logits + self.moe_statics.e_score_correction_bias[0].detach().to(gate_logits.device), k, dim=-1)
|
1640 |
+
topk_prob = torch.gather(gate_logits, dim=1, index=topk_idx) # [Seq, k]
|
1641 |
+
else:
|
1642 |
+
topk_prob, topk_idx = torch.topk(gate_logits, k, dim=-1) # [S, k]
|
1643 |
combine_weights = topk_prob # [S, k]
|
1644 |
expert_id = topk_idx # [S, k]
|
1645 |
y = x.new_zeros((E, capacity, H)) # [E, C, H]
|