|
import logging as log
|
|
import os
|
|
from pathlib import Path
|
|
|
|
import torch
|
|
import torchvision.transforms as transforms
|
|
import torchvision.models as models
|
|
from torch import nn
|
|
from torch.nn import functional as F
|
|
|
|
from enum import Enum
|
|
|
|
|
|
class AdvEnum(Enum):
    """Enum base class with convenience accessors over its members."""

    @classmethod
    def list(cls):
        # All member values, in definition order.
        return [member.value for member in cls]

    @classmethod
    def list_name_value(cls):
        # (name, value) pairs for every member, in definition order.
        return [(member.name, member.value) for member in cls]
|
|
|
|
|
|
class DecoNetMode(AdvEnum):
    # Freeze/unfreeze configurations applied by DecoNet.set_mode.
    FREEZE_DECO = 0  # DECO front-end frozen; everything else trainable
    FREEZE_PTMODEL = 1  # whole backbone frozen; DECO trainable
    FREEZE_PTMODEL_NO_FC = 2  # backbone frozen except its final layer; DECO trainable
    UNFREEZE_ALL = 3  # every parameter trainable
    FREEZE_ALL = 4  # every parameter frozen
    FREEZE_ALL_NO_FC = 5  # everything frozen except the backbone's final layer
|
|
|
|
|
|
class DecoType(AdvEnum):
    # Selects which DECO (colorization) front-end get_deco_model builds.
    NO = 0  # no front-end: input goes straight to the backbone
    DECONV = 1  # StandardDECO upsampling via transposed convolution
    RESIZE_CONV = 2  # StandardDECO upsampling via interpolate + conv
    ColorUDECO = 16  # U-Net style front-end (values are sparse, presumably historical)
    PIXEL_SHUFFLE = 20  # sub-pixel convolution (PixelShuffle) upsampling
|
|
|
|
|
|
def get_deco_model(use_deco, out_deco) -> nn.Module:
    """
    Build the DECO (colorization) front-end selected by `use_deco`.

    :param use_deco: DecoType member selecting the architecture.
    :param out_deco: target output spatial size forwarded to the module.
    :return: the constructed DECO nn.Module.
    :raises ValueError: if `use_deco` has no associated module (e.g. DecoType.NO).
    """
    # BUG FIX: the original tested `use_deco in [DecoType.DECONV, DecoType.DECONV_NORM]`,
    # but DECONV_NORM is not a member of DecoType, so building that list raised
    # AttributeError on every call, whatever `use_deco` was.
    if use_deco is DecoType.DECONV:
        return StandardDECO(out_deco, deconv=True)
    elif use_deco is DecoType.RESIZE_CONV:
        return StandardDECO(out_deco, deconv=False)
    elif use_deco is DecoType.PIXEL_SHUFFLE:
        return PixelShuffle(out_deco, lrelu=False)
    elif use_deco is DecoType.ColorUDECO:
        return ColorUDECO(out_deco)
    else:
        raise ValueError("Module not found")
|
|
|
|
|
|
class PreTrainedModel(AdvEnum):
    # Supported torchvision backbones; see get_pt_model for how each is
    # constructed and which layer is replaced by the task-specific head.
    DENSENET_121 = 0
    RESNET_18 = 1
    RESNET_34 = 2
    RESNET_50 = 3
    VGG11 = 4
    VGG11_BN = 5
|
|
|
|
|
|
def get_pt_model(model, output, pretrained=True):
    """
    Build a torchvision backbone with its final layer replaced for `output` classes.

    The constructed model gets an extra attribute `last_layer_name` naming the
    replaced head (used elsewhere to freeze/reset it).

    :param model: PreTrainedModel member or its raw value.
    :param output: number of output classes for the new final layer.
    :param pretrained: load ImageNet weights.
    :return: (model, input_size) — all supported backbones expect 224x224 input.
    :raises ValueError: if `model` is not a supported backbone.
    """
    # Renamed from `input`, which shadowed the builtin.
    input_size = 224
    if not isinstance(model, PreTrainedModel):
        model = PreTrainedModel(model)

    if model == PreTrainedModel.DENSENET_121:
        pt_model = models.densenet121(pretrained=pretrained)
        num_ftrs = pt_model.classifier.in_features
        pt_model.classifier = nn.Linear(num_ftrs, output)
        pt_model.last_layer_name = "classifier"
    elif model in (PreTrainedModel.RESNET_18, PreTrainedModel.RESNET_34, PreTrainedModel.RESNET_50):
        # The three resnet branches were byte-identical apart from the constructor.
        ctor = {PreTrainedModel.RESNET_18: models.resnet18,
                PreTrainedModel.RESNET_34: models.resnet34,
                PreTrainedModel.RESNET_50: models.resnet50}[model]
        pt_model = ctor(pretrained=pretrained)
        num_ftrs = pt_model.fc.in_features
        pt_model.fc = nn.Linear(num_ftrs, output)
        pt_model.last_layer_name = "fc"
    elif model in (PreTrainedModel.VGG11, PreTrainedModel.VGG11_BN):
        ctor = models.vgg11 if model == PreTrainedModel.VGG11 else models.vgg11_bn
        pt_model = ctor(pretrained=pretrained)
        num_ftrs = pt_model.classifier[6].in_features
        pt_model.classifier[6] = nn.Linear(num_ftrs, output)
        pt_model.last_layer_name = "classifier.6"
    else:
        raise ValueError("Model not found")

    return pt_model, input_size
