# Copyright (c) Meta Platforms, Inc. and affiliates. | |
# All rights reserved. | |
# | |
# This source code is licensed under the license found in the | |
# LICENSE file in the root directory of this source tree. | |
import torch | |
class TorchAutocast:
    """TorchAutocast utility class.

    Allows you to enable and disable autocast. This is especially useful
    when dealing with different architectures and clusters with different
    levels of support.

    Args:
        enabled (bool): Whether to enable torch.autocast or not.
        args: Additional args for torch.autocast.
        kwargs: Additional kwargs for torch.autocast.
    """

    def __init__(self, enabled: bool, *args, **kwargs):
        # When disabled, no torch.autocast object is constructed at all: the
        # extra args/kwargs are ignored and __enter__/__exit__ become no-ops.
        self.autocast = torch.autocast(*args, **kwargs) if enabled else None

    def __enter__(self):
        """Enter the wrapped torch.autocast context, if enabled.

        Returns:
            None; nothing is bound by ``with ... as``.

        Raises:
            RuntimeError: If entering torch.autocast raises RuntimeError. The new
                error reports the requested dtype and device, and the original
                error is attached as ``__cause__``.
        """
        if self.autocast is None:
            return
        try:
            self.autocast.__enter__()
        except RuntimeError as err:
            device = self.autocast.device
            dtype = self.autocast.fast_dtype
            # Chain explicitly so the underlying torch error is kept as the cause
            # instead of only appearing as incidental exception context.
            raise RuntimeError(
                f"There was an error autocasting with dtype={dtype} device={device}\n"
                "If you are on the FAIR Cluster, you might need to use autocast_dtype=float16"
            ) from err

    def __exit__(self, *args, **kwargs):
        """Exit the wrapped torch.autocast context, if enabled.

        Returns:
            The wrapped context manager's ``__exit__`` result, so its
            exception-suppression decision is propagated; None when disabled.
        """
        if self.autocast is None:
            return None
        return self.autocast.__exit__(*args, **kwargs)