import copy
import os
import sys
# NOTE(review): this looks like an accidental auto-import; every function below
# shadows it with its own `verbose` parameter. Kept in case other code relies on it.
from tabnanny import verbose
from typing import List, Optional, Tuple

import torch

from ...third_party.nni_new.algorithms.compression.pytorch.pruning import L1FilterPruner
from ...third_party.nni_new.compression.pytorch.speedup import ModelSpeedup
from ...common.others import get_cur_time_str


def _remove_if_exists(path):
    """Delete `path`, tolerating the case where the file was never created."""
    try:
        os.remove(path)
    except FileNotFoundError:
        pass


def _prune_module(model, pruner, model_input_size, device, verbose=False, need_return_mask=False):
    """Apply the masks produced by `pruner` to `model` via NNI's `ModelSpeedup`.

    The pruner is compressed and its weights/masks are exported to temporary
    files in the current working directory; the mask file is then handed to
    `ModelSpeedup`, which rewrites `model` in place. Both temporary files are
    removed before returning, including when an exception is raised.

    Args:
        model: Model that the speedup is applied to (put into eval mode and modified in place).
        pruner: Pruner object providing `compress()` and `export_model()`.
        model_input_size: Shape of the random dummy input used for the speedup.
        device: Device the dummy input is moved to and that is passed to `ModelSpeedup`.
        verbose (bool, optional): Accepted for interface compatibility; currently unused.
        need_return_mask (bool, optional): Also return the value produced by
            `ModelSpeedup.speedup_model()`. Defaults to False.

    Returns:
        `model` after speedup, or `(model, mask)` if `need_return_mask` is True.
    """
    pruner.compress()

    # pid + timestamp keep concurrently running processes from clobbering each other's files.
    pid = os.getpid()
    timestamp = get_cur_time_str()
    tmp_model_path = './tmp_weight-{}-{}.pth'.format(pid, timestamp)
    tmp_mask_path = './tmp_mask-{}-{}.pth'.format(pid, timestamp)

    try:
        pruner.export_model(model_path=tmp_model_path, mask_path=tmp_mask_path)

        # speed up
        dummy_input = torch.rand(model_input_size).to(device)
        pruned_model = model
        pruned_model.eval()
        model_speedup = ModelSpeedup(pruned_model, dummy_input, tmp_mask_path, device)
        fixed_mask = model_speedup.speedup_model()
    finally:
        # Only the mask file is consumed above; the exported weights are a by-product.
        # Clean up both even if export or speedup failed, so no stale files are left behind.
        _remove_if_exists(tmp_model_path)
        _remove_if_exists(tmp_mask_path)

    if need_return_mask:
        return pruned_model, fixed_mask
    return pruned_model


def l1_prune_model(model: torch.nn.Module, pruned_layers_name: Optional[List[str]], sparsity: float,
                   model_input_size: Tuple[int, ...], device: str, verbose=False, need_return_mask=False,
                   dep_aware=False):
    """Get the pruned model via L1 Filter Pruning.

    Reference:
        Li H, Kadav A, Durdanovic I, et al. Pruning filters for efficient convnets[J].
        arXiv preprint arXiv:1608.08710, 2016.

    The input `model` is never modified: all work is done on deep copies.

    Args:
        model (torch.nn.Module): A PyTorch model.
        pruned_layers_name (Optional[List[str]]): Which layers will be pruned.
            If it's `None`, all `Conv2d` / `ConvTranspose2d` layers will be pruned.
        sparsity (float): Target sparsity. The pruned model is smaller if sparsity is higher.
            If it is 0, an unpruned copy of the model is returned.
        model_input_size (Tuple[int, ...]): Typically be `(1, 3, 32, 32)` or `(1, 3, 224, 224)`.
        device (str): Typically be 'cpu' or 'cuda'.
        verbose (bool, optional): Whether to output the verbose log. Defaults to False.
        need_return_mask (bool, optional): Return the fine-grained mask generated by NNI framework
            for debug. Defaults to False.
        dep_aware (bool, optional): Refers to the argument `dependency_aware` in NNI framework.
            Defaults to False.

    Returns:
        torch.nn.Module: Pruned model, if `need_return_mask` is False.
        Tuple[torch.nn.Module, mask]: Pruned model and the mask, if `need_return_mask` is True.
        The mask is `None` when `sparsity` is 0 because no pruning was performed.
    """
    model = copy.deepcopy(model).to(device)

    if sparsity == 0:
        # Nothing to prune. Keep the return shape consistent with the pruned path so that
        # callers unpacking `(model, mask)` do not break on this edge case.
        if need_return_mask:
            return model, None
        return model

    # The pruner wraps the modules of the model it is given, so the speedup is applied
    # to a separate, untouched copy.
    pruned_model = copy.deepcopy(model).to(device)

    # generate mask
    model.eval()

    config = {'op_types': ['Conv2d', 'ConvTranspose2d']}
    if pruned_layers_name is not None:
        config['op_names'] = pruned_layers_name
    config['sparsity'] = sparsity
    config_list = [config]

    # A dummy input is only supplied to the pruner in dependency-aware mode.
    dummy_input = torch.rand(model_input_size).to(device) if dep_aware else None
    pruner = L1FilterPruner(model, config_list, dependency_aware=dep_aware, dummy_input=dummy_input)

    # Returns either the pruned model or `(pruned_model, mask)`, depending on `need_return_mask`.
    return _prune_module(pruned_model, pruner, model_input_size, device, verbose, need_return_mask)