Why not distinguish between sequence_length == 1 and sequence_length > 1 in the MoE module's forward function?

#26
by nifeng154 - opened

In the MoE module, I use the following:
        if sequence_length == 1:
            # Decode step: visit only the experts the router selected for this token.
            # Note: `selected_experts[0]` covers only the first token, so this assumes batch size 1.
            for expert_idx in selected_experts[0]:
                expert_layer = self.experts[expert_idx]
                idx, top_x = torch.where(expert_mask[expert_idx])
                current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
                current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
                final_logits.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
        else:
            for expert_idx in range(self.num_experts):
                expert_layer = self.experts[expert_idx]
                idx, top_x = torch.where(expert_mask[expert_idx])

                # Index the correct hidden states and compute the expert hidden state for
                # the current expert. We need to make sure to multiply the output hidden
                # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
                current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
                current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]

                # However, `index_add_` only supports torch tensors for indexing, so we'll use
                # the `top_x` tensor here.
                final_logits.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))

With this change, I got a significant improvement in generation speed.
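To make the idea easier to check, here is a minimal self-contained sketch of the same trick on a toy MoE layer. The names `ToyMoE` and `only_routed` are hypothetical and used only for illustration; this is not the actual Transformers implementation, and the toy experts are plain linear layers. It uses `selected_experts.unique()` instead of `selected_experts[0]`, on the assumption that the fast path should also stay correct when the batch size is larger than 1.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyMoE(nn.Module):
    """Hypothetical stand-in for a sparse MoE block (not the real Qwen3 MoE module)."""

    def __init__(self, hidden_dim=16, num_experts=8, top_k=2):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        self.gate = nn.Linear(hidden_dim, num_experts, bias=False)
        self.experts = nn.ModuleList(
            [nn.Linear(hidden_dim, hidden_dim, bias=False) for _ in range(num_experts)]
        )

    def forward(self, hidden_states, only_routed=False):
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)

        # Route every token to its top-k experts and renormalize the weights.
        routing_weights = F.softmax(self.gate(hidden_states), dim=-1, dtype=torch.float)
        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
        routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
        routing_weights = routing_weights.to(hidden_states.dtype)

        final_hidden_states = torch.zeros_like(hidden_states)
        # expert_mask has shape (num_experts, top_k, num_tokens).
        expert_mask = F.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)

        if only_routed:
            # Fast path: loop only over experts that received at least one token.
            expert_ids = selected_experts.unique().tolist()
        else:
            # Baseline: loop over every expert, including those with no tokens.
            expert_ids = list(range(self.num_experts))

        for expert_idx in expert_ids:
            idx, top_x = torch.where(expert_mask[expert_idx])
            current_state = hidden_states[top_x]
            current_hidden_states = (
                self.experts[expert_idx](current_state) * routing_weights[top_x, idx, None]
            )
            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))

        output = final_hidden_states.view(batch_size, sequence_length, hidden_dim)
        return output, len(expert_ids)


torch.manual_seed(0)
moe = ToyMoE()
x = torch.randn(1, 1, 16)  # one sequence, one token: a single decode step

with torch.no_grad():
    full_out, full_visits = moe(x, only_routed=False)
    fast_out, fast_visits = moe(x, only_routed=True)

print(torch.allclose(full_out, fast_out), full_visits, fast_visits)
```

Both paths should return the same output, while the fast path visits only `top_k` experts instead of all `num_experts` for a single token. Note that `.tolist()` forces a host/device synchronization on GPU, so how much of the saving survives in practice likely depends on the hardware and the number of experts.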

nifeng154 changed discussion status to closed
