# zaydzuhri's picture
# Add files using upload-large-folder tool
# 75b6530 verified
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import os
import time
from collections import namedtuple
from datetime import datetime
from typing import Any
import torch
from torch.utils.tensorboard import SummaryWriter
from torchtitan.components.lr_scheduler import LRSchedulersContainer
from torchtitan.components.optimizer import OptimizersContainer
from torchtitan.config_manager import JobConfig
from torchtitan.distributed import ParallelDims
from torchtitan.tools import utils
from torchtitan.tools.logging import logger
from torchtitan.tools.utils import Color, device_module, device_type
# Immutable record used to hand peak device-memory statistics to the logging
# code: peak active/reserved memory (in GiB and as a percentage of capacity)
# plus the allocator's retry and OOM counters.
DeviceMemStats = namedtuple(
    "DeviceMemStats",
    (
        "max_active_gib",
        "max_active_pct",
        "max_reserved_gib",
        "max_reserved_pct",
        "num_alloc_retries",
        "num_ooms",
    ),
)
class DeviceMemoryMonitor:
    """Reports peak memory usage of an accelerator device.

    On construction the monitor records the device's name and total memory,
    then resets the peak-memory counters and empties the allocator cache so
    that later readings start from a clean state.
    """

    def __init__(self, device: str = f"{device_type}:0"):
        """Set up monitoring for ``device``.

        Args:
            device: Device string passed to ``torch.device``.
        """
        self.device = torch.device(device)  # device object
        self.device_name = device_module.get_device_name(self.device)
        self.device_index = device_module.current_device()

        properties = device_module.get_device_properties(self.device)
        self.device_capacity = properties.total_memory
        self.device_capacity_gib = self._to_gib(self.device_capacity)

        device_module.reset_peak_memory_stats()
        device_module.empty_cache()

    def _to_gib(self, memory_in_bytes):
        """Convert a byte count to GiB."""
        # NOTE: a GiB (gibibyte) is 1024**3 bytes, whereas a GB is 1000**3.
        return memory_in_bytes / (1024 * 1024 * 1024)

    def _to_pct(self, memory):
        """Express ``memory`` (bytes) as a percentage of the device capacity."""
        return 100 * memory / self.device_capacity

    def get_peak_stats(self):
        """Read the peak memory counters and return them as ``DeviceMemStats``.

        Counters missing from the device's memory stats default to -1. A
        warning is logged when allocation retries or OOM errors were recorded.
        """
        stats = device_module.memory_stats(self.device)
        peak_active = stats.get("active_bytes.all.peak", -1)
        peak_reserved = stats.get("reserved_bytes.all.peak", -1)
        alloc_retries = stats.get("num_alloc_retries", -1)
        oom_count = stats.get("num_ooms", -1)

        if alloc_retries > 0:
            logger.warning(
                f"{alloc_retries} {device_type.upper()} memory allocation retries."
            )
        if oom_count > 0:
            logger.warning(f"{oom_count} {device_type.upper()} OOM errors thrown.")

        return DeviceMemStats(
            max_active_gib=self._to_gib(peak_active),
            max_active_pct=self._to_pct(peak_active),
            max_reserved_gib=self._to_gib(peak_reserved),
            max_reserved_pct=self._to_pct(peak_reserved),
            num_alloc_retries=alloc_retries,
            num_ooms=oom_count,
        )

    def reset_peak_stats(self):
        """Reset the device's peak-memory counters."""
        device_module.reset_peak_memory_stats()
def build_device_memory_monitor():
    """Create a ``DeviceMemoryMonitor`` for ``device_type`` and log its capacity.

    Returns:
        The newly created monitor.
    """
    monitor = DeviceMemoryMonitor(device_type)
    device_label = device_type.upper()
    logger.info(
        f"{device_label} capacity: {monitor.device_name} "
        f"with {monitor.device_capacity_gib:.2f}GiB memory"
    )
    return monitor
class BaseLogger:
    """No-op metric logger, used when metric logging is disabled."""

    def log(self, metrics: dict[str, Any], step: int) -> None:
        """Accept ``metrics`` for ``step`` and discard them."""
        return None

    def close(self) -> None:
        """Do nothing; there are no resources to release."""
        return None
class TensorBoardLogger(BaseLogger):
    """Metric logger that writes scalars through a TensorBoard ``SummaryWriter``."""

    def __init__(self, log_dir: str, tag: str | None = None):
        """Open a summary writer in ``log_dir``.

        Args:
            log_dir: Directory handed to ``SummaryWriter``.
            tag: Optional prefix prepended (as ``"<tag>/"``) to every metric name.
        """
        self.tag = tag
        self.writer = SummaryWriter(log_dir, max_queue=1000)
        logger.info(f"TensorBoard logging enabled. Logs will be saved at {log_dir}")

    def log(self, metrics: dict[str, Any], step: int) -> None:
        """Write each metric as a scalar for ``step``."""
        prefix = self.tag
        for name, value in metrics.items():
            scalar_tag = f"{prefix}/{name}" if prefix is not None else name
            self.writer.add_scalar(scalar_tag, value, step)

    def close(self) -> None:
        """Close the underlying summary writer."""
        self.writer.close()
