import numpy as np import torch from torch import Tensor import torch.nn as nn import torchvision from torch.utils.data import DataLoader from torchvision import datasets import torchvision.transforms as transforms import os import time import sys import torch.quantization from typing import Any, Callable, List, Optional, Type, Union from torch.quantization.observer import MinMaxObserver, MovingAverageMinMaxObserver, HistogramObserver from typing import Dict, cast from torchvision.models.vgg import VGG # # Setup warnings import warnings warnings.filterwarnings( action = 'ignore', category = DeprecationWarning, module = r'.*' ) warnings.filterwarnings( action = 'default', module = r'torch.quantization' ) from torch.quantization import QuantStub, DeQuantStub def make_layers(cfg: List[Union[str, int]], batch_norm: bool = False) -> nn.Sequential: layers: List[nn.Module] = [] in_channels = 3 for v in cfg: if v == "M": layers += [nn.MaxPool2d(kernel_size = 2, stride = 2)] else: v = cast(int, v) conv2d = nn.Conv2d(in_channels, v, kernel_size = 3, padding = 1) if batch_norm: layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace = True)] else: layers += [conv2d, nn.ReLU(inplace = True)] in_channels = v return nn.Sequential(*layers) cfgs: Dict[str, List[Union[str, int]]] = { "A": [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"], "B": [64, 64, "M", 128, 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"], "D": [64, 64, "M", 128, 128, "M", 256, 256, 256, "M", 512, 512, 512, "M", 512, 512, 512, "M"], "E": [64, 64, "M", 128, 128, "M", 256, 256, 256, 256, "M", 512, 512, 512, 512, "M", 512, 512, 512, 512, "M"], } class qVGG(nn.Module): def __init__( self, num_classes: int = 1000, init_weights: bool = True, dropout: float = 0.5 ) -> None: super().__init__() #_log_api_usage_once(self) self.features = make_layers(cfgs["D"], batch_norm = True) self.avgpool = nn.AdaptiveAvgPool2d((7, 7)) self.classifier = nn.Sequential( nn.Linear(512 * 7 * 7, 4096), nn.ReLU(True), nn.Dropout(p = dropout), nn.Linear(4096, 4096), nn.ReLU(True), nn.Dropout(p = dropout), nn.Linear(4096, num_classes), ) qconfig = torch.quantization.QConfig( activation = MinMaxObserver.with_args(qscheme = torch.per_tensor_symmetric, dtype = torch.quint8), weight = MinMaxObserver.with_args(qscheme = torch.per_tensor_symmetric, dtype = torch.qint8) ) self.quant = torch.quantization.QuantStub(qconfig) self.dequant = DeQuantStub() if init_weights: for m in self.modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight, mode = "fan_out", nonlinearity = "relu") if m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.BatchNorm2d): nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0) elif isinstance(m, nn.Linear): nn.init.normal_(m.weight, 0, 0.01) nn.init.constant_(m.bias, 0) def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.quant(x) x = self.features(x) x = self.avgpool(x) x = torch.flatten(x, 1) x = self.classifier(x) x = self.dequant(x) return x #We define the function to fuse models def fuse_model(self) -> None: for m, n in zip(self.modules(), self.named_modules()): if type(m) == nn.Conv2d: k = int(n[0].split('.')[-1]) torch.quantization.fuse_modules(self, [["features." + str(k), "features." + str(k + 1), "features." + str(k + 2)]], inplace = True)