from dataclasses import dataclass
from typing import Optional

from torch import nn

from torchtitan.components.tokenizer import Tokenizer
from torchtitan.config_manager import JobConfig
from torchtitan.protocols.train_spec import BaseModelArgs
from torchtitan.tools.logging import logger


@dataclass
class TransformerModelArgs(BaseModelArgs):
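    """Arguments defining the Transformer architecture, including the optional
    Mixture-of-Experts (MoE) settings.

    Fields such as `vocab_size`, `max_seq_len`, `norm_type`, and `use_flex_attn`
    are overwritten by `update_from_config` from the tokenizer and job config.
    """
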
    dim: int = 4096
    n_layers: int = 32
    n_heads: int = 32
    n_kv_heads: Optional[int] = None
    vocab_size: int = -1  # set from the tokenizer in update_from_config
    multiple_of: int = 256  # make the FFN hidden dim a multiple of this value
    ffn_dim_multiplier: Optional[float] = None
    norm_eps: float = 1e-5
    rope_theta: float = 10000

    max_seq_len: int = 2048

    # scale per-layer weight init by the block's layer ID rather than the total layer count
    depth_init: bool = True
    norm_type: str = "rmsnorm"

    use_flex_attn: bool = False
    attn_mask_type: str = "causal"
    eos_id: int = 0

    # MoE settings
    moe_enabled: bool = True
    num_experts: int = 8
    use_shared_expert: bool = True
    auto_scale_hidden_dim: bool = True
    # every `interleave_moe_layer_step`-th TransformerBlock uses an MoE layer
    # instead of a dense feed-forward layer
    interleave_moe_layer_step: int = 2
    # number of experts each token is routed to (token-choice routing)
    top_k: int = 1
    use_grouped_mm: bool = True  # grouped mm instead of a per-expert loop

    def update_from_config(self, job_config: JobConfig, tokenizer: Tokenizer) -> None:
        self.norm_type = job_config.model.norm_type
        self.vocab_size = tokenizer.n_words
        self.max_seq_len = job_config.training.seq_len
        self.use_flex_attn = job_config.model.use_flex_attn

    def get_nparams_and_flops(
        self, model: nn.Module, seq_len: int
    ) -> tuple[int, float]:
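        """Return the total parameter count and an estimate of training FLOPs per token.

        The FLOPs estimate counts only the MoE parameters that are active for a
        given token (router, shared expert, and `top_k` routed experts), so it
        reflects compute actually spent rather than total model capacity.
        """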
        nparams_embedding = 0
        nparams_moe_router = 0
        nparams_shared_expert = 0
        nparams_experts = 0
        nparams_dense = 0

        # Bucket parameters by name into dense vs. MoE (router / shared expert /
        # routed experts); embedding parameters are tracked separately so they can
        # be excluded from the FLOPs estimate below.
        for name, p in model.named_parameters():
            if "embedding" in name:
                nparams_embedding += p.numel()
                nparams_dense += p.numel()
            elif "moe.shared_expert" in name:
                nparams_shared_expert += p.numel()
            elif "moe.router" in name:
                nparams_moe_router += p.numel()
            elif "moe.experts" in name:
                nparams_experts += p.numel()
            else:
                nparams_dense += p.numel()

        nparams_sparse = nparams_moe_router + nparams_shared_expert + nparams_experts
        nparams = nparams_dense + nparams_sparse
        # Active parameters: the router and shared expert always run, but only
        # top_k of the num_experts routed experts are used per token.
        nparams_sparse_active = (
            nparams_moe_router
            + nparams_shared_expert
            + nparams_experts * self.top_k // self.num_experts
        )

        logger.info(
            f"Total parameter count: dense {nparams_dense:,}, "
            f"sparse {nparams_sparse:,}, active {nparams_dense + nparams_sparse_active:,}"
        )

        l, h, q, t = (
            self.n_layers,
            self.n_heads,
            self.dim // self.n_heads,
            seq_len,
        )
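        # FLOPs-per-token estimate:
        # - 6 * (active non-embedding params): each weight contributes one
        #   multiply-add in the forward pass (2 FLOPs) and roughly twice that in
        #   the backward pass (4 FLOPs); only `nparams_sparse_active` of the MoE
        #   weights do work for a given token.
        # - 12 * l * h * q * t: self-attention scores and outputs take 2 matmuls
        #   in the forward and 4 in the backward (6 total), each costing a
        #   multiply and an add (x2), over l layers, h heads, head dim q, and
        #   sequence length t.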
        num_flops_per_token = (
            6 * (nparams_dense - nparams_embedding + nparams_sparse_active)
            + 12 * l * h * q * t
        )

        return nparams, num_flops_per_token
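

if __name__ == "__main__":
    # Minimal usage sketch (an illustrative addition, not part of the training
    # flow): estimate parameter and FLOP counts for a toy module using the
    # default arguments. Real models are constructed through the torchtitan
    # train spec; the toy module below only exercises the name-based bucketing.
    class _ToyModel(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.embedding = nn.Embedding(1024, 64)  # name contains "embedding"
            self.proj = nn.Linear(64, 64)  # counted as a dense parameter

    args = TransformerModelArgs()
    nparams, flops_per_token = args.get_nparams_and_flops(
        _ToyModel(), seq_len=args.max_seq_len
    )
    print(f"params={nparams:,}, approx FLOPs/token={flops_per_token:,}")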