class WandBLogger(BaseLogger):
    """Metric logger that reports to Weights & Biases."""

    def __init__(self, log_dir: str, tag: str | None = None):
        """Start a WandB run that stores its files under ``log_dir``.

        The project name is read from the ``WANDB_PROJECT`` environment
        variable and falls back to ``"torchtitan"``.

        Args:
            log_dir: Directory for the run; created if it does not exist.
            tag: Optional prefix prepended (as ``"<tag>/"``) to every metric name.
        """
        # Imported lazily so wandb is not loaded at module import time.
        import wandb

        self.wandb = wandb
        self.tag = tag

        # The run directory must exist before the run is initialised.
        os.makedirs(log_dir, exist_ok=True)
        project_name = os.getenv("WANDB_PROJECT", "torchtitan")
        self.wandb.init(project=project_name, dir=log_dir)
        logger.info("WandB logging enabled")

    def log(self, metrics: dict[str, Any], step: int) -> None:
        """Send ``metrics`` for ``step``, prefixing names with the tag if set."""
        if self.tag is None:
            payload = dict(metrics)
        else:
            payload = {
                f"{self.tag}/{name}": value for name, value in metrics.items()
            }
        self.wandb.log(payload, step=step)

    def close(self) -> None:
        """Finish the WandB run if one is active."""
        if self.wandb.run is None:
            return
        self.wandb.finish()
