File size: 3,291 Bytes
b84549f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import copy
import os
import sys
from tabnanny import verbose
from typing import List, Optional, Tuple
import torch

from ...third_party.nni_new.algorithms.compression.pytorch.pruning import L1FilterPruner
from ...third_party.nni_new.compression.pytorch.speedup import ModelSpeedup
from ...common.others import get_cur_time_str


def _prune_module(model, pruner, model_input_size, device, verbose=False, need_return_mask=False):
    """Apply a pruner's masks to `model` and physically shrink it with NNI `ModelSpeedup`.

    The pruner is compressed and its masks are exported to a temporary mask file in the
    current working directory. `ModelSpeedup` reads that file to rewrite `model`. All
    temporary files are removed afterwards, including when an exception is raised.

    Args:
        model (torch.nn.Module): Model to speed up. It is modified in place; no copy is made here.
        pruner: An NNI pruner (e.g. `L1FilterPruner`) already constructed with its config list.
        model_input_size (Tuple[int]): Shape of the random dummy input passed to `ModelSpeedup`.
        device (str): Device for the dummy input; also forwarded to `ModelSpeedup`.
        verbose (bool, optional): Accepted for API compatibility; currently not used here.
        need_return_mask (bool, optional): Also return the value produced by
            `ModelSpeedup.speedup_model()`. Defaults to False.

    Returns:
        torch.nn.Module if `need_return_mask` is False, otherwise `(model, mask)`.
    """
    pruner.compress()
    pid = os.getpid()
    timestamp = get_cur_time_str()
    # The pid + timestamp suffix is presumably meant to keep concurrent runs from
    # overwriting each other's temp files.
    tmp_model_path = './tmp_weight-{}-{}.pth'.format(pid, timestamp)
    tmp_mask_path = './tmp_mask-{}-{}.pth'.format(pid, timestamp)

    try:
        # Only the mask file is consumed below; the exported weight file is just a
        # by-product of `export_model` and is discarded in the cleanup.
        pruner.export_model(model_path=tmp_model_path, mask_path=tmp_mask_path)

        dummy_input = torch.rand(model_input_size).to(device)
        model.eval()
        model_speedup = ModelSpeedup(model, dummy_input, tmp_mask_path, device)
        fixed_mask = model_speedup.speedup_model()
    finally:
        # Always clean up so failed runs do not leave files behind in the CWD.
        for path in (tmp_model_path, tmp_mask_path):
            if os.path.exists(path):
                os.remove(path)

    if need_return_mask:
        return model, fixed_mask
    return model
    

def l1_prune_model(model: torch.nn.Module, pruned_layers_name: Optional[List[str]], sparsity: float, 
                   model_input_size: Tuple[int], device: str, verbose=False, need_return_mask=False, dep_aware=False):
    """Get the pruned model via L1 Filter Pruning.
    
    Reference: 
        Li H, Kadav A, Durdanovic I, et al. Pruning filters for efficient convnets[J]. arXiv preprint arXiv:1608.08710, 2016.

    The input `model` is never modified; all work is done on deep copies.

    Args:
        model (torch.nn.Module): A PyTorch model.
        pruned_layers_name (Optional[List[str]]): Which layers will be pruned. If it's `None`, all
            `Conv2d` / `ConvTranspose2d` layers will be pruned.
        sparsity (float): Target sparsity. The pruned model is smaller if sparsity is higher.
            If it is exactly 0, a copy of the model is returned without pruning.
        model_input_size (Tuple[int]): Typically be `(1, 3, 32, 32)` or `(1, 3, 224, 224)`.
        device (str): Typically be 'cpu' or 'cuda'.
        verbose (bool, optional): Currently has no effect (TODO: wire it into logging). Defaults to False.
        need_return_mask (bool, optional): Return the fine-grained mask generated by NNI framework for debug. Defaults to False.
        dep_aware (bool, optional): Refers to the argument `dependency_aware` in NNI framework. Defaults to False.

    Returns:
        torch.nn.Module: Pruned model, if `need_return_mask` is False.
        Tuple[torch.nn.Module, Any]: `(pruned_model, mask)` if `need_return_mask` is True.
            `mask` is `None` when `sparsity == 0`, because no pruning is performed.
    """
    model = copy.deepcopy(model).to(device)
    if sparsity == 0:
        # Keep the return shape consistent with the pruned path so callers can unpack.
        return (model, None) if need_return_mask else model

    # Two independent copies: `model` is handed to the pruner to compute masks,
    # while `pruned_model` is the one ModelSpeedup rewrites using those masks.
    pruned_model = copy.deepcopy(model).to(device)

    # generate mask
    model.eval()
    op_config = {'op_types': ['Conv2d', 'ConvTranspose2d']}
    if pruned_layers_name is not None:
        op_config['op_names'] = pruned_layers_name
    op_config['sparsity'] = sparsity
    config_list = [op_config]

    # A dummy input is only needed when dependency-aware pruning traces the graph.
    dummy_input = torch.rand(model_input_size).to(device) if dep_aware else None
    pruner = L1FilterPruner(model, config_list, dependency_aware=dep_aware, dummy_input=dummy_input)
    return _prune_module(pruned_model, pruner, model_input_size, device, verbose, need_return_mask)