SFLY5 commited on
Commit
c45abda
·
verified ·
1 Parent(s): 84c1c32

Update modeling_ernie_45t_vl.py

Browse files
Files changed (1) hide show
  1. modeling_ernie_45t_vl.py +5 -1
modeling_ernie_45t_vl.py CHANGED
@@ -1635,7 +1635,11 @@ class MOELayer(nn.Module):
1635
  S, H = x.shape
1636
  E = gate_logits.shape[1]
1637
  device = x.device
1638
- topk_prob, topk_idx = torch.topk(gate_logits, k, dim=-1) # [S, k]
 
 
 
 
1639
  combine_weights = topk_prob # [S, k]
1640
  expert_id = topk_idx # [S, k]
1641
  y = x.new_zeros((E, capacity, H)) # [E, C, H]
 
1635
  S, H = x.shape
1636
  E = gate_logits.shape[1]
1637
  device = x.device
1638
+ if self.use_correction_bias:
1639
+ _, topk_idx = torch.topk(gate_logits + self.moe_statics.e_score_correction_bias[0].detach().to(gate_logits.device), k, dim=-1)
1640
+ topk_prob = torch.gather(gate_logits, dim=1, index=topk_idx) # [Seq, k]
1641
+ else:
1642
+ topk_prob, topk_idx = torch.topk(gate_logits, k, dim=-1) # [S, k]
1643
  combine_weights = topk_prob # [S, k]
1644
  expert_id = topk_idx # [S, k]
1645
  y = x.new_zeros((E, capacity, H)) # [E, C, H]