# MoE-Plus-Plus-7B / moe_plus_plus_layer.py
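# MoE++-style mixture-of-experts layer: ordinary FFN experts are combined with
# zero-computation experts (ZeroExpert, CopyExpert, ConstantExpert), and the
# router can take an external gate residual (e.g. the previous layer's gating
# logits) as an extra input.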
from collections import defaultdict
from typing import Any, Dict, TYPE_CHECKING, Tuple, List
import torch
import copy
from torch import Tensor
from torch.nn import Module
import torch.nn.functional as F
if TYPE_CHECKING:
Base = Module[Tensor]
else:
Base = Module
MOE_TOP_K = 2  # number of experts each token is routed to
Constant = 2   # number of ConstantExpert modules placed in every expert list
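# Copy expert: a zero-computation expert that returns the token unchanged.
# The `expert` argument exists only so every expert class shares the same
# constructor signature; it is not used.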
class CopyExpert(torch.nn.Module):
def __init__(self, expert):
super(CopyExpert, self).__init__()
pass
def forward(self, inputs):
return inputs
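# Zero expert: a zero-computation expert that discards the token and returns an
# all-zero vector of the same shape.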
class ZeroExpert(torch.nn.Module):
def __init__(self, expert):
super(ZeroExpert, self).__init__()
pass
    def forward(self, inputs):
        # zeros_like already matches the input's dtype and device
        return torch.zeros_like(inputs)
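# Constant expert: blends each token with a single learned constant vector.
# A small two-way gate decides, per token, how much of the input to keep and
# how much of the constant to mix in.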
class ConstantExpert(torch.nn.Module):
def __init__(self, expert):
super(ConstantExpert, self).__init__()
self.constant = torch.nn.Parameter(
torch.empty((expert.hidden_size)))
torch.nn.init.normal_(self.constant)
self.wg = torch.nn.Linear(expert.hidden_size, 2, bias=False)
self.softmax = torch.nn.Softmax(dim=-1)
    def forward(self, inputs):
        # Two-way gate: weight[:, 0] keeps the input, weight[:, 1] adds the constant.
        weight = self.wg(inputs)
        weight = self.softmax(weight)
        return torch.einsum('b,bd->bd', weight[:, 0].type_as(inputs), inputs) + torch.einsum(
            'b,d->bd', weight[:, 1].type_as(inputs), self.constant.type_as(inputs))
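# Top-k gating over the router logits (shape [num_tokens, num_experts]).
# With `moe_use_mixtral_gating` the top-k logits are picked first and then
# softmaxed (Mixtral-style); in the default path the full softmax is taken
# first, the top-k gates are kept, the gate that landed on the last expert
# (the ZeroExpert) is zeroed, and the remaining gates are renormalized.
# `moe_use_logits_norm` rescales logits so their per-token std equals
# `moe_gate_norm_std` before the softmax. Returns a dict mapping each expert
# id to [token indices, gate values].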
def gating(logits: Tensor, moe_use_mixtral_gating=False, moe_use_logits_norm=False, moe_gate_norm_std=1.0) -> Dict[int, List[Tensor]]:
    # logits shape [num_tokens, num_experts]
num_experts = logits.size(1)
if moe_use_mixtral_gating:
if moe_use_logits_norm:
target_std = moe_gate_norm_std
logits_std = logits.std(dim=1, keepdim=True)
logits = logits / (logits_std / target_std)
gates, indices = torch.topk(logits, k=MOE_TOP_K, dim=1)
gates = F.softmax(gates, dim=1)
else:
target_std = moe_gate_norm_std
if moe_use_logits_norm:
logits_std = logits.std(dim=1, keepdim=True)
gates = F.softmax(logits / (logits_std / target_std), dim=1)
else:
gates = F.softmax(logits, dim=1)
# gates shape [num_tokens, MOE_TOP_K]
# indices shape [num_tokens, MOE_TOP_K]
gates, indices = torch.topk(gates, k=MOE_TOP_K, dim=1)
gates = torch.where(indices==(num_experts-1), torch.zeros_like(gates).to(gates.dtype).to(gates.device), gates)
gates /= torch.sum(gates, dim=1, keepdim=True)
expert_info = defaultdict(list)
for expert_id in range(num_experts):
token_ids, score_ids = torch.nonzero(indices == expert_id, as_tuple=True)
expert_info[expert_id] = [token_ids, gates[token_ids, score_ids]]
return expert_info
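# Router: projects token features to per-expert logits, either through a single
# linear layer or a two-layer tanh gate network (`moe_2layer_gate`), always in
# float32. If a `gate_residual` is given it is mapped through `gate_map` and
# added to the logits before `gating` turns them into per-expert assignments.
# Gate weights are tagged with a `router` attribute so downstream code can
# identify them.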
class Router(Module):
def __init__(self,
model_dim: int,
num_experts: int,
moe_use_mixtral_gating: bool,
moe_2layer_gate: bool,
moe_use_logits_norm: bool,
moe_gate_norm_std: float,
) -> None:
super().__init__()
if moe_2layer_gate:
self.wg = torch.nn.Sequential(
torch.nn.Linear(model_dim, num_experts * 8, bias=False).float(),
torch.nn.Tanh(),
torch.nn.Linear(num_experts * 8, num_experts, bias=False).float()).float()
else:
self.wg = torch.nn.Linear(model_dim, num_experts, bias=False).float()
self.gate_map = torch.nn.Linear(num_experts, num_experts, bias=False)
self.gate = gating
self.moe_use_mixtral_gating = moe_use_mixtral_gating
self.moe_use_logits_norm = moe_use_logits_norm
self.moe_gate_norm_std = moe_gate_norm_std
    def forward(self, input: torch.Tensor, gate_residual=None) -> Tuple[Dict[int, List[Tensor]], Tensor]:
if isinstance(self.wg, torch.nn.Linear):
if self.wg.weight.dtype != torch.float32:
self.wg = self.wg.float()
setattr(self.wg.weight, 'router', True)
else:
if self.wg[0].weight.dtype != torch.float32:
self.wg = self.wg.float()
setattr(self.wg[0].weight, "router", True)
setattr(self.wg[2].weight, "router", True)
input_fp32 = input.float()
logits = self.wg(input_fp32)
if gate_residual is not None:
gate_residual = self.gate_map(gate_residual.to(self.gate_map.weight.dtype))
logits += gate_residual
gate_output = self.gate(logits, self.moe_use_mixtral_gating, self.moe_use_logits_norm, self.moe_gate_norm_std)
return gate_output, logits
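# Expert container. For `num_local_experts` slots it holds, in order:
# (num_local_experts - 2 - Constant) deep copies of the given FFN expert,
# `Constant` ConstantExpert modules, one CopyExpert, and one ZeroExpert.
# The ZeroExpert sits at the last index, matching the gate that `gating` zeroes.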
class Experts(torch.nn.Module):
def __init__(self, expert, num_local_experts=1):
super(Experts, self).__init__()
self.experts = torch.nn.ModuleList(
[copy.deepcopy(expert) for _ in range(num_local_experts - 2 - Constant)] +
[ConstantExpert(expert) for _ in range(Constant)] +
[CopyExpert(expert), ZeroExpert(expert)])
def forward(self, inputs):
raise NotImplementedError
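# MoE dispatch layer: flattens the input to [num_tokens, d_model], routes the
# tokens, gathers each expert's tokens with index_select, applies the expert,
# scales the result by its gate value, and scatters it back with index_add_.
# Returns the recombined output together with the raw router logits, which the
# caller may feed into the next layer as `gate_residual`.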
class MOELayer(Base):
def __init__(self,
gate: Module,
experts: Module,
ep_size,
num_local_experts: int,
moe_use_mixtral_gating: bool,
moe_feature_no_mul_topk: bool) -> None:
super().__init__()
self.gate = gate
self.experts = experts
self.ep_size = ep_size
self.num_local_experts = num_local_experts
self.moe_use_mixtral_gating = moe_use_mixtral_gating
self.moe_feature_no_mul_topk = moe_feature_no_mul_topk
    def forward(self, *input: Tensor, gate_residual=None, **kwargs: Any) -> Tuple[Tensor, Tensor]:
d_model = input[0].shape[-1]
reshaped_input = input[0].reshape(-1, d_model)
output = torch.zeros_like(reshaped_input)
expert_info, gate_residual = self.gate(reshaped_input, gate_residual)
        if not (self.moe_use_mixtral_gating or self.moe_feature_no_mul_topk):
            # Scale the features by the top-k count; done out of place so the
            # caller's hidden states (viewed through reshape) are not mutated.
            reshaped_input = reshaped_input * MOE_TOP_K
for expert, token_indices_and_gates in expert_info.items():
indices, gating = token_indices_and_gates
gating = gating.unsqueeze(-1)
tokens = reshaped_input.index_select(dim=0, index=indices)
            expert_output = self.experts.experts[expert](tokens)
            # Weight the expert output by its gate value, cast to the expert
            # output dtype so the accumulation stays in the model's precision.
            expert_output = expert_output * gating.type_as(expert_output)
            output.index_add_(dim=0, index=indices, source=expert_output)
output = output.reshape(input[0].shape)
return output, gate_residual
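# Top-level MoE module: builds the Router and the expert list and exposes
# forward(hidden_states, used_token, gate_residual) -> (output, gate_residual).
# `ep_size` only sets how many experts are considered local here; no expert
# parallelism is performed in this file.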
class MOE(torch.nn.Module):
def __init__(self,
hidden_size,
expert,
num_experts=1,
ep_size=1,
moe_use_mixtral_gating=False,
moe_2layer_gate=True,
moe_use_logits_norm=False,
moe_gate_norm_std=1.0,
moe_feature_no_mul_topk=False):
super(MOE, self).__init__()
self.ep_size = ep_size
self.num_experts = num_experts
self.num_local_experts = num_experts // self.ep_size
self.moe_use_mixtral_gating = moe_use_mixtral_gating
self.moe_2layer_gate = moe_2layer_gate
self.moe_use_logits_norm = moe_use_logits_norm
self.moe_gate_norm_std = moe_gate_norm_std
self.moe_feature_no_mul_topk = moe_feature_no_mul_topk
experts = Experts(expert, self.num_local_experts)
self.moe = MOELayer(Router(hidden_size,
num_experts,
self.moe_use_mixtral_gating,
self.moe_2layer_gate,
self.moe_use_logits_norm,
self.moe_gate_norm_std),
experts,
self.ep_size,
self.num_local_experts,
self.moe_use_mixtral_gating,
self.moe_feature_no_mul_topk,
)
def forward(self, hidden_states, used_token=None, gate_residual=None):
output, gate_residual = self.moe(hidden_states, used_token, gate_residual=gate_residual)
return output, gate_residual
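# Minimal usage sketch (not part of the original upload). A toy FFN stands in
# for the model's expert so the layer can be exercised end to end in float32;
# the toy module only needs the `hidden_size` attribute that ConstantExpert
# reads, and every name below is illustrative rather than taken from the repo.
if __name__ == "__main__":
    class ToyFFN(torch.nn.Module):
        def __init__(self, hidden_size):
            super().__init__()
            self.hidden_size = hidden_size
            self.fc1 = torch.nn.Linear(hidden_size, 4 * hidden_size)
            self.fc2 = torch.nn.Linear(4 * hidden_size, hidden_size)

        def forward(self, x):
            return self.fc2(F.gelu(self.fc1(x)))

    hidden_size, num_experts = 32, 8
    layer = MOE(hidden_size, ToyFFN(hidden_size), num_experts=num_experts)
    hidden_states = torch.randn(2, 16, hidden_size)  # [batch, seq_len, hidden]
    output, gate_logits = layer(hidden_states)
    # output keeps the input shape; gate_logits is [num_tokens, num_experts].
    print(output.shape, gate_logits.shape)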