import torch


def eval_pred(
    pred_dict,
    gt_dict,
    accuracy_tresh=(0.001, 0.01, 0.02, 0.05, 0.1, 0.5),
):
    """Compute point-map error and accuracy metrics for one prediction.

    Only points where ``gt_dict['valid_masks']`` is true are scored.

    Args:
        pred_dict: Must contain ``'pointmaps'``, a tensor whose last
            dimension holds the 3 coordinates of each point and whose leading
            dimensions match ``gt_dict['valid_masks']``. May contain
            ``'classifier'`` logits; these are thresholded at sigmoid > 0.5
            and compared element-wise with the validity mask.
        gt_dict: Must contain ``'pointmaps'`` (same shape as the predicted
            point map) and ``'valid_masks'`` (bool, or 0/1 values that are
            converted to bool).
        accuracy_tresh: Iterable of distance thresholds. For each ``t`` the
            result holds ``'acc_{t}'``, the fraction of valid points whose
            Euclidean error is strictly below ``t``.

    Returns:
        dict of Python floats with ``'dist'`` (mean Euclidean error over valid
        points), ``'classifier_acc'`` (only when classifier logits are given)
        and one ``'acc_{t}'`` entry per threshold. If no point is valid, the
        distance-based entries are NaN (mean of an empty tensor).
    """
    with torch.no_grad():
        # Index with a real bool tensor: float masks cannot be used as indices
        # and uint8 masks are deprecated.
        valid = gt_dict['valid_masks'].bool()

        # A mask over the leading dims selects whole points, giving (N, 3)
        # in the same row-major order as expanding the mask to the coordinate
        # dimension, without allocating the expanded copy or fixing the rank.
        points_pred = pred_dict['pointmaps'][valid].reshape(-1, 3)
        points_gt = gt_dict['pointmaps'][valid].reshape(-1, 3)
        dists = torch.norm(points_pred - points_gt, dim=1)

        results = {'dist': dists.mean().item()}

        if 'classifier' in pred_dict:
            # Comparison already yields a bool tensor.
            classifier_pred = torch.sigmoid(pred_dict['classifier']) > 0.5
            # NOTE(review): assumes the logits have the same shape as
            # valid_masks (no trailing singleton dim) - confirm with the model.
            results['classifier_acc'] = (
                (classifier_pred == valid).float().mean().item()
            )

        for tresh in accuracy_tresh:
            results[f'acc_{tresh}'] = (dists < tresh).float().mean().item()

    return results