Spaces:
Build error
Build error
| # encoding: utf-8 | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import numpy as np | |
| from monoscene.CRP3D import CPMegaVoxels | |
| from monoscene.modules import ( | |
| Process, | |
| Upsample, | |
| Downsample, | |
| SegmentationHead, | |
| ASPP, | |
| ) | |
class UNet3D(nn.Module):
    """3D encoder-decoder that produces semantic scene completion logits.

    ``forward`` takes ``input_dict["x3d"]`` and passes it through two
    ``Process`` + ``Downsample`` stages. The attribute names denote the
    1/4 -> 1/8 -> 1/16 scales. At the coarsest scale it optionally applies
    ``CPMegaVoxels`` (the context prior). It then goes back up through two
    ``Upsample`` stages, adding the matching encoder feature after each one
    (skip connections). Finally it applies ``SegmentationHead``.

    Channel widths are ``feature`` (1/4), ``2 * feature`` (1/8) and
    ``4 * feature`` (1/16). The decoder uses the same widths.
    """

    def __init__(
        self,
        class_num,
        norm_layer,
        feature,
        full_scene_size,
        n_relations=4,
        project_res=None,
        context_prior=True,
        bn_momentum=0.1,
    ):
        """Build the encoder, decoder, segmentation head and optional context prior.

        Args:
            class_num: number of output classes, passed to ``SegmentationHead``.
            norm_layer: forwarded to the ``Process``, ``Downsample`` and
                ``Upsample`` submodules (presumably a class such as
                ``nn.BatchNorm3d`` — confirm against the caller).
            feature: channel count at the finest (1/4) scale.
            full_scene_size: iterable of spatial dimensions. Each one is divided
                by 4 and rounded up to size the ``CPMegaVoxels`` module.
            n_relations: forwarded to ``CPMegaVoxels``.
            project_res: stored as ``self.project_res``; not otherwise used in
                this class. ``None`` (the default) yields a fresh empty list.
            context_prior: if true, ``CPMegaVoxels`` is built here and applied
                in ``forward`` at the 1/16 scale.
            bn_momentum: forwarded to ``Process``, ``Downsample``, ``Upsample``
                and ``CPMegaVoxels``.
        """
        super().__init__()
        # Not read anywhere in this class; kept in case external code uses it.
        self.business_layer = []
        # The previous ``[]`` default was a single list object shared by every
        # instance created without this argument; allocate one per instance.
        self.project_res = [] if project_res is None else project_res

        self.feature_1_4 = feature
        self.feature_1_8 = feature * 2
        self.feature_1_16 = feature * 4

        self.feature_1_16_dec = self.feature_1_16
        self.feature_1_8_dec = self.feature_1_8
        self.feature_1_4_dec = self.feature_1_4

        # NOTE: submodules are registered in the original order so that
        # parameter iteration order (optimizer state, seeded init) is unchanged.
        self.process_1_4 = nn.Sequential(
            Process(self.feature_1_4, norm_layer, bn_momentum, dilations=[1, 2, 3]),
            Downsample(self.feature_1_4, norm_layer, bn_momentum),
        )
        self.process_1_8 = nn.Sequential(
            Process(self.feature_1_8, norm_layer, bn_momentum, dilations=[1, 2, 3]),
            Downsample(self.feature_1_8, norm_layer, bn_momentum),
        )
        self.up_1_16_1_8 = Upsample(
            self.feature_1_16_dec, self.feature_1_8_dec, norm_layer, bn_momentum
        )
        self.up_1_8_1_4 = Upsample(
            self.feature_1_8_dec, self.feature_1_4_dec, norm_layer, bn_momentum
        )
        self.ssc_head_1_4 = SegmentationHead(
            self.feature_1_4_dec, self.feature_1_4_dec, class_num, [1, 2, 3]
        )

        self.context_prior = context_prior
        # Assumes full_scene_size equals the spatial size of input_dict["x3d"]
        # and that each of the two Downsample stages halves it (total /4).
        # TODO confirm against Downsample.
        size_1_16 = tuple(np.ceil(i / 4).astype(int) for i in full_scene_size)

        if context_prior:
            self.CP_mega_voxels = CPMegaVoxels(
                self.feature_1_16,
                size_1_16,
                n_relations=n_relations,
                bn_momentum=bn_momentum,
            )

    def forward(self, input_dict):
        """Run the encoder-decoder on ``input_dict["x3d"]``.

        Args:
            input_dict: mapping whose ``"x3d"`` entry is the input volume.
                It is assumed to be a 5-D ``(batch, feature, X, Y, Z)`` tensor
                (TODO confirm).

        Returns:
            dict containing ``"ssc_logit"`` (the segmentation head output).
            When ``context_prior`` is enabled, it also contains every entry
            returned by ``CPMegaVoxels``, including ``"x"``, which replaces
            the 1/16 feature before decoding.
        """
        res = {}

        x3d_1_4 = input_dict["x3d"]
        x3d_1_8 = self.process_1_4(x3d_1_4)
        x3d_1_16 = self.process_1_8(x3d_1_8)

        if self.context_prior:
            ret = self.CP_mega_voxels(x3d_1_16)
            x3d_1_16 = ret["x"]
            # Expose every context-prior output to the caller.
            res.update(ret)

        # Decoder: upsample, then add the matching encoder feature (skip).
        x3d_up_1_8 = self.up_1_16_1_8(x3d_1_16) + x3d_1_8
        x3d_up_1_4 = self.up_1_8_1_4(x3d_up_1_8) + x3d_1_4

        res["ssc_logit"] = self.ssc_head_1_4(x3d_up_1_4)
        return res