|
|
|
|
|
|
class DecoNet(nn.Module):
    """
    Colorization module (optional) + Model.

    An optional DECO front-end transforms the input before it is classified by
    a pre-trained torchvision backbone whose final layer has been replaced.
    """

    def __init__(self, output=14,
                 deco_type=DecoType.ColorUDECO,
                 pt_model=PreTrainedModel.RESNET_18,
                 pre_trained=True,
                 training_mode=DecoNetMode.FREEZE_PTMODEL_NO_FC,
                 use_aap=False):
        """
        :param output: number of classes for the backbone's final layer.
        :param deco_type: DecoType selecting the front-end (NO disables it).
        :param pt_model: PreTrainedModel member (or raw value) of the backbone.
        :param pre_trained: load ImageNet weights for the backbone.
        :param training_mode: initial freeze/unfreeze configuration.
        :param use_aap: adaptively average-pool the features to the backbone's
                        expected input size before classification.
        """
        super().__init__()

        self.deco_type = deco_type
        self.training_mode = training_mode
        self.use_aap = use_aap
        # get_pt_model returns (backbone, expected input size); that input
        # size doubles as the DECO module's target output size.
        pt_model, self.out_deco = get_pt_model(pt_model, output, pre_trained)
        self.last_layer_name = pt_model.last_layer_name

        if self.deco_type is not DecoType.NO:
            self.deco = get_deco_model(self.deco_type, self.out_deco)
        else:
            self.deco = None
        self.pt_model = pt_model
        self.set_mode(training_mode)

    def set_mode(self, mode, print=True):
        """
        Freeze/unfreeze parameters according to `mode`.

        :param mode: DecoNetMode member or raw value.
        :param print: log every parameter's requires_grad status afterwards.
                      (Parameter name kept for backward compatibility even
                      though it shadows the builtin.)
        """
        if not isinstance(mode, DecoNetMode):
            mode = DecoNetMode(mode)
        if mode == DecoNetMode.UNFREEZE_ALL:
            for param in self.parameters():
                param.requires_grad = True
        elif mode == DecoNetMode.FREEZE_DECO:
            self.set_mode(DecoNetMode.UNFREEZE_ALL, False)
            for param in self.deco.parameters():
                param.requires_grad = False
        elif mode == DecoNetMode.FREEZE_PTMODEL:
            self.set_mode(DecoNetMode.UNFREEZE_ALL, False)
            for param in self.pt_model.parameters():
                param.requires_grad = False
        elif mode == DecoNetMode.FREEZE_PTMODEL_NO_FC:
            self.set_mode(DecoNetMode.UNFREEZE_ALL, False)
            # Freeze the backbone except its final (replaced) layer.
            for name, param in self.pt_model.named_parameters():
                if self.last_layer_name not in name:
                    param.requires_grad = False
        elif mode == DecoNetMode.FREEZE_ALL:
            for param in self.parameters():
                param.requires_grad = False
        elif mode == DecoNetMode.FREEZE_ALL_NO_FC:
            self.set_mode(DecoNetMode.FREEZE_ALL, False)

            # Re-enable only the backbone's final layer.
            for name, param in self.pt_model.named_parameters():
                if self.last_layer_name in name:
                    param.requires_grad = True

        if print:
            log.info("#############################################")
            log.info("PARAMETERS STATUS:")
            for name, param in self.named_parameters():
                log.info("{} : {}".format(name, param.requires_grad))
            log.info("#############################################")

    def get_layer_weight(self, sel_name: str = ""):
        """
        Collect backbone parameters whose name contains `sel_name`
        (defaults to the final layer's name).
        """
        if sel_name == "":
            sel_name = self.last_layer_name
        return [param for name, param in self.pt_model.named_parameters()
                if sel_name in name]

    def forward(self, xb):
        """
        @:param xb : tensor
            Batch of input images

        @:return tensor
            A batch of output images
        """
        if self.deco is not None:
            xb = self.deco(xb)
        if self.use_aap:
            xb = F.adaptive_avg_pool2d(xb, (self.out_deco, self.out_deco))
        return self.pt_model(xb)

    def clean_last_layer(self):
        """Re-initialize the backbone's final layer in place."""
        # BUG FIX: the original compared self.pt_model (an nn.Module instance)
        # against PreTrainedModel enum members, which is always False, so the
        # VGG branch was unreachable. Dispatch on the stored layer name instead.
        if self.last_layer_name == "classifier.6":
            self.pt_model.classifier[6].reset_parameters()
        else:
            last_layer_name = list(self.pt_model._modules)[-1]
            self.pt_model._modules[last_layer_name].reset_parameters()

        log.info("Last layer cleaned!")

    def last_layer_size(self):
        """Return the input-feature count of the backbone's final layer."""
        # BUG FIX: same module-vs-enum comparison as clean_last_layer; in
        # addition the fallback indexed `.shape` on the module itself instead
        # of its weight tensor, which would raise at runtime.
        if self.last_layer_name == "classifier.6":
            return self.pt_model.classifier[6].weight.shape[-1]
        last_layer_name = list(self.pt_model._modules)[-1]
        return self.pt_model._modules[last_layer_name].weight.shape[-1]

    def load_deco_state_dict(self, state_dict):
        """
        Load weights into the DECO module, creating it first if it was disabled.

        :return: True on success, False if the module cannot load a state dict.
        """
        if self.deco is None:
            self.deco = get_deco_model(self.deco_type, self.out_deco)
        if hasattr(self.deco, "load_state_dict"):
            self.deco.load_state_dict(state_dict)
        else:
            return False
        self.set_mode(self.training_mode)
        return True
