|
import torch
from torch import nn
from abc import ABC, abstractmethod

from utils.dl.common.model import get_model_device, get_model_latency, get_model_size
from utils.common.log import logger
|
|
class KTakesAll(nn.Module):
    """Suppress the smallest `k` fraction of entries along the last dimension (set them to 0)."""

    def __init__(self, k):
        super(KTakesAll, self).__init__()
        self.k = k  # sparsity ratio: the fraction of entries to zero out
        self.cached_i = None  # indices of the suppressed entries from the last forward pass

    def forward(self, g: torch.Tensor):
        # number of entries to suppress along the last dimension
        k = int(g.size(-1) * self.k)

        # indices of the k smallest entries (top-k of the negated tensor)
        i = (-g).topk(k, -1)[1]
        self.cached_i = i

        # zero out the suppressed entries; the remaining entries are kept unchanged
        t = g.scatter(-1, i, 0)
        return t
|
|
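# Illustrative sketch (not part of the original API): with k=0.5 and a 4-element attention
# vector, KTakesAll zeroes the two smallest entries and keeps the rest unchanged.
def _example_k_takes_all():
    kta = KTakesAll(0.5)
    g = torch.tensor([[0.1, 0.9, 0.3, 0.7]])
    # returns tensor([[0.0000, 0.9000, 0.0000, 0.7000]])
    return kta(g)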
|
|
|
class Abs(nn.Module):
    def __init__(self):
        super(Abs, self).__init__()

    def forward(self, x):
        return x.abs()
|
|
class Layer_WrappedWithFBS(nn.Module):
    """Base class for a layer wrapped with FBS: holds the KTakesAll gate and the cached channel attention."""

    def __init__(self):
        super(Layer_WrappedWithFBS, self).__init__()

        init_sparsity = 0.5
        self.k_takes_all = KTakesAll(init_sparsity)

        # channel attention before / after KTakesAll suppression, cached from the last forward pass
        self.cached_raw_channel_attention = None
        self.cached_channel_attention = None
        # if True, reuse the cached channel attention instead of recomputing it per input
        self.use_cached_channel_attention = False
|
|
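# Hypothetical sketch of a concrete FBS-wrapped layer (an illustration of how
# Layer_WrappedWithFBS is meant to be subclassed, not the project's actual implementation):
# a small side branch named `fbs` (so that train_only_fbs_of_master_dnn can find its
# parameters via the '.fbs' name filter) predicts per-channel attention from the input,
# KTakesAll suppresses the weakest channels, and the surviving attention scales the
# wrapped layer's output.
class _ExampleLinearWithFBS(Layer_WrappedWithFBS):
    def __init__(self, linear: nn.Linear, r: int):
        super(_ExampleLinearWithFBS, self).__init__()
        self.linear = linear
        # lightweight attention predictor; `r` is an assumed width-reduction ratio
        self.fbs = nn.Sequential(
            Abs(),
            nn.Linear(linear.in_features, linear.out_features // r),
            nn.ReLU(),
            nn.Linear(linear.out_features // r, linear.out_features),
            nn.ReLU()
        )

    def forward(self, x):  # assumes x of shape (batch_size, in_features)
        if self.use_cached_channel_attention and self.cached_channel_attention is not None:
            channel_attention = self.cached_channel_attention
        else:
            raw_channel_attention = self.fbs(x)
            self.cached_raw_channel_attention = raw_channel_attention
            channel_attention = self.k_takes_all(raw_channel_attention)
            self.cached_channel_attention = channel_attention
        return self.linear(x) * channel_attention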
|
|
|
class ElasticDNNUtil(ABC):
    @abstractmethod
    def convert_raw_dnn_to_master_dnn(self, raw_dnn: nn.Module, r: float, ignore_layers=[]):
        raise NotImplementedError

    def convert_raw_dnn_to_master_dnn_with_perf_test(self, raw_dnn: nn.Module, r: float, ignore_layers=[]):
        raw_dnn_size = get_model_size(raw_dnn, True)
        master_dnn = self.convert_raw_dnn_to_master_dnn(raw_dnn, r, ignore_layers)
        master_dnn_size = get_model_size(master_dnn, True)

        logger.info(f'master DNN w/o FBS ({raw_dnn_size:.3f}MB) -> master DNN w/ FBS ({master_dnn_size:.3f}MB) '
                    f'(↑ {(((master_dnn_size - raw_dnn_size) / raw_dnn_size) * 100.):.2f}%)')
        return master_dnn
|
    def set_master_dnn_inference_via_cached_channel_attention(self, master_dnn: nn.Module):
        for name, module in master_dnn.named_modules():
            if isinstance(module, Layer_WrappedWithFBS):
                assert module.cached_channel_attention is not None
                module.use_cached_channel_attention = True

    def set_master_dnn_dynamic_inference(self, master_dnn: nn.Module):
        for name, module in master_dnn.named_modules():
            if isinstance(module, Layer_WrappedWithFBS):
                module.cached_channel_attention = None
                module.use_cached_channel_attention = False
|
|
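    # Illustrative usage (an assumption about the intended calling pattern, not from the
    # original file): after running representative samples through the master DNN so that
    # each FBS layer has a cached channel attention, freeze that attention for static
    # inference, or switch back to per-input (dynamic) gating:
    #
    #   util.set_master_dnn_inference_via_cached_channel_attention(master_dnn)  # reuse cache
    #   util.set_master_dnn_dynamic_inference(master_dnn)                       # recompute per input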
|
    def train_only_fbs_of_master_dnn(self, master_dnn: nn.Module):
        # freeze the backbone and make only the FBS side-branch parameters trainable
        fbs_params = []
        for n, p in master_dnn.named_parameters():
            if '.fbs' in n:
                fbs_params += [p]
                p.requires_grad = True
            else:
                p.requires_grad = False
        return fbs_params
|
|
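    # Illustrative usage (an assumption, not from the original file): optimize only the FBS
    # side branches while the wrapped backbone stays frozen, e.g.:
    #
    #   fbs_params = util.train_only_fbs_of_master_dnn(master_dnn)
    #   optimizer = torch.optim.SGD(fbs_params, lr=1e-2)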
|
    def get_accu_l1_reg_of_raw_channel_attention_in_master_dnn(self, master_dnn: nn.Module):
        # accumulate the L1 norm of the raw (pre-suppression) channel attention over all FBS layers
        res = 0.
        for name, module in master_dnn.named_modules():
            if isinstance(module, Layer_WrappedWithFBS):
                res += module.cached_raw_channel_attention.norm(1)
        return res
|
|
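    # Illustrative usage (an assumption, not from the original file): the accumulated L1 norm
    # of the raw channel attention is typically added to the task loss to encourage sparse
    # gating; `task_loss` and `l1_reg_weight` below are placeholder names:
    #
    #   l1_reg = util.get_accu_l1_reg_of_raw_channel_attention_in_master_dnn(master_dnn)
    #   loss = task_loss + l1_reg_weight * l1_reg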
|
    def get_raw_channel_attention_in_master_dnn(self, master_dnn: nn.Module):
        res = {}
        for name, module in master_dnn.named_modules():
            if isinstance(module, Layer_WrappedWithFBS):
                res[name] = module.cached_raw_channel_attention
        return res
|
    def set_master_dnn_sparsity(self, master_dnn: nn.Module, sparsity: float):
        assert 0 <= sparsity <= 1., sparsity
        for name, module in master_dnn.named_modules():
            if isinstance(module, KTakesAll):
                module.k = sparsity
        logger.debug(f'set master DNN sparsity to {sparsity}')

    def clear_cached_channel_attention_in_master_dnn(self, master_dnn: nn.Module):
        for name, module in master_dnn.named_modules():
            if isinstance(module, Layer_WrappedWithFBS):
                module.cached_raw_channel_attention = None
                module.cached_channel_attention = None
|
    @abstractmethod
    def select_most_rep_sample(self, master_dnn: nn.Module, samples: torch.Tensor):
        raise NotImplementedError

    @abstractmethod
    def extract_surrogate_dnn_via_samples(self, master_dnn: nn.Module, samples: torch.Tensor, return_detail=False):
        raise NotImplementedError
|
    def extract_surrogate_dnn_via_samples_with_perf_test(self, master_dnn: nn.Module, samples: torch.Tensor, return_detail=False):
        master_dnn_size = get_model_size(master_dnn, True)
        master_dnn_latency = get_model_latency(master_dnn, (1, *list(samples.size())[1:]), 50,
                                               get_model_device(master_dnn), 50, False)

        res = self.extract_surrogate_dnn_via_samples(master_dnn, samples, return_detail)
        if not return_detail:
            surrogate_dnn = res
        else:
            surrogate_dnn, unpruned_indexes_of_layers = res
        surrogate_dnn_size = get_model_size(surrogate_dnn, True)
        surrogate_dnn_latency = get_model_latency(surrogate_dnn, (1, *list(samples.size())[1:]), 50,
                                                  get_model_device(surrogate_dnn), 50, False)

        logger.info(f'master DNN ({master_dnn_size:.3f}MB, {master_dnn_latency:.4f}s/sample) -> '
                    f'surrogate DNN ({surrogate_dnn_size:.3f}MB, {surrogate_dnn_latency:.4f}s/sample)\n'
                    f'(model size: ↓ {(master_dnn_size / surrogate_dnn_size):.2f}x, '
                    f'latency: ↓ {(master_dnn_latency / surrogate_dnn_latency):.2f}x)')

        return res
|
|
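# Illustrative end-to-end sketch (an assumption about the intended workflow, not part of the
# original file): given a concrete ElasticDNNUtil subclass instance, convert a raw DNN into a
# master DNN with FBS, choose a sparsity, pick representative samples, and extract a pruned
# surrogate DNN. In practice the FBS side branches would be trained first (see
# train_only_fbs_of_master_dnn and the L1 regularizer above) before extraction.
def _example_elastic_dnn_workflow(util: ElasticDNNUtil, raw_dnn: nn.Module,
                                  samples: torch.Tensor, r: float, sparsity: float = 0.5):
    master_dnn = util.convert_raw_dnn_to_master_dnn_with_perf_test(raw_dnn, r)
    util.set_master_dnn_sparsity(master_dnn, sparsity)
    rep_samples = util.select_most_rep_sample(master_dnn, samples)
    surrogate_dnn = util.extract_surrogate_dnn_via_samples_with_perf_test(master_dnn, rep_samples)
    return surrogate_dnn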