|
bb = breakpoint |
|
import torch |
|
import trimesh |
|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
import os |
|
from pathlib import Path |
|
import pickle |
|
import tqdm |
|
import json |
|
from PIL import Image |
|
|
|
class GenericLoader(torch.utils.data.Dataset):
    """Serve (input-view, novel-view) pairs from preprocessed scene folders.

    Each scene path points at a directory containing a ``cameras/`` folder with
    one sub-directory per camera.  Every camera folder holds ``cam2world.pt``
    and ``intrinsics.pt`` plus, depending on ``mode``:

    * ``mode="slow"`` -- precomputed tensors (``depth.pt``, ``pointmap.pt``,
      ``mask.pt``, ``rgb.pt``, ``dino_features_layer_*.pt``);
    * fast mode (any other value) -- PNG files (``depth.png``, ``mask.png``,
      ``rgb.png``) plus optionally the DINO feature tensors.

    Camera folders whose names start with ``gen`` are treated as rendered
    (generated) views and participate only as allowed by ``rendered_views_mode``.
    """

    def __init__(self, dir="octmae_data/tiny_train/train_processed", seed=747, size=10,
                 datasets=None, split="train", dtype=torch.float32, mode="slow",
                 prefetch_dino=False, dino_features=None, view_select_mode="new_zoom",
                 noise_std=0.0, rendered_views_mode="None", **kwargs):
        """Create the loader and index the available scenes.

        Args:
            dir: Root directory, or list of root directories, holding the
                per-dataset folders.  (Name kept for interface compatibility
                even though it shadows the builtin.)
            seed: Seed for the private numpy RNG (scene shuffling, view choice).
            size: Maximum number of scenes to keep; ``<= 0`` keeps all of them.
            datasets: Dataset folder names to index (default ``["fp_objaverse"]``).
            split: Which ``{split}_scenes.json`` index file to read per dataset.
            dtype: Floating dtype loaded tensors are cast to.
            mode: ``"slow"`` for tensor files; anything else selects PNG loading.
            prefetch_dino: Load DINO features in fast mode (forced on in slow mode).
            dino_features: DINO layer indices to load (default ``[23]``).
            view_select_mode: ``"new_zoom"`` or ``"random"`` input-view policy.
            noise_std: Depth-noise std, presumably in metres; stored converted
                to raw uint16 depth units using the same 10 m / 65535 scale as
                ``depth_to_metric``.  ``torch.iinfo(torch.uint16)`` requires a
                torch version that exposes ``torch.uint16`` — TODO confirm.
            rendered_views_mode: ``"None"``, ``"random"`` or ``"always"`` --
                whether generated (``gen*``) views may/must serve as novel view.
            **kwargs: Forwarded to ``torch.utils.data.Dataset``.
        """
        super().__init__(**kwargs)
        self.dir = dir
        self.rng = np.random.default_rng(seed)
        self.size = size
        # None-sentinel defaults avoid the shared-mutable-default pitfall while
        # keeping the historical default values.
        self.datasets = ["fp_objaverse"] if datasets is None else datasets
        self.split = split
        self.dtype = dtype
        self.mode = mode
        self.prefetch_dino = prefetch_dino
        self.view_select_mode = view_select_mode
        # Convert metric noise std to raw uint16 depth units (see depth_to_metric).
        self.noise_std = noise_std * torch.iinfo(torch.uint16).max / 10.0
        if self.mode == 'slow':
            # Slow mode always reads the precomputed DINO feature tensors.
            self.prefetch_dino = True
        self.dino_features = [23] if dino_features is None else dino_features
        self.rendered_views_mode = rendered_views_mode
        # find_scenes() uses none of the two attributes above, so calling it
        # last is safe and leaves the instance fully initialised beforehand.
        self.find_scenes()

    def find_dataset_location_list(self, dataset):
        """Locate ``dataset`` when ``self.dir`` is a list of root directories.

        Returns:
            str: The unique ``root/dataset`` path.

        Raises:
            ValueError: If the dataset appears in zero or in multiple roots.
        """
        data_dir = None
        for root in self.dir:
            if dataset in os.listdir(root):
                if data_dir is not None:
                    raise ValueError(f"Dataset {dataset} found in multiple locations: {self.dir}")
                data_dir = os.path.join(root, dataset)
        if data_dir is None:
            raise ValueError(f"Dataset {dataset} not found in {self.dir}")
        return data_dir

    def find_dataset_location(self, dataset):
        """Resolve the on-disk directory of ``dataset`` under ``self.dir``.

        Raises:
            ValueError: If the resolved dataset directory does not exist.
        """
        if isinstance(self.dir, list):
            data_dir = self.find_dataset_location_list(dataset)
        else:
            data_dir = os.path.join(self.dir, dataset)
        if not os.path.exists(data_dir):
            raise ValueError(f"Dataset {dataset} not found in {self.dir}")
        return data_dir

    def find_scenes(self):
        """Index all scene paths for ``self.datasets`` / ``self.split``.

        Reads each dataset's ``{split}_scenes.json`` (a list of scene paths),
        builds ``self.scenes`` mapping ``"<dataset>_<parent>_<name>"`` ids to
        paths, shuffles the ids, and truncates to ``self.size`` when positive.
        ``self.size`` is then updated to the number of scenes actually kept.

        Returns:
            dict: The aggregated ``scene_id -> scene_path`` mapping.
            (Previously this returned a leftover loop variable holding only
            the last dataset's raw list.)
        """
        all_scenes = {}
        print("Loading scenes...")
        for dataset in self.datasets:
            dataset_dir = self.find_dataset_location(dataset)
            # Context manager fixes the file-handle leak of json.load(open(...)).
            with open(os.path.join(dataset_dir, f"{self.split}_scenes.json")) as fh:
                scenes = json.load(fh)
            # Scene id = dataset name + last two path components of the scene path.
            scene_ids = [f"{dataset}_{p.split('/')[-2]}_{p.split('/')[-1]}" for p in scenes]
            all_scenes.update(dict(zip(scene_ids, scenes)))
        self.scenes = all_scenes
        self.scene_ids = list(self.scenes.keys())

        self.rng.shuffle(self.scene_ids)
        if self.size > 0:
            self.scene_ids = self.scene_ids[:self.size]
        self.size = len(self.scene_ids)
        return all_scenes

    def __len__(self):
        """Number of scenes served (after shuffling and truncation)."""
        return self.size

    def decide_context_view(self, cam_dir):
        """Pick the input (context) camera and a distinct novel camera.

        The input view is chosen among the real (non-``gen``) cameras:
        ``"new_zoom"`` takes the camera farthest from the world origin,
        ``"random"`` draws one uniformly.  The novel-view candidate pool then
        depends on ``self.rendered_views_mode``:

        * ``"None"``   -- real views only;
        * ``"random"`` -- real and generated views mixed;
        * ``"always"`` -- generated views only, when any exist.

        Returns:
            tuple[str, str]: ``(input_cam, new_cam)`` camera folder names.

        Raises:
            ValueError: On an unknown mode, or when no view other than the
                input view is available.
        """
        cam_dirs = [d for d in os.listdir(cam_dir)
                    if os.path.isdir(os.path.join(cam_dir, d)) and not d.startswith("gen")]

        extrinsics = {c: torch.load(os.path.join(cam_dir, c, 'cam2world.pt'),
                                    map_location='cpu', weights_only=True)
                      for c in cam_dirs}
        # Distance of each camera centre (c2w translation) from the world origin.
        dist_origin = {c: torch.linalg.norm(extrinsics[c][:3, 3]) for c in extrinsics}

        if self.view_select_mode == 'new_zoom':
            input_cam = max(dist_origin, key=dist_origin.get)
        elif self.view_select_mode == 'random':
            input_cam = str(self.rng.choice(list(dist_origin.keys())))
        else:
            raise ValueError(f"Invalid mode: {self.view_select_mode}")

        if self.rendered_views_mode == "None":
            pass
        elif self.rendered_views_mode == "random":
            cam_dirs = [d for d in os.listdir(cam_dir) if os.path.isdir(os.path.join(cam_dir, d))]
        elif self.rendered_views_mode == "always":
            cam_dirs_gen = [d for d in os.listdir(cam_dir)
                            if os.path.isdir(os.path.join(cam_dir, d)) and d.startswith("gen")]
            if len(cam_dirs_gen) > 0:
                cam_dirs = cam_dirs_gen
        else:
            raise ValueError(f"Invalid mode: {self.rendered_views_mode}")

        possible_views = [v for v in cam_dirs if v != input_cam]
        if not possible_views:
            # Fail with a clear message instead of numpy's opaque empty-choice error.
            raise ValueError(f"No alternative view to {input_cam} available in {cam_dir}")
        new_cam = str(self.rng.choice(possible_views))
        return input_cam, new_cam

    def transform_pointmap(self, pointmap_cam, c2w):
        """Map a camera-frame pointmap to world coordinates.

        Args:
            pointmap_cam: ``(..., 3)`` points in the camera frame.
            c2w: ``(4, 4)`` camera-to-world homogeneous transform.

        Returns:
            ``(..., 3)`` tensor of points in world coordinates.
        """
        ones = torch.ones(pointmap_cam.shape[:-1] + (1,)).to(pointmap_cam.device)
        pointmap_cam_h = torch.cat([pointmap_cam, ones], dim=-1)
        pointmap_world_h = pointmap_cam_h @ c2w.T
        # De-homogenise (for a rigid c2w the w component stays 1).
        return pointmap_world_h[..., :3] / pointmap_world_h[..., 3:4]

    def _load_pt(self, cam_dir, cam, filename):
        """Load one saved tensor for camera ``cam`` onto the CPU (weights-only)."""
        return torch.load(os.path.join(cam_dir, cam, filename),
                          map_location='cpu', weights_only=True)

    def _load_png(self, cam_dir, cam, filename, as_float=False):
        """Read one PNG for camera ``cam`` into a torch tensor.

        ``as_float=True`` casts the pixel array to float32 (used for depth
        maps); otherwise the PNG's native dtype is kept.
        """
        arr = np.array(Image.open(os.path.join(cam_dir, cam, filename)))
        if as_float:
            arr = arr.astype(np.float32)
        return torch.from_numpy(arr)

    def load_scene_slow(self, input_cam, new_cam, cam_dir):
        """Load one scene from precomputed ``.pt`` tensors.

        Returns:
            dict: ``'new_cams'`` holds c2ws, depths, pointmaps (already in
            world coordinates), Ks, and boolean valid_masks; ``'input_cams'``
            holds the same plus imgs and one dino_features entry per layer in
            ``self.dino_features``.  Every value is a single-element list.
        """
        data = dict(new_cams={}, input_cams={})

        new = data['new_cams']
        new['c2ws'] = [self._load_pt(cam_dir, new_cam, 'cam2world.pt').to(self.dtype)]
        new['depths'] = [self._load_pt(cam_dir, new_cam, 'depth.pt').to(self.dtype)]
        new['pointmaps'] = [self.transform_pointmap(
            self._load_pt(cam_dir, new_cam, 'pointmap.pt').to(self.dtype), new['c2ws'][0])]
        new['Ks'] = [self._load_pt(cam_dir, new_cam, 'intrinsics.pt').to(self.dtype)]
        new['valid_masks'] = [self._load_pt(cam_dir, new_cam, 'mask.pt').to(torch.bool)]

        inp = data['input_cams']
        inp['c2ws'] = [self._load_pt(cam_dir, input_cam, 'cam2world.pt').to(self.dtype)]
        inp['depths'] = [self._load_pt(cam_dir, input_cam, 'depth.pt').to(self.dtype)]
        inp['pointmaps'] = [self.transform_pointmap(
            self._load_pt(cam_dir, input_cam, 'pointmap.pt').to(self.dtype), inp['c2ws'][0])]
        inp['Ks'] = [self._load_pt(cam_dir, input_cam, 'intrinsics.pt').to(self.dtype)]
        inp['valid_masks'] = [self._load_pt(cam_dir, input_cam, 'mask.pt').to(torch.bool)]
        inp['imgs'] = [self._load_pt(cam_dir, input_cam, 'rgb.pt').to(self.dtype)]
        inp['dino_features'] = [
            self._load_pt(cam_dir, input_cam, f'dino_features_layer_{l}.pt').to(self.dtype)
            for l in self.dino_features]
        return data

    def depth_to_metric(self, depth):
        """Convert raw uint16 depth values to metres (0..65535 maps to 0..10 m)."""
        depth_max = 10.0  # metres encoded at the top of the uint16 range
        return depth_max * (depth / 65535.0)

    def load_scene_fast(self, input_cam, new_cam, cam_dir):
        """Load one scene from PNG files (depths stay in raw uint16 units).

        Unlike :meth:`load_scene_slow`, no pointmaps are produced here, and
        DINO features are read only when ``self.prefetch_dino`` is set (all
        requested layers concatenated along the last dim).
        """
        data = dict(new_cams={}, input_cams={})

        new = data['new_cams']
        new['c2ws'] = [self._load_pt(cam_dir, new_cam, 'cam2world.pt').to(self.dtype)]
        new['Ks'] = [self._load_pt(cam_dir, new_cam, 'intrinsics.pt').to(self.dtype)]
        new['depths'] = [self._load_png(cam_dir, new_cam, 'depth.png', as_float=True)]
        # NOTE(review): unlike slow mode, masks/imgs keep their PNG dtype
        # (no bool/float cast) — confirm downstream code expects this.
        new['valid_masks'] = [self._load_png(cam_dir, new_cam, 'mask.png')]

        inp = data['input_cams']
        inp['c2ws'] = [self._load_pt(cam_dir, input_cam, 'cam2world.pt').to(self.dtype)]
        inp['Ks'] = [self._load_pt(cam_dir, input_cam, 'intrinsics.pt').to(self.dtype)]
        inp['depths'] = [self._load_png(cam_dir, input_cam, 'depth.png', as_float=True)]
        inp['valid_masks'] = [self._load_png(cam_dir, input_cam, 'mask.png')]
        inp['imgs'] = [self._load_png(cam_dir, input_cam, 'rgb.png')]
        if self.prefetch_dino:
            inp['dino_features'] = [torch.cat(
                [self._load_pt(cam_dir, input_cam, f'dino_features_layer_{l}.pt').to(self.dtype)
                 for l in self.dino_features], dim=-1)]
        return data

    def __getitem__(self, idx):
        """Pick the views for scene ``idx`` and load it via the active mode."""
        cam_dir = os.path.join(self.scenes[self.scene_ids[idx]], 'cameras')

        input_cam, new_cam = self.decide_context_view(cam_dir)
        if self.mode == 'slow':
            return self.load_scene_slow(input_cam, new_cam, cam_dir)
        return self.load_scene_fast(input_cam, new_cam, cam_dir)
|
|