import torch
import torch.nn.functional as F
from torch import nn

from .args import TransformerModelArgs


class GroupedExperts(nn.Module):
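    """A bank of SwiGLU feed-forward experts with batched weights.

    ``w1``, ``w2``, and ``w3`` store the weights of all ``num_experts`` experts as
    3D parameters, so the experts can be applied with a per-expert loop, with
    ``torch.bmm``, or with grouped matrix multiplication (``use_grouped_mm``).
    """
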
    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        num_experts: int,
        use_grouped_mm: bool,
    ):
        super().__init__()
        self.num_experts = num_experts
        self.w1 = nn.Parameter(torch.empty(num_experts, dim, hidden_dim))
        self.w2 = nn.Parameter(torch.empty(num_experts, hidden_dim, dim))
        self.w3 = nn.Parameter(torch.empty(num_experts, dim, hidden_dim))
        self.use_grouped_mm = use_grouped_mm

    def forward(
        self,
        x: torch.Tensor,
        num_local_tokens_per_expert: torch.Tensor | list[int] | None = None,
    ) -> torch.Tensor:
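        """Apply the experts' SwiGLU feed-forward networks to routed tokens.

        Args:
            x (torch.Tensor): Either a 2D tensor of shape ``(num_tokens, dim)`` with
                tokens already sorted by expert, or a 3D tensor of shape
                ``(num_experts, tokens_per_expert, dim)``.
            num_local_tokens_per_expert (torch.Tensor | list[int] | None): Number of
                tokens assigned to each expert when ``x`` is 2D; ``None`` when ``x``
                is already batched per expert.

        Returns:
            torch.Tensor: Expert outputs with the same shape as ``x``.
        """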
        if not self.use_grouped_mm:
            if num_local_tokens_per_expert is not None:
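                # x is already sorted by expert; split it into one chunk per expert
                # and apply each expert's SwiGLU MLP in a Python loop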
                x = torch.split(
                    x,
                    split_size_or_sections=num_local_tokens_per_expert,
                    dim=0,
                )
                out_experts_splits = []
                for expert_idx, x_expert in enumerate(x):
                    w1, w2, w3 = (
                        self.w1[expert_idx],
                        self.w2[expert_idx],
                        self.w3[expert_idx],
                    )
                    h = F.silu(torch.matmul(x_expert, w1))
                    h = h * torch.matmul(x_expert, w3)
                    h = torch.matmul(h, w2)
                    out_experts_splits.append(h)
                out = torch.cat(out_experts_splits, dim=0)
            else:
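                # x is batched per expert: (num_experts, tokens_per_expert, dim)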
                h = F.silu(torch.bmm(x, self.w1))
                h = h * torch.bmm(x, self.w3)
                out = torch.bmm(h, self.w2)

            return out

        if num_local_tokens_per_expert is not None:
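            # grouped mm path: offsets are the cumulative token counts per expert,
            # marking where each expert's segment ends along dim 0 of the 2D,
            # expert-sorted input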
            offsets = torch.cumsum(
                num_local_tokens_per_expert, dim=0, dtype=torch.int32
            )
            assert x.dim() == 2
        else:
            offsets = None
            assert x.dim() == 3
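        # torch._grouped_mm (a private PyTorch op) runs one matmul per expert in a
        # single call: with offsets it slices the 2D input per expert, without
        # offsets it behaves like a batched matmul over the 3D input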
        h = F.silu(torch._grouped_mm(x, self.w1, offs=offsets))
        h = h * torch._grouped_mm(x, self.w3, offs=offsets)
        out = torch._grouped_mm(h, self.w2, offs=offsets)

        return out

    def init_weights(self, init_std: float):
        nn.init.trunc_normal_(self.w1, mean=0.0, std=0.02)
        nn.init.trunc_normal_(self.w2, mean=0.0, std=init_std)
        nn.init.trunc_normal_(self.w3, mean=0.0, std=init_std)

class TokenChoiceTopKRouter(nn.Module):
    """This class implements token-choice routing. In token-choice top-K routing, each token
    is routed to the top K experts based on the router scores. The gate is an internal
    ``nn.Linear(dim, num_experts, bias=False)`` that produces the routing scores.

    Args:
        dim (int): Dimension of input tokens.
        num_experts (int): Number of experts in each MoE layer.
        top_k (int): Number of experts each token will be routed to in token-choice routing.
        use_sigmoid (bool): Whether to use sigmoid or softmax for router scores. Default is False.
    """

    def __init__(
        self,
        dim: int,
        num_experts: int,
        top_k: int,
        use_sigmoid: bool = False,
    ):
        super().__init__()
        self.gate = nn.Linear(dim, num_experts, bias=False)
        self.num_experts = num_experts
        self.top_k = top_k
        self.use_sigmoid = use_sigmoid

    def forward(
        self, x: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Args:
            x (torch.Tensor): Input tensor with shape ``(bs*slen, dim)``.

        Returns:
            top_scores (torch.Tensor):
                Routing scores of the selected experts, rearranged so that entries
                assigned to the same expert are contiguous, with shape ``(bs*slen*top_k,)``.
            token_indices_experts_sorted (torch.Tensor):
                Index of the originating token in ``x`` for each routed entry,
                sorted by expert, with shape ``(bs*slen*top_k,)``.
            num_local_tokens_per_expert (torch.Tensor):
                Number of tokens assigned to each expert with shape ``(num_experts,)``.
        """
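        # scores shape (bs*slen, num_experts)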
        scores = self.gate(x)

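        # sigmoid/softmax is computed in float32 to avoid precision issues,
        # then cast back to the input dtype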
        if self.use_sigmoid:
            scores = torch.sigmoid(scores.to(torch.float32)).to(x.dtype)
        else:
            scores = F.softmax(scores.to(torch.float32), dim=1).to(x.dtype)

        top_scores, selected_experts_indices = torch.topk(scores, k=self.top_k, dim=1)

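        # count how many routed (token, expert) entries go to each expert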
        num_local_tokens_per_expert = torch.histc(
            selected_experts_indices.view(-1),
            bins=self.num_experts,
            min=0,
            max=self.num_experts,
        )

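        # a stable argsort over the flattened expert indices groups routed entries
        # by expert; since flat position p belongs to token p // top_k, the integer
        # division below recovers each entry's source-token index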
        token_indices_experts_sorted = torch.argsort(
            selected_experts_indices.view(-1), stable=True
        )
        top_scores = top_scores.view(-1)[token_indices_experts_sorted]
        token_indices_experts_sorted = token_indices_experts_sorted // self.top_k

        return top_scores, token_indices_experts_sorted, num_local_tokens_per_expert

    def init_weights(self, init_std: float):
        nn.init.trunc_normal_(self.gate.weight, mean=0.0, std=init_std)


class MoE(nn.Module):
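    """Token-choice Mixture-of-Experts feed-forward block.

    Tokens are routed to ``top_k`` experts by ``TokenChoiceTopKRouter``, processed by
    ``GroupedExperts``, optionally combined with a shared expert that sees every
    token, and scattered back to their original positions.

    Example (illustrative only; assumes ``model_args`` is a ``TransformerModelArgs``
    with the MoE-related fields used below populated)::

        moe = MoE(model_args)
        moe.init_weights(init_std=0.02)
        x = torch.randn(2, 128, model_args.dim)
        out = moe(x)  # shape (2, 128, model_args.dim)
    """
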
    def __init__(self, model_args: TransformerModelArgs):
        super().__init__()
        dim = model_args.dim

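        # FFN hidden size follows the usual Llama-style rule: 4 * dim, scaled by 2/3
        # and optionally by ffn_dim_multiplier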
        hidden_dim = 4 * model_args.dim
        ffn_dim_multiplier = model_args.ffn_dim_multiplier
        hidden_dim = int(2 * hidden_dim / 3)
        if ffn_dim_multiplier is not None:
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)

        num_experts = model_args.num_experts

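        # when auto_scale_hidden_dim is set, shrink each expert so that the combined
        # hidden size of the activated experts (top_k routed plus the shared expert,
        # if any) stays comparable to a dense FFN, then round up to multiple_of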
        hidden_dim_denom = 1
        if model_args.auto_scale_hidden_dim:
            hidden_dim_denom = model_args.top_k + int(model_args.use_shared_expert)

        if model_args.auto_scale_hidden_dim:
            hidden_dim = int(hidden_dim / hidden_dim_denom)
            hidden_dim += -hidden_dim % model_args.multiple_of

        self.use_grouped_mm = model_args.use_grouped_mm
        self.experts = GroupedExperts(
            dim=dim,
            hidden_dim=hidden_dim,
            num_experts=num_experts,
            use_grouped_mm=self.use_grouped_mm,
        )
        self.router = TokenChoiceTopKRouter(
            dim=dim, num_experts=num_experts, top_k=model_args.top_k
        )
        self.shared_expert = (
            GroupedExperts(
                dim=dim,
                hidden_dim=hidden_dim,
                num_experts=1,
                use_grouped_mm=self.use_grouped_mm,
            )
            if model_args.use_shared_expert
            else None
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): Input tensor with shape ``(bs, slen, dim)``.

        Returns:
            out (torch.Tensor): Output tensor with shape ``(bs, slen, dim)``.
        """
        bs, slen, dim = x.shape

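        # the router returns flat tensors of shape (bs*slen*top_k,): routing scores
        # and source-token indices, both sorted so that entries belonging to the
        # same expert are contiguous, plus per-expert token counts (num_experts,)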
        (
            top_scores,
            token_indices,
            num_local_tokens_per_expert,
        ) = self.router(x.reshape(bs * slen, dim))

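        # expand the flat token indices across the model dimension so each routed
        # entry can gather (and later scatter-add) a full dim-sized row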
        token_indices = token_indices.reshape(-1, 1).expand(-1, dim)

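        # gather the routed tokens in expert-sorted order, scaling each copy by its
        # routing score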
        routed_input = torch.gather(
            x.view(-1, dim),
            dim=0,
            index=token_indices,
        )
        routed_input = routed_input * top_scores.reshape(-1, 1)

        if self.use_grouped_mm:
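            # the grouped mm path needs each expert's token count padded up to a
            # multiple of ALIGN_SIZE_M; generate_permute_indices computes the padded
            # per-expert sizes and the index permutation into that padded layout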
            from torchtitan.experiments.kernels.moe.indices import (
                generate_permute_indices,
            )

            ALIGN_SIZE_M = 16

            with torch.no_grad():
                permuted_indices, m_sizes = generate_permute_indices(
                    num_local_tokens_per_expert,
                    self.experts.num_experts,
                    1,
                    token_indices.shape[0] + self.experts.num_experts * ALIGN_SIZE_M,
                    ALIGN_SIZE_M,
                )
            num_local_tokens_per_expert = m_sizes
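            # append an all-zeros row so that padding slots in permuted_indices have
            # a harmless row to point at, then reorder into the padded layout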
            token_indices = torch.vstack(
                (token_indices, token_indices.new_zeros((dim)))
            )
            token_indices = token_indices[permuted_indices, :]
            routed_input = torch.vstack((routed_input, routed_input.new_zeros((dim))))
            routed_input = routed_input[permuted_indices, :]
        else:
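            # the loop/bmm fallback splits with torch.split, which expects a Python
            # list of chunk sizes rather than a tensor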
            num_local_tokens_per_expert = num_local_tokens_per_expert.tolist()

        routed_output = self.experts(routed_input, num_local_tokens_per_expert)

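        # the shared expert (if present) processes every token; its output is the
        # base onto which the routed experts' outputs are scatter-added below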
        if self.shared_expert is not None:
            out = self.shared_expert(x.reshape(1, bs * slen, dim)).reshape(
                bs * slen, dim
            )
        else:
            out = torch.zeros_like(x.reshape(bs * slen, dim))

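        # add each routed output row back to the token it came from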
        out = out.scatter_add(dim=0, index=token_indices, src=routed_output)
        out = out.reshape(bs, slen, dim)
        return out

    def init_weights(self, init_std: float):
        self.experts.init_weights(init_std)
        self.router.init_weights(init_std)
        if self.shared_expert is not None:
            self.shared_expert.init_weights(init_std)