|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import sys |
|
from functools import partial |
|
|
|
import numpy as np |
|
import torch |
|
import torch.nn as nn |
|
|
|
import annotator.uniformer.mmcv as mmcv |
|
|
|
|
|
def get_model_complexity_info(model,
                              input_shape,
                              print_per_layer_stat=True,
                              as_strings=True,
                              input_constructor=None,
                              flush=False,
                              ost=sys.stdout):
    """Get complexity information of a model.

    This method can calculate FLOPs and parameter counts of a model with
    corresponding input shape. It can also print complexity information for
    each layer in a model.

    Supported layers are listed as below:
        - Convolutions: ``nn.Conv1d``, ``nn.Conv2d``, ``nn.Conv3d``.
        - Activations: ``nn.ReLU``, ``nn.PReLU``, ``nn.ELU``, ``nn.LeakyReLU``,
            ``nn.ReLU6``.
        - Poolings: ``nn.MaxPool1d``, ``nn.MaxPool2d``, ``nn.MaxPool3d``,
            ``nn.AvgPool1d``, ``nn.AvgPool2d``, ``nn.AvgPool3d``,
            ``nn.AdaptiveMaxPool1d``, ``nn.AdaptiveMaxPool2d``,
            ``nn.AdaptiveMaxPool3d``, ``nn.AdaptiveAvgPool1d``,
            ``nn.AdaptiveAvgPool2d``, ``nn.AdaptiveAvgPool3d``.
        - BatchNorms: ``nn.BatchNorm1d``, ``nn.BatchNorm2d``,
            ``nn.BatchNorm3d``, ``nn.GroupNorm``, ``nn.InstanceNorm1d``,
            ``InstanceNorm2d``, ``InstanceNorm3d``, ``nn.LayerNorm``.
        - Linear: ``nn.Linear``.
        - Deconvolution: ``nn.ConvTranspose2d``.
        - Upsample: ``nn.Upsample``.

    Note:
        The model is switched to evaluation mode (``model.eval()``) and the
        counting helper methods are attached to it as attributes. The
        counting hooks are removed before this function returns, including
        when the forward pass or the per-layer printing raises.

    Args:
        model (nn.Module): The model for complexity calculation.
        input_shape (tuple): Input shape used for calculation. A leading
            batch dimension of 1 is prepended when the input is generated
            here.
        print_per_layer_stat (bool): Whether to print complexity information
            for each layer in a model. Default: True.
        as_strings (bool): Output FLOPs and params counts in a string form.
            Default: True.
        input_constructor (None | callable): If specified, it is called with
            ``input_shape`` and its result is passed to the model as keyword
            arguments. Otherwise, an uninitialized tensor with the input
            shape is generated to calculate FLOPs. Default: None.
        flush (bool): same as that in :func:`print`. Default: False.
        ost (stream): same as ``file`` param in :func:`print`.
            Default: sys.stdout.

    Returns:
        tuple[float | str]: If ``as_strings`` is set to True, it will return
            FLOPs and parameter counts in a string format. otherwise, it will
            return those in a float number format.
    """
    assert type(input_shape) is tuple
    assert len(input_shape) >= 1
    assert isinstance(model, nn.Module)
    flops_model = add_flops_counting_methods(model)
    flops_model.eval()
    flops_model.start_flops_count()
    try:
        if input_constructor:
            inputs = input_constructor(input_shape)
            _ = flops_model(**inputs)
        else:
            try:
                batch = torch.ones(()).new_empty(
                    (1, *input_shape),
                    dtype=next(flops_model.parameters()).dtype,
                    device=next(flops_model.parameters()).device)
            except StopIteration:
                # The model has no parameters to copy dtype/device from, so
                # fall back to the defaults of ``torch.ones(())``.
                batch = torch.ones(()).new_empty((1, *input_shape))

            _ = flops_model(batch)

        flops_count, params_count = flops_model.compute_average_flops_cost()
        if print_per_layer_stat:
            print_model_with_flops(
                flops_model, flops_count, params_count, ost=ost, flush=flush)
    finally:
        # Never leave forward hooks attached to the caller's model, even
        # when the forward pass or the printing above fails.
        flops_model.stop_flops_count()

    if as_strings:
        return flops_to_string(flops_count), params_to_string(params_count)

    return flops_count, params_count
