import copy
import logging

import torch
from torch import nn

from convs.cifar_resnet import resnet32
from convs.resnet import resnet18, resnet34, resnet50, resnet101, resnet152
from convs.ucir_cifar_resnet import resnet32 as cosine_resnet32
from convs.ucir_resnet import resnet18 as cosine_resnet18
from convs.ucir_resnet import resnet34 as cosine_resnet34
from convs.ucir_resnet import resnet50 as cosine_resnet50
from convs.linears import SimpleLinear, SplitCosineLinear, CosineLinear
from convs.modified_represnet import resnet18_rep, resnet34_rep
from convs.resnet_cbam import resnet18_cbam, resnet34_cbam, resnet50_cbam
from convs.memo_resnet import get_resnet18_imagenet as get_memo_resnet18  # for MEMO ImageNet
from convs.memo_cifar_resnet import get_resnet32_a2fc as get_memo_resnet32  # for MEMO CIFAR

def get_convnet(args, pretrained=False):
    name = args["convnet_type"].lower()
    if name == "resnet32":
        return resnet32()
    elif name == "resnet18":
        return resnet18(pretrained=pretrained, args=args)
    elif name == "resnet34":
        return resnet34(pretrained=pretrained, args=args)
    elif name == "resnet50":
        return resnet50(pretrained=pretrained, args=args)
    elif name == "cosine_resnet18":
        return cosine_resnet18(pretrained=pretrained, args=args)
    elif name == "cosine_resnet32":
        return cosine_resnet32()
    elif name == "cosine_resnet34":
        return cosine_resnet34(pretrained=pretrained, args=args)
    elif name == "cosine_resnet50":
        return cosine_resnet50(pretrained=pretrained, args=args)
    elif name == "resnet18_rep":
        return resnet18_rep(pretrained=pretrained, args=args)
    elif name == "resnet18_cbam":
        return resnet18_cbam(pretrained=pretrained, args=args)
    elif name == "resnet34_cbam":
        return resnet34_cbam(pretrained=pretrained, args=args)
    elif name == "resnet50_cbam":
        return resnet50_cbam(pretrained=pretrained, args=args)
    # MEMO benchmark backbone
    elif name == "memo_resnet18":
        _basenet, _adaptive_net = get_memo_resnet18()
        return _basenet, _adaptive_net
    elif name == "memo_resnet32":
        _basenet, _adaptive_net = get_memo_resnet32()
        return _basenet, _adaptive_net
    else:
        raise NotImplementedError("Unknown type {}".format(name))
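
# Illustrative usage sketch (not part of the original module): get_convnet reads
# the backbone name from the experiment config dict. The "resnet32" value below
# is just an assumed example; real runs take it from the JSON config files.
def _example_get_convnet():
    args = {"convnet_type": "resnet32"}
    convnet = get_convnet(args, pretrained=False)  # CIFAR-style ResNet-32 backbone
    return convnet.out_dim  # feature dimension used to size the classifier heads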

class BaseNet(nn.Module):
    def __init__(self, args, pretrained):
        super(BaseNet, self).__init__()

        self.convnet = get_convnet(args, pretrained)
        self.fc = None

    @property
    def feature_dim(self):
        return self.convnet.out_dim

    def extract_vector(self, x):
        return self.convnet(x)["features"]

    def forward(self, x):
        x = self.convnet(x)
        out = self.fc(x["features"])
        """
        {
            'fmaps': [x_1, x_2, ..., x_n],
            'features': features,
            'logits': logits
        }
        """
        out.update(x)

        return out

    def update_fc(self, nb_classes):
        pass

    def generate_fc(self, in_dim, out_dim):
        pass

    def copy(self):
        return copy.deepcopy(self)

    def freeze(self):
        for param in self.parameters():
            param.requires_grad = False
        self.eval()

        return self

    def load_checkpoint(self, args):
        if args["init_cls"] == 50:
            pkl_name = "{}_{}_{}_B{}_Inc{}".format(
                args["dataset"],
                args["seed"],
                args["convnet_type"],
                0,
                args["init_cls"],
            )
            checkpoint_name = f"checkpoints/finetune_{pkl_name}_0.pkl"
        else:
            checkpoint_name = f"checkpoints/finetune_{args['csv_name']}_0.pkl"
        model_infos = torch.load(checkpoint_name)
        self.convnet.load_state_dict(model_infos["convnet"])
        self.fc.load_state_dict(model_infos["fc"])
        test_acc = model_infos["test_acc"]
        return test_acc

class IncrementalNet(BaseNet):
    def __init__(self, args, pretrained, gradcam=False):
        super().__init__(args, pretrained)
        self.gradcam = gradcam
        if hasattr(self, "gradcam") and self.gradcam:
            self._gradcam_hooks = [None, None]
            self.set_gradcam_hook()

    def update_fc(self, nb_classes):
        fc = self.generate_fc(self.feature_dim, nb_classes)
        if self.fc is not None:
            nb_output = self.fc.out_features
            weight = copy.deepcopy(self.fc.weight.data)
            bias = copy.deepcopy(self.fc.bias.data)
            fc.weight.data[:nb_output] = weight
            fc.bias.data[:nb_output] = bias

        del self.fc
        self.fc = fc

    def weight_align(self, increment):
        weights = self.fc.weight.data
        newnorm = torch.norm(weights[-increment:, :], p=2, dim=1)
        oldnorm = torch.norm(weights[:-increment, :], p=2, dim=1)
        meannew = torch.mean(newnorm)
        meanold = torch.mean(oldnorm)
        gamma = meanold / meannew
        print("align weights, gamma =", gamma)
        self.fc.weight.data[-increment:, :] *= gamma

    def generate_fc(self, in_dim, out_dim):
        fc = SimpleLinear(in_dim, out_dim)
        return fc

    def forward(self, x):
        x = self.convnet(x)
        out = self.fc(x["features"])
        out.update(x)
        if hasattr(self, "gradcam") and self.gradcam:
            out["gradcam_gradients"] = self._gradcam_gradients
            out["gradcam_activations"] = self._gradcam_activations

        return out

    def unset_gradcam_hook(self):
        self._gradcam_hooks[0].remove()
        self._gradcam_hooks[1].remove()
        self._gradcam_hooks[0] = None
        self._gradcam_hooks[1] = None
        self._gradcam_gradients, self._gradcam_activations = [None], [None]

    def set_gradcam_hook(self):
        self._gradcam_gradients, self._gradcam_activations = [None], [None]

        def backward_hook(module, grad_input, grad_output):
            self._gradcam_gradients[0] = grad_output[0]
            return None

        def forward_hook(module, input, output):
            self._gradcam_activations[0] = output
            return None

        self._gradcam_hooks[0] = self.convnet.last_conv.register_backward_hook(
            backward_hook
        )
        self._gradcam_hooks[1] = self.convnet.last_conv.register_forward_hook(
            forward_hook
        )
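
# Illustrative sketch (not part of the original module): grow the classifier of
# an IncrementalNet across two tasks and apply weight alignment. The backbone
# name and the class counts (10 per task) are assumptions for illustration.
def _example_incremental_fc():
    args = {"convnet_type": "resnet32"}
    net = IncrementalNet(args, pretrained=False)
    net.update_fc(10)     # task 0: head with 10 outputs
    net.update_fc(20)     # task 1: old weights copied into rows [:10] of the new head
    net.weight_align(10)  # rescale the 10 newest rows to match the mean old-row norm
    dummy = torch.randn(2, 3, 32, 32)
    return net(dummy)["logits"].shape  # expected: torch.Size([2, 20])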

class IL2ANet(IncrementalNet):
    def update_fc(self, num_old, num_total, num_aux):
        fc = self.generate_fc(self.feature_dim, num_total + num_aux)
        if self.fc is not None:
            weight = copy.deepcopy(self.fc.weight.data)
            bias = copy.deepcopy(self.fc.bias.data)
            fc.weight.data[:num_old] = weight[:num_old]
            fc.bias.data[:num_old] = bias[:num_old]
        del self.fc
        self.fc = fc

class CosineIncrementalNet(BaseNet):
    def __init__(self, args, pretrained, nb_proxy=1):
        super().__init__(args, pretrained)
        self.nb_proxy = nb_proxy

    def update_fc(self, nb_classes, task_num):
        fc = self.generate_fc(self.feature_dim, nb_classes)
        if self.fc is not None:
            if task_num == 1:
                fc.fc1.weight.data = self.fc.weight.data
                fc.sigma.data = self.fc.sigma.data
            else:
                prev_out_features1 = self.fc.fc1.out_features
                fc.fc1.weight.data[:prev_out_features1] = self.fc.fc1.weight.data
                fc.fc1.weight.data[prev_out_features1:] = self.fc.fc2.weight.data
                fc.sigma.data = self.fc.sigma.data

        del self.fc
        self.fc = fc

    def generate_fc(self, in_dim, out_dim):
        if self.fc is None:
            fc = CosineLinear(in_dim, out_dim, self.nb_proxy, to_reduce=True)
        else:
            prev_out_features = self.fc.out_features // self.nb_proxy
            # prev_out_features = self.fc.out_features
            fc = SplitCosineLinear(
                in_dim, prev_out_features, out_dim - prev_out_features, self.nb_proxy
            )

        return fc

class BiasLayer_BIC(nn.Module):
    def __init__(self):
        super(BiasLayer_BIC, self).__init__()
        self.alpha = nn.Parameter(torch.ones(1, requires_grad=True))
        self.beta = nn.Parameter(torch.zeros(1, requires_grad=True))

    def forward(self, x, low_range, high_range):
        ret_x = x.clone()
        ret_x[:, low_range:high_range] = (
            self.alpha * x[:, low_range:high_range] + self.beta
        )
        return ret_x

    def get_params(self):
        return (self.alpha.item(), self.beta.item())
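
# Illustrative sketch (not part of the original module): BiC's bias layer only
# rescales the logit columns of the newest task, [low_range, high_range); older
# columns pass through unchanged. The tensor sizes below are assumed.
def _example_bias_layer_bic():
    layer = BiasLayer_BIC()
    logits = torch.randn(4, 20)        # 4 samples, 20 classes (10 old + 10 new)
    corrected = layer(logits, 10, 20)  # alpha * logits[:, 10:20] + beta
    return corrected.shape             # torch.Size([4, 20])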

class IncrementalNetWithBias(BaseNet):
    def __init__(self, args, pretrained, bias_correction=False):
        super().__init__(args, pretrained)

        # Bias layer
        self.bias_correction = bias_correction
        self.bias_layers = nn.ModuleList([])
        self.task_sizes = []

    def forward(self, x):
        x = self.convnet(x)
        out = self.fc(x["features"])
        if self.bias_correction:
            logits = out["logits"]
            for i, layer in enumerate(self.bias_layers):
                logits = layer(
                    logits, sum(self.task_sizes[:i]), sum(self.task_sizes[: i + 1])
                )
            out["logits"] = logits

        out.update(x)

        return out

    def update_fc(self, nb_classes):
        fc = self.generate_fc(self.feature_dim, nb_classes)
        if self.fc is not None:
            nb_output = self.fc.out_features
            weight = copy.deepcopy(self.fc.weight.data)
            bias = copy.deepcopy(self.fc.bias.data)
            fc.weight.data[:nb_output] = weight
            fc.bias.data[:nb_output] = bias

        del self.fc
        self.fc = fc

        new_task_size = nb_classes - sum(self.task_sizes)
        self.task_sizes.append(new_task_size)
        self.bias_layers.append(BiasLayer_BIC())

    def generate_fc(self, in_dim, out_dim):
        fc = SimpleLinear(in_dim, out_dim)
        return fc

    def get_bias_params(self):
        params = []
        for layer in self.bias_layers:
            params.append(layer.get_params())

        return params

    def unfreeze(self):
        for param in self.parameters():
            param.requires_grad = True

class DERNet(nn.Module):
    def __init__(self, args, pretrained):
        super(DERNet, self).__init__()
        self.convnet_type = args["convnet_type"]
        self.convnets = nn.ModuleList()
        self.pretrained = pretrained
        self.out_dim = None
        self.fc = None
        self.aux_fc = None
        self.task_sizes = []
        self.args = args

    @property
    def feature_dim(self):
        if self.out_dim is None:
            return 0
        return self.out_dim * len(self.convnets)

    def extract_vector(self, x):
        features = [convnet(x)["features"] for convnet in self.convnets]
        features = torch.cat(features, 1)
        return features

    def forward(self, x):
        features = [convnet(x)["features"] for convnet in self.convnets]
        features = torch.cat(features, 1)

        out = self.fc(features)  # {logits: self.fc(features)}

        aux_logits = self.aux_fc(features[:, -self.out_dim :])["logits"]

        out.update({"aux_logits": aux_logits, "features": features})
        return out
        """
        {
            'features': features,
            'logits': logits,
            'aux_logits': aux_logits
        }
        """

    def update_fc(self, nb_classes):
        if len(self.convnets) == 0:
            self.convnets.append(get_convnet(self.args))
        else:
            self.convnets.append(get_convnet(self.args))
            self.convnets[-1].load_state_dict(self.convnets[-2].state_dict())

        if self.out_dim is None:
            self.out_dim = self.convnets[-1].out_dim
        fc = self.generate_fc(self.feature_dim, nb_classes)
        if self.fc is not None:
            nb_output = self.fc.out_features
            weight = copy.deepcopy(self.fc.weight.data)
            bias = copy.deepcopy(self.fc.bias.data)
            fc.weight.data[:nb_output, : self.feature_dim - self.out_dim] = weight
            fc.bias.data[:nb_output] = bias

        del self.fc
        self.fc = fc

        new_task_size = nb_classes - sum(self.task_sizes)
        self.task_sizes.append(new_task_size)

        self.aux_fc = self.generate_fc(self.out_dim, new_task_size + 1)

    def generate_fc(self, in_dim, out_dim):
        fc = SimpleLinear(in_dim, out_dim)
        return fc

    def copy(self):
        return copy.deepcopy(self)

    def freeze(self):
        for param in self.parameters():
            param.requires_grad = False
        self.eval()

        return self

    def freeze_conv(self):
        for param in self.convnets.parameters():
            param.requires_grad = False
        self.convnets.eval()

    def weight_align(self, increment):
        weights = self.fc.weight.data
        newnorm = torch.norm(weights[-increment:, :], p=2, dim=1)
        oldnorm = torch.norm(weights[:-increment, :], p=2, dim=1)
        meannew = torch.mean(newnorm)
        meanold = torch.mean(oldnorm)
        gamma = meanold / meannew
        print("align weights, gamma =", gamma)
        self.fc.weight.data[-increment:, :] *= gamma

    def load_checkpoint(self, args):
        checkpoint_name = f"checkpoints/finetune_{args['csv_name']}_0.pkl"
        model_infos = torch.load(checkpoint_name)
        assert len(self.convnets) == 1
        self.convnets[0].load_state_dict(model_infos["convnet"])
        self.fc.load_state_dict(model_infos["fc"])
        test_acc = model_infos["test_acc"]
        return test_acc
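
# Illustrative sketch (not part of the original module): DER adds one backbone per
# task and concatenates their features before the shared head. With the assumed
# resnet32 backbone (out_dim 64) and 10 classes per task, the joint head grows
# from 64->10 to 128->20, and aux_fc scores the newest task plus an "old" bin.
def _example_dernet_expansion():
    args = {"convnet_type": "resnet32"}
    net = DERNet(args, pretrained=False)
    net.update_fc(10)  # task 0: one backbone
    net.update_fc(20)  # task 1: second backbone initialized from the first
    dummy = torch.randn(2, 3, 32, 32)
    out = net(dummy)
    return out["logits"].shape, out["aux_logits"].shape  # expected: ([2, 20], [2, 11])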

class SimpleCosineIncrementalNet(BaseNet):
    def __init__(self, args, pretrained):
        super().__init__(args, pretrained)

    def update_fc(self, nb_classes, nextperiod_initialization=None):
        fc = self.generate_fc(self.feature_dim, nb_classes).cuda()
        if self.fc is not None:
            nb_output = self.fc.out_features
            weight = copy.deepcopy(self.fc.weight.data)
            fc.sigma.data = self.fc.sigma.data
            if nextperiod_initialization is not None:
                weight = torch.cat([weight.cuda(), nextperiod_initialization.cuda()])
            else:
                weight = torch.cat(
                    [
                        weight.cuda(),
                        torch.zeros(nb_classes - nb_output, self.feature_dim).cuda(),
                    ]
                )
            fc.weight = nn.Parameter(weight)
        del self.fc
        self.fc = fc

    def load_checkpoint(self, checkpoint):
        self.convnet.load_state_dict(checkpoint["convnet"])
        self.fc.load_state_dict(checkpoint["fc"])

    def generate_fc(self, in_dim, out_dim):
        fc = CosineLinear(in_dim, out_dim)
        return fc

class FOSTERNet(nn.Module):
    def __init__(self, args, pretrained):
        super(FOSTERNet, self).__init__()
        self.convnet_type = args["convnet_type"]
        self.convnets = nn.ModuleList()
        self.pretrained = pretrained
        self.out_dim = None
        self.fc = None
        self.fe_fc = None
        self.task_sizes = []
        self.oldfc = None
        self.args = args

    @property
    def feature_dim(self):
        if self.out_dim is None:
            return 0
        return self.out_dim * len(self.convnets)

    def extract_vector(self, x):
        features = [convnet(x)["features"] for convnet in self.convnets]
        features = torch.cat(features, 1)
        return features

    def load_checkpoint(self, checkpoint):
        if len(self.convnets) == 0:
            self.convnets.append(get_convnet(self.args))
        self.convnets[0].load_state_dict(checkpoint["convnet"])
        self.fc.load_state_dict(checkpoint["fc"])

    def forward(self, x):
        features = [convnet(x)["features"] for convnet in self.convnets]
        features = torch.cat(features, 1)
        out = self.fc(features)
        fe_logits = self.fe_fc(features[:, -self.out_dim :])["logits"]

        out.update({"fe_logits": fe_logits, "features": features})

        if self.oldfc is not None:
            old_logits = self.oldfc(features[:, : -self.out_dim])["logits"]
            out.update({"old_logits": old_logits})

        out.update({"eval_logits": out["logits"]})
        return out

    def update_fc(self, nb_classes):
        self.convnets.append(get_convnet(self.args))
        if self.out_dim is None:
            self.out_dim = self.convnets[-1].out_dim
        fc = self.generate_fc(self.feature_dim, nb_classes)
        if self.fc is not None:
            nb_output = self.fc.out_features
            weight = copy.deepcopy(self.fc.weight.data)
            bias = copy.deepcopy(self.fc.bias.data)
            fc.weight.data[:nb_output, : self.feature_dim - self.out_dim] = weight
            fc.bias.data[:nb_output] = bias
            self.convnets[-1].load_state_dict(self.convnets[-2].state_dict())

        self.oldfc = self.fc
        self.fc = fc
        new_task_size = nb_classes - sum(self.task_sizes)
        self.task_sizes.append(new_task_size)
        self.fe_fc = self.generate_fc(self.out_dim, nb_classes)

    def generate_fc(self, in_dim, out_dim):
        fc = SimpleLinear(in_dim, out_dim)
        return fc

    def copy(self):
        return copy.deepcopy(self)

    def copy_fc(self, fc):
        weight = copy.deepcopy(fc.weight.data)
        bias = copy.deepcopy(fc.bias.data)
        n, m = weight.shape[0], weight.shape[1]
        self.fc.weight.data[:n, :m] = weight
        self.fc.bias.data[:n] = bias

    def freeze(self):
        for param in self.parameters():
            param.requires_grad = False
        self.eval()
        return self

    def freeze_conv(self):
        for param in self.convnets.parameters():
            param.requires_grad = False
        self.convnets.eval()

    def weight_align(self, old, increment, value):
        weights = self.fc.weight.data
        newnorm = torch.norm(weights[-increment:, :], p=2, dim=1)
        oldnorm = torch.norm(weights[:-increment, :], p=2, dim=1)
        meannew = torch.mean(newnorm)
        meanold = torch.mean(oldnorm)
        gamma = meanold / meannew * (value ** (old / increment))
        logging.info("align weights, gamma = {}".format(gamma))
        self.fc.weight.data[-increment:, :] *= gamma
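
# Illustrative sketch (not part of the original module): the scaling factor used by
# FOSTERNet.weight_align is the ratio of mean old-row norm to mean new-row norm,
# additionally scaled by value ** (old / increment). All numbers here are assumed.
def _example_foster_weight_align_gamma():
    weights = torch.randn(20, 64)  # stand-in for fc.weight.data
    old, increment, value = 10, 10, 1.0
    meannew = torch.norm(weights[-increment:, :], p=2, dim=1).mean()
    meanold = torch.norm(weights[:-increment, :], p=2, dim=1).mean()
    gamma = meanold / meannew * (value ** (old / increment))
    return gamma  # the newest `increment` rows would be multiplied by gamma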

class BiasLayer(nn.Module):
    def __init__(self):
        super(BiasLayer, self).__init__()
        self.alpha = nn.Parameter(torch.zeros(1, requires_grad=True))
        self.beta = nn.Parameter(torch.zeros(1, requires_grad=True))

    def forward(self, x, bias=True):
        ret_x = (self.alpha + 1) * x  # + self.beta
        if bias:
            ret_x = ret_x + self.beta
        return ret_x

    def get_params(self):
        return (self.alpha.item(), self.beta.item())

class BEEFISONet(nn.Module):
    def __init__(self, args, pretrained):
        super(BEEFISONet, self).__init__()
        self.convnet_type = args["convnet_type"]
        self.convnets = nn.ModuleList()
        self.pretrained = pretrained
        self.out_dim = None
        self.old_fc = None
        self.new_fc = None
        self.task_sizes = []
        self.forward_prototypes = None
        self.backward_prototypes = None
        self.args = args
        self.biases = nn.ModuleList()

    @property
    def feature_dim(self):
        if self.out_dim is None:
            return 0
        return self.out_dim * len(self.convnets)

    def extract_vector(self, x):
        features = [convnet(x)["features"] for convnet in self.convnets]
        features = torch.cat(features, 1)
        return features

    def forward(self, x):
        features = [convnet(x)["features"] for convnet in self.convnets]
        features = torch.cat(features, 1)

        if self.old_fc is None:
            fc = self.new_fc
            out = fc(features)
        else:
            '''
            merge the weights
            '''
            new_task_size = self.task_sizes[-1]
            fc_weight = torch.cat([self.old_fc.weight, torch.zeros((new_task_size, self.feature_dim - self.out_dim)).cuda()], dim=0)
            new_fc_weight = self.new_fc.weight
            new_fc_bias = self.new_fc.bias
            for i in range(len(self.task_sizes) - 2, -1, -1):
                new_fc_weight = torch.cat([*[self.biases[i](self.backward_prototypes.weight[i].unsqueeze(0), bias=False) for _ in range(self.task_sizes[i])], new_fc_weight], dim=0)
                new_fc_bias = torch.cat([*[self.biases[i](self.backward_prototypes.bias[i].unsqueeze(0), bias=True) for _ in range(self.task_sizes[i])], new_fc_bias])
            fc_weight = torch.cat([fc_weight, new_fc_weight], dim=1)
            fc_bias = torch.cat([self.old_fc.bias, torch.zeros(new_task_size).cuda()])
            fc_bias += new_fc_bias
            logits = features @ fc_weight.permute(1, 0) + fc_bias
            out = {"logits": logits}

            new_fc_weight = self.new_fc.weight
            new_fc_bias = self.new_fc.bias
            for i in range(len(self.task_sizes) - 2, -1, -1):
                new_fc_weight = torch.cat([self.backward_prototypes.weight[i].unsqueeze(0), new_fc_weight], dim=0)
                new_fc_bias = torch.cat([self.backward_prototypes.bias[i].unsqueeze(0), new_fc_bias])
            out["train_logits"] = features[:, -self.out_dim:] @ new_fc_weight.permute(1, 0) + new_fc_bias
        out.update({"eval_logits": out["logits"], "energy_logits": self.forward_prototypes(features[:, -self.out_dim:])["logits"]})
        return out

    def update_fc_before(self, nb_classes):
        new_task_size = nb_classes - sum(self.task_sizes)
        self.biases = nn.ModuleList([BiasLayer() for i in range(len(self.task_sizes))])
        self.convnets.append(get_convnet(self.args))
        if self.out_dim is None:
            self.out_dim = self.convnets[-1].out_dim
        if self.new_fc is not None:
            self.fe_fc = self.generate_fc(self.out_dim, nb_classes)
            self.backward_prototypes = self.generate_fc(self.out_dim, len(self.task_sizes))
            self.convnets[-1].load_state_dict(self.convnets[0].state_dict())
        self.forward_prototypes = self.generate_fc(self.out_dim, nb_classes)
        self.new_fc = self.generate_fc(self.out_dim, new_task_size)
        self.task_sizes.append(new_task_size)

    def generate_fc(self, in_dim, out_dim):
        fc = SimpleLinear(in_dim, out_dim)
        return fc

    def update_fc_after(self):
        if self.old_fc is not None:
            old_fc = self.generate_fc(self.feature_dim, sum(self.task_sizes))
            new_task_size = self.task_sizes[-1]
            old_fc.weight.data = torch.cat([self.old_fc.weight.data, torch.zeros((new_task_size, self.feature_dim - self.out_dim)).cuda()], dim=0)
            new_fc_weight = self.new_fc.weight.data
            new_fc_bias = self.new_fc.bias.data
            for i in range(len(self.task_sizes) - 2, -1, -1):
                new_fc_weight = torch.cat([*[self.biases[i](self.backward_prototypes.weight.data[i].unsqueeze(0), bias=False) for _ in range(self.task_sizes[i])], new_fc_weight], dim=0)
                new_fc_bias = torch.cat([*[self.biases[i](self.backward_prototypes.bias.data[i].unsqueeze(0), bias=True) for _ in range(self.task_sizes[i])], new_fc_bias])
            old_fc.weight.data = torch.cat([old_fc.weight.data, new_fc_weight], dim=1)
            old_fc.bias.data = torch.cat([self.old_fc.bias.data, torch.zeros(new_task_size).cuda()])
            old_fc.bias.data += new_fc_bias
            self.old_fc = old_fc
        else:
            self.old_fc = self.new_fc

    def copy(self):
        return copy.deepcopy(self)

    def copy_fc(self, fc):
        weight = copy.deepcopy(fc.weight.data)
        bias = copy.deepcopy(fc.bias.data)
        n, m = weight.shape[0], weight.shape[1]
        self.fc.weight.data[:n, :m] = weight
        self.fc.bias.data[:n] = bias

    def freeze(self):
        for param in self.parameters():
            param.requires_grad = False
        self.eval()
        return self

    def freeze_conv(self):
        for param in self.convnets.parameters():
            param.requires_grad = False
        self.convnets.eval()

    def weight_align(self, old, increment, value):
        weights = self.fc.weight.data
        newnorm = torch.norm(weights[-increment:, :], p=2, dim=1)
        oldnorm = torch.norm(weights[:-increment, :], p=2, dim=1)
        meannew = torch.mean(newnorm)
        meanold = torch.mean(oldnorm)
        gamma = meanold / meannew * (value ** (old / increment))
        logging.info("align weights, gamma = {}".format(gamma))
        self.fc.weight.data[-increment:, :] *= gamma

class AdaptiveNet(nn.Module):
    def __init__(self, args, pretrained):
        super(AdaptiveNet, self).__init__()
        self.convnet_type = args["convnet_type"]
        self.TaskAgnosticExtractor, _network = get_convnet(args, pretrained)  # generalized blocks
        self.TaskAgnosticExtractor.train()
        self.AdaptiveExtractors = nn.ModuleList()  # specialized blocks
        self.AdaptiveExtractors.append(_network)
        self.pretrained = pretrained
        if args["backbone"] is not None and pretrained:
            self.load_checkpoint(args)
        self.out_dim = None
        self.fc = None
        self.aux_fc = None
        self.task_sizes = []
        self.args = args

    @property
    def feature_dim(self):
        if self.out_dim is None:
            return 0
        return self.out_dim * len(self.AdaptiveExtractors)

    def extract_vector(self, x):
        base_feature_map = self.TaskAgnosticExtractor(x)
        features = [extractor(base_feature_map) for extractor in self.AdaptiveExtractors]
        features = torch.cat(features, 1)
        return features

    def forward(self, x):
        base_feature_map = self.TaskAgnosticExtractor(x)
        features = [extractor(base_feature_map) for extractor in self.AdaptiveExtractors]
        features = torch.cat(features, 1)
        out = self.fc(features)  # {logits: self.fc(features)}

        aux_logits = self.aux_fc(features[:, -self.out_dim:])["logits"]

        out.update({"aux_logits": aux_logits, "features": features})
        out.update({"base_features": base_feature_map})
        return out

        '''
        {
            'features': features,
            'logits': logits,
            'aux_logits': aux_logits
        }
        '''

    def update_fc(self, nb_classes):
        _, _new_extractor = get_convnet(self.args)
        if len(self.AdaptiveExtractors) == 0:
            self.AdaptiveExtractors.append(_new_extractor)
        else:
            self.AdaptiveExtractors.append(_new_extractor)
            self.AdaptiveExtractors[-1].load_state_dict(self.AdaptiveExtractors[-2].state_dict())

        if self.out_dim is None:
            logging.info(self.AdaptiveExtractors[-1])
            self.out_dim = self.AdaptiveExtractors[-1].feature_dim
        fc = self.generate_fc(self.feature_dim, nb_classes)
        if self.fc is not None:
            nb_output = self.fc.out_features
            weight = copy.deepcopy(self.fc.weight.data)
            bias = copy.deepcopy(self.fc.bias.data)
            fc.weight.data[:nb_output, : self.feature_dim - self.out_dim] = weight
            fc.bias.data[:nb_output] = bias

        del self.fc
        self.fc = fc

        new_task_size = nb_classes - sum(self.task_sizes)
        self.task_sizes.append(new_task_size)
        self.aux_fc = self.generate_fc(self.out_dim, new_task_size + 1)

    def generate_fc(self, in_dim, out_dim):
        fc = SimpleLinear(in_dim, out_dim)
        return fc

    def copy(self):
        return copy.deepcopy(self)

    def weight_align(self, increment):
        weights = self.fc.weight.data
        newnorm = torch.norm(weights[-increment:, :], p=2, dim=1)
        oldnorm = torch.norm(weights[:-increment, :], p=2, dim=1)
        meannew = torch.mean(newnorm)
        meanold = torch.mean(oldnorm)
        gamma = meanold / meannew
        print("align weights, gamma =", gamma)
        self.fc.weight.data[-increment:, :] *= gamma

    def load_checkpoint(self, args):
        checkpoint_name = args["backbone"]
        model_infos = torch.load(checkpoint_name)
        model_dict = model_infos["convnet"]
        assert len(self.AdaptiveExtractors) == 1

        base_state_dict = self.TaskAgnosticExtractor.state_dict()
        adap_state_dict = self.AdaptiveExtractors[0].state_dict()

        pretrained_base_dict = {
            k: v
            for k, v in model_dict.items()
            if k in base_state_dict
        }
        pretrained_adap_dict = {
            k: v
            for k, v in model_dict.items()
            if k in adap_state_dict
        }

        base_state_dict.update(pretrained_base_dict)
        adap_state_dict.update(pretrained_adap_dict)

        self.TaskAgnosticExtractor.load_state_dict(base_state_dict)
        self.AdaptiveExtractors[0].load_state_dict(adap_state_dict)
        # self.fc.load_state_dict(model_infos['fc'])
        test_acc = model_infos["test_acc"]
        return test_acc
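
# Illustrative sketch (not part of the original module): the MEMO backbones return a
# (generalized, specialized) pair, and AdaptiveNet shares the generalized blocks while
# appending one specialized extractor per update_fc call. Config values are assumed;
# running this requires the convs.memo_cifar_resnet module from this repository.
def _example_adaptivenet():
    args = {"convnet_type": "memo_resnet32", "backbone": None}
    net = AdaptiveNet(args, pretrained=False)
    net.update_fc(10)  # task 0
    net.update_fc(20)  # task 1: newest specialized block copied from the previous one
    dummy = torch.randn(2, 3, 32, 32)
    out = net(dummy)
    return out["logits"].shape, out["aux_logits"].shape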