import torch.nn as nn import torch from .model_parts import CombinationModule class DecNet(nn.Module): def __init__(self, heads, final_kernel, head_conv, channel): super(DecNet, self).__init__() self.dec_c2 = CombinationModule(128, 64, batch_norm=True) self.dec_c3 = CombinationModule(256, 128, batch_norm=True) self.dec_c4 = CombinationModule(512, 256, batch_norm=True) self.heads = heads for head in self.heads: classes = self.heads[head] if head == 'wh': fc = nn.Sequential(nn.Conv2d(channel, head_conv, kernel_size=7, padding=7//2, bias=True), nn.ReLU(inplace=True), nn.Conv2d(head_conv, classes, kernel_size=7, padding=7 // 2, bias=True)) else: fc = nn.Sequential(nn.Conv2d(channel, head_conv, kernel_size=3, padding=1, bias=True), nn.ReLU(inplace=True), nn.Conv2d(head_conv, classes, kernel_size=final_kernel, stride=1, padding=final_kernel // 2, bias=True)) if 'hm' in head: fc[-1].bias.data.fill_(-2.19) else: self.fill_fc_weights(fc) self.__setattr__(head, fc) def fill_fc_weights(self, layers): for m in layers.modules(): if isinstance(m, nn.Conv2d): if m.bias is not None: nn.init.constant_(m.bias, 0) def forward(self, x): c4_combine = self.dec_c4(x[-1], x[-2]) c3_combine = self.dec_c3(c4_combine, x[-3]) c2_combine = self.dec_c2(c3_combine, x[-4]) dec_dict = {} for head in self.heads: dec_dict[head] = self.__getattr__(head)(c2_combine) if 'hm' in head: dec_dict[head] = torch.sigmoid(dec_dict[head]) return dec_dict