| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from pointnet2_utils import PointNetSetAbstractionMsg, PointNetSetAbstraction | |
class get_model(nn.Module):
    """Point-cloud classifier: three set-abstraction stages plus an MLP head.

    The network chains two ``PointNetSetAbstractionMsg`` stages and one
    ``PointNetSetAbstraction`` stage, then feeds the resulting 1024-wide
    feature through two Linear/BatchNorm/ReLU/Dropout blocks and a final
    Linear layer that emits ``num_class`` log-probabilities.
    """

    def __init__(self, num_class, normal_channel=True):
        """Build the set-abstraction stages and the classification head.

        Args:
            num_class: Width of the final linear layer (number of classes).
            normal_channel: When true, the first stage is configured for 3
                extra per-point feature channels and ``forward`` splits the
                input into its first 3 channels and the remainder; when
                false, the stage gets 0 extra channels and the input is
                passed through unsplit.
        """
        super().__init__()
        self.normal_channel = normal_channel
        extra_channels = 3 if normal_channel else 0

        # Argument meaning is defined by pointnet2_utils (not visible here);
        # presumably (npoint, radii, samples-per-radius, in_channel,
        # per-branch MLP widths) -- TODO confirm against that module.
        self.sa1 = PointNetSetAbstractionMsg(
            512,
            [0.1, 0.2, 0.4],
            [16, 32, 128],
            extra_channels,
            [[32, 32, 64], [64, 64, 128], [64, 96, 128]],
        )
        # 320 == 64 + 128 + 128, the final widths of sa1's three branches.
        self.sa2 = PointNetSetAbstractionMsg(
            128,
            [0.2, 0.4, 0.8],
            [32, 64, 128],
            320,
            [[64, 64, 128], [128, 128, 256], [128, 128, 256]],
        )
        # 640 == 128 + 256 + 256 (sa2 branch widths); the extra 3 is
        # presumably the xyz coordinates -- verify in pointnet2_utils.
        self.sa3 = PointNetSetAbstraction(
            None, None, None, 640 + 3, [256, 512, 1024], True
        )

        # Classification head: 1024 -> 512 -> 256 -> num_class.
        self.fc1 = nn.Linear(1024, 512)
        self.bn1 = nn.BatchNorm1d(512)
        self.drop1 = nn.Dropout(0.4)
        self.fc2 = nn.Linear(512, 256)
        self.bn2 = nn.BatchNorm1d(256)
        self.drop2 = nn.Dropout(0.5)
        self.fc3 = nn.Linear(256, num_class)

    def forward(self, xyz):
        """Classify a batch of point clouds.

        Args:
            xyz: 3-D tensor; dim 0 is the batch and dim 1 holds the
                channels (the first 3 are treated as coordinates when
                ``normal_channel`` is set). Dim 2 is presumably the point
                index -- TODO confirm against the data loader.

        Returns:
            A tuple ``(log_probs, l3_points)``: the log-softmax output of the
            head and the raw feature tensor produced by the last
            set-abstraction stage.
        """
        # The three-way unpack also rejects inputs that are not 3-D.
        batch_size, _, _ = xyz.shape

        if self.normal_channel:
            coords, normals = xyz[:, :3, :], xyz[:, 3:, :]
        else:
            coords, normals = xyz, None

        l1_xyz, l1_points = self.sa1(coords, normals)
        l2_xyz, l2_points = self.sa2(l1_xyz, l1_points)
        l3_xyz, l3_points = self.sa3(l2_xyz, l2_points)

        # Flatten the last stage's features to (batch, 1024) for the head.
        hidden = l3_points.view(batch_size, 1024)

        hidden = self.fc1(hidden)
        hidden = self.bn1(hidden)
        hidden = F.relu(hidden)
        hidden = self.drop1(hidden)

        hidden = self.fc2(hidden)
        hidden = self.bn2(hidden)
        hidden = F.relu(hidden)
        hidden = self.drop2(hidden)

        logits = self.fc3(hidden)
        log_probs = F.log_softmax(logits, -1)
        return log_probs, l3_points
class get_loss(nn.Module):
    """Negative log-likelihood loss for log-probability predictions."""

    def forward(self, pred, target, trans_feat=None):
        """Return ``F.nll_loss(pred, target)``.

        Args:
            pred: Log-probabilities, as accepted by ``F.nll_loss``.
            target: Class indices, as accepted by ``F.nll_loss``.
            trans_feat: Accepted for call-signature compatibility and
                ignored; it does not contribute to the loss, so it is
                optional.

        Returns:
            The loss tensor from ``F.nll_loss`` with its default reduction.
        """
        return F.nll_loss(pred, target)