Spaces:
Runtime error
Runtime error
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| import logging | |
| from contextlib import contextmanager | |
| from functools import wraps | |
| import torch | |
| __all__ = ["retry_if_cuda_oom"] | |
| def _ignore_torch_cuda_oom(): | |
| """ | |
| A context which ignores CUDA OOM exception from pytorch. | |
| """ | |
| try: | |
| yield | |
| except RuntimeError as e: | |
| # NOTE: the string may change? | |
| if "CUDA out of memory. " in str(e): | |
| pass | |
| else: | |
| raise | |
def retry_if_cuda_oom(func):
    """
    Makes a function retry itself after encountering
    pytorch's CUDA OOM error.
    It will first retry after calling `torch.cuda.empty_cache()`.

    If that still fails, it will then retry by trying to convert inputs to CPUs.
    In this case, it expects the function to dispatch to CPU implementation.
    The return values may become CPU tensors as well and it's user's
    responsibility to convert it back to CUDA tensor if needed.

    Args:
        func: a stateless callable that takes tensor-like objects as arguments

    Returns:
        a callable which retries `func` if OOM is encountered. The returned
        callable carries over the metadata (name, docstring, ...) of `func`.

    Examples:
    ::
        output = retry_if_cuda_oom(some_torch_function)(input1, input2)
        # output may be on CPU even if inputs are on GPU

    Note:
        1. When converting inputs to CPU, it will only look at each argument and check
           if it has `.device` and `.to` for conversion. Nested structures of tensors
           are not supported.

        2. Since the function might be called more than once, it has to be
           stateless.
    """

    def maybe_to_cpu(x):
        # Duck-typed check: anything exposing `.device.type == "cuda"` and a
        # `.to` method is treated as a GPU tensor; everything else is passed
        # through untouched.
        try:
            like_gpu_tensor = x.device.type == "cuda" and hasattr(x, "to")
        except AttributeError:
            like_gpu_tensor = False
        if like_gpu_tensor:
            return x.to(device="cpu")
        return x

    @wraps(func)
    def wrapped(*args, **kwargs):
        # First attempt. On success the `return` leaves the function; on a
        # CUDA OOM the context manager swallows the error and we fall through.
        with _ignore_torch_cuda_oom():
            return func(*args, **kwargs)

        # Clear cache and retry
        torch.cuda.empty_cache()
        with _ignore_torch_cuda_oom():
            return func(*args, **kwargs)

        # Try on CPU. This slows down the code significantly, therefore print a notice.
        logger = logging.getLogger(__name__)
        logger.info("Attempting to copy inputs of %s to CPU due to CUDA OOM", func)
        new_args = (maybe_to_cpu(x) for x in args)
        new_kwargs = {k: maybe_to_cpu(v) for k, v in kwargs.items()}
        # Last attempt is not guarded: any error (OOM included) reaches the caller.
        return func(*new_args, **new_kwargs)

    return wrapped