|
|
|
|
|
def flops_to_string(flops, units='GFLOPs', precision=2):
    """Render a FLOPs number as a human-readable string.

    Note that here a multiply-add is counted as one FLOP.

    Args:
        flops (float): FLOPs number to be converted.
        units (str | None): Target unit. Options are None, 'GFLOPs',
            'MFLOPs', 'KFLOPs', 'FLOPs'. With None, the largest unit whose
            integer quotient is positive is picked. Any other value falls
            back to plain 'FLOPs'. Default: 'GFLOPs'.
        precision (int): Digits kept after the decimal point for the scaled
            units; the plain 'FLOPs' form is not rounded. Default: 2.

    Returns:
        str: The converted FLOPs number with units.

    Examples:
        >>> flops_to_string(1e9)
        '1.0 GFLOPs'
        >>> flops_to_string(2e5, 'MFLOPs')
        '0.2 MFLOPs'
        >>> flops_to_string(3e-9, None)
        '3e-09 FLOPs'
    """
    # Largest unit first so automatic selection stops at the best fit.
    scales = (('GFLOPs', 9), ('MFLOPs', 6), ('KFLOPs', 3))

    if units is None:
        for label, exponent in scales:
            if flops // 10**exponent > 0:
                scaled = round(flops / 10.**exponent, precision)
                return str(scaled) + ' ' + label
        return str(flops) + ' FLOPs'

    for label, exponent in scales:
        if units == label:
            scaled = round(flops / 10.**exponent, precision)
            return str(scaled) + ' ' + units
    return str(flops) + ' FLOPs'
|
|
|
|
|
def params_to_string(num_params, units=None, precision=2):
    """Convert parameter number into a string.

    Args:
        num_params (float): Parameter number to be converted.
        units (str | None): Converted parameter units. Options are None,
            'M', 'K' and ''. If set to None, it will automatically choose
            the most suitable unit for the parameter number (note that the
            automatic thousands suffix is a lowercase ``' k'``). Any other
            value returns the plain number. Default: None.
        precision (int): Digit number after the decimal point. Default: 2.

    Returns:
        str: The converted parameter number with units.

    Examples:
        >>> params_to_string(1e9)
        '1000.0 M'
        >>> params_to_string(2e5)
        '200.0 k'
        >>> params_to_string(3e-9)
        '3e-09'
    """
    if units is None:
        if num_params // 10**6 > 0:
            return str(round(num_params / 10**6, precision)) + ' M'
        # Explicit ``> 0`` keeps this branch consistent with the one above:
        # floor division of a negative number is truthy (-1), which would
        # otherwise scale negative inputs into the 'k' form.
        elif num_params // 10**3 > 0:
            return str(round(num_params / 10**3, precision)) + ' k'
        else:
            return str(num_params)
    else:
        if units == 'M':
            return str(round(num_params / 10.**6, precision)) + ' ' + units
        elif units == 'K':
            return str(round(num_params / 10.**3, precision)) + ' ' + units
        else:
            return str(num_params)
