import numpy as np
import torch.nn as nn

from . import resnet
from .dec_net import DecNet


class SpineNet(nn.Module):
    """ResNet-34 backbone followed by a ``DecNet`` decoding head.

    The input is run through ``resnet.resnet34`` and the backbone's output is
    passed straight to ``DecNet``, whose result is returned from ``forward``.

    Args:
        heads: Forwarded unchanged to ``DecNet``. Presumably a mapping from
            head name to number of output channels — TODO confirm against
            ``DecNet``.
        pretrained: Forwarded to ``resnet.resnet34(pretrained=...)``.
        down_ratio: One of 2, 4, 8 or 16. Its base-2 logarithm (stored as
            ``self.l1``) selects which entry of ``_CHANNELS`` is handed to
            ``DecNet``.
        final_kernel: Forwarded unchanged to ``DecNet``.
        head_conv: Forwarded unchanged to ``DecNet``.

    Raises:
        ValueError: If ``down_ratio`` is not one of the supported values.
    """

    # Allowed ``down_ratio`` values; each is a power of two, so log2 yields
    # an exact integer index into ``_CHANNELS``.
    _VALID_DOWN_RATIOS = (2, 4, 8, 16)
    # Channel counts indexed by log2(down_ratio). Index 0 (3) is never
    # selected because down_ratio >= 2.
    _CHANNELS = (3, 64, 64, 128, 256, 512)

    def __init__(self, heads, pretrained, down_ratio, final_kernel, head_conv):
        super().__init__()
        # An explicit check (rather than ``assert``) keeps the validation
        # active when Python runs with -O.
        if down_ratio not in self._VALID_DOWN_RATIOS:
            raise ValueError(
                f"down_ratio must be one of {self._VALID_DOWN_RATIOS}, "
                f"got {down_ratio!r}"
            )
        self.l1 = int(np.log2(down_ratio))
        self.base_network = resnet.resnet34(pretrained=pretrained)
        self.dec_net = DecNet(
            heads, final_kernel, head_conv, self._CHANNELS[self.l1]
        )

    def forward(self, x):
        """Run ``x`` through the backbone, then the decoder.

        Returns whatever ``DecNet`` returns. The original name ``dec_dict``
        suggests a dict of per-head outputs — TODO confirm against ``DecNet``.
        """
        features = self.base_network(x)
        dec_dict = self.dec_net(features)
        return dec_dict