import os
from copy import deepcopy
from typing import Optional

import torch
from torch import nn

from utils.common.log import logger
from utils.common.others import get_cur_time_str
from utils.dl.common.model import (get_model_device, get_model_latency, get_model_size,
                                   get_module, get_super_module, set_module)
from utils.third_party.nni_new.compression.pytorch.speedup import ModelSpeedup

from .base import Abs, KTakesAll, Layer_WrappedWithFBS, ElasticDNNUtil


class Conv2d_WrappedWithFBS(Layer_WrappedWithFBS):
    """A conv + BN pair wrapped with an FBS (Feature Boosting and Suppression)
    side branch that predicts per-sample channel saliencies from the conv input."""

    def __init__(self, raw_conv2d: nn.Conv2d, raw_bn: nn.BatchNorm2d, r):
        super(Conv2d_WrappedWithFBS, self).__init__()

        # channel saliency predictor: |x| -> global avg pool -> two-layer MLP
        # with reduction ratio r
        self.fbs = nn.Sequential(
            Abs(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(raw_conv2d.in_channels, raw_conv2d.out_channels // r),
            nn.ReLU(),
            nn.Linear(raw_conv2d.out_channels // r, raw_conv2d.out_channels),
            nn.ReLU()
        )

        self.raw_conv2d = raw_conv2d
        self.raw_bn = raw_bn  # remember to clear the original BNs in the network

        # initialize the last linear layer so that every channel starts fully "on"
        nn.init.constant_(self.fbs[5].bias, 1.)
        nn.init.kaiming_normal_(self.fbs[5].weight)

    def forward(self, x):
        raw_x = self.raw_bn(self.raw_conv2d(x))

        if self.use_cached_channel_attention and self.cached_channel_attention is not None:
            # reuse the frozen attention (e.g. during surrogate DNN extraction)
            channel_attention = self.cached_channel_attention
        else:
            # predict channel saliencies and keep only the top-k of them
            self.cached_raw_channel_attention = self.fbs(x)
            self.cached_channel_attention = self.k_takes_all(self.cached_raw_channel_attention)
            channel_attention = self.cached_channel_attention

        return raw_x * channel_attention.unsqueeze(2).unsqueeze(3)
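
    # Shape sketch (illustrative, not from the original source): for x of shape
    # (N, C_in, H, W), self.fbs(x) yields saliencies of shape (N, C_out);
    # k_takes_all zeroes all but the top-k of them per sample, and the
    # unsqueezes broadcast the survivors to (N, C_out, 1, 1) so the surviving
    # channels of raw_x are scaled and the rest are suppressed to zero.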


class StaticFBS(nn.Module):
    """The inference-time replacement for an FBS branch: a fixed channel-wise
    scaling with the attention weights of the kept channels."""

    def __init__(self, channel_attention: torch.Tensor):
        super(StaticFBS, self).__init__()
        assert channel_attention.dim() == 1
        self.channel_attention = nn.Parameter(channel_attention.unsqueeze(0).unsqueeze(2).unsqueeze(3), requires_grad=False)

    def forward(self, x):
        return x * self.channel_attention

    def __str__(self) -> str:
        return f'StaticFBS({self.channel_attention.size(1)})'
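
    # e.g. StaticFBS(torch.tensor([1.2, 0.8])) stores a (1, 2, 1, 1) parameter
    # and scales a (N, 2, H, W) feature map channel-wise by (1.2, 0.8).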


class ElasticCNNUtil(ElasticDNNUtil):
    def convert_raw_dnn_to_master_dnn(self, raw_dnn: nn.Module, r: float, ignore_layers=[]):
        model = deepcopy(raw_dnn)

        # map each conv to the BN that follows it (the original BNs are cleared below)
        num_original_bns = 0
        last_conv_name = None
        conv_bn_map = {}
        for name, module in model.named_modules():
            if isinstance(module, nn.Conv2d):
                last_conv_name = name
            if isinstance(module, nn.BatchNorm2d) and (ignore_layers is not None and last_conv_name not in ignore_layers):
                num_original_bns += 1
                conv_bn_map[last_conv_name] = name

        # wrap each non-ignored conv (together with its BN) with an FBS branch
        num_conv = 0
        for name, module in model.named_modules():
            if isinstance(module, nn.Conv2d) and (ignore_layers is not None and name not in ignore_layers):
                set_module(model, name, Conv2d_WrappedWithFBS(module, get_module(model, conv_bn_map[name]), r))
                num_conv += 1
        assert num_conv == num_original_bns

        # the original BNs now live inside the wrappers, so replace them with no-ops
        for bn_layer in conv_bn_map.values():
            set_module(model, bn_layer, nn.Identity())

        return model
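
    # Usage sketch (hedged; the backbone and arguments here are hypothetical):
    #   util = ElasticCNNUtil()
    #   master_dnn = util.convert_raw_dnn_to_master_dnn(resnet, r=16, ignore_layers=['conv1'])
    # Every conv except those in ignore_layers is now wrapped with an FBS branch,
    # and its trailing BN is moved into the wrapper.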

    def select_most_rep_sample(self, master_dnn: nn.Module, samples: torch.Tensor):
        return samples[0].unsqueeze(0)

    def extract_surrogate_dnn_via_samples(self, master_dnn: nn.Module, samples: torch.Tensor):
        sample = self.select_most_rep_sample(master_dnn, samples)
        assert sample.dim() == 4 and sample.size(0) == 1

        # run the master DNN once so each FBS layer caches its channel attention
        master_dnn.eval()
        with torch.no_grad():
            master_dnn_output = master_dnn(sample)

        # build a filter-level pruning mask per wrapped conv: filters whose
        # cached attention is zero are pruned, the rest are kept
        pruning_info = {}
        pruning_masks = {}
        for layer_name, layer in master_dnn.named_modules():
            if not isinstance(layer, Conv2d_WrappedWithFBS):
                continue

            cur_pruning_mask = {'weight': torch.zeros_like(layer.raw_conv2d.weight.data)}
            if layer.raw_conv2d.bias is not None:
                cur_pruning_mask['bias'] = torch.zeros_like(layer.raw_conv2d.bias.data)

            w = get_module(master_dnn, layer_name).cached_channel_attention.squeeze(0)
            unpruned_filters_index = w.nonzero(as_tuple=True)[0]
            pruning_info[layer_name] = w

            cur_pruning_mask['weight'][unpruned_filters_index, ...] = 1.
            if layer.raw_conv2d.bias is not None:
                cur_pruning_mask['bias'][unpruned_filters_index, ...] = 1.

            # the wrapper is unwrapped below into nn.Sequential(conv, bn, Identity),
            # so the conv's name in the surrogate DNN becomes '<layer_name>.0'
            pruning_masks[layer_name + '.0'] = cur_pruning_mask

        # unwrap each FBS layer into conv + BN + a placeholder for the static gate
        surrogate_dnn = deepcopy(master_dnn)
        for name, layer in surrogate_dnn.named_modules():
            if not isinstance(layer, Conv2d_WrappedWithFBS):
                continue
            set_module(surrogate_dnn, name, nn.Sequential(layer.raw_conv2d, layer.raw_bn, nn.Identity()))

        # fixed_pruning_masks = fix_mask_conflict(pruning_masks, fbs_model, sample.size(), None, True, True, True)
        tmp_mask_path = f'tmp_mask_{get_cur_time_str()}_{os.getpid()}.pth'
        torch.save(pruning_masks, tmp_mask_path)

        # let NNI's ModelSpeedup physically remove the pruned filters
        surrogate_dnn.eval()
        model_speedup = ModelSpeedup(surrogate_dnn, sample, tmp_mask_path, sample.device)
        model_speedup.speedup_model()
        os.remove(tmp_mask_path)

        # add the feature boosting module: bake the kept channels' attention weights
        # into a static per-channel scaling (replacing the Identity at '<layer_name>.2')
        for layer_name, feature_boosting_w in pruning_info.items():
            feature_boosting_w = feature_boosting_w[feature_boosting_w.nonzero(as_tuple=True)[0]]
            set_module(surrogate_dnn, layer_name + '.2', StaticFBS(feature_boosting_w))

        # sanity check: the surrogate DNN must reproduce the master DNN's output
        surrogate_dnn.eval()
        with torch.no_grad():
            surrogate_dnn_output = surrogate_dnn(sample)
        output_diff = ((surrogate_dnn_output - master_dnn_output) ** 2).sum()
        assert output_diff < 1e-4, output_diff
        logger.info(f'output diff of master and surrogate DNN: {output_diff}')

        return surrogate_dnn
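

if __name__ == '__main__':
    # Minimal smoke test (a hedged sketch: it assumes the utils/ package and the
    # vendored NNI are importable and that Layer_WrappedWithFBS sets up a usable
    # KTakesAll by default; the toy network below is hypothetical).
    toy_cnn = nn.Sequential(
        nn.Conv2d(3, 16, 3, padding=1), nn.BatchNorm2d(16), nn.ReLU(),
        nn.Conv2d(16, 32, 3, padding=1), nn.BatchNorm2d(32), nn.ReLU(),
        nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(32, 10)
    )

    util = ElasticCNNUtil()
    master_dnn = util.convert_raw_dnn_to_master_dnn(toy_cnn, r=4)
    logger.info(master_dnn)

    samples = torch.rand(2, 3, 32, 32)
    master_dnn.eval()
    with torch.no_grad():
        logger.info(f'master DNN output size: {master_dnn(samples).size()}')

    # extracting the pruned surrogate additionally requires NNI's ModelSpeedup:
    # surrogate_dnn = util.extract_surrogate_dnn_via_samples(master_dnn, samples)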