def ensure_pp_loss_visible(
    parallel_dims: ParallelDims, job_config: JobConfig, color: Color
) -> None:
    """
    Ensures that the loss is visible on the console for pipeline-parallel training.

    For pipeline-parallel training, the loss is only visible on the last pipeline stage.
    This function checks if the appropriate rank is included in the LOG_RANK environment
    variable and warns if it's not.

    LOG_RANK is read as a comma-separated list of ranks. Whitespace around an
    entry and empty entries are ignored, so "0,4", "0, 4" and "0,4," are
    treated the same.

    Args:
        parallel_dims: Provides ``world_size`` and ``pp`` (pipeline-parallel size).
        job_config: Provides ``parallelism.pipeline_parallel_schedule``.
        color: Provides the ``red``, ``yellow`` and ``reset`` codes used in the warning.
    """
    # V Block Schedules return loss on rank 0
    if job_config.parallelism.pipeline_parallel_schedule == "ZBVZeroBubble":
        return

    # Calculate the rank where loss is visible (first rank of the last pipeline stage)
    world_size = parallel_dims.world_size
    pp_size = parallel_dims.pp
    loss_visible_rank = (world_size // pp_size) * (pp_size - 1)

    # Check if the loss-visible rank is included in LOG_RANK environment variable.
    # Entries are stripped so that a value such as "0, 4" still matches rank 4.
    env_logged_ranks = {
        entry.strip()
        for entry in os.environ.get("LOG_RANK", "").split(",")
        if entry.strip()
    }

    if str(loss_visible_rank) not in env_logged_ranks:
        logger.warning(
            f"{color.red}Pipeline Parallel loss is not visible. "
            f"Please add {color.yellow}rank {loss_visible_rank}{color.red} "
            f"to LOG_RANK environment variable in run_train.sh.{color.reset}"
        )
def _get_metrics_rank(
parallel_dims: ParallelDims,
job_config: JobConfig,
) -> int:
"""
Determines which rank should log metrics.
Returns:
int: The rank responsible for logging metrics:
- Rank 0 for non-pipeline-parallel configs
- Rank 0 for pipeline-parallel 'ZBVZeroBubble' schedule
- The first rank of the last pipeline stage for other pipeline-parallel schedules
"""
# Early return for non-pipeline-parallel configurations
if not parallel_dims.pp_enabled:
return 0
# V Block Schedules return loss on rank 0
if job_config.parallelism.pipeline_parallel_schedule == "ZBVZeroBubble":
return 0
# Calculate first rank of the last pipeline stage
world_size = parallel_dims.world_size
pp_size = parallel_dims.pp
return (world_size // pp_size) * (pp_size - 1)
def _build_metric_logger(
    job_config: JobConfig, parallel_dims: ParallelDims, tag: str | None = None
) -> BaseLogger:
    """
    Choose and construct the metric logger for the current rank.

    A no-op ``BaseLogger`` is returned when no backend is enabled or when this
    rank is not the one selected for logging. Otherwise WandB is tried first
    (if enabled); if it cannot be created the error is logged and TensorBoard
    is used when enabled.

    Args:
        job_config: Job configuration; ``metrics`` and ``job.dump_folder`` are read.
        parallel_dims: Parallel dimensions, used to pick the logging rank.
        tag: Optional tag forwarded to the created logger.

    Returns:
        The logger instance to use on this rank.
    """
    cfg = job_config.metrics

    # Log initial config state
    logger.debug(
        f"Building logger with config: wandb={cfg.enable_wandb}, "
        f"tensorboard={cfg.enable_tensorboard}"
    )

    # Is at least one backend switched on?
    any_backend_enabled = cfg.enable_tensorboard or cfg.enable_wandb

    # Unless every rank saves metrics, only the designated metrics rank logs.
    should_log = any_backend_enabled
    if should_log and not cfg.save_for_all_ranks:
        designated_rank = _get_metrics_rank(parallel_dims, job_config)
        should_log = torch.distributed.get_rank() == designated_rank

    logger.debug(
        f"Logging decision: has_logging_enabled={any_backend_enabled}, should_log={should_log}"
    )

    if not should_log:
        logger.debug("Returning BaseLogger due to should_log=False")
        return BaseLogger()

    # Log directory: <dump_folder>/<save_tb_folder>/<timestamp>[/rank_<N>]
    path_parts = [
        job_config.job.dump_folder,
        cfg.save_tb_folder,
        datetime.now().strftime("%Y%m%d-%H%M"),
    ]
    if cfg.save_for_all_ranks:
        path_parts.append(f"rank_{torch.distributed.get_rank()}")
    log_dir = os.path.join(*path_parts)

    # Backends are tried in priority order: WandB, then TensorBoard.
    if cfg.enable_wandb:
        logger.debug("Attempting to create WandB logger")
        try:
            return WandBLogger(log_dir, tag)
        except Exception as e:
            # Best effort: report the failure and fall through to TensorBoard.
            if "No module named 'wandb'" in str(e):
                failure = "Failed to create WandB logger: No module named 'wandb'. Please install it using 'pip install wandb'."
            else:
                failure = f"Failed to create WandB logger: {e}"
            logger.error(failure)

    if cfg.enable_tensorboard:
        logger.debug("Creating TensorBoard logger")
        return TensorBoardLogger(log_dir, tag)

    logger.debug("No loggers enabled, returning BaseLogger")
    return BaseLogger()
class MetricsProcessor:
    """Metrics processor that computes training metrics and logs them.

    The current MetricsProcessor logs some metrics to STDOUT and the full set
    of metrics to TensorBoard or WandB.

    Callers are expected to update ``ntokens_since_last_log`` and
    ``data_loading_times`` between calls to ``log`` and to set
    ``num_flops_per_token`` before the first call to ``log``.

    Args:
        job_config (JobConfig): Job configuration.
        parallel_dims (ParallelDims): Parallel dimensions.
        tag (Optional[str]): Tag to use for TensorBoard or WandB. Defaults to None.
    """

    logger: BaseLogger
    parallel_dims: ParallelDims
    job_config: JobConfig
    device_memory_monitor: DeviceMemoryMonitor
    color: utils.NoColor | utils.Color

    gpu_peak_flops: int
    ntokens_since_last_log: int
    data_loading_times: list[float]
    time_last_log: float

    num_flops_per_token: int
    optimizers: OptimizersContainer | None
    lr_schedulers: LRSchedulersContainer | None

    def __init__(
        self,
        job_config: JobConfig,
        parallel_dims: ParallelDims,
        tag: str | None = None,
    ):
        self.logger = _build_metric_logger(job_config, parallel_dims, tag)
        self.parallel_dims = parallel_dims
        self.job_config = job_config
        self.device_memory_monitor = build_device_memory_monitor()
        # used for colorful printing
        self.color = (
            utils.NoColor()
            if job_config.metrics.disable_color_printing
            else utils.Color()
        )

        self.gpu_peak_flops = utils.get_peak_flops(
            self.device_memory_monitor.device_name
        )
        self.ntokens_since_last_log = 0
        self.data_loading_times = []
        self.time_last_log = time.perf_counter()
        self.device_memory_monitor.reset_peak_stats()

        # These variables have to be set later as they depend on other components or model.
        self.num_flops_per_token = -1
        self.optimizers = None
        self.lr_schedulers = None

    def should_log(self, step: int) -> bool:
        """Return True on step 1 and on every multiple of ``metrics.log_freq``."""
        return step == 1 or step % self.job_config.metrics.log_freq == 0

    def log(
        self,
        step: int,
        global_avg_loss: float,
        global_max_loss: float,
        extra_metrics: dict[str, Any] | None = None,
    ):
        """Compute throughput, timing and memory metrics and report them.

        All metrics (plus ``extra_metrics``) are sent to the metric logger and
        a one-line summary is written to the console logger. Afterwards the
        token counter, data-loading times, timer and peak memory stats are
        reset for the next logging window.

        Args:
            step: Current training step.
            global_avg_loss: Average loss to report.
            global_max_loss: Maximum loss to report.
            extra_metrics: Optional additional metrics. Entries whose key
                contains "loss" are also appended to the console line, with a
                leading "loss_metrics/" prefix removed from the key.

        Raises:
            AssertionError: If ``num_flops_per_token`` has not been set.
        """
        assert self.num_flops_per_token > 0, "num_flops_per_token must be set"

        time_delta = time.perf_counter() - self.time_last_log

        # tokens per second per device, abbreviated as tps
        tps = self.ntokens_since_last_log / (
            time_delta * self.parallel_dims.non_data_parallel_size
        )
        # model FLOPS utilization
        # For its definition and calculation, please refer to the PaLM paper:
        # https://arxiv.org/abs/2204.02311
        mfu = 100 * self.num_flops_per_token * tps / self.gpu_peak_flops
        tflops = self.num_flops_per_token * tps / 1e12

        # NOTE(review): should_log() also fires at step 1, where fewer than
        # log_freq steps have elapsed, so this average is too small for that
        # first window - confirm whether that is intended.
        time_end_to_end = time_delta / self.job_config.metrics.log_freq

        total_data_loading = sum(self.data_loading_times)
        # No recorded data-loading times must not crash logging; report 0.
        time_data_loading = (
            total_data_loading / len(self.data_loading_times)
            if self.data_loading_times
            else 0.0
        )
        time_data_loading_pct = 100 * total_data_loading / time_delta

        device_mem_stats = self.device_memory_monitor.get_peak_stats()

        metrics = {
            "loss_metrics/global_avg_loss": global_avg_loss,
            "loss_metrics/global_max_loss": global_max_loss,
            "throughput(tps)": tps,
            "tflops": tflops,
            "mfu(%)": mfu,
            "time_metrics/end_to_end(s)": time_end_to_end,
            "time_metrics/data_loading(s)": time_data_loading,
            "time_metrics/data_loading(%)": time_data_loading_pct,
            "memory/max_active(GiB)": device_mem_stats.max_active_gib,
            "memory/max_active(%)": device_mem_stats.max_active_pct,
            "memory/max_reserved(GiB)": device_mem_stats.max_reserved_gib,
            "memory/max_reserved(%)": device_mem_stats.max_reserved_pct,
            "memory/num_alloc_retries": device_mem_stats.num_alloc_retries,
            "memory/num_ooms": device_mem_stats.num_ooms,
        }

        if extra_metrics:
            metrics.update(extra_metrics)

        self.logger.log(metrics, step)

        color = self.color
        construct_string = (
            f"{color.red}step: {step:2}  "
            f"{color.green}loss: {global_avg_loss:7.4f}  "
            f"{color.yellow}memory: {device_mem_stats.max_reserved_gib:5.2f}GiB"
            f"({device_mem_stats.max_reserved_pct:.2f}%)  "
            f"{color.blue}tps: {round(tps):,}  "
            f"{color.cyan}tflops: {tflops:,.2f}  "
            f"{color.magenta}mfu: {mfu:.2f}%{color.reset}"
        )

        if extra_metrics:
            for k, v in extra_metrics.items():
                if "loss" in k:
                    # removeprefix drops the exact prefix only; lstrip would
                    # treat its argument as a set of characters and could eat
                    # leading letters of the metric name.
                    short_name = k.removeprefix("loss_metrics/")
                    construct_string += f" {color.white}{short_name}: {v:7.4f}"

        logger.info(construct_string)

        # Start a fresh logging window.
        self.ntokens_since_last_log = 0
        self.data_loading_times.clear()
        self.time_last_log = time.perf_counter()
        self.device_memory_monitor.reset_peak_stats()

    def close(self):
        """Close the underlying metric logger."""
        self.logger.close()
def build_metrics_processor(
    job_config: JobConfig, parallel_dims: ParallelDims, tag: str | None = None
) -> MetricsProcessor:
    """Construct a ``MetricsProcessor`` from the given configuration.

    Args:
        job_config (JobConfig): Job configuration.
        parallel_dims (ParallelDims): Parallel dimensions.
        tag (Optional[str]): Tag to use for TensorBoard or WandB. Defaults to None.

    Returns:
        MetricsProcessor: The newly created metrics processor.
    """
    processor = MetricsProcessor(
        job_config=job_config,
        parallel_dims=parallel_dims,
        tag=tag,
    )
    return processor