|
|
|
|
|
|
def default_deco__weight_init(m):
    """Per-module init: Xavier-uniform conv weights, identity batch-norm."""
    if isinstance(m, nn.BatchNorm2d):
        # Identity affine transform: unit scale, zero shift.
        m.weight.data.fill_(1)
        m.bias.data.zero_()
    elif isinstance(m, nn.Conv2d):
        torch.nn.init.xavier_uniform_(m.weight)
|
|
|
|
|
|
def bn_weight_init(m):
    """Reset BatchNorm2d modules to an identity affine; ignore everything else."""
    if not isinstance(m, nn.BatchNorm2d):
        return
    m.weight.data.fill_(1)
    m.bias.data.zero_()
|
|
|
|
|
|
class BaseDECO(nn.Module):
    """Common base for DECO modules: output-size bookkeeping + weight init."""

    def __init__(self, out=224, init=None):
        super().__init__()
        self.out_s = out  # target spatial output size
        self.init = init  # weight-init scheme selector: None, 0 or 1

    def set_output_size(self, out_s):
        """Update the target output size."""
        self.out_s = out_s

    def init_weights(self):
        """Apply the configured initialization scheme (no-op when init is None)."""
        if self.init == 0:
            self.apply(default_deco__weight_init)
        elif self.init == 1:
            self.apply(bn_weight_init)
|
|
|
|
|
|
class ResBlock(nn.Module):
    """Residual block: two conv layers wrapped by an identity skip connection."""

    def __init__(self, ni, nf=None, kernel=3, stride=1, padding=1):
        super().__init__()
        nf = ni if nf is None else nf
        self.conv1 = conv_layer(ni, nf, kernel=kernel, stride=stride, padding=padding)
        self.conv2 = conv_layer(nf, nf, kernel=kernel, stride=stride, padding=padding)

    def forward(self, x):
        residual = self.conv2(self.conv1(x))
        return x + residual
|
|
|
|
|
|
def conv_layer(in_layer, out_layer, kernel=3, stride=1, padding=1, instanceNorm=False):
    """Conv2d -> (Batch|Instance)Norm2d -> LeakyReLU as one Sequential."""
    conv = nn.Conv2d(in_layer, out_layer, kernel_size=kernel, stride=stride, padding=padding)
    norm = nn.InstanceNorm2d(out_layer) if instanceNorm else nn.BatchNorm2d(out_layer)
    return nn.Sequential(conv, norm, nn.LeakyReLU(inplace=True))
|
|
|
|
|
|
def _make_res_layers(nl, ni, kernel=3, stride=1, padding=1):
    """Stack `nl` ResBlocks with `ni` channels into a Sequential."""
    blocks = [ResBlock(ni, kernel=kernel, stride=stride, padding=padding)
              for _ in range(nl)]
    return nn.Sequential(*blocks)
