import gc
import subprocess
import time
from dataclasses import dataclass
from typing import Optional

import torch
from torch._utils import _get_available_device_type, _get_device_module

from torchtitan.tools.logging import logger


def get_device_info():
    device_type = _get_available_device_type()
    if device_type is None:
        # Fall back to CUDA when no accelerator backend reports itself as available.
        device_type = "cuda"
    device_module = _get_device_module(device_type)  # e.g. torch.cuda or torch.xpu
    return device_type, device_module


# Resolved once at import time.
device_type, device_module = get_device_info()


class GarbageCollection:
    """Disable automatic garbage collection and trigger collections manually:
    once at construction and then every ``gc_freq`` steps."""

    def __init__(self, gc_freq=1000):
        assert gc_freq > 0, "gc_freq must be a positive integer"
        self.gc_freq = gc_freq
        gc.disable()
        self.collect("Initial GC collection.")

    def run(self, step_count):
        if step_count > 1 and step_count % self.gc_freq == 0:
            self.collect("Performing periodical GC collection.")

    @staticmethod
    def collect(reason: str):
        begin = time.monotonic()
        gc.collect(1)
        logger.info("[GC] %s %.2f seconds.", reason, time.monotonic() - begin)

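# Illustrative usage (a sketch, not part of this module's API): a training loop
# would typically construct one handler up front and call run() once per step.
#
#     gc_handler = GarbageCollection(gc_freq=1000)
#     for step in range(1, num_steps + 1):  # `num_steps` is hypothetical
#         ...  # forward / backward / optimizer step
#         gc_handler.run(step)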


# Hardcoded dense BF16 peak FLOPS for common accelerators, used e.g. for MFU reporting.
def get_peak_flops(device_name: str) -> float:
    try:
        # Use lspci (when available) to pick up the full board name, which
        # distinguishes H100 variants (NVL / PCIe / SXM) that the generic
        # device name may not include.
        result = subprocess.run(["lspci"], stdout=subprocess.PIPE, text=True)
        # Keep only lines that mention an NVIDIA H100.
        filtered_lines = [
            line
            for line in result.stdout.splitlines()
            if "NVIDIA" in line and "H100" in line
        ]
        # Join all matching lines into a single string; fall back to the given name.
        device_name = " ".join(filtered_lines) or device_name
    except FileNotFoundError as e:
        logger.warning(f"Error running lspci: {e}, fallback to use device_name")
    if "A100" in device_name:
        # A100 dense BF16 peak: 312 TFLOPS
        return 312e12
    elif "H100" in device_name:
        # H100 dense BF16 peak varies by form factor.
        if "NVL" in device_name:
            return 835e12
        elif "PCIe" in device_name:
            return 756e12
        else:  # SXM and other variants
            return 989e12
    elif "H200" in device_name:
        # H200 dense BF16 peak: 989 TFLOPS
        return 989e12
    elif "MI300X" in device_name or "MI325X" in device_name:
        # MI300X / MI325X dense BF16 peak: ~1300 TFLOPS
        return 1300e12
    elif "MI250X" in device_name:
        # MI250X dense BF16 peak per GCD: 191.5 TFLOPS
        return 191.5e12
    elif "Data Center GPU Max 1550" in device_name:
        # Estimate from device properties: 512 ops/cycle per compute unit
        # at a 1300 MHz clock.
        max_comp_units = torch.xpu.get_device_properties("xpu").max_compute_units
        return 512 * max_comp_units * 1300 * 10**6
    else:
        logger.warning(f"Peak flops undefined for: {device_name}, fallback to A100")
        return 312e12

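# Illustrative usage (a sketch under assumed names): peak FLOPS is typically
# combined with a measured per-step FLOP count to report model FLOPS
# utilization (MFU).
#
#     peak = get_peak_flops(torch.cuda.get_device_name(0))
#     mfu = 100 * flops_per_step / (step_time_seconds * peak)
#
# where `flops_per_step` and `step_time_seconds` are measured by the caller.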


@dataclass(frozen=True)
class Color:
    # ANSI escape sequences for terminal foreground colors.
    black = "\033[30m"
    red = "\033[31m"
    green = "\033[32m"
    yellow = "\033[33m"
    blue = "\033[34m"
    magenta = "\033[35m"
    cyan = "\033[36m"
    white = "\033[37m"
    reset = "\033[39m"  # reset to the default foreground color


@dataclass(frozen=True)
class NoColor:
    # Drop-in replacement for Color that emits no escape sequences.
    black = ""
    red = ""
    green = ""
    yellow = ""
    blue = ""
    magenta = ""
    cyan = ""
    white = ""
    reset = ""

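# Illustrative usage (a sketch, not part of this module): pick the colored or
# plain palette depending on whether output goes to a terminal.
#
#     import sys
#     color = Color() if sys.stdout.isatty() else NoColor()
#     logger.info(f"{color.green}loss: 2.17{color.reset}")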


def check_if_feature_in_pytorch(
    feature_name: str,
    pull_request: str,
    min_nightly_version: Optional[str] = None,
) -> None:
    if "git" in torch.__version__:  # pytorch is built from source
        # The version string alone cannot tell whether the PR is included; warn the user.
        logger.warning(
            "Detected that PyTorch is built from source. Please make sure the PR "
            f"({pull_request}) is included in PyTorch for correct {feature_name}."
        )
    elif min_nightly_version is not None and torch.__version__ < min_nightly_version:
        logger.warning(
            f"Detected that the PyTorch version {torch.__version__} is older than "
            f"{min_nightly_version}. Please upgrade to a newer version to include the "
            f"change in ({pull_request}) for correct {feature_name}."
        )
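

# Illustrative usage (a sketch; the feature name, PR URL, and version below are
# placeholders, not values used by this codebase):
#
#     check_if_feature_in_pytorch(
#         feature_name="some experimental feature",
#         pull_request="https://github.com/pytorch/pytorch/pull/...",
#         min_nightly_version="2.x.0.devYYYYMMDD",
#     )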