import torch
from torch import nn

from methods.elasticdnn.model.base import ElasticDNNUtil


def test(raw_dnn: nn.Module, ignore_layers, elastic_dnn_util: ElasticDNNUtil, input_sample: torch.Tensor, sparsity):
    """Convert raw_dnn into a master DNN, set its sparsity, and extract a
    surrogate DNN, running the built-in perf tests at each step."""
    # Optional warm-up of the raw DNN before conversion:
    # raw_dnn.eval()
    # with torch.no_grad():
    #     raw_dnn(input_sample)

    master_dnn = elastic_dnn_util.convert_raw_dnn_to_master_dnn_with_perf_test(raw_dnn, 16, ignore_layers)
    # print(master_dnn)
    # exit()

    elastic_dnn_util.set_master_dnn_sparsity(master_dnn, sparsity)

    # Optional warm-up of the master DNN before extraction:
    # master_dnn.eval()
    # with torch.no_grad():
    #     master_dnn(input_sample)

    surrogate_dnn = elastic_dnn_util.extract_surrogate_dnn_via_samples_with_perf_test(master_dnn, input_sample)
    return surrogate_dnn
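
# Illustrative standalone check (our addition, not part of the ElasticDNN API):
# surrogate extraction rests on the identity that zeroing the pruned input
# channels of a linear layer is equivalent to slicing both the input and the
# weight down to the unpruned channels. The commented scratch at the bottom of
# this file probes the same identity; this is a minimal self-contained sketch
# of it (helper name and default shapes are ours, chosen to mirror the 768-dim
# tokens used in the scratch below).
def _check_mask_vs_slice_equivalence(in_features=768, out_features=1768, keep=144):
    import torch.nn.functional as F
    weight = torch.rand(out_features, in_features)
    bias = torch.rand(out_features)
    x = torch.rand(1, 197, in_features)  # (batch, tokens, channels)
    perm = torch.randperm(in_features)
    unpruned = perm[:keep].sort()[0]
    pruned = perm[keep:].sort()[0]
    mask = torch.ones_like(x)
    mask[:, :, pruned] = 0.
    o1 = F.linear(x * mask, weight, bias)                        # masked, full-width layer
    o2 = F.linear(x[:, :, unpruned], weight[:, unpruned], bias)  # sliced, narrow layer
    assert torch.allclose(o1, o2, atol=1e-5), 'mask and slice should agree'
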

if __name__ == '__main__':
    from utils.dl.common.env import set_random_seed
    set_random_seed(1)

    # ResNet-50 variant of the same test: only the first two convs of each
    # bottleneck are prunable; every other Conv2d is ignored.
    # from torchvision.models import resnet50
    # from methods.elasticdnn.model.cnn import ElasticCNNUtil
    # raw_cnn = resnet50()
    # prunable_layers = []
    # for i in range(1, 5):
    #     for j in range([3, 4, 6, 3][i - 1]):
    #         prunable_layers += [f'layer{i}.{j}.conv1', f'layer{i}.{j}.conv2']
    # ignore_layers = [layer for layer, m in raw_cnn.named_modules() if isinstance(m, nn.Conv2d) and layer not in prunable_layers]
    # test(raw_cnn, ignore_layers, ElasticCNNUtil(), torch.rand(1, 3, 224, 224), 0.5)  # illustrative sparsity value
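
    # For reference (a sketch, assuming torchvision's resnet50 layout): with the
    # enumeration above, prunable_layers starts
    # ['layer1.0.conv1', 'layer1.0.conv2', 'layer1.1.conv1', ...] and
    # ignore_layers picks up the stem conv, every conv3, and the 1x1 downsample
    # convs, e.g.:
    # print([l for l in ignore_layers if l.startswith('layer1.0')])
    # # -> ['layer1.0.conv3', 'layer1.0.downsample.0']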

    ignore_layers = []
    from methods.elasticdnn.model.vit import ElasticViTUtil
    # raw_vit = torch.load('tmp-master-dnn.pt')
    raw_vit = torch.load('')  # TODO: fill in a master DNN checkpoint path before running
    test(raw_vit, ignore_layers, ElasticViTUtil(), torch.rand(16, 3, 224, 224).cuda(), 0.9)
    exit()  # everything below is kept as alternative test paths and never runs

    # Alternative: build a fresh ViT-B/16 and sweep sparsities.
    from dnns.vit import vit_b_16
    # from methods.elasticdnn.model.vit_new import ElasticViTUtil
    from methods.elasticdnn.model.vit import ElasticViTUtil
    # raw_vit = vit_b_16()
    for s in [0.8, 0.9, 0.95]:
        raw_vit = vit_b_16().cuda()  # rebuild per sparsity so each run starts from a raw model
        ignore_layers = []
        test(raw_vit, ignore_layers, ElasticViTUtil(), torch.rand(16, 3, 224, 224).cuda(), s)
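
    # A rough way to eyeball per-sparsity latency of the extracted surrogate
    # (a sketch using the surrogate returned by test(); the timing code and
    # batch size are illustrative, not part of the original script):
    # import time
    # surrogate_dnn = test(raw_vit, ignore_layers, ElasticViTUtil(), torch.rand(16, 3, 224, 224).cuda(), s)
    # x = torch.rand(16, 3, 224, 224).cuda()
    # with torch.no_grad():
    #     surrogate_dnn(x)  # warm-up
    #     torch.cuda.synchronize()
    #     t0 = time.time()
    #     surrogate_dnn(x)
    #     torch.cuda.synchronize()
    #     print(f'sparsity={s}: {time.time() - t0:.4f} s / batch')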

    # Sweep sparsities on a pretrained MD checkpoint instead:
    # for s in [0, 0.2, 0.4, 0.6, 0.8]:
    #     pretrained_md_models_dict_path = 'experiments/elasticdnn/vit_b_16/offline/fm_to_md/results/20230518/999999-164524-wo_FBS_trial_dsnet_lr/models/md_best.pt'
    #     raw_vit = torch.load(pretrained_md_models_dict_path)['main'].cuda()
    #     ignore_layers = []
    #     test(raw_vit, ignore_layers, ElasticViTUtil(), torch.rand(16, 3, 224, 224).cuda(), s)
    # exit()

    # Scratch checks that masking pruned input channels of F.linear equals
    # slicing the input and weight down to the unpruned channels (first with a
    # 10-dim output, then with a 130-dim one):
    # import torch.nn.functional as F
    # weight = torch.rand((10, 5))
    # bias = torch.rand(10)
    # x = torch.rand((1, 3, 5))
    # t = torch.randperm(5)
    # pruned, unpruned = t[0: 3], t[3:]
    # mask = torch.ones_like(x)
    # mask[:, :, pruned] = 0
    # print(x, x * mask, (x * mask).sum((0, 1)))
    # o1 = F.linear(x * mask, weight, bias)
    # o2 = F.linear(x[:, :, unpruned], weight[:, unpruned], bias)
    # print(o1.size(), o2.size(), ((o1 - o2) ** 2).sum())

    # weight = torch.rand((130, 5))
    # bias = torch.rand(130)
    # x = torch.rand((1, 3, 5))
    # t = torch.randperm(5)
    # pruned, unpruned = t[0: 3], t[3:]
    # mask = torch.ones_like(x)
    # mask[:, :, pruned] = 0
    # print(x, x * mask, (x * mask).sum((0, 1)))
    # o1 = F.linear(x * mask, weight, bias)
    # o2 = F.linear(x[:, :, unpruned], weight[:, unpruned], bias)
    # print(o1.size(), o2.size(), ((o1 - o2) ** 2).sum())

    # Same check at ViT scale (768-dim tokens, 197 tokens, 144 channels kept):
    # weight = torch.rand((1768, 768))
    # bias = torch.rand(1768)
    # x = torch.rand([1, 197, 768])
    # t = torch.randperm(768)
    # unpruned, pruned = t[0: 144], t[144:]
    # unpruned = unpruned.sort()[0]
    # pruned = pruned.sort()[0]
    # mask = torch.ones_like(x)
    # mask[:, :, pruned] = 0
    # print(x.sum((0, 1)).size(), (x * mask).sum((0, 1))[0: 10], x[:, :, unpruned].sum((0, 1))[0: 10])
    # o1 = F.linear(x * mask, weight, bias)
    # o2 = F.linear(x[:, :, unpruned], weight[:, unpruned], bias)
    # print(o1.sum((0, 1))[0: 10], o2.sum((0, 1))[0: 10], o1.size(), o2.size(), ((o1 - o2).abs()).sum(), ((o1 - o2) ** 2).sum())

    # Stale scratch: this indexes the token dimension (dim 1) with channel-style
    # indices, and weight[:, unpruned_indexes] no longer matches x's last dim,
    # so it would error as written.
    # unpruned_indexes = torch.randperm(5)[0: 2]
    # o2 = F.linear(x[:, unpruned_indexes], weight[:, unpruned_indexes])
    # print(o2)