# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import os
import time
from collections import namedtuple
from datetime import datetime
from typing import Any

import torch
from torch.utils.tensorboard import SummaryWriter
from torchtitan.components.lr_scheduler import LRSchedulersContainer
from torchtitan.components.optimizer import OptimizersContainer
from torchtitan.config_manager import JobConfig
from torchtitan.distributed import ParallelDims
from torchtitan.tools import utils
from torchtitan.tools.logging import logger
from torchtitan.tools.utils import Color, device_module, device_type

# named tuple for passing device memory stats for logging
DeviceMemStats = namedtuple(
    "DeviceMemStats",
    [
        "max_active_gib",
        "max_active_pct",
        "max_reserved_gib",
        "max_reserved_pct",
        "num_alloc_retries",
        "num_ooms",
    ],
)


class DeviceMemoryMonitor:
    """Reads peak memory statistics of an accelerator device via ``device_module``."""

    def __init__(self, device: str = f"{device_type}:0"):
        """Record the device's name and capacity, then reset peak stats.

        Args:
            device: Device string passed to ``torch.device``.
        """
        self.device = torch.device(device)  # device object
        self.device_name = device_module.get_device_name(self.device)
        self.device_index = device_module.current_device()
        self.device_capacity = device_module.get_device_properties(
            self.device
        ).total_memory
        self.device_capacity_gib = self._to_gib(self.device_capacity)

        # Start from a clean slate so the first reported peak only covers
        # work done after this monitor was created.
        device_module.reset_peak_memory_stats()
        device_module.empty_cache()

    def _to_gib(self, memory_in_bytes):
        """Convert a byte count to GiB."""
        # NOTE: GiB (gibibyte) is 1024, vs GB is 1000
        _gib_in_bytes = 1024 * 1024 * 1024
        memory_in_gib = memory_in_bytes / _gib_in_bytes
        return memory_in_gib

    def _to_pct(self, memory):
        """Express ``memory`` (bytes) as a percentage of the device capacity."""
        return 100 * memory / self.device_capacity

    def get_peak_stats(self):
        """Collect peak memory statistics for ``self.device``.

        Emits a warning when allocation retries or OOM errors were recorded.

        Returns:
            DeviceMemStats: Peak active/reserved memory (in GiB and as a
            percentage of capacity) plus the allocation-retry and OOM counts.
            A statistic missing from ``memory_stats`` falls back to -1, so the
            values derived from it are negative in that case.
        """
        device_info = device_module.memory_stats(self.device)

        max_active = device_info.get("active_bytes.all.peak", -1)
        max_active_gib = self._to_gib(max_active)
        max_active_pct = self._to_pct(max_active)

        max_reserved = device_info.get("reserved_bytes.all.peak", -1)
        max_reserved_gib = self._to_gib(max_reserved)
        max_reserved_pct = self._to_pct(max_reserved)

        num_retries = device_info.get("num_alloc_retries", -1)
        num_ooms = device_info.get("num_ooms", -1)

        if num_retries > 0:
            logger.warning(
                f"{num_retries} {device_type.upper()} memory allocation retries."
            )
        if num_ooms > 0:
            logger.warning(f"{num_ooms} {device_type.upper()} OOM errors thrown.")

        return DeviceMemStats(
            max_active_gib,
            max_active_pct,
            max_reserved_gib,
            max_reserved_pct,
            num_retries,
            num_ooms,
        )

    def reset_peak_stats(self):
        """Reset the peak memory statistics tracked by ``device_module``."""
        device_module.reset_peak_memory_stats()


def build_device_memory_monitor():
    """Create a ``DeviceMemoryMonitor`` for ``device_type`` and log its capacity."""
    device_memory_monitor = DeviceMemoryMonitor(device_type)
    logger.info(
        f"{device_type.upper()} capacity: {device_memory_monitor.device_name} "
        f"with {device_memory_monitor.device_capacity_gib:.2f}GiB memory"
    )
    return device_memory_monitor


class BaseLogger:
    """Logger that does nothing, used when logging is disabled."""

    def log(self, metrics: dict[str, Any], step: int) -> None:
        pass

    def close(self) -> None:
        pass


