import torch
def eval_pred(pred_dict, gt_dict, accuracy_tresh=(0.001, 0.01, 0.02, 0.05, 0.1, 0.5)):
    """Compute point-map error metrics over the valid pixels of a batch.

    Args:
        pred_dict: Predictions. Must contain ``'pointmaps'``; may contain
            ``'classifier'`` (raw logits, passed through a sigmoid here).
        gt_dict: Ground truth. Must contain ``'pointmaps'`` and
            ``'valid_masks'``. The mask must have the shape of the point
            maps without their last (coordinate) axis; it is cast to bool,
            so 0/1 integer or float masks are accepted as well.
        accuracy_tresh: Distance thresholds. For each one, the fraction of
            valid points whose error is strictly below it is reported.

    Returns:
        A dict of Python floats:
            ``'dist'``            mean Euclidean distance over valid points,
            ``'classifier_acc'``  accuracy of the thresholded classifier
                                  against the valid mask (only when
                                  ``'classifier'`` is in ``pred_dict``),
            ``'acc_<t>'``         one entry per threshold ``t``.
        If no point is valid, the means are taken over empty tensors and
        the distance-based entries are NaN.
    """
    # Cast once so that non-bool masks select elements instead of erroring
    # (float) or being interpreted as gather indices (integer).
    valid = gt_dict['valid_masks'].bool()

    # Indexing the leading dims with the mask yields one row per valid point
    # directly, so no channel-expanded copy of the mask is required.
    points_pred = pred_dict['pointmaps'][valid]
    points_gt = gt_dict['pointmaps'][valid]
    dists = torch.norm(points_pred - points_gt, dim=-1)

    results = {'dist': dists.mean().detach().item()}

    if 'classifier' in pred_dict:
        # NOTE(review): assumes the logits have the same shape as the mask;
        # confirm against the model head.
        classifier_pred = torch.sigmoid(pred_dict['classifier']) > 0.5
        results['classifier_acc'] = (
            (classifier_pred == valid).float().mean().detach().item()
        )

    for tresh in accuracy_tresh:
        results[f'acc_{tresh}'] = (dists < tresh).float().mean().detach().item()

    return results