import pytorch_lightning as pl
import torch
import torch.nn as nn
from monoscene.unet3d_nyu import UNet3D as UNet3DNYU
from monoscene.unet3d_kitti import UNet3D as UNet3DKitti
from monoscene.flosp import FLoSP
import numpy as np
import torch.nn.functional as F
from monoscene.unet2d import UNet2D
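
# Hedged sketch of what the FLoSP projection used below plausibly does, for
# orientation only: this is an assumption about monoscene.flosp.FLoSP, not its
# real implementation. Each voxel gathers the 2D feature at its projected
# pixel, and voxels outside the camera frustum are zeroed.
def flosp_like_project(feat_2d, projected_pix, fov_mask, scene_size):
    # feat_2d: (C, H, W) feature map; projected_pix: (N, 2) per-voxel pixel
    # coords (assumed (x, y) order); fov_mask: (N,) bool; scene_size: (X, Y, Z)
    # with N == X * Y * Z.
    c, h, w = feat_2d.shape
    x = projected_pix[:, 0].clamp(0, w - 1).long()  # clamp only to stay in-bounds
    y = projected_pix[:, 1].clamp(0, h - 1).long()
    sampled = feat_2d.view(c, -1)[:, y * w + x]  # (C, N) gathered features
    sampled = sampled * fov_mask.view(1, -1)  # zero out-of-frustum voxels
    return sampled.view(c, *scene_size)  # (C, X, Y, Z) voxel feature volume
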
class MonoScene(pl.LightningModule):
    """Monocular 3D semantic scene completion: multi-scale 2D UNet features
    are projected into a 3D volume via FLoSP and decoded by a 3D UNet into
    per-voxel class logits."""

    def __init__(
        self,
        n_classes,
        feature,
        project_scale,
        full_scene_size,
        dataset,
        project_res=["1", "2", "4", "8"],
        n_relations=4,
        context_prior=True,
        fp_loss=True,
        frustum_size=4,
        relation_loss=False,
        CE_ssc_loss=True,
        geo_scal_loss=True,
        sem_scal_loss=True,
        lr=1e-4,
        weight_decay=1e-4,
    ):
        super().__init__()
        self.project_res = project_res
        self.fp_loss = fp_loss
        self.dataset = dataset
        self.context_prior = context_prior
        self.frustum_size = frustum_size
        self.relation_loss = relation_loss
        self.CE_ssc_loss = CE_ssc_loss
        self.sem_scal_loss = sem_scal_loss
        self.geo_scal_loss = geo_scal_loss
        self.project_scale = project_scale
        self.lr = lr
        self.weight_decay = weight_decay

        # One FLoSP projection module per 2D feature scale (1/1, 1/2, 1/4, 1/8).
        self.projects = {}
        self.scale_2ds = [1, 2, 4, 8]  # 2D scales
        for scale_2d in self.scale_2ds:
            self.projects[str(scale_2d)] = FLoSP(
                full_scene_size, project_scale=self.project_scale, dataset=self.dataset
            )
        self.projects = nn.ModuleDict(self.projects)

        self.n_classes = n_classes
        # 3D decoder: dataset-specific UNet3D variant.
        if self.dataset == "NYU":
            self.net_3d_decoder = UNet3DNYU(
                self.n_classes,
                nn.BatchNorm3d,
                n_relations=n_relations,
                feature=feature,
                full_scene_size=full_scene_size,
                context_prior=context_prior,
            )
        elif self.dataset == "kitti":
            self.net_3d_decoder = UNet3DKitti(
                self.n_classes,
                nn.BatchNorm3d,
                project_scale=project_scale,
                feature=feature,
                full_scene_size=full_scene_size,
                context_prior=context_prior,
            )
        else:
            raise ValueError(
                "dataset must be 'NYU' or 'kitti', got {!r}".format(dataset)
            )

        # 2D encoder-decoder producing the multi-scale RGB features.
        self.net_rgb = UNet2D.build(out_feature=feature, use_decoder=True)
    def forward(self, batch):
        img = batch["img"]
        bs = len(img)

        # Multi-scale 2D features; keys "1_1", "1_2", "1_4", "1_8" index the
        # output strides of the 2D UNet.
        x_rgb = self.net_rgb(img)

        x3ds = []
        for i in range(bs):
            # Per-voxel projected pixel coords and FoV mask for the target 3D
            # scale; they do not depend on the 2D scale, so fetch them once.
            projected_pix = batch["projected_pix_{}".format(self.project_scale)][i]
            fov_mask = batch["fov_mask_{}".format(self.project_scale)][i]

            # Project the features at each 2D scale to the target 3D scale and
            # sum the resulting voxel volumes.
            x3d = None
            for scale_2d in self.project_res:
                scale_2d = int(scale_2d)
                x3d_scale = self.projects[str(scale_2d)](
                    x_rgb["1_" + str(scale_2d)][i],
                    projected_pix // scale_2d,  # floor-divide coords to this scale
                    fov_mask,
                )
                x3d = x3d_scale if x3d is None else x3d + x3d_scale
            x3ds.append(x3d)

        # Decode the stacked 3D feature volumes into per-voxel class logits.
        out_dict = self.net_3d_decoder({"x3d": torch.stack(x3ds)})
        ssc_pred = out_dict["ssc_logit"]

        # Return hard class predictions as a numpy array of shape (bs, X, Y, Z).
        y_pred = ssc_pred.detach().cpu().numpy()
        y_pred = np.argmax(y_pred, axis=1)
        return y_pred
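
# Hedged usage sketch: the hyperparameters and tensor shapes below are
# illustrative assumptions (NYU-style values guessed for this example, not
# taken from the MonoScene configs). It only exercises the batch contract
# that forward() reads: "img", "projected_pix_<project_scale>" and
# "fov_mask_<project_scale>".
if __name__ == "__main__":
    full_scene_size = (60, 36, 60)  # assumed NYU voxel grid
    model = MonoScene(
        n_classes=12,  # assumed NYU class count
        feature=200,  # assumed feature width
        project_scale=1,
        full_scene_size=full_scene_size,
        dataset="NYU",
    )
    n_voxels = full_scene_size[0] * full_scene_size[1] * full_scene_size[2]
    batch = {
        "img": torch.randn(1, 3, 480, 640),  # single RGB frame
        "projected_pix_1": torch.zeros(1, n_voxels, 2, dtype=torch.long),
        "fov_mask_1": torch.ones(1, n_voxels, dtype=torch.bool),
    }
    y_pred = model(batch)  # (1, X, Y, Z) argmax class ids per voxel
    print(y_pred.shape)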