bb = breakpoint
import torch
import trimesh
import matplotlib.pyplot as plt
import numpy as np
import os
from pathlib import Path
import pickle
import tqdm
import json
from PIL import Image
class GenericLoader(torch.utils.data.Dataset):
    """Dataset that returns one context (input) view and one novel (new) view per scene,
    loading camera poses, intrinsics, depths, masks, RGB and optional DINO features."""
    def __init__(self,dir="octmae_data/tiny_train/train_processed",seed=747,size=10,datasets=["fp_objaverse"],split="train",dtype=torch.float32,mode="slow",
                 prefetch_dino=False,dino_features=[23],view_select_mode="new_zoom",noise_std=0.0,rendered_views_mode="None",**kwargs):
        super().__init__(**kwargs)
        self.dir = dir
        self.rng = np.random.default_rng(seed)
        self.size = size
        self.datasets = datasets
        self.split = split
        self.dtype = dtype
        self.mode = mode
        self.prefetch_dino = prefetch_dino
        self.view_select_mode = view_select_mode
        self.noise_std = noise_std * torch.iinfo(torch.uint16).max / 10.0 # noise std converted from metric units to raw uint16 depth units (65535 raw == 10 metric, see depth_to_metric)
        if self.mode == 'slow':
            self.prefetch_dino = True
        self.find_scenes()
        self.dino_features = dino_features
        self.rendered_views_mode = rendered_views_mode
    def find_dataset_location_list(self,dataset):
        # self.dir is a list of root directories; the dataset must appear in exactly one of them
        data_dir = None
        for d in self.dir:
            datasets = os.listdir(d)
            if dataset in datasets:
                if data_dir is not None:
                    raise ValueError(f"Dataset {dataset} found in multiple locations: {self.dir}")
                else:
                    data_dir = os.path.join(d,dataset)
        if data_dir is None:
            raise ValueError(f"Dataset {dataset} not found in {self.dir}")
        return data_dir
    def find_dataset_location(self,dataset):
        if isinstance(self.dir,list):
            data_dir = self.find_dataset_location_list(dataset)
        else:
            data_dir = os.path.join(self.dir,dataset)
        if not os.path.exists(data_dir):
            raise ValueError(f"Dataset {dataset} not found in {self.dir}")
        return data_dir
    def find_scenes(self):
        # build {scene_id: scene_path} across all requested datasets, then shuffle and optionally truncate to self.size
        all_scenes = {}
        print("Loading scenes...")
        for dataset in self.datasets:
            dataset_dir = self.find_dataset_location(dataset)
            scenes = json.load(open(os.path.join(dataset_dir, f"{self.split}_scenes.json")))
            scene_ids = [dataset + "_" + f.split("/")[-2] + "_" + f.split("/")[-1] for f in scenes]
            all_scenes.update(dict(zip(scene_ids, scenes)))
        self.scenes = all_scenes
        self.scene_ids = list(self.scenes.keys())
        # shuffle the scene ids
        self.rng.shuffle(self.scene_ids)
        if self.size > 0:
            self.scene_ids = self.scene_ids[:self.size]
        self.size = len(self.scene_ids)
        return scenes
    def __len__(self):
        return self.size
    def decide_context_view(self,cam_dir):
        # we pick the view furthest away from the origin as the view for conditioning
        cam_dirs = [d for d in os.listdir(cam_dir) if os.path.isdir(os.path.join(cam_dir,d)) and not d.startswith("gen")] # input cam needs rgb
        extrinsics = {c:torch.load(os.path.join(cam_dir,c,'cam2world.pt'),map_location='cpu',weights_only=True) for c in cam_dirs}
        dist_origin = {c:torch.linalg.norm(extrinsics[c][:3,3]) for c in extrinsics}
        if self.view_select_mode == 'new_zoom':
            # find the view with the maximum distance to the origin
            input_cam = max(dist_origin,key=dist_origin.get)
        elif self.view_select_mode == 'random':
            # pick a random view
            input_cam = str(self.rng.choice(list(dist_origin.keys())))
        else:
            raise ValueError(f"Invalid mode: {self.view_select_mode}")
        # decide which camera folders are eligible as the view to predict
        if self.rendered_views_mode == "None":
            pass
        elif self.rendered_views_mode == "random":
            cam_dirs = [d for d in os.listdir(cam_dir) if os.path.isdir(os.path.join(cam_dir,d))]
        elif self.rendered_views_mode == "always":
            cam_dirs_gen = [d for d in os.listdir(cam_dir) if os.path.isdir(os.path.join(cam_dir,d)) and d.startswith("gen")]
            if len(cam_dirs_gen) > 0:
                cam_dirs = cam_dirs_gen
        else:
            raise ValueError(f"Invalid mode: {self.rendered_views_mode}")
        # pick another random view to predict, excluding the context view
        possible_views = [v for v in cam_dirs if v != input_cam]
        new_cam = str(self.rng.choice(possible_views))
        return input_cam,new_cam
    def transform_pointmap(self,pointmap_cam,c2w):
        # pointmap_cam: shape H x W x 3 (points in the camera frame)
        # c2w: shape 4 x 4 (camera-to-world transform)
        # transform the pointmap to the world frame via homogeneous coordinates
        pointmap_cam_h = torch.cat([pointmap_cam,torch.ones(pointmap_cam.shape[:-1]+(1,)).to(pointmap_cam.device)],dim=-1)
        pointmap_world_h = pointmap_cam_h @ c2w.T
        pointmap_world = pointmap_world_h[...,:3]/pointmap_world_h[...,3:4]
        return pointmap_world
    def load_scene_slow(self,input_cam,new_cam,cam_dir):
        # "slow" path: depths, pointmaps and DINO features are precomputed .pt tensors
        data = dict(new_cams={},input_cams={})
        # view to predict
        data['new_cams']['c2ws'] = [torch.load(os.path.join(cam_dir,new_cam,'cam2world.pt'),map_location='cpu',weights_only=True).to(self.dtype)]
        data['new_cams']['depths'] = [torch.load(os.path.join(cam_dir,new_cam,'depth.pt'),map_location='cpu',weights_only=True).to(self.dtype)]
        data['new_cams']['pointmaps'] = [self.transform_pointmap(torch.load(os.path.join(cam_dir,new_cam,'pointmap.pt'),map_location='cpu',weights_only=True).to(self.dtype),data['new_cams']['c2ws'][0])]
        data['new_cams']['Ks'] = [torch.load(os.path.join(cam_dir,new_cam,'intrinsics.pt'),map_location='cpu',weights_only=True).to(self.dtype)]
        data['new_cams']['valid_masks'] = [torch.load(os.path.join(cam_dir,new_cam,'mask.pt'),map_location='cpu',weights_only=True).to(torch.bool)]
        # add the context views
        data['input_cams']['c2ws'] = [torch.load(os.path.join(cam_dir,input_cam,'cam2world.pt'),map_location='cpu',weights_only=True).to(self.dtype)]
        data['input_cams']['depths'] = [torch.load(os.path.join(cam_dir,input_cam,'depth.pt'),map_location='cpu',weights_only=True).to(self.dtype)]
        data['input_cams']['pointmaps'] = [self.transform_pointmap(torch.load(os.path.join(cam_dir,input_cam,'pointmap.pt'),map_location='cpu',weights_only=True).to(self.dtype),data['input_cams']['c2ws'][0])]
        data['input_cams']['Ks'] = [torch.load(os.path.join(cam_dir,input_cam,'intrinsics.pt'),map_location='cpu',weights_only=True).to(self.dtype)]
        data['input_cams']['valid_masks'] = [torch.load(os.path.join(cam_dir,input_cam,'mask.pt'),map_location='cpu',weights_only=True).to(torch.bool)]
        data['input_cams']['imgs'] = [torch.load(os.path.join(cam_dir,input_cam,'rgb.pt'),map_location='cpu',weights_only=True).to(self.dtype)]
        data['input_cams']['dino_features'] = [torch.load(os.path.join(cam_dir,input_cam,f'dino_features_layer_{l}.pt'),map_location='cpu',weights_only=True).to(self.dtype) for l in self.dino_features]
        return data
    def depth_to_metric(self,depth):
        # depth: shape H x W
        # we want to convert the depth to a metric depth
        depth_max = 10.0
        depth_scaled = depth_max * (depth / 65535.0)
        return depth_scaled
    def load_scene_fast(self,input_cam,new_cam,cam_dir):
        # "fast" path: depth/mask/rgb are read from PNGs; DINO features are only loaded when prefetch_dino is set
        data = dict(new_cams={},input_cams={})
        data['new_cams']['c2ws'] = [torch.load(os.path.join(cam_dir,new_cam,'cam2world.pt'),map_location='cpu',weights_only=True).to(self.dtype)]
        data['new_cams']['Ks'] = [torch.load(os.path.join(cam_dir,new_cam,'intrinsics.pt'),map_location='cpu',weights_only=True).to(self.dtype)]
        data['new_cams']['depths'] = [torch.from_numpy(np.array(Image.open(os.path.join(cam_dir,new_cam,'depth.png'))).astype(np.float32))]
        data['new_cams']['valid_masks'] = [torch.from_numpy(np.array(Image.open(os.path.join(cam_dir,new_cam,'mask.png'))))]
        data['input_cams']['c2ws'] = [torch.load(os.path.join(cam_dir,input_cam,'cam2world.pt'),map_location='cpu',weights_only=True).to(self.dtype)]
        data['input_cams']['Ks'] = [torch.load(os.path.join(cam_dir,input_cam,'intrinsics.pt'),map_location='cpu',weights_only=True).to(self.dtype)]
        data['input_cams']['depths'] = [torch.from_numpy(np.array(Image.open(os.path.join(cam_dir,input_cam,'depth.png'))).astype(np.float32))]
        data['input_cams']['valid_masks'] = [torch.from_numpy(np.array(Image.open(os.path.join(cam_dir,input_cam,'mask.png'))))]
        data['input_cams']['imgs'] = [torch.from_numpy(np.array(Image.open(os.path.join(cam_dir,input_cam,'rgb.png'))))]
        if self.prefetch_dino:
            data['input_cams']['dino_features'] = [torch.cat([torch.load(os.path.join(cam_dir,input_cam,f'dino_features_layer_{l}.pt'),map_location='cpu',weights_only=True).to(self.dtype) for l in self.dino_features],dim=-1)]
        return data
    def __getitem__(self,idx):
        cam_dir = os.path.join(self.scenes[self.scene_ids[idx]],'cameras')
        #data['input_cams'] = {k:[v[0].unsqueeze(0)] for k,v in data['input_cams'].items()}
        input_cam,new_cam = self.decide_context_view(cam_dir)
        if self.mode == 'slow':
            data = self.load_scene_slow(input_cam,new_cam,cam_dir)
        else:
            data = self.load_scene_fast(input_cam,new_cam,cam_dir)
        return data
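
# --- Usage sketch (illustrative, not part of the original loader) ---
# Minimal example of iterating the dataset with a torch DataLoader. The data
# directory, dataset name and parameter values below are assumptions; point
# them at your own processed data layout. Each sample is a nested dict of
# single-element tensor lists, so an identity-style collate keeps that
# structure unchanged.
if __name__ == "__main__":
    dataset = GenericLoader(dir="octmae_data/tiny_train/train_processed",  # assumed default layout
                            datasets=["fp_objaverse"],
                            split="train",
                            size=4,
                            mode="fast",                 # "fast" reads PNGs, "slow" reads .pt tensors
                            view_select_mode="random")
    loader = torch.utils.data.DataLoader(dataset,batch_size=1,num_workers=0,
                                         collate_fn=lambda batch: batch[0])  # keep the nested dict as-is
    for sample in loader:
        print("input depth shape:", sample['input_cams']['depths'][0].shape)
        print("new-view depth shape:", sample['new_cams']['depths'][0].shape)
        break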