|
|
|
|
|
def print_model_with_flops(model,
                           total_flops,
                           total_params,
                           units='GFLOPs',
                           precision=3,
                           ost=sys.stdout,
                           flush=False):
    """Print a model with FLOPs for each layer.

    Every module's ``extra_repr`` is temporarily replaced by one that
    prepends the accumulated parameter count, FLOPs and their share of the
    totals; the original ``extra_repr`` is restored afterwards, also when
    printing raises.

    Args:
        model (nn.Module): The model to be printed. Its supported submodules
            are read for ``__flops__`` / ``__params__`` and the model itself
            for ``__batch_counter__``.
        total_flops (float): Total FLOPs of the model.
        total_params (float): Total parameter counts of the model.
        units (str | None): Converted FLOPs units. Default: 'GFLOPs'.
        precision (int): Digit number after the decimal point. Default: 3.
        ost (stream): same as `file` param in :func:`print`.
            Default: sys.stdout.
        flush (bool): same as that in :func:`print`. Default: False.

    Example:
        >>> class ExampleModel(nn.Module):

        >>>     def __init__(self):
        >>>         super().__init__()
        >>>         self.conv1 = nn.Conv2d(3, 8, 3)
        >>>         self.conv2 = nn.Conv2d(8, 256, 3)
        >>>         self.conv3 = nn.Conv2d(256, 8, 3)
        >>>         self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        >>>         self.flatten = nn.Flatten()
        >>>         self.fc = nn.Linear(8, 1)

        >>>     def forward(self, x):
        >>>         x = self.conv1(x)
        >>>         x = self.conv2(x)
        >>>         x = self.conv3(x)
        >>>         x = self.avg_pool(x)
        >>>         x = self.flatten(x)
        >>>         x = self.fc(x)
        >>>         return x

        >>> model = ExampleModel()
        >>> x = (3, 16, 16)
        to print the complexity information state for each layer, you can use
        >>> get_model_complexity_info(model, x)
        or directly use
        >>> print_model_with_flops(model, 4579784.0, 37361)
        ExampleModel(
          0.037 M, 100.000% Params, 0.005 GFLOPs, 100.000% FLOPs,
          (conv1): Conv2d(0.0 M, 0.600% Params, 0.0 GFLOPs, 0.959% FLOPs, 3, 8, kernel_size=(3, 3), stride=(1, 1))  # noqa: E501
          (conv2): Conv2d(0.019 M, 50.020% Params, 0.003 GFLOPs, 58.760% FLOPs, 8, 256, kernel_size=(3, 3), stride=(1, 1))
          (conv3): Conv2d(0.018 M, 49.356% Params, 0.002 GFLOPs, 40.264% FLOPs, 256, 8, kernel_size=(3, 3), stride=(1, 1))
          (avg_pool): AdaptiveAvgPool2d(0.0 M, 0.000% Params, 0.0 GFLOPs, 0.017% FLOPs, output_size=(1, 1))
          (flatten): Flatten(0.0 M, 0.000% Params, 0.0 GFLOPs, 0.000% FLOPs, )
          (fc): Linear(0.0 M, 0.024% Params, 0.0 GFLOPs, 0.000% FLOPs, in_features=8, out_features=1, bias=True)
        )
    """

    def share(part, total):
        # A model without parameters (or without counted FLOPs) has a zero
        # total; report a 0% share instead of raising ZeroDivisionError.
        return part / total if total else 0.0

    def accumulate_params(self):
        # Supported leaves carry their own count; containers sum children.
        if is_supported_instance(self):
            return self.__params__
        return sum(child.accumulate_params() for child in self.children())

    def accumulate_flops(self):
        # Per-module FLOPs are averaged over the batches seen by ``model``.
        if is_supported_instance(self):
            return self.__flops__ / model.__batch_counter__
        return sum(child.accumulate_flops() for child in self.children())

    def flops_repr(self):
        accumulated_num_params = self.accumulate_params()
        accumulated_flops_cost = self.accumulate_flops()
        return ', '.join([
            params_to_string(
                accumulated_num_params, units='M', precision=precision),
            '{:.3%} Params'.format(
                share(accumulated_num_params, total_params)),
            flops_to_string(
                accumulated_flops_cost, units=units, precision=precision),
            '{:.3%} FLOPs'.format(
                share(accumulated_flops_cost, total_flops)),
            self.original_extra_repr()
        ])

    def add_extra_repr(m):
        m.accumulate_flops = accumulate_flops.__get__(m)
        m.accumulate_params = accumulate_params.__get__(m)
        flops_extra_repr = flops_repr.__get__(m)
        # Skip modules already patched in this pass so the original
        # ``extra_repr`` is not overwritten by the patched one.
        if m.extra_repr != flops_extra_repr:
            m.original_extra_repr = m.extra_repr
            m.extra_repr = flops_extra_repr
            assert m.extra_repr != m.original_extra_repr

    def del_extra_repr(m):
        if hasattr(m, 'original_extra_repr'):
            m.extra_repr = m.original_extra_repr
            del m.original_extra_repr
        # Remove both temporary helpers so nothing stays bound to the module.
        for helper_name in ('accumulate_flops', 'accumulate_params'):
            if hasattr(m, helper_name):
                delattr(m, helper_name)

    try:
        model.apply(add_extra_repr)
        print(model, file=ost, flush=flush)
    finally:
        # Restore the modules even if patching or printing failed midway.
        model.apply(del_extra_repr)
