|
from typing import Optional |
|
import torch |
|
from copy import deepcopy |
|
from torch import nn |
|
from utils.common.others import get_cur_time_str |
|
from utils.dl.common.model import get_model_device, get_model_latency, get_model_size, get_module, get_super_module, set_module |
|
from utils.common.log import logger |
|
from utils.third_party.nni_new.compression.pytorch.speedup import ModelSpeedup |
|
import os |
|
|
|
from .base import Abs, KTakesAll, Layer_WrappedWithFBS, ElasticDNNUtil |
|
|
|
|
|
class Conv2d_WrappedWithFBS(Layer_WrappedWithFBS):
    """Conv2d + BatchNorm2d pair whose output channels are gated by an FBS side network.

    ``self.fbs`` predicts one saliency value per output channel from the layer
    *input*. ``self.k_takes_all`` turns that into the channel attention that
    scales the BN output. ``k_takes_all``, ``use_cached_channel_attention`` and
    ``cached_channel_attention`` are not defined here; they presumably come from
    ``Layer_WrappedWithFBS`` (TODO confirm).
    """

    def __init__(self, raw_conv2d: nn.Conv2d, raw_bn: nn.BatchNorm2d, r):
        """
        Args:
            raw_conv2d: the convolution to wrap.
            raw_bn: the BatchNorm2d applied right after ``raw_conv2d``.
            r: reduction ratio of the FBS hidden layer (int or float). The hidden
                width is ``out_channels // r``, truncated to an int and clamped
                to at least 1.
        """
        super(Conv2d_WrappedWithFBS, self).__init__()

        # int(): `//` with a float r yields a float, which nn.Linear rejects.
        # max(1, ...): avoid a zero-width hidden layer when out_channels < r.
        hidden_channels = max(1, int(raw_conv2d.out_channels // r))

        self.fbs = nn.Sequential(
            Abs(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(raw_conv2d.in_channels, hidden_channels),
            nn.ReLU(),
            nn.Linear(hidden_channels, raw_conv2d.out_channels),
            nn.ReLU()
        )

        self.raw_conv2d = raw_conv2d
        self.raw_bn = raw_bn

        # self.fbs[5] is the output Linear. The bias starts at 1 (presumably so
        # channel saliencies start out positive) and the weight is Kaiming-initialised.
        nn.init.constant_(self.fbs[5].bias, 1.)
        nn.init.kaiming_normal_(self.fbs[5].weight)

    def forward(self, x):
        """Return ``raw_bn(raw_conv2d(x))`` scaled per output channel by the channel attention.

        If ``use_cached_channel_attention`` is set and an attention is already
        cached, that attention is reused. Otherwise it is recomputed from ``x``.
        Both the raw FBS output and the ``k_takes_all`` result are then stored
        on the module.
        """
        raw_x = self.raw_bn(self.raw_conv2d(x))

        if not (self.use_cached_channel_attention and self.cached_channel_attention is not None):
            self.cached_raw_channel_attention = self.fbs(x)
            self.cached_channel_attention = self.k_takes_all(self.cached_raw_channel_attention)
        channel_attention = self.cached_channel_attention

        # (B, C) -> (B, C, 1, 1) so the attention broadcasts over the spatial dims.
        return raw_x * channel_attention.unsqueeze(2).unsqueeze(3)
|
|
|
|
|
class StaticFBS(nn.Module):
    """Fixed (non-trainable) per-channel scaling that stands in for a dynamic FBS gate.

    The stored attention has shape (1, C, 1, 1). It therefore broadcasts against
    an input whose channel axis is dim 1 (presumably (N, C, H, W)).
    """

    def __init__(self, channel_attention: torch.Tensor):
        """
        Args:
            channel_attention: 1-D tensor holding one scale per channel.

        Raises:
            AssertionError: if ``channel_attention`` is not 1-D.
        """
        super(StaticFBS, self).__init__()
        assert channel_attention.dim() == 1
        # Stored as a frozen Parameter so it follows .to()/state_dict() without being trained.
        self.channel_attention = nn.Parameter(
            channel_attention.unsqueeze(0).unsqueeze(2).unsqueeze(3), requires_grad=False
        )

    def forward(self, x):
        """Scale each channel of ``x`` by its stored attention value."""
        return x * self.channel_attention

    def __str__(self) -> str:
        # size(1) is already the channel count (an int); calling len() on it raises TypeError.
        return f'StaticFBS({self.channel_attention.size(1)})'
|
|
|
|
|
class ElasticCNNUtil(ElasticDNNUtil):
    """ElasticDNNUtil for CNNs.

    Wraps Conv2d/BatchNorm2d pairs with FBS gates to build a master model, and
    extracts a physically pruned surrogate model from that master.
    """

    def convert_raw_dnn_to_master_dnn(self, raw_dnn: nn.Module, r: float, ignore_layers=None):
        """Return a copy of ``raw_dnn`` in which each Conv2d/BatchNorm2d pair is FBS-gated.

        Each BatchNorm2d is paired with the most recent Conv2d that precedes it
        in ``named_modules()`` order. The conv is replaced by a
        ``Conv2d_WrappedWithFBS`` holding both modules. The BN is replaced by
        ``nn.Identity()``, because it now runs inside the wrapper.

        Args:
            raw_dnn: model to convert. It is deep-copied and left unmodified.
            r: reduction ratio forwarded to ``Conv2d_WrappedWithFBS``.
            ignore_layers: iterable of Conv2d module names to leave untouched,
                together with the BN that follows them. ``None`` (the default)
                ignores nothing.

        Returns:
            The converted copy.

        Raises:
            KeyError: if a non-ignored Conv2d has no BatchNorm2d paired with it.
            AssertionError: if the number of wrapped convs differs from the
                number of paired BNs.
        """
        # None previously made the guards below false for every layer, so
        # nothing was wrapped. It now means "ignore nothing".
        ignored = frozenset(ignore_layers) if ignore_layers is not None else frozenset()
        model = deepcopy(raw_dnn)

        # Pass 1: pair each BN with the conv that precedes it.
        num_original_bns = 0
        last_conv_name = None
        conv_bn_map = {}
        for name, module in model.named_modules():
            if isinstance(module, nn.Conv2d):
                last_conv_name = name
            if isinstance(module, nn.BatchNorm2d) and last_conv_name not in ignored:
                num_original_bns += 1
                conv_bn_map[last_conv_name] = name

        # Pass 2: wrap the convs. Iterate over a snapshot so that set_module()
        # never mutates the tree underneath a live named_modules() generator.
        num_conv = 0
        for name, module in list(model.named_modules()):
            if not isinstance(module, nn.Conv2d) or name in ignored:
                continue
            if name not in conv_bn_map:
                raise KeyError(f'Conv2d {name!r} has no following BatchNorm2d; add it to ignore_layers')
            set_module(model, name, Conv2d_WrappedWithFBS(module, get_module(model, conv_bn_map[name]), r))
            num_conv += 1

        assert num_conv == num_original_bns

        # The BNs now live inside the wrappers, so remove them from the main path.
        for bn_layer in conv_bn_map.values():
            set_module(model, bn_layer, nn.Identity())

        return model

    def select_most_rep_sample(self, master_dnn: nn.Module, samples: torch.Tensor):
        """Return the first sample of ``samples`` as a batch of one.

        ``master_dnn`` is currently unused. It is presumably kept in the
        signature for a smarter selection strategy.
        """
        return samples[0].unsqueeze(0)

    def extract_surrogate_dnn_via_samples(self, master_dnn: nn.Module, samples: torch.Tensor):
        """Build a physically pruned surrogate of ``master_dnn`` for the given samples.

        Steps:
            1. Run a representative sample through ``master_dnn`` in eval mode,
               so every ``Conv2d_WrappedWithFBS`` caches its channel attention.
            2. Prune filters whose attention is zero, using NNI ``ModelSpeedup``.
            3. Bake the surviving non-zero attention values in as ``StaticFBS``
               modules.
            4. Check the surrogate's output on the sample against the master's.

        Args:
            master_dnn: model produced by ``convert_raw_dnn_to_master_dnn``.
                It is left in eval mode on return.
            samples: input batch. ``select_most_rep_sample`` must produce a 4-D
                tensor with batch size 1.

        Returns:
            The pruned surrogate model.

        Raises:
            AssertionError: if the selected sample has an unexpected shape, or if
                the summed squared output difference is not below 1e-4.
        """
        sample = self.select_most_rep_sample(master_dnn, samples)
        assert sample.dim() == 4 and sample.size(0) == 1

        master_dnn.eval()
        with torch.no_grad():
            master_dnn_output = master_dnn(sample)

        pruning_info = {}
        pruning_masks = {}
        for layer_name, layer in master_dnn.named_modules():
            if not isinstance(layer, Conv2d_WrappedWithFBS):
                continue

            # Attention cached by the forward pass above. Squeeze the batch dim (size 1).
            w = layer.cached_channel_attention.squeeze(0)
            unpruned_filters_index = w.nonzero(as_tuple=True)[0]
            pruning_info[layer_name] = w

            # Mask = 1 keeps an output filter; only filters with non-zero attention survive.
            cur_pruning_mask = {'weight': torch.zeros_like(layer.raw_conv2d.weight.data)}
            cur_pruning_mask['weight'][unpruned_filters_index, ...] = 1.
            if layer.raw_conv2d.bias is not None:
                cur_pruning_mask['bias'] = torch.zeros_like(layer.raw_conv2d.bias.data)
                cur_pruning_mask['bias'][unpruned_filters_index, ...] = 1.
            # '.0' addresses raw_conv2d inside the Sequential built below.
            pruning_masks[layer_name + '.0'] = cur_pruning_mask

        # Replace each wrapper by Sequential(conv, bn, placeholder). The
        # placeholder (index 2) becomes a StaticFBS after speedup.
        surrogate_dnn = deepcopy(master_dnn)
        for name, layer in list(surrogate_dnn.named_modules()):
            if isinstance(layer, Conv2d_WrappedWithFBS):
                set_module(surrogate_dnn, name, nn.Sequential(layer.raw_conv2d, layer.raw_bn, nn.Identity()))

        # ModelSpeedup reads the masks from a file. Remove it even if speedup fails.
        tmp_mask_path = f'tmp_mask_{get_cur_time_str()}_{os.getpid()}.pth'
        torch.save(pruning_masks, tmp_mask_path)
        try:
            surrogate_dnn.eval()
            model_speedup = ModelSpeedup(surrogate_dnn, sample, tmp_mask_path, sample.device)
            model_speedup.speedup_model()
        finally:
            os.remove(tmp_mask_path)

        # Only the non-zero attention entries correspond to the remaining filters.
        for layer_name, feature_boosting_w in pruning_info.items():
            feature_boosting_w = feature_boosting_w[feature_boosting_w.nonzero(as_tuple=True)[0]]
            set_module(surrogate_dnn, layer_name + '.2', StaticFBS(feature_boosting_w))

        surrogate_dnn.eval()
        with torch.no_grad():
            surrogate_dnn_output = surrogate_dnn(sample)
        output_diff = ((surrogate_dnn_output - master_dnn_output) ** 2).sum()
        assert output_diff < 1e-4, output_diff
        logger.info(f'output diff of master and surrogate DNN: {output_diff}')

        return surrogate_dnn
|
|