# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from collections import namedtuple

import torch

# named tuple for passing GPU memory stats for logging
GPUMemStats = namedtuple(
    "GPUMemStats",
    [
        "max_active_gib",
        "max_active_pct",
        "max_reserved_gib",
        "max_reserved_pct",
        "num_alloc_retries",
        "num_ooms",
    ],
)


class GPUMemoryMonitor:
    """Tracks peak GPU memory usage (active and reserved) via
    torch.cuda.memory_stats, for periodic logging during training."""

    def __init__(self, logger, device: str = "cuda:0"):
        self.device = torch.device(device)  # device object
        self.device_name = torch.cuda.get_device_name(self.device)
        self.device_index = torch.cuda.current_device()
        self.device_capacity = torch.cuda.get_device_properties(
            self.device
        ).total_memory
        self.device_capacity_gib = self._to_gib(self.device_capacity)
        self.logger = logger

        # start from a clean slate so subsequent peak stats reflect this run
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.empty_cache()

    def _to_gib(self, memory_in_bytes):
        # NOTE: GiB (gibibyte) is 1024**3 bytes, vs GB (gigabyte) at 1000**3
        _gib_in_bytes = 1024 * 1024 * 1024
        memory_in_gib = memory_in_bytes / _gib_in_bytes
        return memory_in_gib

    def _to_pct(self, memory):
        return 100 * memory / self.device_capacity

    def get_peak_stats(self):
        """Return a GPUMemStats snapshot of peak memory usage since the last
        reset, logging a warning if the allocator retried or hit OOMs."""
        cuda_info = torch.cuda.memory_stats(self.device)

        max_active = cuda_info["active_bytes.all.peak"]
        max_active_gib = self._to_gib(max_active)
        max_active_pct = self._to_pct(max_active)

        max_reserved = cuda_info["reserved_bytes.all.peak"]
        max_reserved_gib = self._to_gib(max_reserved)
        max_reserved_pct = self._to_pct(max_reserved)

        num_retries = cuda_info["num_alloc_retries"]
        num_ooms = cuda_info["num_ooms"]

        if num_retries > 0:
            self.logger.warning(f"{num_retries} CUDA memory allocation retries.")
        if num_ooms > 0:
            self.logger.warning(f"{num_ooms} CUDA OOM errors thrown.")

        return GPUMemStats(
            max_active_gib,
            max_active_pct,
            max_reserved_gib,
            max_reserved_pct,
            num_retries,
            num_ooms,
        )

    def reset_peak_stats(self):
        torch.cuda.reset_peak_memory_stats()


def build_gpu_memory_monitor(logger):
    gpu_memory_monitor = GPUMemoryMonitor(logger, "cuda")
    logger.info(
        f"GPU capacity: {gpu_memory_monitor.device_name} ({gpu_memory_monitor.device_index}) "
        f"with {gpu_memory_monitor.device_capacity_gib:.2f}GiB memory"
    )
    return gpu_memory_monitor
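

if __name__ == "__main__":
    # Illustrative usage sketch, not part of the original module: wires the
    # monitor to a stdlib logger, runs a stand-in CUDA workload, and prints
    # the peak stats. Requires a CUDA-capable GPU; the matmul below is an
    # arbitrary example workload, not anything the module itself prescribes.
    import logging

    logging.basicConfig(level=logging.INFO)
    example_logger = logging.getLogger(__name__)

    monitor = build_gpu_memory_monitor(example_logger)

    # any CUDA workload would do; a single large matmul as a stand-in
    x = torch.randn(4096, 4096, device=monitor.device)
    y = x @ x

    stats = monitor.get_peak_stats()
    example_logger.info(
        f"peak active: {stats.max_active_gib:.2f} GiB ({stats.max_active_pct:.2f}%), "
        f"peak reserved: {stats.max_reserved_gib:.2f} GiB ({stats.max_reserved_pct:.2f}%)"
    )
    monitor.reset_peak_stats()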