|
|
|
|
|
def get_model_parameters_number(model):
    """Count the trainable parameter elements of a model.

    Args:
        model (nn.module): The model for parameter number calculation.

    Returns:
        int: Sum of ``numel()`` over all parameters of ``model`` whose
            ``requires_grad`` is truthy.
    """
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        total += param.numel()
    return total
|
|
|
|
|
def add_flops_counting_methods(net_main_module):
    """Attach the FLOPs counting methods to a network object.

    Binds ``start_flops_count``, ``stop_flops_count``, ``reset_flops_count``
    and ``compute_average_flops_cost`` to ``net_main_module`` as instance
    attributes, then resets the counters once.

    Args:
        net_main_module: The object to extend; it is modified in place.

    Returns:
        The same ``net_main_module`` object.
    """
    counting_methods = (
        start_flops_count,
        stop_flops_count,
        reset_flops_count,
        compute_average_flops_cost,
    )
    for method in counting_methods:
        # ``__get__`` turns the plain function into a method bound to
        # this particular instance.
        setattr(net_main_module, method.__name__,
                method.__get__(net_main_module))

    net_main_module.reset_flops_count()

    return net_main_module
|
|
|
|
|
def compute_average_flops_cost(self):
    """Compute average FLOPs cost.

    A method to compute average FLOPs cost, which will be available after
    `add_flops_counting_methods()` is called on a desired net object.

    Returns:
        tuple: ``(mean_flops, params)`` where ``mean_flops`` is the summed
            ``__flops__`` of all supported submodules divided by
            ``__batch_counter__`` and ``params`` is the result of
            ``get_model_parameters_number(self)``.
    """
    batches_count = self.__batch_counter__
    flops_sum = sum(module.__flops__ for module in self.modules()
                    if is_supported_instance(module))
    params_sum = get_model_parameters_number(self)
    mean_flops = flops_sum / batches_count
    return mean_flops, params_sum
|
|
|
|
|
def start_flops_count(self):
    """Activate the computation of mean flops consumption per image.

    Available after ``add_flops_counting_methods()`` is called on a desired
    net object; it should be called before running the network. Registers
    the batch counter hook on ``self`` and one FLOPs counting forward hook
    on every supported submodule that does not already have one.
    """
    add_batch_counter_hook_function(self)

    def add_flops_counter_hook_function(module):
        if not is_supported_instance(module):
            return
        # A stored handle means the hook is already registered.
        if hasattr(module, '__flops_handle__'):
            return
        module.__flops_handle__ = module.register_forward_hook(
            get_modules_mapping()[type(module)])

    self.apply(partial(add_flops_counter_hook_function))
|
|
|
|
|
def stop_flops_count(self):
    """Stop computing the mean flops consumption per image.

    Available after ``add_flops_counting_methods()`` is called on a desired
    net object. Removes the batch counter hook from ``self`` and the FLOPs
    counter hooks from its submodules; the counters themselves are not
    reset here.
    """
    remove_batch_counter_hook_function(self)
    self.apply(remove_flops_counter_hook_function)
