import argparse
import copy
import os
import sys

import numpy as np
import torch
import torch.distributed as dist
from huggingface_hub import hf_hub_download
from PIL import Image
from torch.utils.data import DataLoader
from torchvision import transforms

# make the repo root importable when this file is run as a script
current_dir = os.getcwd()
sys.path.append(current_dir)

from engine import eval_model
from eval_wrapper.eval_utils import filter_all_masks
from eval_wrapper.sample_poses import pointmap_to_poses
from models.losses import *  # noqa: F403
from models.rayquery import *  # noqa: F403 -- model classes are instantiated via eval() below
from utils import misc
from utils.collate import collate
from utils.fusion import fuse_batch
from utils.geometry import compute_pointmap_torch
from utils.viz import just_load_viz
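# Pipeline overview (summary of the code below): given a single posed RGB-D view of a
# scene, the model predicts pointmaps, depths, confidences and foreground masks for
# sampled novel views, and everything is merged into a single world-frame point cloud.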
class EvalWrapper(torch.nn.Module):
    """Thin wrapper that restores a RaySt3R checkpoint and runs it in eval mode."""

    def __init__(self, checkpoint_path, distributed=False, device="cuda", dtype=torch.float32, **kwargs):
        super().__init__()
        checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
        # the checkpoint stores the model constructor as a string, evaluated against
        # the names pulled in by the star imports above
        model_string = checkpoint['args'].model
        self.model = eval(model_string).to(device)
        if distributed:
            rank, world_size, local_rank = misc.setup_distributed()
            self.model = torch.nn.parallel.DistributedDataParallel(self.model, device_ids=[local_rank], find_unused_parameters=True)
        self.dtype = dtype
        self.model.load_state_dict(checkpoint['model'])
        self.model.eval()

    def to(self, device):
        self.model.to(device)
        return self

    def forward(self, x, dino_model=None):
        pred, gt, loss, scale = eval_model(self.model, x, mode='viz', dino_model=dino_model, return_scale=True)
        return pred, gt, loss, scale
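# Usage sketch (illustrative; mirrors main() below): the wrapper is called like the
# underlying model, with a collated batch and an optional DINOv2 feature extractor.
#
#   model = EvalWrapper("rayst3r.pth", distributed=False)
#   pred, gt, loss, scale = model(batch, dino_model)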
class PostProcessWrapper(torch.nn.Module):
    """Filters raw network outputs into per-view validity masks."""

    def __init__(self, pred_mask_threshold=0.5, mode='novel_views',
                 debug=False, conf_dist_mode='isotonic', set_conf=None, percentile=20,
                 no_input_mask=False, no_pred_mask=False):
        super().__init__()
        self.pred_mask_threshold = pred_mask_threshold
        self.mode = mode
        self.debug = debug
        self.conf_dist_mode = conf_dist_mode
        self.set_conf = set_conf
        self.percentile = percentile
        self.no_input_mask = no_input_mask
        self.no_pred_mask = no_pred_mask

    def transform_pointmap(self, pointmap_cam, c2w):
        # pointmap_cam: shape H x W x 3
        # c2w: shape 4 x 4
        # transform the pointmap to the world frame via homogeneous coordinates
        pointmap_cam_h = torch.cat([pointmap_cam, torch.ones(pointmap_cam.shape[:-1] + (1,)).to(pointmap_cam.device)], dim=-1)
        pointmap_world_h = pointmap_cam_h @ c2w.T
        pointmap_world = pointmap_world_h[..., :3] / pointmap_world_h[..., 3:4]
        return pointmap_world

    def reject_conf_points(self, conf_pts):
        if self.set_conf is None:
            raise ValueError("set_conf must be set")
        conf_mask = conf_pts > self.set_conf
        return conf_mask

    def project_input_mask(self, pred_dict, batch):
        # project the predicted points of every novel view into the input view and
        # keep those that land inside the input mask
        input_mask = batch['input_cams']['original_valid_masks'][0][0]  # shape H x W
        input_c2w = batch['input_cams']['c2ws'][0][0]
        input_w2c = torch.linalg.inv(input_c2w)
        input_K = batch['input_cams']['Ks'][0][0]
        H, W = input_mask.shape
        pointmaps_input_cam = torch.stack([self.transform_pointmap(pmap, input_w2c @ c2w) for pmap, c2w in zip(pred_dict['pointmaps'][0], batch['new_cams']['c2ws'][0])])  # bp: Assuming batch size is 1!!
        img_coords = pointmaps_input_cam @ input_K.T
        img_coords = (img_coords[..., :2] / img_coords[..., 2:3]).int()
        n_views, H, W = img_coords.shape[:3]
        device = input_mask.device
        if self.no_input_mask:
            combined_mask = torch.ones((n_views, H, W), device=device)
        else:
            combined_mask = torch.zeros((n_views, H, W), device=device)
        # Flatten spatial dims
        xs = img_coords[..., 0].view(n_views, -1)  # [V, H*W]
        ys = img_coords[..., 1].view(n_views, -1)  # [V, H*W]
        # Create base pixel coords (i, j)
        i_coords = torch.arange(H, device=device).view(-1, 1).expand(H, W).reshape(-1)  # [H*W]
        j_coords = torch.arange(W, device=device).view(1, -1).expand(H, W).reshape(-1)  # [H*W]
        mask_coords = torch.stack((i_coords, j_coords), dim=-1)  # [H*W, 2], shared across views
        # Mask for valid projections
        valid = (xs >= 0) & (xs < W) & (ys >= 0) & (ys < H)  # [V, H*W]
        # Clip out-of-bounds coords for indexing (only valid ones are used anyway)
        xs_clipped = torch.clamp(xs, 0, W - 1)
        ys_clipped = torch.clamp(ys, 0, H - 1)
        # input_mask lookup per view
        flat_input_mask = input_mask[ys_clipped, xs_clipped]  # [V, H*W]
        input_mask_mask = flat_input_mask & valid  # apply valid range mask
        # Apply mask to coords and depths
        depth_points = pointmaps_input_cam[..., -1].view(n_views, -1)  # [V, H*W]
        input_depths = batch['input_cams']['depths'][0][0][ys_clipped, xs_clipped]  # [V, H*W]
        depth_mask = (depth_points > input_depths) & input_mask_mask  # final mask [V, H*W]
        # Get final (i,j) coords to write
        final_i = mask_coords[:, 0].unsqueeze(0).expand(n_views, -1)[depth_mask]  # [N_mask]
        final_j = mask_coords[:, 1].unsqueeze(0).expand(n_views, -1)[depth_mask]  # [N_mask]
        final_view_idx = torch.arange(n_views, device=device).view(-1, 1).expand(-1, H * W)[depth_mask]  # [N_mask]
        # Scatter final mask
        combined_mask[final_view_idx, final_i, final_j] = 1
        return combined_mask.unsqueeze(0).bool()
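    # Note on the depth test above: a predicted point survives only when it projects
    # inside the input-view object mask and lies at or beyond the observed input
    # depth, i.e. in the region that was occluded from the input camera.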
    def forward(self, pred_dict, batch):
        if self.mode == 'novel_views':
            project_masks = self.project_input_mask(pred_dict, batch)
            pred_mask_raw = torch.sigmoid(pred_dict['classifier'])
            if self.no_pred_mask:
                pred_masks = torch.ones_like(project_masks).bool()
            else:
                pred_masks = (pred_mask_raw > self.pred_mask_threshold).bool()
            conf_masks = self.reject_conf_points(pred_dict['conf_pointmaps'])
            combined_mask = project_masks & pred_masks & conf_masks
            batch['new_cams']['valid_masks'] = combined_mask
        elif self.mode == 'input_view':
            conf_masks = self.reject_conf_points(pred_dict['conf_pointmaps'])
            if self.no_pred_mask:
                pred_masks = torch.ones_like(conf_masks).bool()
            else:
                pred_mask_raw = torch.sigmoid(pred_dict['classifier'])
                pred_masks = (pred_mask_raw > self.pred_mask_threshold).bool()
            combined_mask = conf_masks & batch['new_cams']['valid_masks'] & pred_masks
            batch['new_cams']['valid_masks'] = combined_mask  # this is for visualization
        return pred_dict, batch
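# Minimal worked example of the homogeneous-coordinate transform used above. This
# helper is illustrative only (not part of the original pipeline) and is never
# called: a pure-translation c2w must shift every point by the same offset.
def _demo_transform_pointmap():
    pp = PostProcessWrapper()
    pointmap = torch.zeros(2, 2, 3)  # all points at the camera origin
    c2w = torch.eye(4)
    c2w[:3, 3] = torch.tensor([1.0, 2.0, 3.0])  # translate by (1, 2, 3)
    world = pp.transform_pointmap(pointmap, c2w)
    assert torch.allclose(world, torch.tensor([1.0, 2.0, 3.0]).expand(2, 2, 3))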
class GenericLoaderSmall(torch.utils.data.Dataset):
    """Loads a single scene (rgb.png, depth.png, mask.png, intrinsics.pt) and samples novel views to predict."""

    def __init__(self, data_dir, mode="single_scene", dtype=torch.float32, n_pred_views=3, pred_input_only=False, min_depth=0.1,
                 pointmap_for_bb=None, run_octmae=False, false_positive=None, false_negative=None):
        self.data_dir = data_dir
        self.mode = mode
        self.dtype = dtype
        self.rng = np.random.RandomState(seed=42)
        self.n_pred_views = n_pred_views
        self.min_depth = self.depth_metric_to_uint16(min_depth)
        if self.mode == "single_scene":
            self.inputs = [data_dir]
        self.pred_input_only = pred_input_only
        if self.pred_input_only:
            self.n_pred_views = 1
        self.desired_resolution = (480, 640)
        self.resize_transform_rgb = transforms.Resize(self.desired_resolution)
        self.resize_transform_depth = transforms.Resize(self.desired_resolution, interpolation=transforms.InterpolationMode.NEAREST)
        self.pointmap_for_bb = pointmap_for_bb
        self.run_octmae = run_octmae
        self.false_positive = false_positive
        self.false_negative = false_negative
    def transform_pointmap(self, pointmap_cam, c2w):
        # pointmap_cam: shape H x W x 3
        # c2w: shape 4 x 4
        # transform the pointmap to the world frame via homogeneous coordinates
        pointmap_cam_h = torch.cat([pointmap_cam, torch.ones(pointmap_cam.shape[:-1] + (1,)).to(pointmap_cam.device)], dim=-1)
        pointmap_world_h = pointmap_cam_h @ c2w.T
        pointmap_world = pointmap_world_h[..., :3] / pointmap_world_h[..., 3:4]
        return pointmap_world
    def __len__(self):
        return len(self.inputs)

    def look_at(self, cam_pos, center=(0, 0, 0), up=(0, 0, 1)):
        # build a camera-to-world matrix with the camera at cam_pos looking at center
        z = center - cam_pos
        z /= np.linalg.norm(z, axis=-1, keepdims=True)
        y = -np.float32(up)
        y = y - np.sum(y * z, axis=-1, keepdims=True) * z
        y /= np.linalg.norm(y, axis=-1, keepdims=True)
        x = np.cross(y, z, axis=-1)
        cam2w = np.r_[np.c_[x, y, z, cam_pos], [[0, 0, 0, 1]]]
        return cam2w.astype(np.float32)

    def find_new_views(self, n_views, geometric_median=(0, 0, 0), r_min=0.4, r_max=0.9):
        # sample camera centers in a spherical shell around the scene, aimed at its center
        rad = self.rng.uniform(r_min, r_max, size=n_views)
        azi = self.rng.uniform(0, 2 * np.pi, size=n_views)
        ele = self.rng.uniform(-np.pi, np.pi, size=n_views)
        cam_centers = np.c_[np.cos(azi), np.sin(azi)]
        cam_centers = rad[:, None] * np.c_[np.cos(ele)[:, None] * cam_centers, np.sin(ele)] + geometric_median
        c2ws = [self.look_at(cam_pos=cam_center, center=geometric_median) for cam_center in cam_centers]
        return c2ws

    def depth_uint16_to_metric(self, depth):
        return depth / torch.iinfo(torch.uint16).max * 10.0  # uint16 depth spans [0, 10] m

    def depth_metric_to_uint16(self, depth):
        return depth * torch.iinfo(torch.uint16).max / 10.0  # uint16 depth spans [0, 10] m

    def resize(self, depth, img, mask, K):
        # resize all modalities to the desired resolution and rescale the intrinsics to match
        s_x = self.desired_resolution[1] / img.shape[1]
        s_y = self.desired_resolution[0] / img.shape[0]
        depth = self.resize_transform_depth(depth.unsqueeze(0)).squeeze(0)
        img = self.resize_transform_rgb(img.permute(-1, 0, 1)).permute(1, 2, 0)
        mask = self.resize_transform_depth(mask.unsqueeze(0)).squeeze(0)
        K[0] *= s_x
        K[1] *= s_y
        return depth, img, mask, K
    def add_false_positives_and_negatives(self, valid_mask, false_positive, false_negative):
        # corrupt the valid mask for the mask ablation: turn on pixels outside the mask
        # (false positives) and turn off pixels inside it (false negatives)
        false_positive = 0.0 if false_positive is None else false_positive
        false_negative = 0.0 if false_negative is None else false_negative
        n_total_pixels = int(valid_mask.sum())
        n_pixels_left = int(n_total_pixels * (1 - false_positive))
        mask_pixels_coords = torch.where(valid_mask)
        left_pixels_coords = torch.where(~valid_mask)
        # false positives: randomly sample pixels outside the mask
        n_false_positives = min(int(n_pixels_left * false_positive), n_pixels_left)
        false_positives = torch.randperm(len(left_pixels_coords[0]))[:n_false_positives]
        valid_mask[left_pixels_coords[0][false_positives], left_pixels_coords[1][false_positives]] = 1
        # false negatives: randomly sample pixels inside the mask
        n_false_negatives = min(int(n_total_pixels * false_negative), n_total_pixels)
        false_negatives = torch.randperm(len(mask_pixels_coords[0]))[:n_false_negatives]
        valid_mask[mask_pixels_coords[0][false_negatives], mask_pixels_coords[1][false_negatives]] = 0
        return valid_mask
    def __getitem__(self, idx):
        scene_dir = self.inputs[idx]
        data = dict(new_cams={}, input_cams={})
        c2w_path = os.path.join(scene_dir, 'cam2world.pt')
        if os.path.exists(c2w_path):
            data['input_cams']['c2ws_original'] = [torch.load(c2w_path, map_location='cpu', weights_only=True).to(self.dtype)]
        else:
            data['input_cams']['c2ws_original'] = [torch.eye(4).to(self.dtype)]
        data['input_cams']['c2ws'] = [torch.eye(4).to(self.dtype)]
        data['input_cams']['Ks'] = [torch.load(os.path.join(scene_dir, 'intrinsics.pt'), map_location='cpu', weights_only=True).to(self.dtype)]
        data['input_cams']['depths'] = [torch.from_numpy(np.array(Image.open(os.path.join(scene_dir, 'depth.png'))).astype(np.float32))]
        data['input_cams']['valid_masks'] = [torch.from_numpy(np.array(Image.open(os.path.join(scene_dir, 'mask.png')))).bool()]
        data['input_cams']['imgs'] = [torch.from_numpy(np.array(Image.open(os.path.join(scene_dir, 'rgb.png'))))]
        if self.false_positive is not None or self.false_negative is not None:
            data['input_cams']['valid_masks'][0] = self.add_false_positives_and_negatives(data['input_cams']['valid_masks'][0], self.false_positive, self.false_negative)
        if data['input_cams']['depths'][0].shape != self.desired_resolution:
            data['input_cams']['depths'][0], data['input_cams']['imgs'][0], data['input_cams']['valid_masks'][0], data['input_cams']['Ks'][0] = \
                self.resize(data['input_cams']['depths'][0], data['input_cams']['imgs'][0], data['input_cams']['valid_masks'][0], data['input_cams']['Ks'][0])
        data['input_cams']['original_valid_masks'] = [data['input_cams']['valid_masks'][0].clone()]
        data['input_cams']['valid_masks'][0] = data['input_cams']['valid_masks'][0] & \
            (data['input_cams']['depths'][0] > self.min_depth)
        if self.pred_input_only:
            c2ws = [data['input_cams']['c2ws'][0].cpu().numpy()]
        else:
            input_mask = data['input_cams']['valid_masks'][0]
            if self.pointmap_for_bb is not None:
                pointmap_input = self.pointmap_for_bb
            else:
                pointmap_input = compute_pointmap_torch(self.depth_uint16_to_metric(data['input_cams']['depths'][0]), data['input_cams']['c2ws'][0], data['input_cams']['Ks'][0], device='cpu')[input_mask]
            c2ws = pointmap_to_poses(pointmap_input, self.n_pred_views, inner_radius=1.1, outer_radius=2.5, device='cpu', run_octmae=self.run_octmae)
        self.n_pred_views = len(c2ws)
        data['new_cams'] = {}
        data['new_cams']['c2ws'] = [torch.from_numpy(c2w).to(self.dtype) for c2w in c2ws]
        data['new_cams']['depths'] = [torch.zeros_like(data['input_cams']['depths'][0]) for _ in range(self.n_pred_views)]
        data['new_cams']['Ks'] = [data['input_cams']['Ks'][0] for _ in range(self.n_pred_views)]
        if self.pred_input_only:
            data['new_cams']['valid_masks'] = data['input_cams']['original_valid_masks']
        else:
            data['new_cams']['valid_masks'] = [torch.ones_like(data['input_cams']['valid_masks'][0]) for _ in range(self.n_pred_views)]
        return data
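# Usage sketch (illustrative): the scene directory is expected to contain rgb.png,
# depth.png (uint16, 10 m range), mask.png, intrinsics.pt and optionally
# cam2world.pt, as read in __getitem__ above.
#
#   dataset = GenericLoaderSmall("/path/to/scene", n_pred_views=5)
#   loader = DataLoader(dataset, batch_size=1, collate_fn=collate)
#   batch = next(iter(loader))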
def dict_to_float(d):
    return {k: v.float() for k, v in d.items()}

def merge_dicts(d1, d2):
    # concatenate the tensors of d2 onto d1 along the view dimension (dim 1)
    for k, v in d1.items():
        d1[k] = torch.cat([d1[k], d2[k]], dim=1)
    return d1
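# Example (illustrative): merging a 1-view dict with a 5-view dict yields 6 views.
#   a = {'depths': torch.zeros(1, 1, 480, 640)}
#   b = {'depths': torch.zeros(1, 5, 480, 640)}
#   merge_dicts(a, b)['depths'].shape  # torch.Size([1, 6, 480, 640])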
def compute_all_points(pred_dict, batch):
    # back-project every predicted depth map to a world-frame point cloud and keep
    # only the points that survived the validity masks
    n_views = pred_dict['depths'].shape[1]
    all_points = None
    for i in range(n_views):
        mask = batch['new_cams']['valid_masks'][0, i]
        pointmap = compute_pointmap_torch(pred_dict['depths'][0, i], batch['new_cams']['c2ws'][0, i], batch['new_cams']['Ks'][0, i])
        masked_points = pointmap[mask]
        if all_points is None:
            all_points = masked_points
        else:
            all_points = torch.cat([all_points, masked_points], dim=0)
    return all_points
def eval_scene(model, data_dir, visualize=False, rr_addr=None, run_octmae=False, set_conf=5,
               no_input_mask=False, no_pred_mask=False, no_filter_input_view=False, false_positive=None, false_negative=None, n_pred_views=5,
               do_filter_all_masks=False, dino_model=None, tsdf=False, device='cpu'):
    """Runs the full pipeline on one scene and returns the fused world-frame point cloud."""
    if dino_model is None:
        # Load the DINOv2 feature extractor
        dino_model = torch.hub.load('facebookresearch/dinov2', "dinov2_vitl14_reg")
        dino_model.eval()
        dino_model.to(device)
    # first pass: predict the input view only, to obtain a filtered input point cloud
    dataset_input_view = GenericLoaderSmall(data_dir, n_pred_views=1, pred_input_only=True, false_positive=false_positive, false_negative=false_negative)
    input_view_loader = DataLoader(dataset_input_view, batch_size=1, shuffle=True, collate_fn=collate)
    input_view_batch = next(iter(input_view_loader))
    postprocessor_input_view = PostProcessWrapper(mode='input_view', set_conf=set_conf,
                                                  no_input_mask=no_input_mask, no_pred_mask=no_pred_mask)
    postprocessor_pred_views = PostProcessWrapper(mode='novel_views', debug=False, set_conf=set_conf,
                                                  no_input_mask=no_input_mask, no_pred_mask=no_pred_mask)
    fused_meshes = None
    with torch.no_grad():
        pred_input_view, gt_input_view, _, scale_factor = model(input_view_batch, dino_model)
        if no_filter_input_view:
            pred_input_view['pointmaps'] = input_view_batch['input_cams']['pointmaps']
            pred_input_view['depths'] = input_view_batch['input_cams']['depths']
        else:
            pred_input_view, input_view_batch = postprocessor_input_view(pred_input_view, input_view_batch)
        input_points = pred_input_view['pointmaps'][0][0][input_view_batch['new_cams']['valid_masks'][0][0]] * (1.0 / scale_factor)
        if input_points.shape[0] == 0:
            input_points = None
        # second pass: sample novel views around the input points and predict them
        dataset_pred_views = GenericLoaderSmall(data_dir, n_pred_views=n_pred_views, pred_input_only=False,
                                                pointmap_for_bb=input_points, run_octmae=run_octmae)
        pred_views_loader = DataLoader(dataset_pred_views, batch_size=1, shuffle=True, collate_fn=collate)
        pred_views_batch = next(iter(pred_views_loader))
        # this is for the mask ablation
        if (false_positive is not None or false_negative is not None) and input_points is not None:
            pred_views_batch['input_cams']['valid_masks'] = input_view_batch['input_cams']['valid_masks']
        pred_new_views, gt_new_views, _, scale_factor = model(pred_views_batch, dino_model)
        pred_new_views, pred_views_batch = postprocessor_pred_views(pred_new_views, pred_views_batch)
        pred = merge_dicts(dict_to_float(pred_input_view), dict_to_float(pred_new_views))
        gt = merge_dicts(dict_to_float(gt_input_view), dict_to_float(gt_new_views))
        batch = copy.deepcopy(input_view_batch)
        batch['new_cams'] = merge_dicts(input_view_batch['new_cams'], pred_views_batch['new_cams'])
        gt['pointmaps'] = None  # make sure it's not used in viz
        if do_filter_all_masks:
            batch = filter_all_masks(pred, input_view_batch, max_outlier_views=1)
        # scale_factor is the scale applied to the input view for inference; undo it here
        all_points = compute_all_points(pred, batch)
        all_points = all_points * (1.0 / scale_factor)
        # transform all_points back to the original coordinate system
        all_points_h = torch.cat([all_points, torch.ones(all_points.shape[:-1] + (1,)).to(all_points.device)], dim=-1)
        all_points_original = all_points_h @ batch['input_cams']['c2ws_original'][0][0].T
        all_points = all_points_original[..., :3]
        if tsdf:
            # optionally fuse the predictions into a simple TSDF mesh for visualization
            fused_meshes = fuse_batch(pred, gt, batch, voxel_size=0.002)
        else:
            fused_meshes = None
    if visualize:
        just_load_viz(pred, gt, batch, addr=rr_addr, fused_meshes=fused_meshes)
    return all_points
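# Note on the two-pass design above: the first pass predicts only the input view to
# obtain a filtered, metrically scaled point cloud, which then bounds the novel-view
# sampling (via pointmap_for_bb) for the second pass.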
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("data_dir", type=str)
    parser.add_argument("--rr_addr", type=str, default="0.0.0.0:" + os.getenv("RERUN_RECORDING", "9876"))
    parser.add_argument("--visualize", action="store_true", default=False)
    parser.add_argument("--run_octmae", action="store_true", default=False)
    parser.add_argument("--set_conf", type=float, default=5)
    parser.add_argument("--n_pred_views", type=int, default=5)
    parser.add_argument("--filter_all_masks", action="store_true", default=False)
    parser.add_argument("--tsdf", action="store_true", default=False)
    # ablation settings
    parser.add_argument("--no_input_mask", action="store_true", default=False)
    parser.add_argument("--no_pred_mask", action="store_true", default=False)
    parser.add_argument("--no_filter_input_view", action="store_true", default=False)
    parser.add_argument("--false_positive", type=float, default=None)
    parser.add_argument("--false_negative", type=float, default=None)
    args = parser.parse_args()
    print("Loading checkpoint from Huggingface")
    rayst3r_checkpoint = hf_hub_download("bartduis/rayst3r", "rayst3r.pth")
    model = EvalWrapper(rayst3r_checkpoint, distributed=False)
    all_points = eval_scene(model, args.data_dir, visualize=args.visualize, rr_addr=args.rr_addr, run_octmae=args.run_octmae, set_conf=args.set_conf,
                            no_input_mask=args.no_input_mask, no_pred_mask=args.no_pred_mask, no_filter_input_view=args.no_filter_input_view,
                            false_positive=args.false_positive, false_negative=args.false_negative, n_pred_views=args.n_pred_views,
                            do_filter_all_masks=args.filter_all_masks, tsdf=args.tsdf).cpu().numpy()
    all_points_save = os.path.join(args.data_dir, "inference_points.ply")
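    # write the points as a minimal ASCII PLY (sketch; assumes saving the point
    # cloud to the path above is the intent -- the path is otherwise unused)
    with open(all_points_save, "w") as f:
        f.write("ply\nformat ascii 1.0\n")
        f.write(f"element vertex {all_points.shape[0]}\n")
        f.write("property float x\nproperty float y\nproperty float z\nend_header\n")
        np.savetxt(f, all_points, fmt="%.6f")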
if __name__ == "__main__":
    main()
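
# Example invocation (paths are illustrative):
#   python eval_wrapper/eval.py /path/to/scene --visualize --tsdf --n_pred_views 5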