# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import gc
import subprocess
import time
from dataclasses import dataclass
from typing import Optional

import torch
from torch._utils import _get_available_device_type, _get_device_module

from torchtitan.tools.logging import logger


def get_device_info():
    """Return ``(device_type, device_module)`` for the available accelerator.

    Falls back to ``"cuda"`` (and therefore its device module) when torch
    reports no available device type.
    """
    device_type = _get_available_device_type()
    if device_type is None:
        device_type = "cuda"  # default device_type: cuda
    device_module = _get_device_module(device_type)  # default device_module: torch.cuda
    return device_type, device_module


# Resolved once, at import time.
device_type, device_module = get_device_info()


# used to avoid stragglers in garbage collection
class GarbageCollection:
    """Replace Python's automatic GC with explicit, periodic collections.

    Automatic collection is disabled on construction and an initial
    collection runs immediately; afterwards :meth:`run` triggers a
    collection every ``gc_freq`` steps.

    Args:
        gc_freq: Number of steps between collections; must be positive.
    """

    def __init__(self, gc_freq: int = 1000):
        assert gc_freq > 0, "gc_freq must be a positive integer"
        self.gc_freq = gc_freq
        gc.disable()
        self.collect("Initial GC collection.")

    def run(self, step_count: int):
        """Collect when ``step_count`` is a multiple of ``gc_freq``.

        Step 1 never triggers a collection (the constructor already ran one).
        """
        if step_count > 1 and step_count % self.gc_freq == 0:
            self.collect("Performing periodical GC collection.")

    @staticmethod
    def collect(reason: str):
        """Run a generation-1 collection and log ``reason`` with its duration."""
        begin = time.monotonic()
        # Generation 1 only, rather than a full ``gc.collect()``.
        gc.collect(1)
        logger.info("[GC] %s %.2f seconds.", reason, time.monotonic() - begin)


# Hardcoded BF16 peak FLOPS for NVIDIA A100, H100, H200; AMD MI250X, MI300X,
# MI325X; and Intel PVC (Data Center GPU Max 1550).
def get_peak_flops(device_name: str) -> float:
    """Return the hardcoded peak BF16 FLOPS for ``device_name``.

    If ``lspci`` reports NVIDIA H100 devices, those ``lspci`` lines replace
    ``device_name`` before matching, so the H100 NVL / PCIe checks below see
    the ``lspci`` description. Matching is by substring. Unknown devices fall
    back to the A100 figure with a warning.

    Args:
        device_name: Accelerator name to match against when ``lspci`` is
            unavailable or reports no NVIDIA H100.

    Returns:
        Peak BF16 FLOPS. Most entries are float constants; the Intel PVC
        entry is an ``int`` computed from the device's compute-unit count.
    """
    try:
        # Run the lspci command and capture the output.
        result = subprocess.run(["lspci"], stdout=subprocess.PIPE, text=True)
        # Filter the output for lines containing both "NVIDIA" and "H100".
        filtered_lines = [
            line
            for line in result.stdout.splitlines()
            if "NVIDIA" in line and "H100" in line
        ]
        # Join all filtered lines into a single string; keep the caller's
        # device_name when nothing matched.
        device_name = " ".join(filtered_lines) or device_name
    except OSError as e:
        # Covers FileNotFoundError (lspci not installed) as well as other
        # OS-level failures such as PermissionError; lspci is best-effort.
        logger.warning(f"Error running lspci: {e}, fallback to use device_name")

    if "A100" in device_name:
        # data from https://www.nvidia.com/en-us/data-center/a100/
        return 312e12
    elif "H100" in device_name:
        # data from https://www.nvidia.com/en-us/data-center/h100/
        # NOTE: Specifications are one-half lower without sparsity.
        if "NVL" in device_name:
            return 835e12
        elif "PCIe" in device_name:
            return 756e12
        else:  # for H100 SXM and other variants
            return 989e12
    elif "H200" in device_name:
        # data from https://www.nvidia.com/en-us/data-center/h200/
        return 989e12
    elif "MI300X" in device_name or "MI325X" in device_name:
        # MI300X data from https://www.amd.com/en/products/accelerators/instinct/mi300/mi300x.html
        # MI325X data from https://www.amd.com/en/products/accelerators/instinct/mi300/mi325x.html
        return 1300e12
    elif "MI250X" in device_name:
        # data from https://www.amd.com/en/products/accelerators/instinct/mi200/mi250x.html (per GCD)
        return 191.5e12
    elif "Data Center GPU Max 1550" in device_name:
        # Also known as Ponte Vecchio (PVC).
        # data from https://www.intel.com/content/www/us/en/docs/oneapi/optimization-guide-gpu/2025-0/intel-xe-gpu-architecture.html
        # Dot Product Accumulate Systolic (DPAS):
        # - Freq: 1300MHz
        # - #ops: 512
        # Full EU mode (i.e. 512 max compute units): 340.8 TFLOPS (BF16)
        # Standard EU mode (i.e. 448 max compute units): 298.2 TFLOPS (BF16)
        max_comp_units = torch.xpu.get_device_properties("xpu").max_compute_units
        return 512 * max_comp_units * 1300 * 10**6
    else:  # for other GPU types, assume A100
        logger.warning(f"Peak flops undefined for: {device_name}, fallback to A100")
        return 312e12


@dataclass(frozen=True)
class Color:
    """ANSI escape sequences for terminal foreground colors."""

    black = "\033[30m"
    red = "\033[31m"
    green = "\033[32m"
    yellow = "\033[33m"
    blue = "\033[34m"
    magenta = "\033[35m"
    cyan = "\033[36m"
    white = "\033[37m"
    reset = "\033[39m"


@dataclass(frozen=True)
class NoColor:
    """Same attribute names as ``Color`` but every code is empty (no coloring)."""

    black = ""
    red = ""
    green = ""
    yellow = ""
    blue = ""
    magenta = ""
    cyan = ""
    white = ""
    reset = ""


def check_if_feature_in_pytorch(
    feature_name: str,
    pull_request: str,
    min_nightly_version: Optional[str] = None,
) -> None:
    """Log a warning if the installed PyTorch may lack ``pull_request``.

    Args:
        feature_name: Human-readable feature name used in the warning.
        pull_request: Link or identifier of the PyTorch PR that provides the
            feature; it is echoed in the warning text.
        min_nightly_version: If given, warn when ``torch.__version__``
            compares lower than this version string.
    """
    if "git" in torch.__version__:  # pytorch is built from source
        # notify users to check if the pull request is included in their pytorch
        logger.warning(
            "detected that the pytorch is built from source. Please make sure the PR "
            f"({pull_request}) is included in pytorch for correct {feature_name}."
        )
    # NOTE: torch.__version__ is a TorchVersion, whose comparison operators
    # are version-aware rather than plain lexicographic string comparison.
    elif min_nightly_version is not None and torch.__version__ < min_nightly_version:
        logger.warning(
            f"detected that the pytorch version {torch.__version__} is older than "
            f"{min_nightly_version}. Please upgrade a newer version to include the "
            f"change in ({pull_request}) for correct {feature_name}."
        )