|
|
|
|
|
def reset_flops_count(self):
    """Reset statistics computed so far.

    Available after `add_flops_counting_methods()` is called on a desired
    net object. Zeroes the batch counter on ``self`` and (re)initializes the
    per-module counter variables on its submodules.
    """
    add_batch_counter_variables_or_reset(self)
    self.apply(add_flops_counter_variable_or_reset)
|
|
|
|
|
|
|
def empty_flops_counter_hook(module, input, output):
    """Forward hook that contributes zero FLOPs.

    Args:
        module: Object carrying a ``__flops__`` counter.
        input: Unused.
        output: Unused.
    """
    module.__flops__ += 0
|
|
|
|
|
def upsample_flops_counter_hook(module, input, output):
    """Forward hook counting FLOPs of an upsampling layer.

    Adds the number of elements of ``output[0]`` (the product of its shape)
    to ``module.__flops__``.

    NOTE(review): if ``output`` is a batched tensor, ``output[0]`` is the
    first sample only, so the count does not scale with the batch size --
    confirm this is intended.

    Args:
        module: Object carrying a ``__flops__`` counter.
        input: Unused.
        output: Indexable result of the forward pass.
    """
    dims = tuple(output[0].shape)
    element_count = dims[0]
    for extent in dims[1:]:
        element_count *= extent
    module.__flops__ += int(element_count)
|
|
|
|
|
def relu_flops_counter_hook(module, input, output):
    """Forward hook counting one FLOP per output element.

    Args:
        module: Object carrying a ``__flops__`` counter.
        input: Unused.
        output: Result of the forward pass; must provide ``numel()``.
    """
    element_count = int(output.numel())
    module.__flops__ += element_count
|
|
|
|
|
def linear_flops_counter_hook(module, input, output):
    """Forward hook counting FLOPs of a linear layer.

    Adds ``prod(input[0].shape) * output.shape[-1]`` to
    ``module.__flops__``.

    Args:
        module: Object carrying a ``__flops__`` counter.
        input: Positional inputs of the forward pass; only the first is used.
        output: Result of the forward pass.
    """
    first_input = input[0]
    out_features = output.shape[-1]
    layer_flops = np.prod(first_input.shape) * out_features
    module.__flops__ += int(layer_flops)
|
|
|
|
|
def pool_flops_counter_hook(module, input, output):
    """Forward hook counting one FLOP per input element of a pooling layer.

    Args:
        module: Object carrying a ``__flops__`` counter.
        input: Positional inputs of the forward pass; only the first is used.
        output: Unused.
    """
    first_input = input[0]
    element_count = np.prod(first_input.shape)
    module.__flops__ += int(element_count)
|
|
|
|
|
def norm_flops_counter_hook(module, input, output):
    """Forward hook counting FLOPs of a normalization layer.

    Counts one FLOP per input element, doubled when the module has a truthy
    ``affine`` or ``elementwise_affine`` attribute.

    Args:
        module: Object carrying a ``__flops__`` counter.
        input: Positional inputs of the forward pass; only the first is used.
        output: Unused.
    """
    first_input = input[0]
    element_count = np.prod(first_input.shape)

    has_affine = (getattr(module, 'affine', False)
                  or getattr(module, 'elementwise_affine', False))
    multiplier = 2 if has_affine else 1
    module.__flops__ += int(element_count * multiplier)
|
|
|
|
|
def deconv_flops_counter_hook(conv_module, input, output):
    """Forward hook counting FLOPs of a 2D transposed convolution.

    The convolution part is counted per input spatial position as
    ``kernel_h * kernel_w * in_channels * (out_channels // groups)``; when
    the module has a bias, one FLOP per output element is added.

    Args:
        conv_module: Module providing ``kernel_size``, ``in_channels``,
            ``out_channels``, ``groups``, ``bias`` and a ``__flops__``
            counter.
        input: Positional inputs of the forward pass; only the first is
            used. Its shape is unpacked as (batch, channels, height, width).
        output: Result of the forward pass; dimensions 2 and 3 of its shape
            are read as output height and width.
    """
    input = input[0]

    batch_size = input.shape[0]
    input_height, input_width = input.shape[2:]

    kernel_height, kernel_width = conv_module.kernel_size
    in_channels = conv_module.in_channels
    out_channels = conv_module.out_channels
    groups = conv_module.groups

    filters_per_channel = out_channels // groups
    conv_per_position_flops = (
        kernel_height * kernel_width * in_channels * filters_per_channel)

    active_elements_count = batch_size * input_height * input_width
    overall_conv_flops = conv_per_position_flops * active_elements_count
    bias_flops = 0
    if conv_module.bias is not None:
        output_height, output_width = output.shape[2:]
        # One addition per output element: height * width, not height**2,
        # so non-square outputs are counted correctly.
        bias_flops = out_channels * batch_size * output_height * output_width
    overall_flops = overall_conv_flops + bias_flops

    conv_module.__flops__ += int(overall_flops)