|
|
|
|
|
|
class StandardDECO(BaseDECO):
    """
    Standard DECO Module: conv stem + 8 residual blocks at 1/4 resolution,
    then a 4x upsampling stage (transposed conv or resize-conv).
    """

    def __init__(self, out=224, init=0, deconv=False):
        super().__init__(out, init)
        # Stem: 1-channel input -> 64 feature maps, downsampled by conv stride
        # and the max-pool below.
        self.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=2)
        self.bn1 = nn.BatchNorm2d(64)

        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.resblocks = _make_res_layers(8, 64)
        self.conv_last = nn.Conv2d(64, 3, kernel_size=1)
        # `self.deconv` ends up either False (resize-conv path) or the
        # transposed-conv module itself; forward() relies on its truthiness.
        self.deconv = deconv
        if deconv:
            self.deconv = nn.ConvTranspose2d(in_channels=3, out_channels=3, kernel_size=8, padding=2, stride=4,
                                             groups=3, bias=False)
        else:
            self.pad = nn.ReflectionPad2d(1)
            self.conv_up = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, padding=0, stride=1)

        self.init_weights()

    def forward(self, xb):
        """
        @:param xb : Tensor
            Batch of input images

        @:return tensor
            A batch of output images
        """
        feats = self.maxpool(F.leaky_relu(self.bn1(self.conv1(xb))))
        feats = self.conv_last(self.resblocks(feats))
        if self.deconv:
            # Learned 4x upsampling back to the input resolution.
            return self.deconv(feats, output_size=xb.shape)
        # Resize-conv: nearest-neighbour 4x upsample, reflection pad, 3x3 conv.
        upsampled = F.interpolate(feats, scale_factor=4, mode='nearest')
        return self.conv_up(self.pad(upsampled))
|
|
|
|
|
|
def icnr(x, scale=4, init=nn.init.kaiming_normal_):
    """ ICNR init of `x`, with `scale` and `init` function.

    Checkerboard artifact free sub-pixel convolution: https://arxiv.org/ftp/arxiv/papers/1707/1707.02937.pdf
    """
    out_ch, in_ch, h, w = x.shape
    sub_ch = int(out_ch / (scale ** 2))
    # Initialize a sub-kernel for 1/scale^2 of the output channels, then
    # replicate it so every group of scale^2 channels shares the same values.
    kernel = init(torch.zeros([sub_ch, in_ch, h, w])).transpose(0, 1)
    kernel = kernel.contiguous().view(sub_ch, in_ch, -1)
    kernel = kernel.repeat(1, 1, scale ** 2)
    kernel = kernel.contiguous().view([in_ch, out_ch, h, w]).transpose(0, 1)
    x.data.copy_(kernel)
|
|
|
|
|
|
class PixelShuffle_ICNR(nn.Module):
    """ Upsample by `scale` from `ni` filters to `nf` (default `ni`), using `nn.PixelShuffle`, `icnr` init,
    and `weight_norm`.

    "Super-Resolution using Convolutional Neural Networks without Any Checkerboard Artifacts":
    https://arxiv.org/abs/1806.02658
    """

    def __init__(self, ni: int, nf: int = None, scale: int = 4, icnr_init=True, blur_k=2, blur_s=1,
                 blur_pad=(1, 0, 1, 0), lrelu=True):
        """
        :param ni: input channels.
        :param nf: output channels (defaults to ni).
        :param scale: upsampling factor.
        :param icnr_init: apply checkerboard-free ICNR init to the conv weight.
        :param blur_k: post-shuffle blur (average-pool) kernel size.
        :param blur_s: post-shuffle blur stride.
        :param blur_pad: replication padding applied before the blur.
        :param lrelu: use conv_layer (+LeakyReLU); otherwise plain conv+BN with Hardtanh.
        """
        super().__init__()
        nf = ni if nf is None else nf
        # BUG FIX: the non-lrelu branch hard-coded 64 input and 3*(scale**2)
        # output channels, ignoring `ni`/`nf`. Using the parameters keeps the
        # file's only caller (ni=64, nf=3) identical and makes other channel
        # counts work.
        self.conv = conv_layer(ni, nf * (scale ** 2), kernel=1, padding=0, stride=1) if lrelu else nn.Sequential(
            nn.Conv2d(ni, nf * (scale ** 2), 1, 1, 0), nn.BatchNorm2d(nf * (scale ** 2)))
        if icnr_init:
            icnr(self.conv[0].weight, scale=scale)
        self.act = nn.LeakyReLU(inplace=False) if lrelu else nn.Hardtanh(-10000, 10000)
        self.shuf = nn.PixelShuffle(scale)

        # Blurring after the shuffle suppresses residual checkerboard artifacts.
        self.pad = nn.ReplicationPad2d(blur_pad)
        self.blur = nn.AvgPool2d(blur_k, stride=blur_s)

    def forward(self, x):
        x = self.shuf(self.act(self.conv(x)))
        return self.blur(self.pad(x))
