# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import os
from collections import namedtuple
from datetime import datetime
from typing import Any, Dict, Optional

import torch

# named tuple for passing GPU memory stats for logging
GPUMemStats = namedtuple(
    "GPUMemStats",
    [
        "max_active_gib",
        "max_active_pct",
        "max_reserved_gib",
        "max_reserved_pct",
        "num_alloc_retries",
        "num_ooms",
    ],
)


class GPUMemoryMonitor:
    """Tracks peak CUDA memory usage on a single device.

    Wraps ``torch.cuda.memory_stats`` to report peak active/reserved
    memory (in GiB and as a percentage of device capacity) plus allocator
    retry and OOM counters, suitable for periodic logging during training.

    Args:
        logger: any object with ``info``/``warning`` methods
            (e.g. a ``logging.Logger``).
        device: CUDA device string to monitor, e.g. ``"cuda:0"``.
    """

    def __init__(self, logger, device: str = "cuda:0"):
        self.device = torch.device(device)  # device object
        self.device_name = torch.cuda.get_device_name(self.device)
        # BUGFIX: resolve the index of the *monitored* device rather than
        # torch.cuda.current_device(), which reports whatever device happens
        # to be current and can disagree with `device` on multi-GPU hosts.
        # Fall back to the current device only for a bare "cuda" spec
        # (torch.device("cuda").index is None).
        self.device_index = (
            self.device.index
            if self.device.index is not None
            else torch.cuda.current_device()
        )
        self.device_capacity = torch.cuda.get_device_properties(
            self.device
        ).total_memory
        self.device_capacity_gib = self._to_gib(self.device_capacity)
        self.logger = logger

        # Start from a clean slate. BUGFIX: pass the monitored device —
        # the no-arg form resets the *current* device's counters, which
        # would leave stale peaks on self.device if it isn't current.
        torch.cuda.reset_peak_memory_stats(self.device)
        torch.cuda.empty_cache()

    def _to_gib(self, memory_in_bytes: int) -> float:
        """Convert a byte count to GiB.

        NOTE: GiB (gibibyte) is 1024**3 bytes, vs GB which is 1000**3.
        """
        _gib_in_bytes = 1024 * 1024 * 1024
        return memory_in_bytes / _gib_in_bytes

    def _to_pct(self, memory: int) -> float:
        """Express `memory` (bytes) as a percentage of device capacity."""
        return 100 * memory / self.device_capacity

    def get_peak_stats(self) -> GPUMemStats:
        """Return peak memory stats for the monitored device.

        Warns via ``self.logger`` if the CUDA caching allocator had to
        retry allocations or threw OOM errors since the last reset.

        Returns:
            GPUMemStats: peak active/reserved memory (GiB and percent of
            capacity), allocation-retry count, and OOM count.
        """
        cuda_info = torch.cuda.memory_stats(self.device)

        max_active = cuda_info["active_bytes.all.peak"]
        max_active_gib = self._to_gib(max_active)
        max_active_pct = self._to_pct(max_active)

        max_reserved = cuda_info["reserved_bytes.all.peak"]
        max_reserved_gib = self._to_gib(max_reserved)
        max_reserved_pct = self._to_pct(max_reserved)

        num_retries = cuda_info["num_alloc_retries"]
        num_ooms = cuda_info["num_ooms"]

        if num_retries > 0:
            self.logger.warning(f"{num_retries} CUDA memory allocation retries.")
        if num_ooms > 0:
            self.logger.warning(f"{num_ooms} CUDA OOM errors thrown.")

        return GPUMemStats(
            max_active_gib,
            max_active_pct,
            max_reserved_gib,
            max_reserved_pct,
            num_retries,
            num_ooms,
        )

    def reset_peak_stats(self) -> None:
        """Reset the peak counters for the monitored device.

        BUGFIX: pass the monitored device (no-arg form acts on the
        current device, which may differ on multi-GPU hosts).
        """
        torch.cuda.reset_peak_memory_stats(self.device)


def build_gpu_memory_monitor(logger) -> GPUMemoryMonitor:
    """Construct a GPUMemoryMonitor on the default CUDA device and log its capacity."""
    gpu_memory_monitor = GPUMemoryMonitor(logger, "cuda")
    logger.info(
        f"GPU capacity: {gpu_memory_monitor.device_name} ({gpu_memory_monitor.device_index}) "
        f"with {gpu_memory_monitor.device_capacity_gib:.2f}GiB memory"
    )
    return gpu_memory_monitor