import functools
import warnings
from collections import abc
from inspect import getfullargspec

import numpy as np
import torch
import torch.nn as nn

from annotator.uniformer.mmcv.utils import TORCH_VERSION, digit_version
from .dist_utils import allreduce_grads as _allreduce_grads

try:
    # ``torch.cuda.amp.autocast`` only exists in PyTorch >= 1.6.0. When it is
    # unavailable, the decorators below fall back to the original mmcv
    # implementation, so the failed import is silently ignored.
    from torch.cuda.amp import autocast
except ImportError:
    pass

def cast_tensor_type(inputs, src_type, dst_type):
    """Recursively convert Tensors in inputs from src_type to dst_type.

    Args:
        inputs: Inputs to be cast.
        src_type (torch.dtype): Source type.
        dst_type (torch.dtype): Destination type.

    Returns:
        The same type as inputs, but with all contained Tensors cast to
        dst_type.
    """
    if isinstance(inputs, nn.Module):
        return inputs
    elif isinstance(inputs, torch.Tensor):
        return inputs.to(dst_type)
    elif isinstance(inputs, str):
        return inputs
    elif isinstance(inputs, np.ndarray):
        return inputs
    elif isinstance(inputs, abc.Mapping):
        return type(inputs)({
            k: cast_tensor_type(v, src_type, dst_type)
            for k, v in inputs.items()
        })
    elif isinstance(inputs, abc.Iterable):
        return type(inputs)(
            cast_tensor_type(item, src_type, dst_type) for item in inputs)
    else:
        return inputs

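# Illustrative sketch, not part of the original mmcv API: shows how
# ``cast_tensor_type`` recurses through nested containers while leaving
# non-tensor leaves (strings, ndarrays) untouched. All names are made up.
def _example_cast_tensor_type():
    batch = {
        'img': torch.rand(2, 3),
        'meta': {
            'filename': 'demo.jpg',
            'scale': np.array([1.0, 2.0]),
        },
    }
    half_batch = cast_tensor_type(batch, torch.float, torch.half)
    assert half_batch['img'].dtype == torch.half
    assert half_batch['meta']['filename'] == 'demo.jpg'
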
def auto_fp16(apply_to=None, out_fp32=False):
    """Decorator to enable fp16 training automatically.

    This decorator is useful when you write custom modules and want to support
    mixed precision training. If the input arguments are fp32 tensors, they
    will be converted to fp16 automatically. Arguments other than fp32 tensors
    are ignored. If you are using PyTorch >= 1.6, torch.cuda.amp is used as
    the backend; otherwise, the original mmcv implementation is used.

    Args:
        apply_to (Iterable, optional): The argument names to be converted.
            `None` indicates all arguments.
        out_fp32 (bool): Whether to convert the output back to fp32.

    Example:

        >>> import torch.nn as nn
        >>> class MyModule1(nn.Module):
        >>>
        >>>     # Convert x and y to fp16
        >>>     @auto_fp16()
        >>>     def forward(self, x, y):
        >>>         pass

        >>> import torch.nn as nn
        >>> class MyModule2(nn.Module):
        >>>
        >>>     # convert pred to fp16
        >>>     @auto_fp16(apply_to=('pred', ))
        >>>     def do_something(self, pred, others):
        >>>         pass
    """

    def auto_fp16_wrapper(old_func):

        @functools.wraps(old_func)
        def new_func(*args, **kwargs):
            # check if the module has set the attribute `fp16_enabled`; if
            # not, just fall back to the original method.
            if not isinstance(args[0], torch.nn.Module):
                raise TypeError('@auto_fp16 can only be used to decorate the '
                                'method of nn.Module')
            if not (hasattr(args[0], 'fp16_enabled') and args[0].fp16_enabled):
                return old_func(*args, **kwargs)

            # get the arg spec of the decorated method
            args_info = getfullargspec(old_func)
            # get the argument names to be cast
            args_to_cast = args_info.args if apply_to is None else apply_to
            # convert the positional args that need to be processed
            new_args = []
            # NOTE: default args are not taken into consideration
            if args:
                arg_names = args_info.args[:len(args)]
                for i, arg_name in enumerate(arg_names):
                    if arg_name in args_to_cast:
                        new_args.append(
                            cast_tensor_type(args[i], torch.float, torch.half))
                    else:
                        new_args.append(args[i])
            # convert the kwargs that need to be processed
            new_kwargs = {}
            if kwargs:
                for arg_name, arg_value in kwargs.items():
                    if arg_name in args_to_cast:
                        new_kwargs[arg_name] = cast_tensor_type(
                            arg_value, torch.float, torch.half)
                    else:
                        new_kwargs[arg_name] = arg_value
            # apply the converted arguments to the decorated method
            if (TORCH_VERSION != 'parrots' and
                    digit_version(TORCH_VERSION) >= digit_version('1.6.0')):
                with autocast(enabled=True):
                    output = old_func(*new_args, **new_kwargs)
            else:
                output = old_func(*new_args, **new_kwargs)
            # cast the results back to fp32 if necessary
            if out_fp32:
                output = cast_tensor_type(output, torch.half, torch.float)
            return output

        return new_func

    return auto_fp16_wrapper

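# Illustrative sketch, not part of the original file: ``@auto_fp16`` is a
# no-op until the module exposes a truthy ``fp16_enabled`` attribute, which
# ``wrap_fp16_model`` (defined below) flips to True. The module is made up.
class _ExampleFP16Module(nn.Module):

    def __init__(self):
        super().__init__()
        self.fp16_enabled = False  # gate checked by @auto_fp16 at call time
        self.fc = nn.Linear(4, 2)

    @auto_fp16(apply_to=('x', ))
    def forward(self, x):
        return self.fc(x)
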
def force_fp32(apply_to=None, out_fp16=False):
    """Decorator to convert input arguments to fp32 in force.

    This decorator is useful when you write custom modules and want to support
    mixed precision training. If there are some inputs that must be processed
    in fp32 mode, this decorator can handle it. If the input arguments are
    fp16 tensors, they will be converted to fp32 automatically. Arguments
    other than fp16 tensors are ignored. If you are using PyTorch >= 1.6,
    torch.cuda.amp is used as the backend; otherwise, the original mmcv
    implementation is used.

    Args:
        apply_to (Iterable, optional): The argument names to be converted.
            `None` indicates all arguments.
        out_fp16 (bool): Whether to convert the output back to fp16.

    Example:

        >>> import torch.nn as nn
        >>> class MyModule1(nn.Module):
        >>>
        >>>     # Convert x and y to fp32
        >>>     @force_fp32()
        >>>     def loss(self, x, y):
        >>>         pass

        >>> import torch.nn as nn
        >>> class MyModule2(nn.Module):
        >>>
        >>>     # convert pred to fp32
        >>>     @force_fp32(apply_to=('pred', ))
        >>>     def post_process(self, pred, others):
        >>>         pass
    """

    def force_fp32_wrapper(old_func):

        @functools.wraps(old_func)
        def new_func(*args, **kwargs):
            # check if the module has set the attribute `fp16_enabled`; if
            # not, just fall back to the original method.
            if not isinstance(args[0], torch.nn.Module):
                raise TypeError('@force_fp32 can only be used to decorate the '
                                'method of nn.Module')
            if not (hasattr(args[0], 'fp16_enabled') and args[0].fp16_enabled):
                return old_func(*args, **kwargs)
            # get the arg spec of the decorated method
            args_info = getfullargspec(old_func)
            # get the argument names to be cast
            args_to_cast = args_info.args if apply_to is None else apply_to
            # convert the positional args that need to be processed
            new_args = []
            if args:
                arg_names = args_info.args[:len(args)]
                for i, arg_name in enumerate(arg_names):
                    if arg_name in args_to_cast:
                        new_args.append(
                            cast_tensor_type(args[i], torch.half, torch.float))
                    else:
                        new_args.append(args[i])
            # convert the kwargs that need to be processed
            new_kwargs = dict()
            if kwargs:
                for arg_name, arg_value in kwargs.items():
                    if arg_name in args_to_cast:
                        new_kwargs[arg_name] = cast_tensor_type(
                            arg_value, torch.half, torch.float)
                    else:
                        new_kwargs[arg_name] = arg_value
            # apply the converted arguments to the decorated method
            if (TORCH_VERSION != 'parrots' and
                    digit_version(TORCH_VERSION) >= digit_version('1.6.0')):
                with autocast(enabled=False):
                    output = old_func(*new_args, **new_kwargs)
            else:
                output = old_func(*new_args, **new_kwargs)
            # cast the results back to fp16 if necessary
            if out_fp16:
                output = cast_tensor_type(output, torch.float, torch.half)
            return output

        return new_func

    return force_fp32_wrapper

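# Illustrative sketch, not part of the original file: a loss computed in fp32
# even when the surrounding model runs in fp16. The class below is made up.
class _ExampleFP16Loss(nn.Module):

    def __init__(self):
        super().__init__()
        self.fp16_enabled = False  # flipped to True by ``wrap_fp16_model``

    @force_fp32(apply_to=('pred', 'target'))
    def forward(self, pred, target):
        # ``pred`` and ``target`` arrive here as fp32 tensors
        return nn.functional.mse_loss(pred, target)
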
def allreduce_grads(params, coalesce=True, bucket_size_mb=-1):
    warnings.warn(
        '"mmcv.runner.fp16_utils.allreduce_grads" is deprecated, and will be '
        'removed in v2.8. Please switch to "mmcv.runner.allreduce_grads".')
    _allreduce_grads(params, coalesce=coalesce, bucket_size_mb=bucket_size_mb)

def wrap_fp16_model(model):
    """Wrap the FP32 model to FP16.

    If you are using PyTorch >= 1.6, torch.cuda.amp is used as the backend;
    otherwise, the original mmcv implementation is used.

    For PyTorch >= 1.6, this function will
    1. Set the `fp16_enabled` flag inside the model to True.

    Otherwise:
    1. Convert the FP32 model to FP16.
    2. Keep some necessary layers in FP32, e.g., normalization layers.
    3. Set the `fp16_enabled` flag inside the model to True.

    Args:
        model (nn.Module): Model in FP32.
    """
    if (TORCH_VERSION == 'parrots'
            or digit_version(TORCH_VERSION) < digit_version('1.6.0')):
        # convert model to fp16
        model.half()
        # patch the normalization layers to make them run in fp32 mode
        patch_norm_fp32(model)
    # set the `fp16_enabled` flag on every module that declares it
    for m in model.modules():
        if hasattr(m, 'fp16_enabled'):
            m.fp16_enabled = True

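# Illustrative usage sketch, not part of the original file: wrapping a small
# fp32 model for mixed precision training. ``ToyNet`` is made up.
def _example_wrap_fp16_model():
    class ToyNet(nn.Module):

        def __init__(self):
            super().__init__()
            self.fp16_enabled = False
            self.conv = nn.Conv2d(3, 8, 3)
            self.bn = nn.BatchNorm2d(8)

        def forward(self, x):
            return self.bn(self.conv(x))

    model = ToyNet()
    wrap_fp16_model(model)
    # On PyTorch >= 1.6 only the flag is flipped; on older versions the
    # weights are also halved while normalization layers stay in fp32.
    assert model.fp16_enabled
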
def patch_norm_fp32(module):
    """Recursively convert normalization layers from FP16 to FP32.

    Args:
        module (nn.Module): The FP16 module whose normalization layers will
            be converted to FP32.

    Returns:
        nn.Module: The converted module, whose normalization layers have been
            converted to FP32.
    """
    if isinstance(module, (nn.modules.batchnorm._BatchNorm, nn.GroupNorm)):
        module.float()
        # GroupNorm and the BatchNorm kernels of old PyTorch versions cannot
        # take fp16 inputs with fp32 parameters, so their forward is patched
        # to cast the inputs first.
        if isinstance(module, nn.GroupNorm) or (
                TORCH_VERSION != 'parrots'
                and digit_version(TORCH_VERSION) < digit_version('1.3')):
            module.forward = patch_forward_method(module.forward, torch.half,
                                                  torch.float)
    for child in module.children():
        patch_norm_fp32(child)
    return module

def patch_forward_method(func, src_type, dst_type, convert_output=True):
    """Patch the forward method of a module.

    Args:
        func (callable): The original forward method.
        src_type (torch.dtype): Type of input arguments to be converted from.
        dst_type (torch.dtype): Type of input arguments to be converted to.
        convert_output (bool): Whether to convert the output back to src_type.

    Returns:
        callable: The patched forward method.
    """

    def new_forward(*args, **kwargs):
        output = func(*cast_tensor_type(args, src_type, dst_type),
                      **cast_tensor_type(kwargs, src_type, dst_type))
        if convert_output:
            output = cast_tensor_type(output, dst_type, src_type)
        return output

    return new_forward

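# Illustrative sketch, not part of the original file: patching a single
# BatchNorm layer so that it accepts fp16 inputs while computing in fp32.
def _example_patch_forward_method():
    bn = nn.BatchNorm2d(4).float()
    bn.forward = patch_forward_method(bn.forward, torch.half, torch.float)
    out = bn(torch.randn(1, 4, 8, 8).half())
    assert out.dtype == torch.half  # output is converted back to fp16
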
class LossScaler:
    """Class that manages loss scaling in mixed precision training which
    supports both dynamic and static mode.

    The implementation refers to
    https://github.com/NVIDIA/apex/blob/master/apex/fp16_utils/loss_scaler.py.
    Dynamic loss scaling is enabled by supplying ``mode='dynamic'``.
    It's important to understand how :class:`LossScaler` operates.
    Loss scaling is designed to combat the problem of gradients underflowing
    late in the training of fp16 networks.
    Dynamic loss scaling begins by attempting a very high loss
    scale. Ironically, this may result in OVERflowing gradients.
    If overflowing gradients are encountered, the optimizer then skips the
    update step for this particular iteration/minibatch,
    and :class:`LossScaler` adjusts the loss scale to a lower value.
    If a certain number of iterations occur without overflowing gradients
    detected, :class:`LossScaler` increases the loss scale once more.
    In this way :class:`LossScaler` attempts to "ride the edge" of always
    using the highest loss scale possible without incurring overflow.

    Args:
        init_scale (float): Initial loss scale value. Default: 2**32.
        scale_factor (float): Factor used when adjusting the loss scale.
            Default: 2.
        mode (str): Loss scaling mode, 'dynamic' or 'static'.
            Default: 'dynamic'.
        scale_window (int): Number of consecutive iterations without an
            overflow to wait before increasing the loss scale. Default: 1000.
    """

    def __init__(self,
                 init_scale=2**32,
                 mode='dynamic',
                 scale_factor=2.,
                 scale_window=1000):
        self.cur_scale = init_scale
        self.cur_iter = 0
        assert mode in ('dynamic',
                        'static'), 'mode can only be dynamic or static'
        self.mode = mode
        self.last_overflow_iter = -1
        self.scale_factor = scale_factor
        self.scale_window = scale_window

    def has_overflow(self, params):
        """Check if params contain overflow."""
        if self.mode != 'dynamic':
            return False
        for p in params:
            if p.grad is not None and LossScaler._has_inf_or_nan(p.grad.data):
                return True
        return False

    @staticmethod
    def _has_inf_or_nan(x):
        """Check if x contains inf or NaN values."""
        try:
            cpu_sum = float(x.float().sum())
        except RuntimeError as instance:
            if 'value cannot be converted' not in instance.args[0]:
                raise
            return True
        else:
            if cpu_sum == float('inf') or cpu_sum == -float('inf') \
                    or cpu_sum != cpu_sum:
                return True
            return False

    def update_scale(self, overflow):
        """Update the loss scale: shrink it on overflow, grow it after
        ``scale_window`` clean iterations."""
        if self.mode != 'dynamic':
            return
        if overflow:
            self.cur_scale = max(self.cur_scale / self.scale_factor, 1)
            self.last_overflow_iter = self.cur_iter
        else:
            if (self.cur_iter - self.last_overflow_iter) % \
                    self.scale_window == 0:
                self.cur_scale *= self.scale_factor
        self.cur_iter += 1

    def state_dict(self):
        """Returns the state of the scaler as a :class:`dict`."""
        return dict(
            cur_scale=self.cur_scale,
            cur_iter=self.cur_iter,
            mode=self.mode,
            last_overflow_iter=self.last_overflow_iter,
            scale_factor=self.scale_factor,
            scale_window=self.scale_window)

    def load_state_dict(self, state_dict):
        """Loads the loss_scaler state dict.

        Args:
            state_dict (dict): scaler state.
        """
        self.cur_scale = state_dict['cur_scale']
        self.cur_iter = state_dict['cur_iter']
        self.mode = state_dict['mode']
        self.last_overflow_iter = state_dict['last_overflow_iter']
        self.scale_factor = state_dict['scale_factor']
        self.scale_window = state_dict['scale_window']

    @property
    def loss_scale(self):
        return self.cur_scale

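# Illustrative training-step sketch, not part of the original file: how a
# dynamic ``LossScaler`` is typically driven. ``model``, ``optimizer``,
# ``loss`` and ``scaler`` are placeholders for the caller's own objects.
def _example_fp16_train_step(model, optimizer, loss, scaler):
    optimizer.zero_grad()
    (loss * scaler.loss_scale).backward()  # scale the loss to avoid underflow
    params = [p for p in model.parameters() if p.requires_grad]
    if scaler.has_overflow(params):
        scaler.update_scale(True)  # shrink the scale and skip this step
        return
    for p in params:
        if p.grad is not None:
            p.grad.div_(scaler.loss_scale)  # unscale gradients before updating
    optimizer.step()
    scaler.update_scale(False)  # grow the scale after a clean window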