from typing import Optional
import torch
from copy import deepcopy
from torch import nn
from utils.common.others import get_cur_time_str
from utils.dl.common.model import get_model_device, get_model_latency, get_model_size, get_module, get_super_module, set_module
from utils.common.log import logger
from utils.third_party.nni_new.compression.pytorch.speedup import ModelSpeedup
import os
from .base import Abs, KTakesAll, Layer_WrappedWithFBS, ElasticDNNUtil
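
# This module implements elastic width for CNNs via dynamic channel gating in the
# style of FBS (Feature Boosting and Suppression, Gao et al., ICLR 2019): a small
# side branch predicts per-channel saliencies from the layer input, a
# winner-take-all operator (KTakesAll, from .base) zeroes all but the top-k of
# them, and the surviving channels are rescaled ("boosted") by their saliencies.
# A gated "master" DNN can then be specialized into statically pruned
# "surrogate" DNNs.
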
class Conv2d_WrappedWithFBS(Layer_WrappedWithFBS):
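    """A Conv2d (plus its following BatchNorm2d) gated by an FBS-style channel
    attention branch: the gate sees the layer input and predicts one non-negative
    weight per output channel, of which only the top-k survive.
    """
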
    def __init__(self, raw_conv2d: nn.Conv2d, raw_bn: nn.BatchNorm2d, r):
        super(Conv2d_WrappedWithFBS, self).__init__()

        # Channel-attention gate: |x| -> global average pool -> two-layer MLP
        # with reduction ratio r, one non-negative output per conv filter.
        # int() guards against a float ratio r; max() against tiny layers.
        gate_hidden_features = max(int(raw_conv2d.out_channels // r), 1)
        self.fbs = nn.Sequential(
            Abs(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(raw_conv2d.in_channels, gate_hidden_features),
            nn.ReLU(),
            nn.Linear(gate_hidden_features, raw_conv2d.out_channels),
            nn.ReLU()
        )

        self.raw_conv2d = raw_conv2d
        self.raw_bn = raw_bn  # NOTE: the original BN must be removed from the host network (see ElasticCNNUtil.convert_raw_dnn_to_master_dnn)

        # Initialize the gate's output layer so the attention starts around 1,
        # i.e. the wrapped layer initially behaves like the raw conv + BN.
        nn.init.constant_(self.fbs[5].bias, 1.)
        nn.init.kaiming_normal_(self.fbs[5].weight)

    def forward(self, x):
        raw_x = self.raw_bn(self.raw_conv2d(x))

        if self.use_cached_channel_attention and self.cached_channel_attention is not None:
            # Reuse the attention cached from a previous (representative) input.
            channel_attention = self.cached_channel_attention
        else:
            # Predict a fresh attention vector and sparsify it with the
            # winner-take-all operator (k_takes_all) from the base class.
            self.cached_raw_channel_attention = self.fbs(x)
            self.cached_channel_attention = self.k_takes_all(self.cached_raw_channel_attention)
            channel_attention = self.cached_channel_attention

        # Scale each output channel by its attention weight.
        return raw_x * channel_attention.unsqueeze(2).unsqueeze(3)


class StaticFBS(nn.Module):
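    """Static stand-in for the FBS gate in an extracted surrogate DNN: scales the
    feature map by a fixed per-channel attention, stored as a frozen parameter of
    shape (1, C, 1, 1).
    """
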
def __init__(self, channel_attention: torch.Tensor):
super(StaticFBS, self).__init__()
assert channel_attention.dim() == 1
self.channel_attention = nn.Parameter(channel_attention.unsqueeze(0).unsqueeze(2).unsqueeze(3), requires_grad=False)

    def forward(self, x):
        # Broadcast the fixed per-channel scale over batch and spatial dims.
        return x * self.channel_attention

    def __str__(self) -> str:
        return f'StaticFBS({self.channel_attention.size(1)})'


class ElasticCNNUtil(ElasticDNNUtil):
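    """CNN-specific `ElasticDNNUtil`: converts a raw CNN into an FBS-gated master
    DNN and extracts statically pruned surrogate DNNs from it.
    """
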
    def convert_raw_dnn_to_master_dnn(self, raw_dnn: nn.Module, r: float, ignore_layers: Optional[list] = None):
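        """Wrap every Conv2d (folding in the BatchNorm2d that follows it) with
        `Conv2d_WrappedWithFBS`, skipping layers named in `ignore_layers`;
        `r` is the reduction ratio of the gate's hidden layer.
        """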
        model = deepcopy(raw_dnn)
        ignore_layers = ignore_layers if ignore_layers is not None else []

        # First pass: map each Conv2d to the BatchNorm2d that follows it in
        # named_modules() order, so the BN can be folded into the wrapper.
        num_original_bns = 0
        last_conv_name = None
        conv_bn_map = {}
        for name, module in model.named_modules():
            if isinstance(module, nn.Conv2d):
                last_conv_name = name
            if isinstance(module, nn.BatchNorm2d) and last_conv_name not in ignore_layers:
                num_original_bns += 1
                conv_bn_map[last_conv_name] = name

        # Second pass: wrap each recorded (Conv2d, BatchNorm2d) pair with the
        # FBS gate, then neutralize the original BNs.
        num_conv = 0
        for name, module in model.named_modules():
            if isinstance(module, nn.Conv2d) and name not in ignore_layers:
                set_module(model, name, Conv2d_WrappedWithFBS(module, get_module(model, conv_bn_map[name]), r))
                num_conv += 1
        assert num_conv == num_original_bns, 'every wrapped Conv2d needs a matching BatchNorm2d'

        for bn_name in conv_bn_map.values():
            set_module(model, bn_name, nn.Identity())

        return model

    def select_most_rep_sample(self, master_dnn: nn.Module, samples: torch.Tensor):
        # Trivial strategy: the first sample of the batch serves as the
        # representative input that decides which channels are kept.
        return samples[0].unsqueeze(0)

    def extract_surrogate_dnn_via_samples(self, master_dnn: nn.Module, samples: torch.Tensor):
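        """Materialize a statically pruned surrogate DNN whose kept channels are
        determined by the attentions that `master_dnn` produces on a
        representative sample; verifies the two models agree on that sample.
        """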
        sample = self.select_most_rep_sample(master_dnn, samples)
        assert sample.dim() == 4 and sample.size(0) == 1

        # One forward pass populates each wrapper's cached channel attention
        # and records the reference output.
        master_dnn.eval()
        with torch.no_grad():
            master_dnn_output = master_dnn(sample)

        # Build NNI-style pruning masks from the cached attentions: filters whose
        # attention was zeroed by KTakesAll are pruned, the rest are kept.
        pruning_info = {}
        pruning_masks = {}
        for layer_name, layer in master_dnn.named_modules():
            if not isinstance(layer, Conv2d_WrappedWithFBS):
                continue

            cur_pruning_mask = {'weight': torch.zeros_like(layer.raw_conv2d.weight.data)}
            if layer.raw_conv2d.bias is not None:
                cur_pruning_mask['bias'] = torch.zeros_like(layer.raw_conv2d.bias.data)

            w = layer.cached_channel_attention.squeeze(0)
            unpruned_filters_index = w.nonzero(as_tuple=True)[0]
            pruning_info[layer_name] = w

            cur_pruning_mask['weight'][unpruned_filters_index, ...] = 1.
            if layer.raw_conv2d.bias is not None:
                cur_pruning_mask['bias'][unpruned_filters_index, ...] = 1.

            # '.0' targets the conv inside the nn.Sequential that replaces each
            # wrapper below (conv at index 0, BN at 1, placeholder at 2).
            pruning_masks[layer_name + '.0'] = cur_pruning_mask

        # Unwrap: replace each FBS wrapper with a plain (conv, BN, placeholder)
        # block; the trailing nn.Identity() is swapped for a StaticFBS after speedup.
        surrogate_dnn = deepcopy(master_dnn)
        for name, layer in surrogate_dnn.named_modules():
            if not isinstance(layer, Conv2d_WrappedWithFBS):
                continue
            set_module(surrogate_dnn, name, nn.Sequential(layer.raw_conv2d, layer.raw_bn, nn.Identity()))
# fixed_pruning_masks = fix_mask_conflict(pruning_masks, fbs_model, sample.size(), None, True, True, True)
        tmp_mask_path = f'tmp_mask_{get_cur_time_str()}_{os.getpid()}.pth'
        torch.save(pruning_masks, tmp_mask_path)

        # Physically remove the pruned filters with NNI's ModelSpeedup.
        surrogate_dnn.eval()
        model_speedup = ModelSpeedup(surrogate_dnn, sample, tmp_mask_path, sample.device)
        model_speedup.speedup_model()
        os.remove(tmp_mask_path)

        # Re-apply the (now static) feature boosting: scale each kept channel by
        # its attention weight so the surrogate reproduces the master's outputs.
        for layer_name, feature_boosting_w in pruning_info.items():
            feature_boosting_w = feature_boosting_w[feature_boosting_w.nonzero(as_tuple=True)[0]]
            set_module(surrogate_dnn, layer_name + '.2', StaticFBS(feature_boosting_w))

        # Sanity check: master and surrogate must agree on the representative sample.
        surrogate_dnn.eval()
        with torch.no_grad():
            surrogate_dnn_output = surrogate_dnn(sample)
        output_diff = ((surrogate_dnn_output - master_dnn_output) ** 2).sum()
        assert output_diff < 1e-4, f'unexpectedly large output diff between master and surrogate DNN: {output_diff}'
        logger.info(f'output diff of master and surrogate DNN: {output_diff}')

        return surrogate_dnn
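

# A minimal usage sketch (illustration only, kept commented out; assumes this
# repo's `utils` package and the bundled NNI are importable, and that `.base`
# computes fresh channel attentions by default so the forward pass inside
# extract_surrogate_dnn_via_samples() can populate the caches):
#
#     toy_cnn = nn.Sequential(
#         nn.Conv2d(3, 16, 3, padding=1), nn.BatchNorm2d(16), nn.ReLU(),
#         nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(16, 10),
#     )
#     util = ElasticCNNUtil()
#     master = util.convert_raw_dnn_to_master_dnn(toy_cnn, r=4)
#     # ... train `master` so the gates learn meaningful channel saliencies ...
#     surrogate = util.extract_surrogate_dnn_via_samples(master, torch.rand(8, 3, 32, 32))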