class TensorBoardLogger(BaseLogger):
    """Logger implementation for TensorBoard."""

    def __init__(self, log_dir: str, tag: str | None = None):
        """Open a ``SummaryWriter`` on ``log_dir``.

        Args:
            log_dir: Directory handed to ``SummaryWriter``.
            tag: Optional prefix prepended (as ``"{tag}/"``) to every metric name.
        """
        self.tag = tag
        self.writer = SummaryWriter(log_dir, max_queue=1000)
        logger.info(f"TensorBoard logging enabled. Logs will be saved at {log_dir}")

    def log(self, metrics: dict[str, Any], step: int) -> None:
        """Write each metric as a scalar at ``step``."""
        for k, v in metrics.items():
            tag = k if self.tag is None else f"{self.tag}/{k}"
            self.writer.add_scalar(tag, v, step)

    def close(self) -> None:
        """Close the underlying ``SummaryWriter``."""
        self.writer.close()


class WandBLogger(BaseLogger):
    """Logger implementation for Weights & Biases."""

    def __init__(self, log_dir: str, tag: str | None = None):
        """Create ``log_dir`` and initialize a WandB run in it.

        The project name is read from the ``WANDB_PROJECT`` environment
        variable and defaults to ``"torchtitan"``.

        Args:
            log_dir: Directory created if missing and passed to ``wandb.init``.
            tag: Optional prefix prepended (as ``"{tag}/"``) to every metric name.
        """
        # Import wandb here to avoid startup import
        import wandb

        self.wandb = wandb
        self.tag = tag

        # Create logging directory
        os.makedirs(log_dir, exist_ok=True)

        self.wandb.init(
            project=os.getenv("WANDB_PROJECT", "torchtitan"),
            dir=log_dir,
        )
        logger.info("WandB logging enabled")

    def log(self, metrics: dict[str, Any], step: int) -> None:
        """Send all metrics to WandB in a single call at ``step``."""
        wandb_metrics = {
            (k if self.tag is None else f"{self.tag}/{k}"): v
            for k, v in metrics.items()
        }
        self.wandb.log(wandb_metrics, step=step)

    def close(self) -> None:
        """Finish the WandB run if one is active."""
        if self.wandb.run is not None:
            self.wandb.finish()


