|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
from collections import namedtuple |
|
from datetime import datetime |
|
from typing import Any, Dict, Optional |
|
|
|
import torch |
|
|
|
|
|
# Immutable snapshot of peak CUDA memory statistics: peak active/reserved
# memory in GiB and as a percent of device capacity, plus allocator retry
# and OOM counts.
GPUMemStats = namedtuple(
    "GPUMemStats",
    "max_active_gib max_active_pct "
    "max_reserved_gib max_reserved_pct "
    "num_alloc_retries num_ooms",
)
|
|
|
|
|
class GPUMemoryMonitor:
    """Track peak CUDA memory usage on a single device.

    Wraps ``torch.cuda.memory_stats`` to report peak active/reserved
    memory — in GiB and as a percentage of total device capacity —
    along with allocator retry and OOM counts.
    """

    def __init__(self, logger, device: str = "cuda:0"):
        self.device = torch.device(device)
        self.device_name = torch.cuda.get_device_name(self.device)
        self.device_index = torch.cuda.current_device()
        props = torch.cuda.get_device_properties(self.device)
        self.device_capacity = props.total_memory
        self.device_capacity_gib = self._to_gib(self.device_capacity)

        self.logger = logger

        # Start from a clean slate so future peaks reflect only upcoming work.
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.empty_cache()

    def _to_gib(self, memory_in_bytes):
        """Convert a byte count into GiB."""
        return memory_in_bytes / (1024 * 1024 * 1024)

    def _to_pct(self, memory):
        """Express a byte count as a percentage of total device capacity."""
        return 100 * memory / self.device_capacity

    def get_peak_stats(self):
        """Return a ``GPUMemStats`` snapshot of peak allocator statistics.

        Warns via the logger when the allocator recorded any allocation
        retries or out-of-memory errors.
        """
        allocator_stats = torch.cuda.memory_stats(self.device)

        peak_active = allocator_stats["active_bytes.all.peak"]
        peak_reserved = allocator_stats["reserved_bytes.all.peak"]
        num_retries = allocator_stats["num_alloc_retries"]
        num_ooms = allocator_stats["num_ooms"]

        if num_retries > 0:
            self.logger.warning(f"{num_retries} CUDA memory allocation retries.")
        if num_ooms > 0:
            self.logger.warning(f"{num_ooms} CUDA OOM errors thrown.")

        return GPUMemStats(
            self._to_gib(peak_active),
            self._to_pct(peak_active),
            self._to_gib(peak_reserved),
            self._to_pct(peak_reserved),
            num_retries,
            num_ooms,
        )

    def reset_peak_stats(self):
        """Clear the allocator's peak-memory counters."""
        torch.cuda.reset_peak_memory_stats()
|
|
|
|
|
def build_gpu_memory_monitor(logger):
    """Construct a ``GPUMemoryMonitor`` on the default CUDA device.

    Logs the device name, index, and total capacity, then returns the
    monitor.
    """
    monitor = GPUMemoryMonitor(logger, "cuda")
    logger.info(
        f"GPU capacity: {monitor.device_name} ({monitor.device_index}) "
        f"with {monitor.device_capacity_gib:.2f}GiB memory"
    )
    return monitor