# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from torch import nn

from torchtitan.tools.logging import logger


def get_nparams_and_flops(
    model: nn.Module, model_config, seq_len: int
) -> tuple[int, int]:
    """Return the total parameter count and the estimated FLOPs per token
    (forward + backward) for a transformer model at the given sequence length.
    """
    nparams = sum(p.numel() for p in model.parameters())
    # Embedding parameters are counted separately; note that only direct
    # children of the model are inspected, so the embedding must be a
    # top-level submodule to be excluded from the dense-FLOPs term below.
    nparams_embedding = sum(
        sum(p.numel() for p in m.parameters())
        for m in model.children()
        if isinstance(m, nn.Embedding)
    )
    if hasattr(model_config, "num_heads"):
        num_heads = model_config.num_heads
    elif hasattr(model_config, "num_attention_heads"):
        num_heads = model_config.num_attention_heads
    else:
        num_heads = 1
        logger.warning("num_heads not found in model_config, defaulting to 1.")
    l, h, q, t = (
        model_config.num_hidden_layers,
        num_heads,
        model_config.hidden_size // num_heads,
        seq_len,
    )
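    # l: number of transformer layers, h: attention heads,
    # q: per-head dimension (hidden_size // num_heads), t: sequence length.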
    # Reasoning behind the factor of 12 for the self-attention part of the formula:
    # 1. each self-attention does 2 matmuls in the forward and 4 in the backward (6)
    # 2. flash attention does 1 more matmul recomputation in the backward,
    #    but recomputation should not be counted when calculating MFU (+0)
    # 3. each matmul performs 1 multiplication and 1 addition (*2)
    # 4. we follow the convention and do not account for sparsity in causal attention
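    # As a worked example (hypothetical 8B-scale shapes, not taken from this
    # file: l=32, h=32, q=128, t=8192), the attention term alone contributes
    # 12 * 32 * 32 * 128 * 8192 ~= 1.29e10 FLOPs per token, on top of the
    # 6 * (nparams - nparams_embedding) dense term.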
    num_flops_per_token = 6 * (nparams - nparams_embedding) + 12 * l * h * q * t
    return nparams, num_flops_per_token
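

# A minimal smoke test sketch, assuming nothing beyond the function above.
# ToyModel and the SimpleNamespace config are hypothetical stand-ins for
# illustration, not torchtitan classes.
if __name__ == "__main__":
    from types import SimpleNamespace

    class ToyModel(nn.Module):
        def __init__(self, vocab_size: int, hidden_size: int) -> None:
            super().__init__()
            # The embedding is a direct child, so get_nparams_and_flops can
            # exclude it from the 6N dense-FLOPs term.
            self.tok_embeddings = nn.Embedding(vocab_size, hidden_size)
            self.output = nn.Linear(hidden_size, vocab_size)

    config = SimpleNamespace(
        num_attention_heads=4, num_hidden_layers=2, hidden_size=64
    )
    model = ToyModel(vocab_size=1000, hidden_size=64)
    nparams, num_flops_per_token = get_nparams_and_flops(model, config, seq_len=128)
    print(f"nparams={nparams:,} flops_per_token={num_flops_per_token:,}")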