import copy
from collections import defaultdict
from typing import Any, Dict, List, TYPE_CHECKING, Tuple

import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Module

if TYPE_CHECKING:
    Base = Module[Tensor]
else:
    Base = Module


# Each token is routed to the top-k scoring experts.
MOE_TOP_K = 2
# Number of ConstantExpert instances appended to every expert list.
Constant = 2


class CopyExpert(torch.nn.Module):
    """Identity expert: returns its input unchanged."""

    def __init__(self, expert):
        super().__init__()

    def forward(self, inputs):
        return inputs


class ZeroExpert(torch.nn.Module):
    """Expert that always outputs zeros with the input's shape, dtype and device."""

    def __init__(self, expert):
        super().__init__()

    def forward(self, inputs):
        return torch.zeros_like(inputs)


class ConstantExpert(torch.nn.Module):
    """Expert that mixes each token with a learned constant vector.

    A per-token linear gate produces two softmax weights: one scales the token
    itself, the other scales the learned constant.
    """

    def __init__(self, expert):
        super().__init__()
        self.constant = torch.nn.Parameter(torch.empty(expert.hidden_size))
        torch.nn.init.normal_(self.constant)

        self.wg = torch.nn.Linear(expert.hidden_size, 2, bias=False)
        self.softmax = torch.nn.Softmax(dim=-1)

    def forward(self, inputs):
        weight = self.wg(inputs)
        weight = self.softmax(weight)
        # weight[:, 0] scales the token, weight[:, 1] scales the learned constant.
        return torch.einsum('b,bd->bd', [weight[:, 0].type_as(inputs), inputs]) + torch.einsum(
            'b,d->bd', [weight[:, 1].type_as(inputs), self.constant.type_as(inputs)])
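

# Illustrative sanity check (not part of the original module): ConstantExpert only
# reads `hidden_size` from the expert it wraps, so a bare namespace with an assumed
# size is enough to exercise it.
def _constant_expert_example() -> None:
    from types import SimpleNamespace
    layer = ConstantExpert(SimpleNamespace(hidden_size=8))
    tokens = torch.randn(4, 8)  # 4 tokens with hidden size 8 (hypothetical sizes)
    out = layer(tokens)
    assert out.shape == tokens.shape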


def gating(logits: Tensor,
           moe_use_mixtral_gating=False,
           moe_use_logits_norm=False,
           moe_gate_norm_std=1.0) -> Dict[int, List[Tensor]]:
    """Compute top-k routing decisions.

    Returns a dict mapping each expert id to [token_ids, gate_scores] for the
    tokens routed to that expert.
    """
    num_experts = logits.size(1)
    if moe_use_mixtral_gating:
        # Mixtral-style gating: pick the top-k logits first, then softmax over them.
        if moe_use_logits_norm:
            target_std = moe_gate_norm_std
            logits_std = logits.std(dim=1, keepdim=True)
            logits = logits / (logits_std / target_std)
        gates, indices = torch.topk(logits, k=MOE_TOP_K, dim=1)
        gates = F.softmax(gates, dim=1)
    else:
        # Standard gating: softmax over all experts, then pick the top-k gates.
        target_std = moe_gate_norm_std
        if moe_use_logits_norm:
            logits_std = logits.std(dim=1, keepdim=True)
            gates = F.softmax(logits / (logits_std / target_std), dim=1)
        else:
            gates = F.softmax(logits, dim=1)

        gates, indices = torch.topk(gates, k=MOE_TOP_K, dim=1)

    # The last expert is the ZeroExpert; drop its gate and renormalize the rest.
    gates = torch.where(indices == (num_experts - 1), torch.zeros_like(gates), gates)
    gates /= torch.sum(gates, dim=1, keepdim=True)

    expert_info = defaultdict(list)
    for expert_id in range(num_experts):
        token_ids, score_ids = torch.nonzero(indices == expert_id, as_tuple=True)
        expert_info[expert_id] = [token_ids, gates[token_ids, score_ids]]

    return expert_info
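

# Illustrative sanity check (not part of the original module): run `gating` on a
# small random logits tensor. The sizes below are assumptions for demonstration only.
def _gating_example() -> None:
    logits = torch.randn(4, 8)  # 4 tokens routed across 8 experts (hypothetical sizes)
    expert_info = gating(logits)
    for expert_id, (token_ids, scores) in expert_info.items():
        # Each expert maps to the ids of the tokens routed to it and their gate weights.
        assert token_ids.shape == scores.shape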


class Router(Module):
    """Gating network that maps token features to expert routing decisions."""

    def __init__(self,
                 model_dim: int,
                 num_experts: int,
                 moe_use_mixtral_gating: bool,
                 moe_2layer_gate: bool,
                 moe_use_logits_norm: bool,
                 moe_gate_norm_std: float,
                 ) -> None:
        super().__init__()

        if moe_2layer_gate:
            # Two-layer gate: expand to num_experts * 8 with a tanh in between.
            self.wg = torch.nn.Sequential(
                torch.nn.Linear(model_dim, num_experts * 8, bias=False).float(),
                torch.nn.Tanh(),
                torch.nn.Linear(num_experts * 8, num_experts, bias=False).float()).float()
        else:
            self.wg = torch.nn.Linear(model_dim, num_experts, bias=False).float()

        # Projects the previous layer's gate logits when a gate residual is passed in.
        self.gate_map = torch.nn.Linear(num_experts, num_experts, bias=False)

        self.gate = gating
        self.moe_use_mixtral_gating = moe_use_mixtral_gating
        self.moe_use_logits_norm = moe_use_logits_norm
        self.moe_gate_norm_std = moe_gate_norm_std

    def forward(self, input: torch.Tensor, gate_residual=None) -> Tuple[Dict[int, List[Tensor]], Tensor]:
        # The gate always runs in fp32 for numerical stability; its weights are
        # tagged so downstream code can identify router parameters.
        if isinstance(self.wg, torch.nn.Linear):
            if self.wg.weight.dtype != torch.float32:
                self.wg = self.wg.float()
            setattr(self.wg.weight, 'router', True)
        else:
            if self.wg[0].weight.dtype != torch.float32:
                self.wg = self.wg.float()
            setattr(self.wg[0].weight, 'router', True)
            setattr(self.wg[2].weight, 'router', True)
        input_fp32 = input.float()
        logits = self.wg(input_fp32)

        if gate_residual is not None:
            gate_residual = self.gate_map(gate_residual.to(self.gate_map.weight.dtype))
            logits += gate_residual

        gate_output = self.gate(logits, self.moe_use_mixtral_gating, self.moe_use_logits_norm, self.moe_gate_norm_std)

        return gate_output, logits
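

# Illustrative usage sketch (not part of the original module): the Router returns the
# per-expert routing info from `gating` together with the raw fp32 logits, which can be
# fed back to a later layer as `gate_residual`. All sizes here are assumptions.
def _router_example() -> None:
    router = Router(model_dim=16, num_experts=8,
                    moe_use_mixtral_gating=False, moe_2layer_gate=True,
                    moe_use_logits_norm=False, moe_gate_norm_std=1.0)
    tokens = torch.randn(6, 16)
    expert_info, logits = router(tokens)
    # Chain the logits into a second call as a gate residual.
    expert_info2, logits2 = router(tokens, gate_residual=logits)
    assert logits.shape == (6, 8) and logits2.shape == (6, 8)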


class Experts(torch.nn.Module):
    """Container for the local experts: trained copies of `expert` followed by
    ConstantExpert, CopyExpert and ZeroExpert pseudo-experts."""

    def __init__(self, expert, num_local_experts=1):
        super().__init__()

        self.experts = torch.nn.ModuleList(
            [copy.deepcopy(expert) for _ in range(num_local_experts - 2 - Constant)] +
            [ConstantExpert(expert) for _ in range(Constant)] +
            [CopyExpert(expert), ZeroExpert(expert)])

    def forward(self, inputs):
        # Experts are dispatched individually by MOELayer; this module only stores them.
        raise NotImplementedError


class MOELayer(Base):
    """Dispatches tokens to experts according to the gate's routing decisions and
    combines the expert outputs weighted by their gate scores."""

    def __init__(self,
                 gate: Module,
                 experts: Module,
                 ep_size: int,
                 num_local_experts: int,
                 moe_use_mixtral_gating: bool,
                 moe_feature_no_mul_topk: bool) -> None:
        super().__init__()
        self.gate = gate
        self.experts = experts
        self.ep_size = ep_size
        self.num_local_experts = num_local_experts
        self.moe_use_mixtral_gating = moe_use_mixtral_gating
        self.moe_feature_no_mul_topk = moe_feature_no_mul_topk

    def forward(self, *input: Tensor, gate_residual=None, **kwargs: Any) -> Tuple[Tensor, Tensor]:
        d_model = input[0].shape[-1]
        reshaped_input = input[0].reshape(-1, d_model)
        output = torch.zeros_like(reshaped_input)
        expert_info, gate_residual = self.gate(reshaped_input, gate_residual)
        if not (self.moe_use_mixtral_gating or self.moe_feature_no_mul_topk):
            # Scale the features by the top-k count out of place, so the caller's
            # hidden states are not mutated through the reshape view.
            reshaped_input = reshaped_input * MOE_TOP_K
        for expert, token_indices_and_gates in expert_info.items():
            indices, gate_scores = token_indices_and_gates
            gate_scores = gate_scores.unsqueeze(-1)
            tokens = reshaped_input.index_select(dim=0, index=indices)
            expert_output = self.experts.experts[expert](tokens)
            expert_output *= gate_scores
            output.index_add_(dim=0, index=indices, source=expert_output)
        output = output.reshape(input[0].shape)

        return output, gate_residual


class MOE(torch.nn.Module):
    """Mixture-of-Experts block: a Router plus a MOELayer over the local experts."""

    def __init__(self,
                 hidden_size,
                 expert,
                 num_experts=1,
                 ep_size=1,
                 moe_use_mixtral_gating=False,
                 moe_2layer_gate=True,
                 moe_use_logits_norm=False,
                 moe_gate_norm_std=1.0,
                 moe_feature_no_mul_topk=False):
        super().__init__()

        self.ep_size = ep_size
        self.num_experts = num_experts
        self.num_local_experts = num_experts // self.ep_size
        self.moe_use_mixtral_gating = moe_use_mixtral_gating
        self.moe_2layer_gate = moe_2layer_gate
        self.moe_use_logits_norm = moe_use_logits_norm
        self.moe_gate_norm_std = moe_gate_norm_std
        self.moe_feature_no_mul_topk = moe_feature_no_mul_topk

        experts = Experts(expert, self.num_local_experts)
        self.moe = MOELayer(Router(hidden_size,
                                   num_experts,
                                   self.moe_use_mixtral_gating,
                                   self.moe_2layer_gate,
                                   self.moe_use_logits_norm,
                                   self.moe_gate_norm_std),
                            experts,
                            self.ep_size,
                            self.num_local_experts,
                            self.moe_use_mixtral_gating,
                            self.moe_feature_no_mul_topk,
                            )

    def forward(self, hidden_states, used_token=None, gate_residual=None):
        output, gate_residual = self.moe(hidden_states, used_token, gate_residual=gate_residual)
        return output, gate_residual
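

# Minimal usage sketch (illustrative only): wraps a toy expert in an MoE block and runs a
# forward pass. `_ToyExpert` and every size below are assumptions for demonstration, not
# part of the original API; the real expert only needs to expose a `hidden_size` attribute.
class _ToyExpert(torch.nn.Module):
    def __init__(self, hidden_size: int = 16) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        self.fc = torch.nn.Linear(hidden_size, hidden_size)

    def forward(self, x: Tensor) -> Tensor:
        return self.fc(x)


def _moe_example() -> None:
    moe = MOE(hidden_size=16, expert=_ToyExpert(16), num_experts=8)
    hidden_states = torch.randn(2, 5, 16)  # (batch, sequence, hidden)
    output, gate_logits = moe(hidden_states)
    assert output.shape == hidden_states.shape
    # The routing logits can be chained into the next MoE layer as a gate residual.
    output2, _ = moe(hidden_states, gate_residual=gate_logits)
    assert output2.shape == hidden_states.shape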