zaydzuhri's picture
Add files using upload-large-folder tool
3c70147 verified
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import gc
import subprocess
import time
from dataclasses import dataclass
from typing import Optional
import torch
from torch._utils import _get_available_device_type, _get_device_module
from torchtitan.tools.logging import logger
def get_device_info():
device_type = _get_available_device_type()
if device_type is None:
device_type = "cuda" # default device_type: cuda
device_module = _get_device_module(device_type) # default device_module:torch.cuda
return device_type, device_module
# Resolved once, when this module is imported, and exposed as module-level globals.
device_type, device_module = get_device_info()
# used to avoid stragglers in garbage collection
class GarbageCollection:
    """Run garbage collection manually at a fixed step interval.

    Constructing an instance turns off Python's automatic garbage collection
    and immediately performs one manual collection.

    Args:
        gc_freq: Number of steps between periodic collections; must be > 0.
    """

    def __init__(self, gc_freq=1000):
        assert gc_freq > 0, "gc_freq must be a positive integer"
        self.gc_freq = gc_freq
        # Automatic collection is switched off; collections are triggered
        # explicitly through ``collect`` instead.
        gc.disable()
        self.collect("Initial GC collection.")

    def run(self, step_count):
        """Collect when ``step_count`` is above 1 and a multiple of ``gc_freq``."""
        if step_count <= 1 or step_count % self.gc_freq != 0:
            return
        self.collect("Peforming periodical GC collection.")

    @staticmethod
    def collect(reason: str):
        """Run ``gc.collect(1)`` and log ``reason`` with the elapsed time."""
        started = time.monotonic()
        gc.collect(1)
        elapsed = time.monotonic() - started
        logger.info("[GC] %s %.2f seconds.", reason, elapsed)
# Hardcoded BF16 peak FLOPS for NVIDIA A100, H100, H200 GPUs, AMD MI250X, MI300X, MI325X, and Intel PVC.
def get_peak_flops(device_name: str) -> int:
    """Return the hardcoded peak BF16 FLOPS for an accelerator.

    ``lspci`` is probed first on a best-effort basis: if its output has lines
    mentioning both "NVIDIA" and "H100", those lines replace ``device_name``
    so the H100 variant (NVL / PCIe / other) can be told apart. If ``lspci``
    cannot be launched, a warning is logged and ``device_name`` is used as-is.

    Args:
        device_name: Device name string; models are matched by substring.

    Returns:
        The peak FLOPS figure for the matched model. Unknown devices log a
        warning and fall back to the A100 figure. NOTE: despite the ``int``
        annotation, most branches return a float literal (e.g. ``312e12``).
    """
    try:
        # Run the lspci command and capture the output
        result = subprocess.run(["lspci"], stdout=subprocess.PIPE, text=True)
    except OSError as e:
        # Best-effort probe: a missing binary (FileNotFoundError), a permission
        # problem, or any other launch failure must not abort the lookup.
        logger.warning(f"Error running lspci: {e}, fallback to use device_name")
    else:
        # Filter the output for lines containing both "NVIDIA" and "H100"
        filtered_lines = [
            line
            for line in result.stdout.splitlines()
            if "NVIDIA" in line and "H100" in line
        ]
        # Join all filtered lines into a single string; keep the caller's
        # device_name when nothing matched.
        device_name = " ".join(filtered_lines) or device_name

    if "A100" in device_name:
        # data from https://www.nvidia.com/en-us/data-center/a100/
        return 312e12
    if "H100" in device_name:
        # data from https://www.nvidia.com/en-us/data-center/h100/
        # NOTE: Specifications are one-half lower without sparsity.
        if "NVL" in device_name:
            return 835e12
        if "PCIe" in device_name:
            return 756e12
        # for H100 SXM and other variants
        return 989e12
    if "H200" in device_name:
        # data from https://www.nvidia.com/en-us/data-center/h200/
        return 989e12
    if "MI300X" in device_name or "MI325X" in device_name:
        # MI300X data from https://www.amd.com/en/products/accelerators/instinct/mi300/mi300x.html
        # MI325X data from https://www.amd.com/en/products/accelerators/instinct/mi300/mi325x.html
        return 1300e12
    if "MI250X" in device_name:
        # data from https://www.amd.com/en/products/accelerators/instinct/mi200/mi250x.html (per GCD)
        return 191.5e12
    if "Data Center GPU Max 1550" in device_name:
        # Also known as Ponte Vecchio (PVC).
        # data from https://www.intel.com/content/www/us/en/docs/oneapi/optimization-guide-gpu/2025-0/intel-xe-gpu-architecture.html
        # Dot Product Accumulate Systolic (DPAS):
        # - Freq: 1300MHz
        # - #ops: 512
        # Full EU mode (i.e. 512 max compute units): 340.8 TFLOPS (BF16)
        # Standard EU mode (i.e. 448 max compute units): 298.2 TFLOPS (BF16)
        max_comp_units = torch.xpu.get_device_properties("xpu").max_compute_units
        return 512 * max_comp_units * 1300 * 10**6

    # for other GPU types, assume A100
    logger.warning(f"Peak flops undefined for: {device_name}, fallback to A100")
    return 312e12
@dataclass(frozen=True)
class Color:
    """ANSI escape sequences for terminal foreground colors, plus ``reset``."""

    # The attributes carry no annotations, so they are plain class attributes
    # rather than dataclass fields.
    black = "\x1b[30m"
    red = "\x1b[31m"
    green = "\x1b[32m"
    yellow = "\x1b[33m"
    blue = "\x1b[34m"
    magenta = "\x1b[35m"
    cyan = "\x1b[36m"
    white = "\x1b[37m"
    reset = "\x1b[39m"
@dataclass(frozen=True)
class NoColor:
    """Color palette whose every entry is the empty string.

    Exposes the attribute names ``black`` through ``white`` plus ``reset``,
    so output built from these attributes contains no escape sequences.
    """

    # Plain class attributes (no annotations), all bound to the empty string.
    black = red = green = yellow = blue = magenta = cyan = white = reset = ""
def check_if_feature_in_pytorch(
    feature_name: str,
    pull_request: str,
    min_nightly_version: Optional[str] = None,
) -> None:
    """Warn when the installed PyTorch may be missing a required change.

    Args:
        feature_name: Name of the feature that relies on the change; used in
            the warning text.
        pull_request: Reference to the pull request carrying the change; it is
            interpolated into the warning text.
        min_nightly_version: Version that ``torch.__version__`` is compared
            against. When ``None``, the version check is skipped.
    """
    if "git" in torch.__version__:  # pytorch is built from source
        # notify users to check if the pull request is included in their pytorch
        logger.warning(
            "detected that the pytorch is built from source. Please make sure the PR "
            f"({pull_request}) is included in pytorch for correct {feature_name}."
        )
    elif min_nightly_version is not None and torch.__version__ < min_nightly_version:
        logger.warning(
            f"detected that the pytorch version {torch.__version__} is older than "
            f"{min_nightly_version}. Please upgrade a newer version to include the "
            f"change in ({pull_request}) for correct {feature_name}."
        )