|
|
|
|
|
def conv_flops_counter_hook(conv_module, input, output):
    """Forward hook counting FLOPs of an N-dimensional convolution.

    Per output position the cost is
    ``prod(kernel_size) * in_channels * (out_channels // groups)``; when the
    module has a bias, ``out_channels`` FLOPs per output position are added.
    The number of output positions is the input batch size times the
    product of the output's spatial dimensions.

    Args:
        conv_module: Module providing ``kernel_size``, ``in_channels``,
            ``out_channels``, ``groups``, ``bias`` and a ``__flops__``
            counter.
        input: Positional inputs of the forward pass; only the first is used.
        output: Result of the forward pass; shape dimensions from index 2
            onwards are treated as spatial.
    """
    first_input = input[0]

    # Number of output positions the kernel is applied at.
    spatial_extent = int(np.prod(list(output.shape[2:])))
    positions = first_input.shape[0] * spatial_extent

    kernel_volume = int(np.prod(list(conv_module.kernel_size)))
    n_out = conv_module.out_channels
    per_position = (
        kernel_volume * conv_module.in_channels
        * (n_out // conv_module.groups))

    total = per_position * positions
    if conv_module.bias is not None:
        total = total + n_out * positions

    conv_module.__flops__ += int(total)
|
|
|
|
|
def batch_counter_hook(module, input, output):
    """Forward hook accumulating the number of processed samples.

    The batch size is taken as ``len()`` of the first positional input. If
    there are no positional inputs, a warning is printed and a batch size
    of 1 is assumed.

    Args:
        module: Object carrying a ``__batch_counter__`` attribute.
        input: Positional inputs of the forward pass.
        output: Unused.
    """
    if len(input) > 0:
        # There may be several inputs; the first one defines the batch size.
        first_input = input[0]
        batch_size = len(first_input)
    else:
        print('Warning! No positional inputs found for a module, '
              'assuming batch size is 1.')
        batch_size = 1
    module.__batch_counter__ += batch_size
|
|
|
|
|
def add_batch_counter_variables_or_reset(module):
    """Create the ``__batch_counter__`` attribute on ``module`` or reset it.

    Args:
        module: Object whose batch counter is set to 0.
    """
    setattr(module, '__batch_counter__', 0)
|
|
|
|
|
def add_batch_counter_hook_function(module):
    """Register ``batch_counter_hook`` as a forward hook on ``module``.

    Does nothing when the module already has a
    ``__batch_counter_handle__`` attribute; otherwise the handle returned
    by ``register_forward_hook`` is stored under that name.

    Args:
        module: Object providing ``register_forward_hook``.
    """
    if not hasattr(module, '__batch_counter_handle__'):
        module.__batch_counter_handle__ = module.register_forward_hook(
            batch_counter_hook)
|
|
|
|
|
def remove_batch_counter_hook_function(module):
    """Remove the batch counter hook from ``module`` if one is registered.

    Calls ``remove()`` on the stored ``__batch_counter_handle__`` and
    deletes the attribute; does nothing when the attribute is absent.

    Args:
        module: Object that may carry a ``__batch_counter_handle__``.
    """
    if not hasattr(module, '__batch_counter_handle__'):
        return
    module.__batch_counter_handle__.remove()
    delattr(module, '__batch_counter_handle__')
|
|
|
|
|
def add_flops_counter_variable_or_reset(module):
    """Initialize or reset the counter variables of a supported module.

    For supported modules, sets ``__flops__`` to 0 and ``__params__`` to
    the module's parameter number. A warning is printed first when either
    attribute already exists. Unsupported modules are left untouched.

    Args:
        module: The module to initialize.
    """
    if not is_supported_instance(module):
        return

    already_defined = (hasattr(module, '__flops__')
                       or hasattr(module, '__params__'))
    if already_defined:
        print('Warning: variables __flops__ or __params__ are already '
              'defined for the module' + type(module).__name__ +
              ' ptflops can affect your code!')
    module.__flops__ = 0
    module.__params__ = get_model_parameters_number(module)
|
|
|
|
|
def is_supported_instance(module):
    """Tell whether FLOPs counting is implemented for ``module``.

    The check uses the exact type of ``module`` (subclasses of a supported
    type are not matched unless listed themselves).

    Args:
        module: The object to check.

    Returns:
        bool: True if ``type(module)`` is a key of ``get_modules_mapping()``.
    """
    return type(module) in get_modules_mapping()
|
|
|
|
|
def remove_flops_counter_hook_function(module):
    """Remove the FLOPs counter hook from a supported module.

    For a supported module carrying a ``__flops_handle__``, calls
    ``remove()`` on the handle and deletes the attribute. Otherwise does
    nothing.

    Args:
        module: The module to clean up.
    """
    if is_supported_instance(module) and hasattr(module, '__flops_handle__'):
        module.__flops_handle__.remove()
        delattr(module, '__flops_handle__')
|
|
|
|
|
def get_modules_mapping():
    """Build the mapping from supported layer types to counter hooks.

    Returns:
        dict: Maps each supported module class to the forward hook function
            used to count its FLOPs. A new dict is built on every call.
    """
    hook_groups = (
        # convolutions
        (conv_flops_counter_hook, (
            nn.Conv1d,
            nn.Conv2d,
            mmcv.cnn.bricks.Conv2d,
            nn.Conv3d,
            mmcv.cnn.bricks.Conv3d,
        )),
        # activations
        (relu_flops_counter_hook, (
            nn.ReLU,
            nn.PReLU,
            nn.ELU,
            nn.LeakyReLU,
            nn.ReLU6,
        )),
        # poolings
        (pool_flops_counter_hook, (
            nn.MaxPool1d,
            nn.AvgPool1d,
            nn.AvgPool2d,
            nn.MaxPool2d,
            mmcv.cnn.bricks.MaxPool2d,
            nn.MaxPool3d,
            mmcv.cnn.bricks.MaxPool3d,
            nn.AvgPool3d,
            nn.AdaptiveMaxPool1d,
            nn.AdaptiveAvgPool1d,
            nn.AdaptiveMaxPool2d,
            nn.AdaptiveAvgPool2d,
            nn.AdaptiveMaxPool3d,
            nn.AdaptiveAvgPool3d,
        )),
        # normalizations
        (norm_flops_counter_hook, (
            nn.BatchNorm1d,
            nn.BatchNorm2d,
            nn.BatchNorm3d,
            nn.GroupNorm,
            nn.InstanceNorm1d,
            nn.InstanceNorm2d,
            nn.InstanceNorm3d,
            nn.LayerNorm,
        )),
        # fully connected
        (linear_flops_counter_hook, (
            nn.Linear,
            mmcv.cnn.bricks.Linear,
        )),
        # upscale
        (upsample_flops_counter_hook, (
            nn.Upsample,
        )),
        # deconvolution
        (deconv_flops_counter_hook, (
            nn.ConvTranspose2d,
            mmcv.cnn.bricks.ConvTranspose2d,
        )),
    )

    mapping = {}
    for hook, layer_types in hook_groups:
        for layer_type in layer_types:
            mapping[layer_type] = hook
    return mapping
|
|