# NOTE: removed non-code residue ("Spaces:" / "Running") left over from a log/notebook export.
from typing import Optional
import torch
from copy import deepcopy
from torch import nn
from utils.common.others import get_cur_time_str
from utils.dl.common.model import get_model_device, get_model_latency, get_model_size, get_module, get_super_module, set_module
from utils.common.log import logger
from utils.third_party.nni_new.compression.pytorch.speedup import ModelSpeedup
import os
from .base import Abs, KTakesAll, Layer_WrappedWithFBS, ElasticDNNUtil
class Conv2d_WrappedWithFBS(Layer_WrappedWithFBS):
    """Conv2d + BatchNorm2d pair wrapped with a Feature Boosting & Suppression gate.

    A small side branch (`self.fbs`) predicts a per-output-channel saliency from
    the input feature map; `self.k_takes_all` (provided by the base class)
    sparsifies it, and the conv+BN output is rescaled channel-wise by the result.
    """

    def __init__(self, raw_conv2d: nn.Conv2d, raw_bn: nn.BatchNorm2d, r):
        super().__init__()

        in_ch = raw_conv2d.in_channels
        out_ch = raw_conv2d.out_channels

        # Saliency predictor: global average pooling over |x|, then a
        # two-layer MLP with channel-reduction ratio r.
        self.fbs = nn.Sequential(
            Abs(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(in_ch, out_ch // r),
            nn.ReLU(),
            nn.Linear(out_ch // r, out_ch),
            nn.ReLU()
        )

        self.raw_conv2d = raw_conv2d
        self.raw_bn = raw_bn  # remember clear the original BNs in the network

        # Start with all channels fully "on": bias of the final Linear is 1.
        final_fc = self.fbs[5]
        nn.init.constant_(final_fc.bias, 1.)
        nn.init.kaiming_normal_(final_fc.weight)

    def forward(self, x):
        out = self.raw_bn(self.raw_conv2d(x))

        # Recompute the channel attention unless a cached one should be reused.
        use_cache = self.use_cached_channel_attention and self.cached_channel_attention is not None
        if not use_cache:
            raw_attention = self.fbs(x)
            self.cached_raw_channel_attention = raw_attention
            self.cached_channel_attention = self.k_takes_all(raw_attention)
        attention = self.cached_channel_attention

        # Broadcast (N, C) attention over the spatial dims of (N, C, H, W).
        return out * attention.unsqueeze(2).unsqueeze(3)
class StaticFBS(nn.Module):
    """Frozen channel-attention rescaling layer.

    Holds a fixed (non-trainable) per-channel attention vector, baked in from a
    Conv2d_WrappedWithFBS layer at surrogate-DNN extraction time, and multiplies
    the input feature map by it channel-wise.
    """

    def __init__(self, channel_attention: torch.Tensor):
        """
        Args:
            channel_attention: 1-D tensor of length C (one weight per channel).
        """
        super(StaticFBS, self).__init__()
        assert channel_attention.dim() == 1
        # Stored as (1, C, 1, 1) so it broadcasts over (N, C, H, W) inputs.
        self.channel_attention = nn.Parameter(
            channel_attention.unsqueeze(0).unsqueeze(2).unsqueeze(3), requires_grad=False)

    def forward(self, x):
        return x * self.channel_attention

    def __str__(self) -> str:
        # BUG FIX: size(1) already returns an int (the channel count C);
        # wrapping it in len() raised TypeError whenever __str__ was called.
        return f'StaticFBS({self.channel_attention.size(1)})'
class ElasticCNNUtil(ElasticDNNUtil):
    """CNN-specific implementation of ElasticDNNUtil.

    Converts a raw CNN into a "master" DNN whose Conv2d layers are wrapped with
    FBS channel gates, and extracts a statically-pruned surrogate DNN from the
    master DNN's cached channel attentions via NNI's ModelSpeedup.
    """

    def convert_raw_dnn_to_master_dnn(self, raw_dnn: nn.Module, r: float, ignore_layers=[]):
        """Return a deep copy of `raw_dnn` where each Conv2d (except those named
        in `ignore_layers`) is replaced by Conv2d_WrappedWithFBS (which absorbs
        the conv's following BatchNorm2d), and each absorbed BN is replaced by
        nn.Identity.

        Args:
            raw_dnn: original network; not modified (deep-copied first).
            r: channel-reduction ratio of the FBS saliency MLP.
            ignore_layers: module names of Conv2d layers to leave untouched.
                NOTE(review): mutable default argument; it is only read here,
                but confirm nothing mutates it.
        """
        model = deepcopy(raw_dnn)
        
        # clear original BNs
        num_original_bns = 0
        last_conv_name = None
        conv_bn_map = {}
        # Pair each BatchNorm2d with the most recently seen Conv2d.
        # NOTE(review): assumes named_modules() yields each conv before its BN
        # with no other conv in between — confirm for the target architectures.
        for name, module in model.named_modules():
            if isinstance(module, nn.Conv2d):
                last_conv_name = name
            if isinstance(module, nn.BatchNorm2d) and (ignore_layers is not None and last_conv_name not in ignore_layers):
                num_original_bns += 1
                conv_bn_map[last_conv_name] = name
        
        num_conv = 0
        for name, module in model.named_modules():
            if isinstance(module, nn.Conv2d) and (ignore_layers is not None and name not in ignore_layers):
                # NOTE(review): raises KeyError if a non-ignored conv has no
                # paired BN in conv_bn_map.
                set_module(model, name, Conv2d_WrappedWithFBS(module, get_module(model, conv_bn_map[name]), r))
                num_conv += 1
        # every wrapped conv must have absorbed exactly one BN
        assert num_conv == num_original_bns
        
        # The BNs now live inside the wrappers; neutralize the originals so
        # normalization is not applied twice.
        for bn_layer in conv_bn_map.values():
            set_module(model, bn_layer, nn.Identity())
        
        return model
    
    def select_most_rep_sample(self, master_dnn: nn.Module, samples: torch.Tensor):
        # Trivial policy: the first sample (kept as a batch of size 1) is taken
        # as the "most representative" one.
        return samples[0].unsqueeze(0)
    
    def extract_surrogate_dnn_via_samples(self, master_dnn: nn.Module, samples: torch.Tensor):
        """Extract a statically-pruned surrogate DNN from `master_dnn`.

        Runs one representative sample through the master DNN to populate each
        FBS wrapper's cached channel attention, builds NNI pruning masks from
        the zeroed channels, physically shrinks the network with ModelSpeedup,
        and bakes the surviving attention weights in as StaticFBS layers.
        Ends with a sanity check that the surrogate reproduces the master's
        output on the sample.
        """
        sample = self.select_most_rep_sample(master_dnn, samples)
        assert sample.dim() == 4 and sample.size(0) == 1  # a single NCHW sample
        
        # Forward pass to fill each wrapper's cached_channel_attention.
        master_dnn.eval()
        with torch.no_grad():
            master_dnn_output = master_dnn(sample)
        
        pruning_info = {}   # wrapper name -> dense channel-attention vector
        pruning_masks = {}  # NNI-style masks, keyed by the conv's post-unwrap name
        for layer_name, layer in master_dnn.named_modules():
            if not isinstance(layer, Conv2d_WrappedWithFBS):
                continue
            
            cur_pruning_mask = {'weight': torch.zeros_like(layer.raw_conv2d.weight.data)}
            if layer.raw_conv2d.bias is not None:
                cur_pruning_mask['bias'] = torch.zeros_like(layer.raw_conv2d.bias.data)
            
            w = get_module(master_dnn, layer_name).cached_channel_attention.squeeze(0)
            # channels whose attention is non-zero survive pruning
            unpruned_filters_index = w.nonzero(as_tuple=True)[0]
            pruning_info[layer_name] = w
            
            cur_pruning_mask['weight'][unpruned_filters_index, ...] = 1.
            if layer.raw_conv2d.bias is not None:
                cur_pruning_mask['bias'][unpruned_filters_index, ...] = 1.
            # '.0' anticipates the unwrap below: the wrapper becomes
            # nn.Sequential(conv, bn, identity), so the conv is '<name>.0'.
            pruning_masks[layer_name + '.0'] = cur_pruning_mask
        
        # Unwrap: wrapper -> Sequential(raw conv, raw BN, placeholder Identity).
        surrogate_dnn = deepcopy(master_dnn)
        for name, layer in surrogate_dnn.named_modules():
            if not isinstance(layer, Conv2d_WrappedWithFBS):
                continue
            set_module(surrogate_dnn, name, nn.Sequential(layer.raw_conv2d, layer.raw_bn, nn.Identity()))
        
        # fixed_pruning_masks = fix_mask_conflict(pruning_masks, fbs_model, sample.size(), None, True, True, True)
        # ModelSpeedup consumes masks from a file; use a unique temp path
        # (timestamp + PID) so concurrent processes don't clash.
        tmp_mask_path = f'tmp_mask_{get_cur_time_str()}_{os.getpid()}.pth'
        torch.save(pruning_masks, tmp_mask_path)
        
        surrogate_dnn.eval()
        model_speedup = ModelSpeedup(surrogate_dnn, sample, tmp_mask_path, sample.device)
        model_speedup.speedup_model()
        os.remove(tmp_mask_path)

        # add feature boosting module
        for layer_name, feature_boosting_w in pruning_info.items():
            # keep only the surviving channels' attention values
            feature_boosting_w = feature_boosting_w[feature_boosting_w.nonzero(as_tuple=True)[0]]
            # replace the placeholder Identity ('<name>.2') with the baked-in gate
            set_module(surrogate_dnn, layer_name + '.2', StaticFBS(feature_boosting_w))
        
        # Sanity check: surrogate must reproduce the master's output on the sample.
        surrogate_dnn.eval()
        with torch.no_grad():
            surrogate_dnn_output = surrogate_dnn(sample)
            output_diff = ((surrogate_dnn_output - master_dnn_output) ** 2).sum()
            assert output_diff < 1e-4, output_diff
            logger.info(f'output diff of master and surrogate DNN: {output_diff}')
        
        return surrogate_dnn