import numpy as np
import torch
from torch import Tensor
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader
from torchvision import datasets
import torchvision.transforms as transforms
import os
import time
import sys
import torch.quantization
from typing import Any, Callable, Dict, List, Optional, Type, Union, cast
from torch.quantization.observer import MinMaxObserver, MovingAverageMinMaxObserver, HistogramObserver
from torchvision.models.vgg import VGG

import warnings

# Silence DeprecationWarnings globally, but keep warnings emitted from torch.quantization visible.
warnings.filterwarnings(
    action = 'ignore',
    category = DeprecationWarning,
    module = r'.*'
)
warnings.filterwarnings(
    action = 'default',
    module = r'torch.quantization'
)

from torch.quantization import QuantStub, DeQuantStub


def make_layers(cfg: List[Union[str, int]], batch_norm: bool = False) -> nn.Sequential:
    # Build the VGG feature extractor: "M" entries become 2x2 max-pool layers and
    # integer entries become 3x3 convolutions (optionally with BatchNorm) followed by ReLU.
    layers: List[nn.Module] = []
    in_channels = 3
    for v in cfg:
        if v == "M":
            layers += [nn.MaxPool2d(kernel_size = 2, stride = 2)]
        else:
            v = cast(int, v)
            conv2d = nn.Conv2d(in_channels, v, kernel_size = 3, padding = 1)
            if batch_norm:
                layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace = True)]
            else:
                layers += [conv2d, nn.ReLU(inplace = True)]
            in_channels = v
    return nn.Sequential(*layers)


# Channel/pooling configurations for the VGG variants ("M" marks a max-pool layer).
cfgs: Dict[str, List[Union[str, int]]] = {
    "A": [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
    "B": [64, 64, "M", 128, 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
    "D": [64, 64, "M", 128, 128, "M", 256, 256, 256, "M", 512, 512, 512, "M", 512, 512, 512, "M"],
    "E": [64, 64, "M", 128, 128, "M", 256, 256, 256, 256, "M", 512, 512, 512, 512, "M", 512, 512, 512, 512, "M"],
}
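
# For reference, the feature extractor used below, make_layers(cfgs["D"], batch_norm = True),
# is the VGG16-BN stack: 13 Conv/BN/ReLU blocks interleaved with 5 max-pool layers.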


class qVGG(nn.Module):
    def __init__(
        self, num_classes: int = 1000, init_weights: bool = True, dropout: float = 0.5
    ) -> None:
        super().__init__()

        # VGG16-BN backbone (config "D") followed by the standard VGG classifier head.
        self.features = make_layers(cfgs["D"], batch_norm = True)
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(True),
            nn.Dropout(p = dropout),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(p = dropout),
            nn.Linear(4096, num_classes),
        )

        # Symmetric per-tensor MinMax observers for the input QuantStub; this qconfig is
        # attached to the stub only, the rest of the model inherits whatever qconfig is
        # assigned before torch.quantization.prepare().
        qconfig = torch.quantization.QConfig(
            activation = MinMaxObserver.with_args(qscheme = torch.per_tensor_symmetric, dtype = torch.quint8),
            weight = MinMaxObserver.with_args(qscheme = torch.per_tensor_symmetric, dtype = torch.qint8)
        )
        self.quant = QuantStub(qconfig)
        self.dequant = DeQuantStub()

        # Standard VGG weight initialization (Kaiming for convs, N(0, 0.01) for linears).
        if init_weights:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    nn.init.kaiming_normal_(m.weight, mode = "fan_out", nonlinearity = "relu")
                    if m.bias is not None:
                        nn.init.constant_(m.bias, 0)
                elif isinstance(m, nn.BatchNorm2d):
                    nn.init.constant_(m.weight, 1)
                    nn.init.constant_(m.bias, 0)
                elif isinstance(m, nn.Linear):
                    nn.init.normal_(m.weight, 0, 0.01)
                    nn.init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Quantize the input, run the network, then dequantize the output logits.
        x = self.quant(x)
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        x = self.dequant(x)
        return x

    def fuse_model(self) -> None:
        # Fuse each Conv2d + BatchNorm2d + ReLU triple in `features` ahead of quantization.
        # Collect the conv indices first so the module tree is not mutated while iterating over it.
        conv_indices = [int(name.split('.')[-1]) for name, module in self.named_modules() if isinstance(module, nn.Conv2d)]
        for k in conv_indices:
            torch.quantization.fuse_modules(self, [[f"features.{k}", f"features.{k + 1}", f"features.{k + 2}"]], inplace = True)
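

# Illustrative usage sketch (not part of the original pipeline): post-training static
# quantization of qVGG on CPU. The 'fbgemm' backend, num_classes = 10, and the random
# calibration batches are assumptions; substitute a real calibration DataLoader.
if __name__ == "__main__":
    model = qVGG(num_classes = 10)
    model.eval()                                   # fusion and calibration run in eval mode

    model.fuse_model()                             # fold Conv + BN + ReLU into single modules
    model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
    torch.quantization.prepare(model, inplace = True)

    # Calibrate the observers with a few representative batches.
    with torch.no_grad():
        for _ in range(4):
            model(torch.randn(8, 3, 224, 224))

    torch.quantization.convert(model, inplace = True)
    print(model)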