import collections.abc
from itertools import repeat

import torch
from torch import nn as nn
from torchvision.ops.misc import FrozenBatchNorm2d


def freeze_batch_norm_2d(module, module_match=None, name=''):
    """
    Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is
    itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and
    returned. Otherwise, the module is walked recursively and submodules are converted in place.

    Args:
        module (torch.nn.Module): Any PyTorch module.
        module_match (dict | None): Collection of full module names to freeze (all if empty or None).
            Only membership (`name in module_match`) is used.
        name (str): Full module name (prefix)

    Returns:
        torch.nn.Module: Resulting module

    Inspired by detectron2's `FrozenBatchNorm2d.convert_frozen_batchnorm`
    (detectron2/layers/batch_norm.py).
    """
    res = module
    is_match = True
    if module_match:
        is_match = name in module_match
    if is_match and isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)):
        res = FrozenBatchNorm2d(module.num_features)
        res.num_features = module.num_features
        res.affine = module.affine
        if module.affine:
            # Detached copies so the frozen layer does not share storage / autograd history with the source.
            res.weight.data = module.weight.data.clone().detach()
            res.bias.data = module.bias.data.clone().detach()
        # Carry over the learned running statistics; without this the frozen layer would
        # silently normalise with its default (mean 0, var 1) buffers.
        res.running_mean.data = module.running_mean.data
        res.running_var.data = module.running_var.data
        res.eps = module.eps
    else:
        # Snapshot children first: add_module below rebinds entries of the dict being walked.
        for child_name, child in list(module.named_children()):
            full_child_name = '.'.join([name, child_name]) if name else child_name
            new_child = freeze_batch_norm_2d(child, module_match, full_child_name)
            if new_child is not child:
                res.add_module(child_name, new_child)
    return res


# From PyTorch internals
def _ntuple(n):
    """Return a parser that passes iterables through unchanged and repeats scalars `n` times into a tuple."""
    def parse(x):
        if isinstance(x, collections.abc.Iterable):
            return x
        return tuple(repeat(x, n))
    return parse


to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)


def to_ntuple(n, x):
    """Apply the `_ntuple(n)` parser to `x`."""
    return _ntuple(n)(x)


# Replaces linear layers whose attribute name is in `include_modules` with `linear_replacement`
# TODO: add int8 support for other linear layers including attn and convnets
def replace_linear(model, linear_replacement, include_modules=('c_fc', 'c_proj'), copy_weights=True):
    """
    Recursively replace `torch.nn.Linear` children of `model` whose (local) attribute name is in
    `include_modules` with `linear_replacement(in_features, out_features, has_bias)`.

    Args:
        model (torch.nn.Module): Module to modify in place.
        linear_replacement (callable): Factory called as `linear_replacement(in_features, out_features, bias)`;
            the result is expected to expose `.weight` / `.bias` like `nn.Linear` when `copy_weights` is True.
        include_modules (Iterable[str]): Child attribute names eligible for replacement.
        copy_weights (bool): Copy the original weight (and bias, if any) into the replacement.

    Returns:
        torch.nn.Module: `model` itself (modified in place).
    """
    # Snapshot children: entries of model._modules are rebound inside the loop.
    for name, module in list(model.named_children()):
        if len(list(module.children())) > 0:
            replace_linear(module, linear_replacement, include_modules, copy_weights)

        if isinstance(module, torch.nn.Linear) and name in include_modules:
            old_module = model._modules[name]
            model._modules[name] = linear_replacement(
                module.in_features,
                module.out_features,
                module.bias is not None,
            )
            if copy_weights:
                new_module = model._modules[name]
                new_module.weight.data.copy_(old_module.weight.data)
                if new_module.bias is not None:
                    new_module.bias.data.copy_(old_module.bias.data)

    return model


def convert_int8_model_to_inference_mode(model):
    """
    Call `prepare_for_eval()` on every submodule that defines it, recording the weight dtype it had
    beforehand as `int8_original_dtype` on that submodule.
    """
    for m in model.modules():
        if hasattr(m, 'prepare_for_eval'):
            # Capture before prepare_for_eval, which presumably changes the weight representation — TODO confirm.
            int8_original_dtype = m.weight.dtype
            m.prepare_for_eval()
            m.int8_original_dtype = int8_original_dtype


def accuracy(output, target, topk=(1,)):
    """
    Compute top-k accuracy

    output: torch.Tensor
        shape (N, C) where N is the number of examples, C the number of classes.
        these are the logits.

    target: torch.Tensor
        shape (N,) where N is the number of examples. Groundtruth class id of each example.

    topk: tuple
        which topk to compute, e.g., topk=(1,5) will compute top-1 and top-5 accuracies

    Returns
    -------

    list of top-k accuracies in the same order as `topk`
    """
    # (maxk, N) matrix of predicted class ids, best first along dim 0.
    pred = output.topk(max(topk), 1, True, True)[1].t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    n = len(target)
    # .item() yields a Python float directly; float() on a 1-element ndarray is deprecated in NumPy >= 1.25.
    return [correct[:k].reshape(-1).float().sum().item() / n for k in topk]