def ensure_pp_loss_visible(
    parallel_dims: ParallelDims, job_config: JobConfig, color: Color
) -> None:
    """
    Ensures that the loss is visible on the console for pipeline-parallel training.

    For pipeline-parallel training, the loss is only visible on the last pipeline stage.
    This function checks if the appropriate rank is included in the LOG_RANK environment
    variable and warns if it's not.

    LOG_RANK is read as a comma-separated list; whitespace around entries and
    empty entries are ignored.
    """

    # V Block Schedules return loss on rank 0
    if job_config.parallelism.pipeline_parallel_schedule == "ZBVZeroBubble":
        return

    # Calculate the rank where loss is visible (first rank of the last pipeline stage)
    world_size = parallel_dims.world_size
    pp_size = parallel_dims.pp
    loss_visible_rank = (world_size // pp_size) * (pp_size - 1)

    # Check if the loss-visible rank is included in LOG_RANK environment variable.
    # Entries are stripped so that a value such as "0, 4" still matches rank 4
    # instead of producing a spurious warning.
    env_logged_ranks = {
        rank.strip()
        for rank in os.environ.get("LOG_RANK", "").split(",")
        if rank.strip()
    }

    if str(loss_visible_rank) not in env_logged_ranks:
        logger.warning(
            f"{color.red}Pipeline Parallel loss is not visible. "
            f"Please add {color.yellow}rank {loss_visible_rank}{color.red} "
            f"to LOG_RANK environment variable in run_train.sh.{color.reset}"
        )


def _get_metrics_rank(
    parallel_dims: ParallelDims,
    job_config: JobConfig,
) -> int:
    """
    Determines which rank should log metrics.

    Returns:
       int: The rank responsible for logging metrics:
            - Rank 0 for non-pipeline-parallel configs
            - Rank 0 for pipeline-parallel 'ZBVZeroBubble' schedule
            - The first rank of the last pipeline stage for other pipeline-parallel schedules
    """
    # Early return for non-pipeline-parallel configurations
    if not parallel_dims.pp_enabled:
        return 0

    # V Block Schedules return loss on rank 0
    if job_config.parallelism.pipeline_parallel_schedule == "ZBVZeroBubble":
        return 0

    # Calculate first rank of the last pipeline stage
    world_size = parallel_dims.world_size
    pp_size = parallel_dims.pp
    return (world_size // pp_size) * (pp_size - 1)


def _build_metric_logger(
    job_config: JobConfig, parallel_dims: ParallelDims, tag: str | None = None
) -> BaseLogger:
    """
    Build an appropriate metric logger based on configuration.

    Args:
        job_config: Job configuration; its ``metrics`` and ``job`` sections are read.
        parallel_dims: Parallel dimensions, used to pick the logging rank.
        tag: Optional prefix forwarded to the created logger.

    Returns:
        BaseLogger: A ``WandBLogger`` if WandB is enabled and can be created,
        otherwise a ``TensorBoardLogger`` if TensorBoard is enabled, otherwise
        a no-op ``BaseLogger`` (also used on ranks that should not log).
    """
    metrics_config = job_config.metrics

    # Log initial config state
    logger.debug(
        f"Building logger with config: wandb={metrics_config.enable_wandb}, "
        f"tensorboard={metrics_config.enable_tensorboard}"
    )

    # Check if any logging backend is enabled
    has_logging_enabled = (
        metrics_config.enable_tensorboard or metrics_config.enable_wandb
    )

    # Determine if this rank should log
    should_log = has_logging_enabled
    if (not metrics_config.save_for_all_ranks) and should_log:
        metrics_rank = _get_metrics_rank(parallel_dims, job_config)
        should_log = torch.distributed.get_rank() == metrics_rank

    logger.debug(
        f"Logging decision: has_logging_enabled={has_logging_enabled}, should_log={should_log}"
    )

    if not should_log:
        logger.debug("Returning BaseLogger due to should_log=False")
        return BaseLogger()

    # Setup logging directory
    dump_dir = job_config.job.dump_folder
    base_log_dir = os.path.join(
        dump_dir, metrics_config.save_tb_folder, datetime.now().strftime("%Y%m%d-%H%M")
    )

    if metrics_config.save_for_all_ranks:
        base_log_dir = os.path.join(
            base_log_dir, f"rank_{torch.distributed.get_rank()}"
        )

    # Create loggers in priority order
    if metrics_config.enable_wandb:
        logger.debug("Attempting to create WandB logger")
        try:
            return WandBLogger(base_log_dir, tag)
        except Exception as e:
            # Deliberately best-effort: a WandB failure is reported and
            # execution falls through to the TensorBoard / no-op fallback
            # below instead of aborting.
            if "No module named 'wandb'" in str(e):
                logger.error(
                    "Failed to create WandB logger: No module named 'wandb'. Please install it using 'pip install wandb'."
                )
            else:
                logger.error(f"Failed to create WandB logger: {e}")

    if metrics_config.enable_tensorboard:
        logger.debug("Creating TensorBoard logger")
        return TensorBoardLogger(base_log_dir, tag)

    logger.debug("No loggers enabled, returning BaseLogger")
    return BaseLogger()


class MetricsProcessor:
    """Metrics processor that processes and logs the metrics.

    The current MetricsProcessor logs some metrics to STDOUT and some metrics to
    TensorBoard or WandB.

    Args:
        job_config (JobConfig): Job configuration.
        parallel_dims (ParallelDims): Parallel dimensions.
        tag (Optional[str]): Tag to use for TensorBoard or WandB. Defaults to None.
    """

    logger: BaseLogger
    parallel_dims: ParallelDims
    job_config: JobConfig
    device_memory_monitor: DeviceMemoryMonitor
    color: utils.NoColor | utils.Color

    gpu_peak_flops: int
    ntokens_since_last_log: int
    data_loading_times: list[float]
    time_last_log: float

    num_flops_per_token: int
    optimizers: OptimizersContainer | None
    lr_schedulers: LRSchedulersContainer | None

    def __init__(
        self,
        job_config: JobConfig,
        parallel_dims: ParallelDims,
        tag: str | None = None,
    ):
        self.logger = _build_metric_logger(job_config, parallel_dims, tag)
        self.parallel_dims = parallel_dims
        self.job_config = job_config
        self.device_memory_monitor = build_device_memory_monitor()
        # used for colorful printing
        self.color = (
            utils.NoColor()
            if job_config.metrics.disable_color_printing
            else utils.Color()
        )

        self.gpu_peak_flops = utils.get_peak_flops(
            self.device_memory_monitor.device_name
        )
        self.ntokens_since_last_log = 0
        self.data_loading_times = []
        self.time_last_log = time.perf_counter()
        self.device_memory_monitor.reset_peak_stats()

        # These variables have to be set later as they depend on other components or model.
        self.num_flops_per_token = -1
        self.optimizers = None
        self.lr_schedulers = None

    def should_log(self, step: int) -> bool:
        """Return True on step 1 and on every multiple of ``metrics.log_freq``."""
        return step == 1 or step % self.job_config.metrics.log_freq == 0

    def log(
        self,
        step: int,
        global_avg_loss: float,
        global_max_loss: float,
        extra_metrics: dict[str, Any] | None = None,
    ):
        """Compute throughput/time/memory metrics, log them, and reset counters.

        The metrics are sent to the configured backend logger and a one-line
        summary is written to the console logger.

        Args:
            step: Current training step.
            global_avg_loss: Average loss to report.
            global_max_loss: Maximum loss to report.
            extra_metrics: Optional additional metrics merged into the logged
                dict. Entries whose key contains ``"loss"`` are also appended
                to the console line, with a leading ``"loss_metrics/"`` prefix
                removed from the displayed name.

        Raises:
            AssertionError: If ``num_flops_per_token`` has not been set to a
                positive value.
        """
        assert self.num_flops_per_token > 0, "num_flops_per_token must be set"

        time_delta = time.perf_counter() - self.time_last_log

        # tokens per second per device, abbreviated as tps
        tps = self.ntokens_since_last_log / (
            time_delta * self.parallel_dims.non_data_parallel_size
        )
        # model FLOPS utilization
        # For its definition and calculation, please refer to the PaLM paper:
        # https://arxiv.org/abs/2204.02311
        mfu = 100 * self.num_flops_per_token * tps / self.gpu_peak_flops
        tflops = self.num_flops_per_token * tps / 1e12

        time_end_to_end = time_delta / self.job_config.metrics.log_freq

        # Guard against an empty window: if no data-loading time was recorded
        # since the last log, report 0 instead of raising ZeroDivisionError.
        data_loading_total = sum(self.data_loading_times)
        time_data_loading = (
            data_loading_total / len(self.data_loading_times)
            if self.data_loading_times
            else 0.0
        )
        time_data_loading_pct = 100 * data_loading_total / time_delta

        device_mem_stats = self.device_memory_monitor.get_peak_stats()

        metrics = {
            "loss_metrics/global_avg_loss": global_avg_loss,
            "loss_metrics/global_max_loss": global_max_loss,
            "throughput(tps)": tps,
            "tflops": tflops,
            "mfu(%)": mfu,
            "time_metrics/end_to_end(s)": time_end_to_end,
            "time_metrics/data_loading(s)": time_data_loading,
            "time_metrics/data_loading(%)": time_data_loading_pct,
            "memory/max_active(GiB)": device_mem_stats.max_active_gib,
            "memory/max_active(%)": device_mem_stats.max_active_pct,
            "memory/max_reserved(GiB)": device_mem_stats.max_reserved_gib,
            "memory/max_reserved(%)": device_mem_stats.max_reserved_pct,
            "memory/num_alloc_retries": device_mem_stats.num_alloc_retries,
            "memory/num_ooms": device_mem_stats.num_ooms,
        }

        if extra_metrics:
            metrics.update(extra_metrics)

        self.logger.log(metrics, step)

        color = self.color
        construct_string = (
            f"{color.red}step: {step:2}  "
            f"{color.green}loss: {global_avg_loss:7.4f}  "
            f"{color.yellow}memory: {device_mem_stats.max_reserved_gib:5.2f}GiB"
            f"({device_mem_stats.max_reserved_pct:.2f}%)  "
            f"{color.blue}tps: {round(tps):,}  "
            f"{color.cyan}tflops: {tflops:,.2f}  "
            f"{color.magenta}mfu: {mfu:.2f}%{color.reset}"
        )

        if extra_metrics:
            for k, v in extra_metrics.items():
                if "loss" in k:
                    # removeprefix drops only the literal "loss_metrics/"
                    # prefix; lstrip would treat the argument as a character
                    # set and could eat the start of (or all of) the name.
                    display_name = k.removeprefix("loss_metrics/")
                    construct_string += f"  {color.white}{display_name}: {v:7.4f}"

        logger.info(construct_string)

        # Start a fresh measurement window for the next log call.
        self.ntokens_since_last_log = 0
        self.data_loading_times.clear()
        self.time_last_log = time.perf_counter()
        self.device_memory_monitor.reset_peak_stats()

    def close(self):
        """Close the backend metric logger."""
        self.logger.close()


def build_metrics_processor(
    job_config: JobConfig, parallel_dims: ParallelDims, tag: str | None = None
) -> MetricsProcessor:
    """Create a metrics processor.

    Args:
        job_config (JobConfig): Job configuration.
        parallel_dims (ParallelDims): Parallel dimensions.
        tag (Optional[str]): Tag to use for TensorBoard or WandB. Defaults to None.

    Returns:
        MetricsProcessor: A metrics processor.
    """
    return MetricsProcessor(job_config, parallel_dims, tag)