|
from copy import deepcopy |
|
from typing import Optional, Union |
|
import torch |
|
from torch import nn |
|
from einops import rearrange, repeat |
|
from einops.layers.torch import Rearrange |
|
import tqdm |
|
|
|
from utils.dl.common.model import get_model_device, get_model_size, set_module, get_module |
|
from .base import Abs, KTakesAll, ElasticDNNUtil, Layer_WrappedWithFBS |
|
from utils.common.log import logger |
|
|
|
|
|
class SqueezeLast(nn.Module): |
|
def __init__(self): |
|
super(SqueezeLast, self).__init__() |
|
|
|
def forward(self, x): |
|
return x.squeeze(-1) |
|
|
|
|
|
class ProjConv_WrappedWithFBS(Layer_WrappedWithFBS): |
|
def __init__(self, proj: nn.Conv2d, r): |
|
super(ProjConv_WrappedWithFBS, self).__init__() |
|
|
|
        self.proj = proj

        # FBS head: pool |x| over its last dimension to get per-input-channel
        # statistics, then a two-layer MLP predicts a saliency score for every
        # output channel of the wrapped projection
|
self.fbs = nn.Sequential( |
|
Abs(), |
|
nn.AdaptiveAvgPool1d(1), |
|
SqueezeLast(), |
|
nn.Linear(proj.in_channels, proj.out_channels // r), |
|
nn.ReLU(), |
|
nn.Linear(proj.out_channels // r, proj.out_channels), |
|
nn.ReLU() |
|
) |
|
|
|
        # the last Linear sits at index 5 here (this wrapper has no leading
        # Rearrange, unlike the wrappers below); initialize it so all channel
        # saliencies start near 1
        nn.init.constant_(self.fbs[5].bias, 1.)
        nn.init.kaiming_normal_(self.fbs[5].weight)
|
|
|
|
|
def forward(self, x): |
|
if self.use_cached_channel_attention and self.cached_channel_attention is not None: |
|
channel_attention = self.cached_channel_attention |
|
else: |
|
self.cached_raw_channel_attention = self.fbs(x) |
|
self.cached_channel_attention = self.k_takes_all(self.cached_raw_channel_attention) |
|
|
|
channel_attention = self.cached_channel_attention |
|
|
|
raw_res = self.proj(x) |
|
|
|
return channel_attention.unsqueeze(1) * raw_res |
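
# NOTE: ProjConv_WrappedWithFBS is not instantiated anywhere in this file
# (ElasticViltUtil below only wraps FFN linears with Linear_WrappedWithFBS);
# it is presumably kept for parity with the ViT patch-embedding variant.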
|
|
|
|
|
class Linear_WrappedWithFBS(Layer_WrappedWithFBS): |
|
def __init__(self, linear: nn.Linear, r): |
|
super(Linear_WrappedWithFBS, self).__init__() |
|
|
|
        self.linear = linear

        # FBS head: (B, N, D) -> (B, D) by averaging |x| over the token dimension,
        # then a two-layer MLP predicts a saliency score per output feature
|
self.fbs = nn.Sequential( |
|
Rearrange('b n d -> b d n'), |
|
Abs(), |
|
nn.AdaptiveAvgPool1d(1), |
|
SqueezeLast(), |
|
nn.Linear(linear.in_features, linear.out_features // r), |
|
nn.ReLU(), |
|
nn.Linear(linear.out_features // r, linear.out_features), |
|
nn.ReLU() |
|
) |
|
|
|
nn.init.constant_(self.fbs[6].bias, 1.) |
|
nn.init.kaiming_normal_(self.fbs[6].weight) |
|
|
|
|
|
def forward(self, x): |
|
if self.use_cached_channel_attention and self.cached_channel_attention is not None: |
|
channel_attention = self.cached_channel_attention |
|
else: |
|
self.cached_raw_channel_attention = self.fbs(x) |
|
self.cached_channel_attention = self.k_takes_all(self.cached_raw_channel_attention) |
|
|
|
channel_attention = self.cached_channel_attention |
|
|
|
raw_res = self.linear(x) |
|
|
|
return channel_attention.unsqueeze(1) * raw_res |
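
# A minimal, illustrative sketch of how the wrapper above behaves; the concrete
# layer sizes and batch shape are hypothetical, not taken from this module:
#
#   ffn = nn.Linear(768, 3072)
#   wrapped = Linear_WrappedWithFBS(ffn, r=16)
#   out = wrapped(torch.rand(2, 40, 768))
#   # out: (2, 40, 3072); output channels outside the top-k kept by k_takes_all
#   # are multiplied by zero, so the layer stays "elastic" at its full width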
|
|
|
|
|
class ToQKV_WrappedWithFBS(Layer_WrappedWithFBS): |
|
""" |
|
This regards to_q/to_k/to_v as a whole (in fact it consists of multiple heads) and prunes it. |
|
It seems different channels of different heads are pruned according to the input. |
|
This is different from "removing some head" or "removing the same channels in each head". |
|
""" |
|
def __init__(self, to_qkv: nn.Linear, r): |
|
        super(ToQKV_WrappedWithFBS, self).__init__()

        # split the fused qkv projection: the first two thirds of the output rows
        # are q and k (kept dense); the last third is v, which is gated by FBS
self.to_qk = nn.Linear(to_qkv.in_features, to_qkv.out_features // 3 * 2, bias=to_qkv.bias is not None) |
|
self.to_v = nn.Linear(to_qkv.in_features, to_qkv.out_features // 3, bias=to_qkv.bias is not None) |
|
self.to_qk.weight.data.copy_(to_qkv.weight.data[0: to_qkv.out_features // 3 * 2]) |
|
if to_qkv.bias is not None: |
|
self.to_qk.bias.data.copy_(to_qkv.bias.data[0: to_qkv.out_features // 3 * 2]) |
|
self.to_v.weight.data.copy_(to_qkv.weight.data[to_qkv.out_features // 3 * 2: ]) |
|
if to_qkv.bias is not None: |
|
self.to_v.bias.data.copy_(to_qkv.bias.data[to_qkv.out_features // 3 * 2: ]) |
|
|
|
self.fbs = nn.Sequential( |
|
Rearrange('b n d -> b d n'), |
|
Abs(), |
|
nn.AdaptiveAvgPool1d(1), |
|
SqueezeLast(), |
|
nn.Linear(to_qkv.in_features, to_qkv.out_features // 3 // r), |
|
            nn.ReLU(),
            nn.Linear(to_qkv.out_features // 3 // r, self.to_v.out_features),
|
nn.ReLU() |
|
) |
|
|
|
nn.init.constant_(self.fbs[6].bias, 1.) |
|
nn.init.kaiming_normal_(self.fbs[6].weight) |
|
|
|
def forward(self, x): |
|
if self.use_cached_channel_attention and self.cached_channel_attention is not None: |
|
channel_attention = self.cached_channel_attention |
|
else: |
|
            self.cached_raw_channel_attention = self.fbs(x)
            self.cached_channel_attention = self.k_takes_all(self.cached_raw_channel_attention)

channel_attention = self.cached_channel_attention |
|
|
|
        qk = self.to_qk(x)
        # only the value projection is gated by the sparse channel attention
        v = channel_attention.unsqueeze(1) * self.to_v(x)
        return torch.cat([qk, v], dim=-1)
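
# An illustrative shape walk-through with hypothetical sizes (not taken from this
# module): for a fused qkv projection nn.Linear(768, 3 * 768),
#
#   w = ToQKV_WrappedWithFBS(nn.Linear(768, 3 * 768), r=16)
#   y = w(torch.rand(2, 40, 768))
#   # y: (2, 40, 2304) = concat of dense q/k (1536 dims) and FBS-gated v (768 dims)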
|
|
class StaticFBS(nn.Module): |
|
def __init__(self, static_channel_attention): |
|
super(StaticFBS, self).__init__() |
|
assert static_channel_attention.dim() == 2 and static_channel_attention.size(0) == 1 |
|
self.static_channel_attention = nn.Parameter(static_channel_attention, requires_grad=False) |
|
|
|
    def forward(self, x):
        # x: (B, N, C); the frozen attention (1, C) broadcasts over batch and tokens
        return x * self.static_channel_attention.unsqueeze(1)
|
|
|
|
|
class ElasticViltUtil(ElasticDNNUtil): |
|
def convert_raw_dnn_to_master_dnn(self, raw_dnn: nn.Module, r: float, ignore_layers=[]): |
|
assert len(ignore_layers) == 0, 'not supported yet' |
|
|
|
        raw_vit = deepcopy(raw_dnn)

        # wrap the first FFN linear (intermediate.dense) of every transformer
        # block with FBS; attention and output projections stay dense
        for name, module in raw_vit.named_modules():
            if name.endswith('intermediate'):
|
set_module(module, 'dense', Linear_WrappedWithFBS(module.dense, r)) |
|
|
|
return raw_vit |
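
    # A hedged example of what the conversion yields for a HuggingFace ViLT model
    # (the module path follows transformers' ViLT implementation; the 768 -> 3072
    # FFN size is the ViLT-B default and is only illustrative):
    #
    #   master = ElasticViltUtil().convert_raw_dnn_to_master_dnn(vilt_model, r=16)
    #   get_module(master, 'vilt.encoder.layer.0.intermediate.dense')
    #   # -> Linear_WrappedWithFBS wrapping the original 768 -> 3072 linear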
|
|
|
    def set_master_dnn_sparsity(self, master_dnn: nn.Module, sparsity: float):
        return super().set_master_dnn_sparsity(master_dnn, sparsity)
|
|
|
    def select_most_rep_sample(self, master_dnn: nn.Module, samples: torch.Tensor):
        # NOTE: despite the annotation, `samples` is expected to be a dict of
        # batched tensors (the keyword arguments of a ViLT forward); the first
        # sample of each entry is kept as the representative input
        res = {k: v[0: 1] for k, v in samples.items()}
        return res
|
|
|
def extract_surrogate_dnn_via_samples(self, master_dnn: nn.Module, samples: torch.Tensor, return_detail=False): |
|
        sample = self.select_most_rep_sample(master_dnn, samples)

        # run the master DNN once on the representative sample so that every FBS
        # wrapper caches its channel attention and the channel indices chosen by
        # k_takes_all
        master_dnn.eval()
        self.clear_cached_channel_attention_in_master_dnn(master_dnn)
        with torch.no_grad():
            master_dnn_output = master_dnn(**sample)

        boosted_vit = deepcopy(master_dnn)
|
|
|
        def get_unpruned_indexes_from_channel_attn(channel_attn: torch.Tensor, k):
            # not used below; the loop reads k_takes_all.cached_i directly
            assert channel_attn.size(0) == 1, 'use A representative sample to generate channel attentions'
            res = channel_attn[0].nonzero(as_tuple=True)[0]
            return res
|
|
|
        unpruned_indexes_of_layers = {}

        for block_i, block in enumerate(boosted_vit.vilt.encoder.layer):
            # 1. shrink the FBS-wrapped intermediate.dense to its unpruned output
            #    channels and freeze its cached channel attention into a StaticFBS
            ff_0 = get_module(block, 'intermediate.dense')

            ff_0_pruned_indexes = ff_0.k_takes_all.cached_i[0].sort()[0]
            ff_0_unpruned_indexes = torch.LongTensor([ii for ii in range(ff_0.cached_channel_attention.size(1)) if ii not in ff_0_pruned_indexes])
            new_ff_0 = nn.Linear(ff_0.linear.in_features, ff_0_unpruned_indexes.size(0), ff_0.linear.bias is not None)
            new_ff_0.weight.data.copy_(ff_0.linear.weight.data[ff_0_unpruned_indexes])
            if ff_0.linear.bias is not None:
                new_ff_0.bias.data.copy_(ff_0.linear.bias.data[ff_0_unpruned_indexes])
            set_module(block, 'intermediate.dense', nn.Sequential(new_ff_0, StaticFBS(ff_0.cached_channel_attention[:, ff_0_unpruned_indexes])))

            # 2. shrink the following output.dense so that its input dimension
            #    matches the channels kept above (its output dimension is unchanged)
            ff_1 = get_module(block, 'output.dense')
            new_ff_1 = nn.Linear(ff_0_unpruned_indexes.size(0), ff_1.out_features, ff_1.bias is not None)
            new_ff_1.weight.data.copy_(ff_1.weight.data[:, ff_0_unpruned_indexes])
            if ff_1.bias is not None:
                new_ff_1.bias.data.copy_(ff_1.bias.data)
            set_module(block, 'output.dense', new_ff_1)

            unpruned_indexes_of_layers[f'vilt.encoder.layer.{block_i}.intermediate.dense.0.weight'] = ff_0_unpruned_indexes
|
|
|
surrogate_dnn = boosted_vit |
|
surrogate_dnn.eval() |
|
surrogate_dnn = surrogate_dnn.to(get_model_device(master_dnn)) |
|
|
|
with torch.no_grad(): |
|
surrogate_dnn_output = surrogate_dnn(**sample) |
|
|
|
output_diff = ((surrogate_dnn_output.logits - master_dnn_output.logits) ** 2).sum() |
|
|
|
        logger.info(f'output diff of master and surrogate DNN: {output_diff}')

if return_detail: |
|
return boosted_vit, unpruned_indexes_of_layers |
|
|
|
return boosted_vit |
|
|
|
def extract_surrogate_dnn_via_samples_with_perf_test(self, master_dnn: nn.Module, samples: torch.Tensor, return_detail=False): |
|
master_dnn_size = get_model_size(master_dnn, True) |
|
master_dnn_latency = self._get_model_latency(master_dnn, samples, 50, |
|
get_model_device(master_dnn), 50, False) |
|
|
|
res = self.extract_surrogate_dnn_via_samples(master_dnn, samples, return_detail) |
|
if not return_detail: |
|
surrogate_dnn = res |
|
else: |
|
surrogate_dnn, unpruned_indexes_of_layers = res |
|
        surrogate_dnn_size = get_model_size(surrogate_dnn, True)
        surrogate_dnn_latency = self._get_model_latency(surrogate_dnn, samples, 50,
                                                        get_model_device(master_dnn), 50, False)
|
|
|
logger.info(f'master DNN ({master_dnn_size:.3f}MB, {master_dnn_latency:.4f}s/sample) -> ' |
|
f'surrogate DNN ({surrogate_dnn_size:.3f}MB, {surrogate_dnn_latency:.4f}s/sample)\n' |
|
f'(model size: ↓ {(master_dnn_size / surrogate_dnn_size):.2f}x, ' |
|
f'latency: ↓ {(master_dnn_latency / surrogate_dnn_latency):.2f}x)') |
|
|
|
return res |
|
|
|
def _get_model_latency(self, model: torch.nn.Module, model_input_size, sample_num: int, |
|
device: str, warmup_sample_num: int, return_detail=False): |
|
import time |
|
|
|
        # model_input_size is either an input-size tuple (turned into a random
        # tensor) or, as used by the callers above, an already-prepared dict of
        # sample tensors that is unpacked into the model's keyword arguments
        if isinstance(model_input_size, tuple):
            dummy_input = torch.rand(model_input_size).to(device)
        else:
            dummy_input = model_input_size
|
|
|
model = model.to(device) |
|
model.eval() |
|
|
|
|
|
with torch.no_grad(): |
|
for _ in range(warmup_sample_num): |
|
model(**dummy_input) |
|
|
|
infer_time_list = [] |
|
|
|
if device == 'cuda' or 'cuda' in str(device): |
|
with torch.no_grad(): |
|
for _ in range(sample_num): |
|
s, e = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True) |
|
s.record() |
|
model(**dummy_input) |
|
e.record() |
|
torch.cuda.synchronize() |
|
cur_model_infer_time = s.elapsed_time(e) / 1000. |
|
infer_time_list += [cur_model_infer_time] |
|
|
|
else: |
|
with torch.no_grad(): |
|
for _ in range(sample_num): |
|
start = time.time() |
|
model(**dummy_input) |
|
cur_model_infer_time = time.time() - start |
|
infer_time_list += [cur_model_infer_time] |
|
|
|
avg_infer_time = sum(infer_time_list) / sample_num |
|
|
|
if return_detail: |
|
return avg_infer_time, infer_time_list |
|
return avg_infer_time |
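

# A minimal end-to-end sketch of the intended workflow. The checkpoint name,
# processor usage and sample construction are illustrative assumptions; only the
# ElasticViltUtil calls come from the code above:
#
#   from transformers import ViltForQuestionAnswering
#
#   util = ElasticViltUtil()
#   raw = ViltForQuestionAnswering.from_pretrained('dandelin/vilt-b32-finetuned-vqa')
#   master = util.convert_raw_dnn_to_master_dnn(raw, r=16)
#   util.set_master_dnn_sparsity(master, 0.5)
#
#   # samples: a dict of batched tensors from ViltProcessor (input_ids,
#   # attention_mask, token_type_ids, pixel_values, pixel_mask, ...)
#   surrogate = util.extract_surrogate_dnn_via_samples_with_perf_test(master, samples)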