|
|
|
|
|
|
class PixelShuffle(BaseDECO):
    """
    PixelShuffle Module: DECO variant that upsamples with sub-pixel convolution.
    """

    def __init__(self, out=224, init=1, scale=4, lrelu=False):
        super().__init__(out, init)

        # Stem identical to StandardDECO: 1 channel in, 64 feature maps out.
        self.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=2)
        self.bn1 = nn.BatchNorm2d(64)
        self.act1 = nn.LeakyReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.resblocks = _make_res_layers(8, 64)
        self.pixel_shuffle = PixelShuffle_ICNR(ni=64, nf=3, scale=scale, lrelu=lrelu)
        self.init_weights()

    def forward(self, xb):
        """
        @:param xb : Tensor
            Batch of input images

        @:return tensor
            A batch of output images
        """
        # Stem downsamples, residual trunk refines, sub-pixel conv upsamples.
        stem = self.maxpool(self.act1(self.bn1(self.conv1(xb))))
        return self.pixel_shuffle(self.resblocks(stem))
|
|
|
|
|
|
class ColorUDECO(BaseDECO):
    """
    ColorUDECO Module: U-Net style front-end with three downsampling stages,
    two upsampling stages fused with encoder skips, and an output head.
    """

    def __init__(self, out=224, init=0, in_ch=1, out_ch=3):
        super().__init__(out, init)
        self.dw1 = ColorDown(in_ch, 16)
        self.dw2 = ColorDown(16, 32)
        self.dw3 = ColorDown(32, 64)
        self.up1 = ColorUp(64, 32)
        self.up2 = ColorUp(64, 16)
        self.out = ColorOut(32, 16, out_ch)

    def forward(self, x1):
        """
        @:param x1 : Tensor
            Batch of input images

        @:return tensor
            A batch of output images
        """
        # Encoder: each stage halves the spatial resolution.
        enc1 = self.dw1(x1)
        enc2 = self.dw2(enc1)
        enc3 = self.dw3(enc2)
        # Decoder: upsample and concatenate with the matching encoder output.
        dec = self.up1(enc3)
        dec = self.up2(torch.cat([enc2, dec], dim=1))
        return self.out(torch.cat([enc1, dec], dim=1))
|
|
|
|
|
|
class ColorDown(nn.Module):
    """Encoder stage: 3x3 conv, strided 4x4 conv (halves resolution), batch norm."""

    def __init__(self, in_ch, out_ch, htan=False):
        super().__init__()
        act = nn.Hardtanh if htan else nn.LeakyReLU
        self.d = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1),
            act(),
            nn.Conv2d(out_ch, out_ch, kernel_size=4, stride=2, padding=1),
            act(),
            nn.BatchNorm2d(out_ch),
        )

    def forward(self, x):
        return self.d(x)
|
|
|
|
|
|
class ColorUp(nn.Module):
    """Decoder stage: transposed conv (doubles resolution), two 3x3 convs, batch norm."""

    def __init__(self, in_ch, out_ch, htan=False):
        super().__init__()
        act = nn.Hardtanh if htan else nn.LeakyReLU
        self.u = nn.Sequential(
            nn.ConvTranspose2d(in_ch, out_ch, 4, 2, 1),
            act(),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=1, padding=1),
            act(),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=1, padding=1),
            act(),
            nn.BatchNorm2d(out_ch),
        )

    def forward(self, x):
        return self.u(x)
|
|
|
|
|
|
class ColorOut(nn.Module):
    """Output head: transposed conv (doubles resolution), 3x3 conv, 1x1 projection
    to `out_last` channels with no final activation or norm."""

    def __init__(self, in_ch, out_ch, out_last, htan=False):
        super().__init__()
        act = nn.Hardtanh if htan else nn.LeakyReLU
        self.u = nn.Sequential(
            nn.ConvTranspose2d(in_ch, out_ch, 4, 2, 1),
            act(),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=1, padding=1),
            act(),
            nn.Conv2d(out_ch, out_last, kernel_size=1, stride=1, padding=0),
        )

    def forward(self, x):
        return self.u(x)
|
|
|